[Mlir-commits] [mlir] [mlir] introduce transform.loop.forall_to_for (PR #65474)

Matthias Springer llvmlistbot at llvm.org
Wed Sep 6 07:55:40 PDT 2023


================
@@ -64,6 +69,79 @@ transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ForallToForOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
+                                transform::TransformResults &results,
+                                transform::TransformState &state) {
+  auto payload = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(payload)) {
+    return emitSilenceableError() << "expected a single payload op";
+  }
+  auto target = dyn_cast<scf::ForallOp>(*payload.begin());
+  if (!target) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "expected the payload to be scf.forall";
+    diag.attachNote((*payload.begin())->getLoc()) << "payload op";
+    return diag;
+  }
+
+  rewriter.setInsertionPoint(target);
+
+  if (!target.getOutputs().empty()) {
+    return emitSilenceableError()
+           << "unsupported shared outputs (didn't bufferize?)";
+  }
+
+  auto materialize = [](OpBuilder &b, Location loc, OpFoldResult r) -> Value {
+    if (Value v = r.dyn_cast<Value>())
+      return v;
+    return b.create<arith::ConstantIndexOp>(
+        loc, r.get<Attribute>().cast<IntegerAttr>().getValue().getSExtValue());
+  };
+
+  SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
+  SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
+  SmallVector<OpFoldResult> steps = target.getMixedStep();
+
+  if (getNumResults() != lbs.size()) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError()
+        << "op expects as many results (" << getNumResults()
+        << ") as payload has induction variables (" << lbs.size() << ")";
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  auto loc = target.getLoc();
+  SmallVector<Value> ivs;
+  for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
----------------
matthias-springer wrote:

I'm not sure whether this is useful here, but there is also `mlir::scf::buildLoopNest`, which builds a nest of `scf.for` loops from lists of lower bounds, upper bounds, and steps.

https://github.com/llvm/llvm-project/pull/65474


More information about the Mlir-commits mailing list