[Mlir-commits] [mlir] 286bd42 - [mlir] Extract forall_to_for logic into reusable function and add pass (#89636)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 24 09:57:52 PDT 2024
Author: Jorn Tuyls
Date: 2024-04-24T09:57:48-07:00
New Revision: 286bd42a7a799e3d9035c09bf0d64cb1a1eef682
URL: https://github.com/llvm/llvm-project/commit/286bd42a7a799e3d9035c09bf0d64cb1a1eef682
DIFF: https://github.com/llvm/llvm-project/commit/286bd42a7a799e3d9035c09bf0d64cb1a1eef682.diff
LOG: [mlir] Extract forall_to_for logic into reusable function and add pass (#89636)
This PR extracts the existing `scf.forall` to `scf.for` conversion logic
inside a transform op (https://github.com/llvm/llvm-project/pull/65474)
into a standalone function which can be used in other transformations
and adds a `scf-forall-to-for` pass.
Added:
mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
mlir/test/Dialect/SCF/forall-to-for.mlir
Modified:
mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 90b315e83a8cfd..31c3d0eb629d28 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -59,6 +59,9 @@ createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
/// loop range.
std::unique_ptr<Pass> createForLoopRangeFoldingPass();
+/// Creates a pass that converts each scf.forall op into a nest of scf.for ops.
+std::unique_ptr<Pass> createForallToForLoopPass();
+
// Creates a pass which lowers for loops into while loops.
std::unique_ptr<Pass> createForToWhileLoopPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873d..a7aeb42d60c0e9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -120,6 +120,11 @@ def SCFForLoopRangeFolding : Pass<"scf-for-loop-range-folding"> {
let constructor = "mlir::createForLoopRangeFoldingPass()";
}
+def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
+  let summary = "Convert SCF forall loops to SCF for loops";
+  let description = [{
+    Converts every `scf.forall` op into a perfectly nested set of `scf.for`
+    loops, one per dimension, and inlines the forall body into the innermost
+    loop. `scf.forall` ops with `shared_outs` are not supported.
+  }];
+  let constructor = "mlir::createForallToForLoopPass()";
+}
+
def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
let summary = "Convert SCF for loops to SCF while loops";
let constructor = "mlir::createForToWhileLoopPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 220dcb35571d27..b063e6e775e634 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -28,10 +28,17 @@ class Value;
namespace scf {
class IfOp;
+class ForallOp;
class ForOp;
class ParallelOp;
class WhileOp;
+/// Try converting `forallOp` into a nest of scf.for loops, one per dimension.
+/// On success `forallOp` is erased and, if `results` is non-null, the new
+/// scf.for ops are appended to it, outermost loop first.
+LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
+ SmallVectorImpl<Operation *> *results = nullptr);
+
/// Fuses all adjacent scf.parallel operations with identical bounds and step
/// into one scf.parallel operations. Uses a naive aliasing and dependency
/// analysis.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 7e4faf8b73afbb..69f83d8bd70da1 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -69,16 +69,12 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
return diag;
}
- rewriter.setInsertionPoint(target);
-
if (!target.getOutputs().empty()) {
return emitSilenceableError()
<< "unsupported shared outputs (didn't bufferize?)";
}
SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
- SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
- SmallVector<OpFoldResult> steps = target.getMixedStep();
if (getNumResults() != lbs.size()) {
DiagnosedSilenceableFailure diag =
@@ -89,28 +85,15 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
return diag;
}
- auto loc = target.getLoc();
- SmallVector<Value> ivs;
- for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
- Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
- Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
- Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
- auto loop = rewriter.create<scf::ForOp>(
- loc, lbValue, ubValue, stepValue, ValueRange(),
- [](OpBuilder &, Location, Value, ValueRange) {});
- ivs.push_back(loop.getInductionVar());
- rewriter.setInsertionPointToStart(loop.getBody());
- rewriter.create<scf::YieldOp>(loc);
- rewriter.setInsertionPointToStart(loop.getBody());
+ SmallVector<Operation *> opResults;
+ if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
+ DiagnosedSilenceableFailure diag = emitSilenceableError()
+ << "failed to convert forall into for";
+ return diag;
}
- rewriter.eraseOp(target.getBody()->getTerminator());
- rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
- ivs);
- rewriter.eraseOp(target);
-
- for (auto &&[i, iv] : llvm::enumerate(ivs)) {
- results.set(cast<OpResult>(getTransformed()[i]),
- {iv.getParentBlock()->getParentOp()});
+
+ for (auto &&[i, res] : llvm::enumerate(opResults)) {
+ results.set(cast<OpResult>(getTransformed()[i]), {res});
}
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index a2925aef17ca78..e7671c9cc28f8b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
+ ForallToFor.cpp
ForToWhile.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
new file mode 100644
index 00000000000000..198cb2e6cc69ef
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -0,0 +1,108 @@
+//===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms scf.forall ops into nests of scf.for ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFFORALLTOFORLOOP
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace llvm;
+using namespace mlir;
+using scf::ForallOp;
+using scf::ForOp;
+using scf::LoopNest;
+
+/// Rewrites `forallOp` into a perfectly nested set of scf.for loops, one per
+/// forall dimension, with the forall body inlined into the innermost loop.
+/// Fails without modifying the IR when `forallOp` has shared outputs: only
+/// induction variables are forwarded to the inlined body, so the body's extra
+/// block arguments and the op's results could not be mapped.
+LogicalResult
+mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
+                           SmallVectorImpl<Operation *> *results) {
+  // Check before creating any IR so that a failure leaves the input intact.
+  if (!forallOp.getOutputs().empty())
+    return failure();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+
+  Location loc = forallOp.getLoc();
+  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedLowerBound());
+  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedUpperBound());
+  SmallVector<Value> steps =
+      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+  LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
+
+  // The scf.forall terminator has no counterpart inside scf.for; drop it
+  // before moving the body.
+  rewriter.eraseOp(forallOp.getBody()->getTerminator());
+
+  if (loopNest.loops.empty()) {
+    // buildLoopNest creates no loops for a zero-dimensional forall; treat it
+    // as a single iteration and inline the body in place of the op.
+    rewriter.inlineBlockBefore(forallOp.getBody(), forallOp);
+  } else {
+    SmallVector<Value> ivs =
+        llvm::map_to_vector(loopNest.loops, [](scf::ForOp loop) {
+          return loop.getInductionVar();
+        });
+    // Loops are ordered outermost first; the body goes into the innermost
+    // loop, ahead of its scf.yield terminator.
+    Block *innermostBlock = loopNest.loops.back().getBody();
+    rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
+                               innermostBlock->getTerminator()->getIterator(),
+                               ivs);
+  }
+  rewriter.eraseOp(forallOp);
+
+  // Report the new loops, outermost first, after any existing entries.
+  if (results)
+    llvm::move(loopNest.loops, std::back_inserter(*results));
+
+  return success();
+}
+
+namespace {
+/// Converts every scf.forall nested under the anchor op into scf.for loops,
+/// stopping at, and reporting, the first one that cannot be converted.
+struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
+  void runOnOperation() override {
+    Operation *parentOp = getOperation();
+    IRRewriter rewriter(parentOp->getContext());
+
+    // The default post-order walk visits nested foralls before their parents
+    // and permits erasing the op being visited.
+    WalkResult status = parentOp->walk([&](scf::ForallOp forallOp) {
+      if (failed(scf::forallToForLoop(rewriter, forallOp))) {
+        forallOp.emitError("failed to convert scf.forall into scf.for loops");
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+    if (status.wasInterrupted())
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
+  return std::make_unique<ForallToForLoop>();
+}
diff --git a/mlir/test/Dialect/SCF/forall-to-for.mlir b/mlir/test/Dialect/SCF/forall-to-for.mlir
new file mode 100644
index 00000000000000..e7d183fb9d2b54
--- /dev/null
+++ b/mlir/test/Dialect/SCF/forall-to-for.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for))' -split-input-file | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+// A two-dimensional forall is expected to become two nested scf.for loops
+// whose induction variables replace %i and %j in the call.
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+ // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+ // CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+ // CHECK: func.call @callee(%[[IV1]], %[[IV2]])
+ return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+// Two sibling foralls in one function are each expected to be converted into
+// their own two-level scf.for nest.
+// CHECK-LABEL: @repeated
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @repeated(%ub1: index, %ub2: index) {
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+ // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+ // CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+ // CHECK: func.call @callee(%[[IV1]], %[[IV2]])
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+ // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+ // CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+ // CHECK: func.call @callee(%[[IV1]], %[[IV2]])
+ return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+
+// A forall nested inside another forall: both are converted, so the call is
+// expected inside a four-level scf.for nest using all four induction values.
+// CHECK-LABEL: @nested
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
+func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
+ // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+ // CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+ // CHECK: scf.for %[[IV3:.+]] = %{{.*}} to %[[UB3]]
+ // CHECK: scf.for %[[IV4:.+]] = %{{.*}} to %[[UB4]]
+ // CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ scf.forall (%k, %l) in (%ub3, %ub4) {
+ func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+ }
+ }
+ return
+}
More information about the Mlir-commits
mailing list