[Mlir-commits] [mlir] [mlir] Extract forall_to_for logic into reusable function and add pass (PR #89636)

Jorn Tuyls llvmlistbot at llvm.org
Tue Apr 23 01:08:01 PDT 2024


https://github.com/jtuyls updated https://github.com/llvm/llvm-project/pull/89636

>From 45f506578b434386283a233b003218e8a3eda76f Mon Sep 17 00:00:00 2001
From: Jorn Tuyls <jorn.tuyls at gmail.com>
Date: Sat, 20 Apr 2024 05:55:35 -0700
Subject: [PATCH] [mlir] Extract forall_to_for logic into reusable function and
 add pass

---
 .../mlir/Dialect/SCF/Transforms/Passes.h      |  3 +
 .../mlir/Dialect/SCF/Transforms/Passes.td     |  5 ++
 .../mlir/Dialect/SCF/Transforms/Transforms.h  |  5 ++
 .../SCF/TransformOps/SCFTransformOps.cpp      | 38 ++------
 .../lib/Dialect/SCF/Transforms/CMakeLists.txt |  1 +
 .../Dialect/SCF/Transforms/ForallToFor.cpp    | 90 +++++++++++++++++++
 mlir/test/Dialect/SCF/forall-to-for.mlir      | 57 ++++++++++++
 7 files changed, 169 insertions(+), 30 deletions(-)
 create mode 100644 mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
 create mode 100644 mlir/test/Dialect/SCF/forall-to-for.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 90b315e83a8cfd..31c3d0eb629d28 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -59,6 +59,9 @@ createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
 /// loop range.
 std::unique_ptr<Pass> createForLoopRangeFoldingPass();
 
+/// Creates a pass that converts SCF forall loops to SCF for loops.
+std::unique_ptr<Pass> createForallToForLoopPass();
+
 // Creates a pass which lowers for loops into while loops.
 std::unique_ptr<Pass> createForToWhileLoopPass();
 
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873d..a7aeb42d60c0e9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -120,6 +120,11 @@ def SCFForLoopRangeFolding : Pass<"scf-for-loop-range-folding"> {
   let constructor = "mlir::createForLoopRangeFoldingPass()";
 }
 
+def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
+  let summary = "Convert SCF forall loops to SCF for loops";
+  let constructor = "mlir::createForallToForLoopPass()";
+}
+
 def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   let summary = "Convert SCF for loops to SCF while loops";
   let constructor = "mlir::createForToWhileLoopPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 220dcb35571d27..9a9731161ddf7d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -28,10 +28,15 @@ class Value;
 namespace scf {
 
 class IfOp;
+class ForallOp;
 class ForOp;
 class ParallelOp;
 class WhileOp;
 
+/// Try converting scf.forall into a set of nested scf.for loops.
+LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
+                              SmallVector<Operation *> *results);
+
 /// Fuses all adjacent scf.parallel operations with identical bounds and step
 /// into one scf.parallel operations. Uses a naive aliasing and dependency
 /// analysis.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 7e4faf8b73afbb..0e3bc8ad4cacee 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -69,16 +69,7 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
     return diag;
   }
 
-  rewriter.setInsertionPoint(target);
-
-  if (!target.getOutputs().empty()) {
-    return emitSilenceableError()
-           << "unsupported shared outputs (didn't bufferize?)";
-  }
-
   SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
-  SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
-  SmallVector<OpFoldResult> steps = target.getMixedStep();
 
   if (getNumResults() != lbs.size()) {
     DiagnosedSilenceableFailure diag =
@@ -89,28 +80,15 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
     return diag;
   }
 
-  auto loc = target.getLoc();
-  SmallVector<Value> ivs;
-  for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
-    Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
-    Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
-    Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
-    auto loop = rewriter.create<scf::ForOp>(
-        loc, lbValue, ubValue, stepValue, ValueRange(),
-        [](OpBuilder &, Location, Value, ValueRange) {});
-    ivs.push_back(loop.getInductionVar());
-    rewriter.setInsertionPointToStart(loop.getBody());
-    rewriter.create<scf::YieldOp>(loc);
-    rewriter.setInsertionPointToStart(loop.getBody());
+  SmallVector<Operation *> opResults;
+  if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "failed to convert forall into for";
+    return diag;
   }
