[Mlir-commits] [mlir] 286bd42 - [mlir] Extract forall_to_for logic into reusable function and add pass (#89636)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 24 09:57:52 PDT 2024


Author: Jorn Tuyls
Date: 2024-04-24T09:57:48-07:00
New Revision: 286bd42a7a799e3d9035c09bf0d64cb1a1eef682

URL: https://github.com/llvm/llvm-project/commit/286bd42a7a799e3d9035c09bf0d64cb1a1eef682
DIFF: https://github.com/llvm/llvm-project/commit/286bd42a7a799e3d9035c09bf0d64cb1a1eef682.diff

LOG: [mlir] Extract forall_to_for logic into reusable function and add pass (#89636)

This PR extracts the existing `scf.forall` to `scf.for` conversion logic
inside a transform op (https://github.com/llvm/llvm-project/pull/65474)
into a standalone function which can be used in other transformations
and adds a `scf-forall-to-for` pass.

Added: 
    mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
    mlir/test/Dialect/SCF/forall-to-for.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
    mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
    mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 90b315e83a8cfd..31c3d0eb629d28 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -59,6 +59,9 @@ createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
 /// loop range.
 std::unique_ptr<Pass> createForLoopRangeFoldingPass();
 
+/// Creates a pass that converts SCF forall loops to SCF for loops.
+std::unique_ptr<Pass> createForallToForLoopPass();
+
 // Creates a pass which lowers for loops into while loops.
 std::unique_ptr<Pass> createForToWhileLoopPass();
 

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873d..a7aeb42d60c0e9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -120,6 +120,11 @@ def SCFForLoopRangeFolding : Pass<"scf-for-loop-range-folding"> {
   let constructor = "mlir::createForLoopRangeFoldingPass()";
 }
 
+def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
+  let summary = "Convert SCF forall loops to SCF for loops";
+  let constructor = "mlir::createForallToForLoopPass()";
+}
+
 def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   let summary = "Convert SCF for loops to SCF while loops";
   let constructor = "mlir::createForToWhileLoopPass()";

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 220dcb35571d27..b063e6e775e634 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -28,10 +28,17 @@ class Value;
 namespace scf {
 
 class IfOp;
+class ForallOp;
 class ForOp;
 class ParallelOp;
 class WhileOp;
 
+/// Try converting scf.forall into a set of nested scf.for loops.
+/// One scf.for is built per induction variable of `forallOp`, the body of
+/// `forallOp` is moved into the innermost loop, and `forallOp` is erased.
+/// The newly created scf.for ops will be returned through the `results`
+/// vector if provided; they are appended in the order produced by
+/// scf::buildLoopNest.
+/// NOTE(review): the transform-op caller rejects foralls with shared outputs
+/// before calling this function; other callers presumably need the same
+/// check -- confirm against the implementation.
+LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
+                              SmallVectorImpl<Operation *> *results = nullptr);
+
 /// Fuses all adjacent scf.parallel operations with identical bounds and step
 /// into one scf.parallel operations. Uses a naive aliasing and dependency
 /// analysis.

diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 7e4faf8b73afbb..69f83d8bd70da1 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -69,16 +69,12 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
     return diag;
   }
 
-  rewriter.setInsertionPoint(target);
-
   if (!target.getOutputs().empty()) {
     return emitSilenceableError()
            << "unsupported shared outputs (didn't bufferize?)";
   }
 
   SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
-  SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
-  SmallVector<OpFoldResult> steps = target.getMixedStep();
 
   if (getNumResults() != lbs.size()) {
     DiagnosedSilenceableFailure diag =
@@ -89,28 +85,15 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
     return diag;
   }
 
-  auto loc = target.getLoc();
-  SmallVector<Value> ivs;
-  for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
-    Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
-    Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
-    Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
-    auto loop = rewriter.create<scf::ForOp>(
-        loc, lbValue, ubValue, stepValue, ValueRange(),
-        [](OpBuilder &, Location, Value, ValueRange) {});
-    ivs.push_back(loop.getInductionVar());
-    rewriter.setInsertionPointToStart(loop.getBody());
-    rewriter.create<scf::YieldOp>(loc);
-    rewriter.setInsertionPointToStart(loop.getBody());
+  SmallVector<Operation *> opResults;
+  if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "failed to convert forall into for";
+    return diag;
   }
-  rewriter.eraseOp(target.getBody()->getTerminator());
-  rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
-                             ivs);
-  rewriter.eraseOp(target);
-
-  for (auto &&[i, iv] : llvm::enumerate(ivs)) {
-    results.set(cast<OpResult>(getTransformed()[i]),
-                {iv.getParentBlock()->getParentOp()});
+
+  for (auto &&[i, res] : llvm::enumerate(opResults)) {
+    results.set(cast<OpResult>(getTransformed()[i]), {res});
   }
   return DiagnosedSilenceableFailure::success();
 }

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index a2925aef17ca78..e7671c9cc28f8b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  ForallToFor.cpp
   ForToWhile.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp

