[Mlir-commits] [mlir] [mlir][scf] Implement conversion from scf.forall to scf.parallel (PR #94109)

Spenser Bauman llvmlistbot at llvm.org
Mon Jun 3 05:38:28 PDT 2024


================
@@ -688,33 +689,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
 
 LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
                                               PatternRewriter &rewriter) const {
-  Location loc = forallOp.getLoc();
-  if (!forallOp.getOutputs().empty())
-    return rewriter.notifyMatchFailure(
-        forallOp,
-        "only fully bufferized scf.forall ops can be lowered to scf.parallel");
-
-  // Convert mixed bounds and steps to SSA values.
-  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
-      rewriter, loc, forallOp.getMixedLowerBound());
-  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
-      rewriter, loc, forallOp.getMixedUpperBound());
-  SmallVector<Value> steps =
-      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
-
-  // Create empty scf.parallel op.
-  auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
-  rewriter.eraseBlock(&parallelOp.getRegion().front());
-  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
-                              parallelOp.getRegion().begin());
-  // Replace the terminator.
-  rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
-  rewriter.replaceOpWithNewOp<scf::ReduceOp>(
-      parallelOp.getRegion().front().getTerminator());
-
-  // Erase the scf.forall op.
-  rewriter.replaceOp(forallOp, parallelOp);
-  return success();
+  return scf::forallToParallelLoop(rewriter, forallOp);
----------------
sabauma wrote:

Fixed. Thanks for catching this. It would not be the first time we broke the build due to a missing dependency.

https://github.com/llvm/llvm-project/pull/94109

