[Mlir-commits] [mlir] [mlir] Extract forall_to_for logic into reusable function and add pass (PR #89636)

Jorn Tuyls llvmlistbot at llvm.org
Tue Apr 23 04:31:28 PDT 2024


================
@@ -89,28 +80,15 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
     return diag;
   }
 
-  auto loc = target.getLoc();
-  SmallVector<Value> ivs;
-  for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
-    Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
-    Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
-    Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
-    auto loop = rewriter.create<scf::ForOp>(
-        loc, lbValue, ubValue, stepValue, ValueRange(),
-        [](OpBuilder &, Location, Value, ValueRange) {});
-    ivs.push_back(loop.getInductionVar());
-    rewriter.setInsertionPointToStart(loop.getBody());
-    rewriter.create<scf::YieldOp>(loc);
-    rewriter.setInsertionPointToStart(loop.getBody());
+  SmallVector<Operation *> opResults;
+  if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "failed to convert forall into for";
+    return diag;
   }
-  rewriter.eraseOp(target.getBody()->getTerminator());
-  rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
-                             ivs);
-  rewriter.eraseOp(target);
-
-  for (auto &&[i, iv] : llvm::enumerate(ivs)) {
-    results.set(cast<OpResult>(getTransformed()[i]),
-                {iv.getParentBlock()->getParentOp()});
+
+  for (auto [i, res] : llvm::enumerate(opResults)) {
----------------
jtuyls wrote:

Done

https://github.com/llvm/llvm-project/pull/89636


More information about the Mlir-commits mailing list