[Mlir-commits] [mlir] [mlir] Extract forall_to_for logic into reusable function and add pass (PR #89636)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 22 10:39:27 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: Jorn Tuyls (jtuyls)

<details>
<summary>Changes</summary>

This PR extracts the existing `scf.forall` to `scf.for` conversion logic from the `transform::ForallToForOp` transform op (https://github.com/llvm/llvm-project/pull/65474) into a standalone function, `scf::forallToForLoop`, that can be reused by other transformations, and adds a `scf-forall-to-for` pass built on top of it.

cc @<!-- -->ftynse @<!-- -->matthias-springer @<!-- -->MaheshRavishankar Could you help with reviewing this PR?

---
Full diff: https://github.com/llvm/llvm-project/pull/89636.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.h (+3) 
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+5) 
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h (+5) 
- (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+8-30) 
- (modified) mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp (+92) 
- (added) mlir/test/Dialect/SCF/forall-to-for.mlir (+57) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 90b315e83a8cfd..31c3d0eb629d28 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -59,6 +59,9 @@ createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
 /// loop range.
 std::unique_ptr<Pass> createForLoopRangeFoldingPass();
 
+/// Creates the scf-forall-to-for pass: rewrites scf.forall into nested scf.for.
+std::unique_ptr<Pass> createForallToForLoopPass();
+
 // Creates a pass which lowers for loops into while loops.
 std::unique_ptr<Pass> createForToWhileLoopPass();
 
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873d..a7aeb42d60c0e9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -120,6 +120,11 @@ def SCFForLoopRangeFolding : Pass<"scf-for-loop-range-folding"> {
   let constructor = "mlir::createForLoopRangeFoldingPass()";
 }
 
+def SCFForallToForLoop : Pass<"scf-forall-to-for"> { // used in pipelines as e.g. func.func(scf-forall-to-for)
+  let summary = "Convert SCF forall loops to SCF for loops";
+  let constructor = "mlir::createForallToForLoopPass()"; // factory defined in ForallToFor.cpp
+}
+
 def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   let summary = "Convert SCF for loops to SCF while loops";
   let constructor = "mlir::createForToWhileLoopPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 220dcb35571d27..9a9731161ddf7d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -28,10 +28,15 @@ class Value;
 namespace scf {
 
 class IfOp;
+class ForallOp;
 class ForOp;
 class ParallelOp;
 class WhileOp;
 
+/// Converts `forallOp` into nested scf.for loops; fails on shared outputs.
+LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
+                              SmallVector<Operation *> *results); // may be null; otherwise receives the created loops, outermost first
+
 /// Fuses all adjacent scf.parallel operations with identical bounds and step
 /// into one scf.parallel operations. Uses a naive aliasing and dependency
 /// analysis.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 7e4faf8b73afbb..0e3bc8ad4cacee 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -69,16 +69,7 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
     return diag;
   }
 
-  rewriter.setInsertionPoint(target);
-
-  if (!target.getOutputs().empty()) {
-    return emitSilenceableError()
-           << "unsupported shared outputs (didn't bufferize?)";
-  }
-
   SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
-  SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
-  SmallVector<OpFoldResult> steps = target.getMixedStep();
 
   if (getNumResults() != lbs.size()) {
     DiagnosedSilenceableFailure diag =
@@ -89,28 +80,15 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
     return diag;
   }
 
-  auto loc = target.getLoc();
-  SmallVector<Value> ivs;
-  for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
-    Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
-    Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
-    Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
-    auto loop = rewriter.create<scf::ForOp>(
-        loc, lbValue, ubValue, stepValue, ValueRange(),
-        [](OpBuilder &, Location, Value, ValueRange) {});
-    ivs.push_back(loop.getInductionVar());
-    rewriter.setInsertionPointToStart(loop.getBody());
-    rewriter.create<scf::YieldOp>(loc);
-    rewriter.setInsertionPointToStart(loop.getBody());
+  SmallVector<Operation *> opResults;
+  if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "failed to convert forall into for";
+    return diag;
   }