diff  --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
new file mode 100644
index 00000000000000..198cb2e6cc69ef
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -0,0 +1,79 @@
+//===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ForallOp's into SCF.ForOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFFORALLTOFORLOOP
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace llvm;
+using namespace mlir;
+using scf::ForallOp;
+using scf::ForOp;
+using scf::LoopNest;
+
+/// Rewrites `forallOp` as a nest of scf.for loops, one per induction variable.
+///
+/// The loop nest is created immediately before `forallOp`, the forall body
+/// (minus its terminator) is inlined into the innermost loop with the forall
+/// induction variables replaced by the scf.for induction variables, and
+/// `forallOp` is erased. The rewriter's insertion point is restored on return.
+///
+/// \param rewriter  Rewriter used for all IR mutations.
+/// \param forallOp  The scf.forall to convert.
+/// \param results   Optional; when non-null the created scf.for ops are
+///                  appended in the order returned by scf::buildLoopNest.
+/// \returns failure(), without modifying the IR, if `forallOp` has shared
+///          outputs or has no induction variables; success() otherwise.
+LogicalResult
+mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
+                           SmallVectorImpl<Operation *> *results) {
+  // The loops built below carry no iter_args, so a forall with shared outputs
+  // cannot be expressed; erasing it would also leave its results dangling.
+  // Bail out before touching the IR so callers can treat failure as a no-op.
+  if (!forallOp.getOutputs().empty())
+    return failure();
+
+  // Without induction variables no loop is built and there is no innermost
+  // body to inline into (`loopNest.loops.back()` would be invalid).
+  SmallVector<OpFoldResult> mixedLbs = forallOp.getMixedLowerBound();
+  if (mixedLbs.empty())
+    return failure();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+
+  // Materialize constant bounds/steps as index values for scf.for.
+  Location loc = forallOp.getLoc();
+  SmallVector<Value> lbs =
+      getValueOrCreateConstantIndexOp(rewriter, loc, mixedLbs);
+  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedUpperBound());
+  SmallVector<Value> steps =
+      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+  LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
+
+  SmallVector<Value> ivs = llvm::map_to_vector(
+      loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
+
+  // Drop the forall terminator, then splice the remaining body in front of
+  // the innermost loop's own terminator, remapping block arguments to `ivs`.
+  Block *innermostBlock = loopNest.loops.back().getBody();
+  rewriter.eraseOp(forallOp.getBody()->getTerminator());
+  rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
+                             innermostBlock->getTerminator()->getIterator(),
+                             ivs);
+  rewriter.eraseOp(forallOp);
+
+  if (results) {
+    llvm::move(loopNest.loops, std::back_inserter(*results));
+  }
+
+  return success();
+}
+
+namespace {
+/// Pass that converts every scf.forall nested under the root operation into
+/// scf.for loops using scf::forallToForLoop.
+struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
+  /// Walks the root operation and converts each scf.forall it visits. On the
+  /// first conversion failure an error is emitted on the offending op, the
+  /// walk stops, and the pass is marked as failed.
+  void runOnOperation() override {
+    Operation *parentOp = getOperation();
+    IRRewriter rewriter(parentOp->getContext());
+
+    WalkResult status = parentOp->walk([&](scf::ForallOp forallOp) {
+      if (failed(scf::forallToForLoop(rewriter, forallOp))) {
+        // Report which op could not be converted instead of failing silently.
+        forallOp.emitError("failed to convert scf.forall into scf.for loops");
+        // Stop here: no point rewriting further once the pass has failed.
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+
+    if (status.wasInterrupted())
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Factory for the scf-forall-to-for pass; returns a newly allocated
+/// ForallToForLoop instance owned by the caller.
+std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
+  std::unique_ptr<Pass> forallToForPass = std::make_unique<ForallToForLoop>();
+  return forallToForPass;
+}

diff  --git a/mlir/test/Dialect/SCF/forall-to-for.mlir b/mlir/test/Dialect/SCF/forall-to-for.mlir
new file mode 100644
index 00000000000000..e7d183fb9d2b54
--- /dev/null
+++ b/mlir/test/Dialect/SCF/forall-to-for.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for))' -split-input-file | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     func.call @callee(%[[IV1]], %[[IV2]])
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @repeated
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @repeated(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     func.call @callee(%[[IV1]], %[[IV2]])
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     func.call @callee(%[[IV1]], %[[IV2]])
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+
+// CHECK-LABEL: @nested
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
+func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     scf.for %[[IV3:.+]] = %{{.*}} to %[[UB3]]
+  // CHECK:       scf.for %[[IV4:.+]] = %{{.*}} to %[[UB4]]
+  // CHECK:         func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    scf.forall (%k, %l) in (%ub3, %ub4) {
+      func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+    }
+  }
+  return
+}


        


More information about the Mlir-commits mailing list