[Mlir-commits] [mlir] [mlir] introduce transform.loop.forall_to_for (PR #65474)
Matthias Springer
llvmlistbot at llvm.org
Wed Sep 6 07:55:40 PDT 2023
================
@@ -64,6 +69,79 @@ transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// ForallToForOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto payload = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(payload)) {
+ return emitSilenceableError() << "expected a single payload op";
+ }
+ auto target = dyn_cast<scf::ForallOp>(*payload.begin());
+ if (!target) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "expected the payload to be scf.forall";
+ diag.attachNote((*payload.begin())->getLoc()) << "payload op";
+ return diag;
+ }
+
+ rewriter.setInsertionPoint(target);
+
+ if (!target.getOutputs().empty()) {
+ return emitSilenceableError()
+ << "unsupported shared outputs (didn't bufferize?)";
+ }
+
+ auto materialize = [](OpBuilder &b, Location loc, OpFoldResult r) -> Value {
+ if (Value v = r.dyn_cast<Value>())
+ return v;
+ return b.create<arith::ConstantIndexOp>(
+ loc, r.get<Attribute>().cast<IntegerAttr>().getValue().getSExtValue());
+ };
+
+ SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
+ SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
+ SmallVector<OpFoldResult> steps = target.getMixedStep();
+
+ if (getNumResults() != lbs.size()) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "op expects as many results (" << getNumResults()
+ << ") as payload has induction variables (" << lbs.size() << ")";
+ diag.attachNote(target.getLoc()) << "payload op";
+ return diag;
+ }
+
+ auto loc = target.getLoc();
+ SmallVector<Value> ivs;
+ for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
----------------
matthias-springer wrote:
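I'm not sure whether it is useful here, but there is also `mlir::scf::buildLoopNest`, which builds a perfect nest of `scf.for` loops from lists of lower bounds, upper bounds, and steps.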
https://github.com/llvm/llvm-project/pull/65474
More information about the Mlir-commits
mailing list