-  rewriter.eraseOp(target.getBody()->getTerminator());
-  rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
-                             ivs);
-  rewriter.eraseOp(target);
-
-  for (auto &&[i, iv] : llvm::enumerate(ivs)) {
-    results.set(cast<OpResult>(getTransformed()[i]),
-                {iv.getParentBlock()->getParentOp()});
+
+  for (auto [i, res] : llvm::enumerate(opResults)) {
+    results.set(cast<OpResult>(getTransformed()[i]), {res});
   }
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index a2925aef17ca78..e7671c9cc28f8b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  ForallToFor.cpp
   ForToWhile.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
new file mode 100644
index 00000000000000..2a4e97564fbe2c
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -0,0 +1,92 @@
+//===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ForallOp's into SCF.ForOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFFORALLTOFORLOOP
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace llvm;
+using namespace mlir;
+using scf::ForallOp;
+using scf::ForOp;
+/// Replaces `forallOp` by nested scf.for loops, one per (lb, ub, step) triple.
+LogicalResult
+mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
+                           SmallVector<Operation *> *results = nullptr) { // NOTE(review): default is on this definition only; the Transforms.h declaration has none
+  rewriter.setInsertionPoint(forallOp); // new loops are built right before the forall
+  // Shared outputs are not supported: emit an op error and return failure.
+  if (!forallOp.getOutputs().empty()) {
+    return forallOp.emitOpError()
+           << "unsupported shared outputs (didn't bufferize?)";
+  }
+  // Bounds and steps are OpFoldResults; they are turned into Values below.
+  SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
+  SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
+  SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
+  // Create one scf.for per (lb, ub, step); each is nested inside the previous.
+  auto loc = forallOp.getLoc();
+  SmallVector<Value> ivs;
+  for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
+    Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
+    Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
+    Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
+    auto loop =
+        rewriter.create<ForOp>(loc, lbValue, ubValue, stepValue, ValueRange(),
+                               [](OpBuilder &, Location, Value, ValueRange) {}); // empty body builder; the yield is created explicitly below
+    if (results)
+      results->push_back(loop); // loops are recorded outermost first
+    ivs.push_back(loop.getInductionVar());
+    rewriter.setInsertionPointToStart(loop.getBody());
+    rewriter.create<scf::YieldOp>(loc);
+    rewriter.setInsertionPointToStart(loop.getBody()); // back before the yield: the next loop / inlined body goes here
+  }
+  rewriter.eraseOp(forallOp.getBody()->getTerminator()); // the forall terminator is not carried over
+  rewriter.inlineBlockBefore(forallOp.getBody(), &*rewriter.getInsertionPoint(),
+                             ivs); // forall block arguments are replaced by the new induction variables
+  rewriter.eraseOp(forallOp);
+  return success(); // NOTE(review): the rewriter insertion point is not restored (no InsertionGuard)
+}
+
+namespace {
+struct ForallToForLoopLoweringPattern : public OpRewritePattern<ForallOp> {
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+  // Delegates to scf::forallToForLoop without requesting the created loops.
+  LogicalResult matchAndRewrite(ForallOp forallOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(scf::forallToForLoop(rewriter, forallOp))) // relies on the definition-site default `results = nullptr`
+      return failure();
+    return success();
+  }
+};
+/// Pass that greedily applies ForallToForLoopLoweringPattern to its target op.
+struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
+  void runOnOperation() override {
+    auto *parentOp = getOperation();
+    MLIRContext *ctx = parentOp->getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<ForallToForLoopLoweringPattern>(ctx);
+    (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns)); // result discarded: this pass never signals failure
+  }
+};
+} // namespace
+/// Factory for ForallToForLoop; declared in Passes.h and named in Passes.td.
+std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
+  return std::make_unique<ForallToForLoop>();
+}
diff --git a/mlir/test/Dialect/SCF/forall-to-for.mlir b/mlir/test/Dialect/SCF/forall-to-for.mlir
new file mode 100644
index 00000000000000..e7d183fb9d2b54
--- /dev/null
+++ b/mlir/test/Dialect/SCF/forall-to-for.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for))' -split-input-file | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+// One 2-D scf.forall: expect a 2-deep scf.for nest around the call.
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     func.call @callee(%[[IV1]], %[[IV2]])
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+// Two sibling scf.forall ops in one function: each becomes its own loop nest.
+// CHECK-LABEL: @repeated
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @repeated(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     func.call @callee(%[[IV1]], %[[IV2]])
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     func.call @callee(%[[IV1]], %[[IV2]])
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+// scf.forall nested in scf.forall: expect a single 4-deep scf.for nest.
+// CHECK-LABEL: @nested
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
+func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     scf.for %[[IV3:.+]] = %{{.*}} to %[[UB3]]
+  // CHECK:       scf.for %[[IV4:.+]] = %{{.*}} to %[[UB4]]
+  // CHECK:         func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    scf.forall (%k, %l) in (%ub3, %ub4) {
+      func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+    }
+  }
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/89636


More information about the Mlir-commits mailing list