[Mlir-commits] [mlir] 87568ff - [mlir][SCF] convert-scf-to-cf: Lower scf.forall to scf.parallel (#65449)

llvmlistbot at llvm.org
Wed Sep 6 04:28:03 PDT 2023


Author: Matthias Springer
Date: 2023-09-06T13:27:59+02:00
New Revision: 87568ff3efb216e35161bf201897b2c587ffc1c4

URL: https://github.com/llvm/llvm-project/commit/87568ff3efb216e35161bf201897b2c587ffc1c4
DIFF: https://github.com/llvm/llvm-project/commit/87568ff3efb216e35161bf201897b2c587ffc1c4.diff

LOG: [mlir][SCF] convert-scf-to-cf: Lower scf.forall to scf.parallel (#65449)

scf.forall ops without shared outputs (i.e., fully bufferized ops) are
lowered to scf.parallel, as sketched below. scf.forall ops are typically
lowered by an earlier pass, depending on the execution target; for
example, there are optimized lowerings for GPU execution. This new
lowering is for completeness (convert-scf-to-cf can now lower all SCF
loop constructs) and provides a simple CPU lowering strategy for testing
purposes.
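
To illustrate, here is a sketch of the new rewrite, based on the test
added below (the constants are materialized by ForallLowering from the
loop's mixed bounds and step):

    // Input: an scf.forall with no shared outputs.
    scf.forall (%thread_idx) in (%num_threads) {
      "test.foo"(%thread_idx) : (index) -> ()
      scf.forall.in_parallel {
      }
    }

    // After ForallLowering: an equivalent scf.parallel with explicit
    // lower bound, upper bound, and step, terminated by scf.yield.
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    scf.parallel (%thread_idx) = (%c0) to (%num_threads) step (%c1) {
      "test.foo"(%thread_idx) : (index) -> ()
      scf.yield
    }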

scf.parallel is currently lowered to scf.for, which executes
sequentially. The scf.parallel lowering could be improved in the future
to run on multiple threads.
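
Conceptually, the remaining steps then look as follows: the existing
ParallelLowering pattern rewrites the scf.parallel into a sequential
scf.for nest, which ForLowering flattens into cf.br/cf.cond_br blocks
(see the CHECK lines in the new test). A sketch of the intermediate
sequential form:

    // After ParallelLowering: a plain sequential loop (single-threaded).
    scf.for %thread_idx = %c0 to %num_threads step %c1 {
      "test.foo"(%thread_idx) : (index) -> ()
    }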

Added: 
    

Modified: 
    mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
    mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 91dbdb429f948e..f5face5929916a 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -298,6 +298,18 @@ struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
   LogicalResult matchAndRewrite(IndexSwitchOp op,
                                 PatternRewriter &rewriter) const override;
 };
+
+/// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
+/// has no shared outputs. Ops with shared outputs should be bufferized first.
+/// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
+/// dialects/passes.
+struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
+  using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
@@ -677,10 +689,41 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
   return success();
 }
 
+LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
+                                              PatternRewriter &rewriter) const {
+  Location loc = forallOp.getLoc();
+  if (!forallOp.getOutputs().empty())
+    return rewriter.notifyMatchFailure(
+        forallOp,
+        "only fully bufferized scf.forall ops can be lowered to scf.parallel");
+
+  // Convert mixed bounds and steps to SSA values.
+  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedLowerBound());
+  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedUpperBound());
+  SmallVector<Value> steps =
+      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+
+  // Create empty scf.parallel op.
+  auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
+  rewriter.eraseBlock(&parallelOp.getRegion().front());
+  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
+                              parallelOp.getRegion().begin());
+  // Replace the terminator.
+  rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
+  rewriter.replaceOpWithNewOp<scf::YieldOp>(
+      parallelOp.getRegion().front().getTerminator());
+
+  // Erase the scf.forall op.
+  rewriter.replaceOp(forallOp, parallelOp);
+  return success();
+}
+
 void mlir::populateSCFToControlFlowConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
-               ExecuteRegionLowering, IndexSwitchLowering>(
+  patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
+               WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
       patterns.getContext());
   patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
 }
@@ -691,7 +734,7 @@ void SCFToControlFlowPass::runOnOperation() {
 
   // Configure conversion to lower out SCF operations.
   ConversionTarget target(getContext());
-  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
+  target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
                       scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
   if (failed(

diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index 36307a910a6cad..99b47ea94cc0b1 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -648,3 +648,31 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
   // CHECK-NEXT: return %[[V]]
   return %0 : i32
 }
+
+// Note: scf.forall is lowered to scf.parallel, which is currently lowered to
+// scf.for and then to unstructured control flow. scf.parallel could lower more
+// efficiently to multi-threaded IR, at which point scf.forall would
+// automatically lower to multi-threaded IR.
+
+// CHECK-LABEL: func @forall(
+//  CHECK-SAME:     %[[num_threads:.*]]: index)
+//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[c1:.*]] = arith.constant 1 : index
+//       CHECK:   cf.br ^[[bb1:.*]](%[[c0]] : index)
+//       CHECK: ^[[bb1]](%[[arg0:.*]]: index):
+//       CHECK:   %[[cmpi:.*]] = arith.cmpi slt, %[[arg0]], %[[num_threads]]
+//       CHECK:   cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]]
+//       CHECK: ^[[bb2]]:
+//       CHECK:   "test.foo"(%[[arg0]])
+//       CHECK:   %[[addi:.*]] = arith.addi %[[arg0]], %[[c1]]
+//       CHECK:   cf.br ^[[bb1]](%[[addi]] : index)
+//       CHECK: ^[[bb3]]:
+//       CHECK:   return
+func.func @forall(%num_threads: index) {
+  scf.forall (%thread_idx) in (%num_threads) {
+    "test.foo"(%thread_idx) : (index) -> ()
+    scf.forall.in_parallel {
+    }
+  }
+  return
+}