[Mlir-commits] [mlir] [mlir] Extract forall_to_for logic into reusable function and add pass (PR #89636)
Jorn Tuyls
llvmlistbot at llvm.org
Tue Apr 23 04:31:28 PDT 2024
================
@@ -89,28 +80,15 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
return diag;
}
- auto loc = target.getLoc();
- SmallVector<Value> ivs;
- for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
- Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
- Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
- Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
- auto loop = rewriter.create<scf::ForOp>(
- loc, lbValue, ubValue, stepValue, ValueRange(),
- [](OpBuilder &, Location, Value, ValueRange) {});
- ivs.push_back(loop.getInductionVar());
- rewriter.setInsertionPointToStart(loop.getBody());
- rewriter.create<scf::YieldOp>(loc);
- rewriter.setInsertionPointToStart(loop.getBody());
+ SmallVector<Operation *> opResults;
+ if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
+ DiagnosedSilenceableFailure diag = emitSilenceableError()
+ << "failed to convert forall into for";
+ return diag;
}
- rewriter.eraseOp(target.getBody()->getTerminator());
- rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
- ivs);
- rewriter.eraseOp(target);
-
- for (auto &&[i, iv] : llvm::enumerate(ivs)) {
- results.set(cast<OpResult>(getTransformed()[i]),
- {iv.getParentBlock()->getParentOp()});
+
+ for (auto [i, res] : llvm::enumerate(opResults)) {
----------------
jtuyls wrote:
Done
https://github.com/llvm/llvm-project/pull/89636
More information about the Mlir-commits mailing list