[Mlir-commits] [mlir] [mlir][SCF] convert-scf-to-cf: Lower scf.forall to scf.parallel (PR #65449)

Matthias Springer llvmlistbot at llvm.org
Wed Sep 6 00:44:38 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/65449:

scf.forall ops without shared outputs (i.e., fully bufferized ops) are lowered to scf.parallel. scf.forall ops are typically lowered by an earlier pass depending on the execution target. E.g., there are optimized lowerings for GPU execution. This new lowering is for completeness (convert-scf-to-cf can now lower all SCF loop constructs) and provides a simple CPU lowering strategy for testing purposes.

scf.parallel is currently lowered to scf.for, which executes sequentially. The scf.parallel lowering could be improved in the future to run on multiple threads.

>From 6904abb7b0fab155616af938b63ff66178a3811f Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 6 Sep 2023 09:43:33 +0200
Subject: [PATCH] [mlir][SCF] convert-scf-to-cf: Lower scf.forall to
 scf.parallel

scf.forall ops without shared outputs (i.e., fully bufferized ops) are lowered to scf.parallel. scf.forall ops are typically lowered by an earlier pass depending on the execution target. E.g., there are optimized lowerings for GPU execution. This new lowering is for completeness (convert-scf-to-cf can now lower all SCF loop constructs) and provides a simple CPU lowering strategy for testing purposes.

scf.parallel is currently lowered to scf.for, which executes sequentially. The scf.parallel lowering could be improved in the future to run on multiple threads.
---
 .../SCFToControlFlow/SCFToControlFlow.cpp     | 49 +++++++++++++++++--
 .../SCFToControlFlow/convert-to-cfg.mlir      | 28 +++++++++++
 2 files changed, 74 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 91dbdb429f948e..f5face5929916a 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -298,6 +298,18 @@ struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
   LogicalResult matchAndRewrite(IndexSwitchOp op,
                                 PatternRewriter &rewriter) const override;
 };
+
+/// Lower an `scf.forall` operation to an `scf.parallel` op. Only ops without
+/// shared outputs (i.e., fully bufferized ops) match; bufferize others first.
+/// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
+/// dialects/passes.
+struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
+  using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
@@ -677,10 +689,41 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
   return success();
 }
 
+LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
+                                              PatternRewriter &rewriter) const {
+  Location loc = forallOp.getLoc();
+  if (!forallOp.getOutputs().empty())
+    return rewriter.notifyMatchFailure(
+        forallOp,
+        "only fully bufferized scf.forall ops can be lowered to scf.parallel");
+
+  // Convert mixed bounds and steps to SSA values (static ones -> constants).
+  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedLowerBound());
+  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedUpperBound());
+  SmallVector<Value> steps =
+      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+
+  // Create scf.parallel and swap its default body for the forall's body.
+  auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
+  rewriter.eraseBlock(&parallelOp.getRegion().front());
+  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
+                              parallelOp.getRegion().begin());
+  // Replace the scf.forall.in_parallel terminator with scf.yield.
+  rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
+  rewriter.replaceOpWithNewOp<scf::YieldOp>(
+      parallelOp.getRegion().front().getTerminator());
+
+  // Replace (and thereby erase) the scf.forall op with the scf.parallel op.
+  rewriter.replaceOp(forallOp, parallelOp);
+  return success();
+}
+
 void mlir::populateSCFToControlFlowConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
-               ExecuteRegionLowering, IndexSwitchLowering>(
+  patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
+               WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
       patterns.getContext());
   patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
 }
@@ -691,7 +734,7 @@ void SCFToControlFlowPass::runOnOperation() {
 
   // Configure conversion to lower out SCF operations.
   ConversionTarget target(getContext());
-  target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
+  target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
                       scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
   if (failed(
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index 36307a910a6cad..99b47ea94cc0b1 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -648,3 +648,31 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
   // CHECK-NEXT: return %[[V]]
   return %0 : i32
 }
+
+// Note: scf.forall is lowered to scf.parallel, which is currently lowered to
+// scf.for and then to unstructured control flow. If scf.parallel is ever
+// lowered to multi-threaded IR instead, scf.forall will automatically lower
+// to multi-threaded IR as well. The CHECKs below expect a sequential loop.
+
+// CHECK-LABEL: func @forall(
+//  CHECK-SAME:     %[[num_threads:.*]]: index)
+//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[c1:.*]] = arith.constant 1 : index
+//       CHECK:   cf.br ^[[bb1:.*]](%[[c0]] : index)
+//       CHECK: ^[[bb1]](%[[arg0:.*]]: index):
+//       CHECK:   %[[cmpi:.*]] = arith.cmpi slt, %[[arg0]], %[[num_threads]]
+//       CHECK:   cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]]
+//       CHECK: ^[[bb2]]:
+//       CHECK:   "test.foo"(%[[arg0]])
+//       CHECK:   %[[addi:.*]] = arith.addi %[[arg0]], %[[c1]]
+//       CHECK:   cf.br ^[[bb1]](%[[addi]] : index)
+//       CHECK: ^[[bb3]]:
+//       CHECK:   return
+func.func @forall(%num_threads: index) {
+  scf.forall (%thread_idx) in (%num_threads) {
+    "test.foo"(%thread_idx) : (index) -> ()
+    scf.forall.in_parallel {
+    }
+  }
+  return
+}



More information about the Mlir-commits mailing list