-  rewriter.eraseOp(target.getBody()->getTerminator());
-  rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
-                             ivs);
-  rewriter.eraseOp(target);
-
-  for (auto &&[i, iv] : llvm::enumerate(ivs)) {
-    results.set(cast<OpResult>(getTransformed()[i]),
-                {iv.getParentBlock()->getParentOp()});
+
+  for (auto [i, res] : llvm::enumerate(opResults)) {
+    results.set(cast<OpResult>(getTransformed()[i]), {res});
   }
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index a2925aef17ca78..e7671c9cc28f8b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  ForallToFor.cpp
   ForToWhile.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
new file mode 100644
index 00000000000000..205ac6cb87d7a9
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -0,0 +1,90 @@
+//===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ForallOp's into SCF.ForOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFFORALLTOFORLOOP
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace llvm;
+using namespace mlir;
+using scf::ForallOp;
+using scf::ForOp;
+using scf::LoopNest;
+
+/// Rewrites `forallOp` into a nest of scf.for loops (appended to `results` when
+/// non-null) and moves its body into the innermost loop. Emits an op error and
+/// fails if the op has shared outputs.
+LogicalResult
+mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
+                           SmallVector<Operation *> *results = nullptr) {
+  if (!forallOp.getOutputs().empty())
+    return forallOp.emitOpError()
+           << "unsupported shared outputs (didn't bufferize?)";
+
+  // Scope the insertion point change so it does not leak to the caller.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+
+  Location loc = forallOp.getLoc();
+  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedLowerBound());
+  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedUpperBound());
+  SmallVector<Value> steps =
+      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+  LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
+  if (results)
+    llvm::copy(loopNest.loops, std::back_inserter(*results));
+
+  SmallVector<Value> ivs;
+  for (scf::ForOp loop : loopNest.loops)
+    ivs.push_back(loop.getInductionVar());
+
+  Block *innermostBlock = loopNest.loops.back().getBody();
+  rewriter.eraseOp(forallOp.getBody()->getTerminator());
+  rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
+                             innermostBlock->getTerminator()->getIterator(),
+                             ivs);
+  rewriter.eraseOp(forallOp);
+  return success();
+}
+
+namespace {
+/// Pass that converts every scf.forall nested under the root operation into
+/// scf.for loops, signalling failure on the first op that cannot be converted.
+struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    IRRewriter rewriter(root->getContext());
+    // Gather the ops up front so the IR is not mutated while being walked.
+    SmallVector<scf::ForallOp> worklist;
+    root->walk([&](scf::ForallOp op) { worklist.push_back(op); });
+
+    for (scf::ForallOp op : worklist)
+      if (failed(scf::forallToForLoop(rewriter, op)))
+        return signalPassFailure();
+  }
+};
+} // namespace
+
+/// Creates an instance of the scf-forall-to-for pass.
+std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
+  return std::make_unique<ForallToForLoop>();
+}
diff --git a/mlir/test/Dialect/SCF/forall-to-for.mlir b/mlir/test/Dialect/SCF/forall-to-for.mlir
new file mode 100644
index 00000000000000..e7d183fb9d2b54
--- /dev/null
+++ b/mlir/test/Dialect/SCF/forall-to-for.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for))' -split-input-file | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     func.call @callee(%[[IV1]], %[[IV2]])
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @repeated
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @repeated(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     func.call @callee(%[[IV1]], %[[IV2]])
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     func.call @callee(%[[IV1]], %[[IV2]])
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+
+// CHECK-LABEL: @nested
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
+func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     scf.for %[[IV3:.+]] = %{{.*}} to %[[UB3]]
+  // CHECK:       scf.for %[[IV4:.+]] = %{{.*}} to %[[UB4]]
+  // CHECK:         func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    scf.forall (%k, %l) in (%ub3, %ub4) {
+      func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+    }
+  }
+  return
+}



More information about the Mlir-commits mailing list