[Mlir-commits] [mlir] [MLIR][SCF] Define `-scf-rotate-while` pass (PR #99850)
Victor Perez
llvmlistbot at llvm.org
Wed Jul 24 05:05:37 PDT 2024
https://github.com/victor-eds updated https://github.com/llvm/llvm-project/pull/99850
>From f746d30b74fb8b59778b9efab4174aa1248030e3 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Mon, 22 Jul 2024 09:41:10 +0100
Subject: [PATCH 1/5] [MLIR][SCF] Define `scf.while` rotation pass
Define pass rotating `scf.while` loops leveraging existing
`mlir::scf::wrapWhileLoopInZeroTripCheck` and exposing it as a pass.
This pass rotates `scf.while` operations, mutating them from "while" loops to
"do-while" loops. A guard checking the condition for the first
iteration is inserted. Note this guard can be optimized away if the
compiler can prove the loop will be executed at least once.
Using this pass, the following while loop:
```mlir
scf.while (%arg0 = %init) : (i32) -> i64 {
%val = .., %arg0 : i64
%cond = arith.cmpi .., %arg0 : i32
scf.condition(%cond) %val : i64
} do {
^bb0(%arg1: i64):
%next = .., %arg1 : i32
scf.yield %next : i32
}
```
Can be transformed into:
``` mlir
%pre_val = .., %init : i64
%pre_cond = arith.cmpi .., %init : i32
scf.if %pre_cond -> i64 {
%res = scf.while (%arg1 = %pre_val) : (i64) -> i64 {
// Original after block
%next = .., %arg1 : i32
// Original before block
%val = .., %next : i64
%cond = arith.cmpi .., %next : i32
scf.condition(%cond) %val : i64
} do {
^bb0(%arg2: i64):
scf.yield %arg2 : i64
}
scf.yield %res : i64
} else {
scf.yield %pre_val : i64
}
```
Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
.../mlir/Dialect/SCF/Transforms/Passes.td | 51 ++++++++++++
.../mlir/Dialect/SCF/Transforms/Patterns.h | 5 ++
.../lib/Dialect/SCF/Transforms/CMakeLists.txt | 1 +
.../SCF/Transforms/RotateWhileLoop.cpp | 82 +++++++++++++++++++
...zero-trip-check.mlir => rotate-while.mlir} | 4 +-
mlir/test/lib/Dialect/SCF/CMakeLists.txt | 1 -
.../SCF/TestSCFWrapInZeroTripCheck.cpp | 72 ----------------
mlir/tools/mlir-opt/mlir-opt.cpp | 2 -
8 files changed, 141 insertions(+), 77 deletions(-)
create mode 100644 mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
rename mlir/test/Dialect/SCF/{wrap-while-loop-in-zero-trip-check.mlir => rotate-while.mlir} (95%)
delete mode 100644 mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 9b29affb97c43..beb52130e0137 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -164,4 +164,55 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
}];
}
+def SCFRotateWhileLoopPass : Pass<"scf-rotate-while"> {
+ let summary = "Rotate while loops, turning them into do-while loops.";
+ let description = [{
+ This pass rotates `scf.while` operations, mutating them from "while" loops to
+ "do-while" loops. A guard checking the condition for the first iteration is
+ inserted. Note this guard can be optimized away if the compiler can prove
+ the loop will be executed at least once.
+
+ Using this pass, the following while loop:
+
+ ```mlir
+// Before:
+scf.while (%arg0 = %init) : (i32) -> i64 {
+ %val = .., %arg0 : i64
+ %cond = arith.cmpi .., %arg0 : i32
+ scf.condition(%cond) %val : i64
+} do {
+^bb0(%arg1: i64):
+ %next = .., %arg1 : i32
+ scf.yield %next : i32
+}
+ ```
+
+ Can be transformed into:
+ ``` mlir
+%pre_val = .., %init : i64
+%pre_cond = arith.cmpi .., %init : i32
+scf.if %pre_cond -> i64 {
+ %res = scf.while (%arg1 = %pre_val) : (i64) -> i64 {
+ // Original after block
+ %next = .., %arg1 : i32
+ // Original before block
+ %val = .., %next : i64
+ %cond = arith.cmpi .., %next : i32
+ scf.condition(%cond) %val : i64
+ } do {
+ ^bb0(%arg2: i64):
+ scf.yield %arg2 : i64
+ }
+ scf.yield %res : i64
+} else {
+ scf.yield %pre_val : i64
+}
+ ```
+ }];
+ let options = [
+ Option<"forceCreateCheck", "force-create-check", "bool", /*default=*/"false",
+ "Create loop guard even if the loop is already in a do-while form.">
+ ];
+}
+
#endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index fdf2570626980..3a6ed32f37c67 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -16,6 +16,7 @@
namespace mlir {
class ConversionTarget;
+class SCFRotateWhileLoopPassOptions;
class TypeConverter;
namespace scf {
@@ -85,6 +86,10 @@ void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
/// * `after` block containing arith.addi
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
+/// Populate patterns to rotate `scf.while` ops, constructing `do-while` loops
+/// from `while` loops.
+void populateSCFRotateWhileLoopPatterns(
+ RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options);
} // namespace scf
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index d363ffe941fce..8c73515c608f5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
ParallelLoopCollapsing.cpp
ParallelLoopFusion.cpp
ParallelLoopTiling.cpp
+ RotateWhileLoop.cpp
StructuralTypeConversions.cpp
TileUsingInterface.cpp
WrapInZeroTripCheck.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
new file mode 100644
index 0000000000000..f493adaad6be9
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
@@ -0,0 +1,82 @@
+//===- RotateWhileLoop.cpp - scf.while loop rotation ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Rotates `scf.while` loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "scf-rotate-while"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFROTATEWHILELOOPPASS
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
+ RotateWhileLoopPattern(bool rotateLoop, MLIRContext *context,
+ PatternBenefit benefit = 1,
+ ArrayRef<StringRef> generatedNames = {})
+ : OpRewritePattern<scf::WhileOp>(context, benefit, generatedNames),
+ forceCreateCheck(rotateLoop) {}
+
+ LogicalResult matchAndRewrite(scf::WhileOp whileOp,
+ PatternRewriter &rewriter) const final {
+ FailureOr<scf::WhileOp> result =
+ scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck);
+ if (failed(result) || *result == whileOp) {
+ LLVM_DEBUG(whileOp->emitRemark("Failed to rotate loop"));
+ return failure();
+ };
+ return success();
+ }
+
+ bool forceCreateCheck;
+};
+
+struct SCFRotateWhileLoopPass
+ : impl::SCFRotateWhileLoopPassBase<SCFRotateWhileLoopPass> {
+ using Base::Base;
+
+ void runOnOperation() final {
+ Operation *parentOp = getOperation();
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ SCFRotateWhileLoopPassOptions options{forceCreateCheck};
+ scf::populateSCFRotateWhileLoopPatterns(patterns, options);
+ // Avoid applying the pattern to a loop more than once.
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ [[maybe_unused]] LogicalResult success =
+ applyPatternsAndFoldGreedily(parentOp, std::move(patterns), config);
+ }
+};
+} // namespace
+
+namespace mlir {
+namespace scf {
+void populateSCFRotateWhileLoopPatterns(
+ RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options) {
+ patterns.add<RotateWhileLoopPattern>(options.forceCreateCheck,
+ patterns.getContext());
+}
+} // namespace scf
+} // namespace mlir
diff --git a/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir b/mlir/test/Dialect/SCF/rotate-while.mlir
similarity index 95%
rename from mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
rename to mlir/test/Dialect/SCF/rotate-while.mlir
index 8954839c3c93e..d6b205ee330d0 100644
--- a/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
+++ b/mlir/test/Dialect/SCF/rotate-while.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-wrap-scf-while-loop-in-zero-trip-check -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-wrap-scf-while-loop-in-zero-trip-check='force-create-check=true' -split-input-file | FileCheck %s --check-prefix FORCE-CREATE-CHECK
+// RUN: mlir-opt %s -scf-rotate-while -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -scf-rotate-while='force-create-check=true' -split-input-file | FileCheck %s --check-prefix FORCE-CREATE-CHECK
func.func @wrap_while_loop_in_zero_trip_check(%bound : i32) -> i32 {
%cst0 = arith.constant 0 : i32
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 792430cc84b65..2fa550e2d9f87 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -3,7 +3,6 @@ add_mlir_library(MLIRSCFTestPasses
TestLoopParametricTiling.cpp
TestLoopUnrolling.cpp
TestSCFUtils.cpp
- TestSCFWrapInZeroTripCheck.cpp
TestUpliftWhileToFor.cpp
TestWhileOpBuilder.cpp
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
deleted file mode 100644
index 10206dd7cedf6..0000000000000
--- a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
+++ /dev/null
@@ -1,72 +0,0 @@
-//===- TestWrapInZeroTripCheck.cpp -- Passes to test SCF zero-trip-check --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements the passes to test wrap-in-zero-trip-check transforms on
-// SCF loop ops.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Transforms.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-
-using namespace mlir;
-
-namespace {
-
-struct TestWrapWhileLoopInZeroTripCheckPass
- : public PassWrapper<TestWrapWhileLoopInZeroTripCheckPass,
- OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- TestWrapWhileLoopInZeroTripCheckPass)
-
- StringRef getArgument() const final {
- return "test-wrap-scf-while-loop-in-zero-trip-check";
- }
-
- StringRef getDescription() const final {
- return "test scf::wrapWhileLoopInZeroTripCheck";
- }
-
- TestWrapWhileLoopInZeroTripCheckPass() = default;
- TestWrapWhileLoopInZeroTripCheckPass(
- const TestWrapWhileLoopInZeroTripCheckPass &) {}
- explicit TestWrapWhileLoopInZeroTripCheckPass(bool forceCreateCheckParam) {
- forceCreateCheck = forceCreateCheckParam;
- }
-
- void runOnOperation() override {
- func::FuncOp func = getOperation();
- MLIRContext *context = &getContext();
- IRRewriter rewriter(context);
- func.walk([&](scf::WhileOp op) {
- FailureOr<scf::WhileOp> result =
- scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
- // Ignore not implemented failure in tests. The expected output should
- // catch problems (e.g. transformation doesn't happen).
- (void)result;
- });
- }
-
- Option<bool> forceCreateCheck{
- *this, "force-create-check",
- llvm::cl::desc("Force to create zero-trip-check."),
- llvm::cl::init(false)};
-};
-
-} // namespace
-
-namespace mlir {
-namespace test {
-void registerTestSCFWrapInZeroTripCheckPasses() {
- PassRegistration<TestWrapWhileLoopInZeroTripCheckPass>();
-}
-} // namespace test
-} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 8cafb0afac9ae..af99bf55a5e7c 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -138,7 +138,6 @@ void registerTestRecursiveTypesPass();
void registerTestSCFUpliftWhileToFor();
void registerTestSCFUtilsPass();
void registerTestSCFWhileOpBuilderPass();
-void registerTestSCFWrapInZeroTripCheckPasses();
void registerTestShapeMappingPass();
void registerTestSliceAnalysisPass();
void registerTestTensorCopyInsertionPass();
@@ -270,7 +269,6 @@ void registerTestPasses() {
mlir::test::registerTestSCFUpliftWhileToFor();
mlir::test::registerTestSCFUtilsPass();
mlir::test::registerTestSCFWhileOpBuilderPass();
- mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
mlir::test::registerTestShapeMappingPass();
mlir::test::registerTestSliceAnalysisPass();
mlir::test::registerTestTensorCopyInsertionPass();
>From d4869d59828dd67b3e64d016e89b8071f6c8bf1d Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Mon, 22 Jul 2024 10:00:10 +0100
Subject: [PATCH 2/5] Fix build failure
---
mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 3a6ed32f37c67..9a8d840ef1ee3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -16,7 +16,7 @@
namespace mlir {
class ConversionTarget;
-class SCFRotateWhileLoopPassOptions;
+struct SCFRotateWhileLoopPassOptions;
class TypeConverter;
namespace scf {
@@ -88,6 +88,10 @@ void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
/// Populate patterns to rotate `scf.while` ops, constructing `do-while` loops
/// from `while` loops.
+///
+/// Note applying these patterns recursively to newly created operations will
+/// lead to infinite recursion, so `mlir::GreedyRewriteStrictness::ExistingOps`
+/// must be used in passes using these patterns.
void populateSCFRotateWhileLoopPatterns(
RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options);
} // namespace scf
>From a3dd6a3a45015c58bc2e288954a84fad5d1f8253 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Wed, 24 Jul 2024 10:15:41 +0100
Subject: [PATCH 3/5] Second round of review comments
---
.../mlir/Dialect/SCF/Transforms/Patterns.h | 7 +----
.../SCF/Transforms/RotateWhileLoop.cpp | 27 ++++++++-----------
2 files changed, 12 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 9a8d840ef1ee3..729cbc3d837c3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -88,12 +88,7 @@ void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
/// Populate patterns to rotate `scf.while` ops, constructing `do-while` loops
/// from `while` loops.
-///
-/// Note applying these patterns recursively to newly created operations will
-/// lead to infinite recursion, so `mlir::GreedyRewriteStrictness::ExistingOps`
-/// must be used in passes using these patterns.
-void populateSCFRotateWhileLoopPatterns(
- RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options);
+void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns);
} // namespace scf
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
index f493adaad6be9..5d72051060903 100644
--- a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
@@ -19,10 +19,6 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/Debug.h"
-
-#define DEBUG_TYPE "scf-rotate-while"
-
namespace mlir {
#define GEN_PASS_DEF_SCFROTATEWHILELOOPPASS
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -42,16 +38,18 @@ struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
PatternRewriter &rewriter) const final {
FailureOr<scf::WhileOp> result =
scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck);
- if (failed(result) || *result == whileOp) {
- LLVM_DEBUG(whileOp->emitRemark("Failed to rotate loop"));
- return failure();
- };
- return success();
+ return success(succeeded(result) && *result != whileOp);
}
bool forceCreateCheck;
};
+static void populateSCFRotateWhileLoopPatterns(
+ RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options) {
+ patterns.add<RotateWhileLoopPattern>(options.forceCreateCheck,
+ patterns.getContext());
+}
+
struct SCFRotateWhileLoopPass
: impl::SCFRotateWhileLoopPassBase<SCFRotateWhileLoopPass> {
using Base::Base;
@@ -61,22 +59,19 @@ struct SCFRotateWhileLoopPass
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
SCFRotateWhileLoopPassOptions options{forceCreateCheck};
- scf::populateSCFRotateWhileLoopPatterns(patterns, options);
+ populateSCFRotateWhileLoopPatterns(patterns, options);
// Avoid applying the pattern to a loop more than once.
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
- [[maybe_unused]] LogicalResult success =
- applyPatternsAndFoldGreedily(parentOp, std::move(patterns), config);
+ (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns), config);
}
};
} // namespace
namespace mlir {
namespace scf {
-void populateSCFRotateWhileLoopPatterns(
- RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options) {
- patterns.add<RotateWhileLoopPattern>(options.forceCreateCheck,
- patterns.getContext());
+void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns) {
+ ::populateSCFRotateWhileLoopPatterns(patterns, {});
}
} // namespace scf
} // namespace mlir
>From 172d25e86b79f6c94b020bfa67e5de15dece10f7 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Wed, 24 Jul 2024 10:16:48 +0100
Subject: [PATCH 4/5] Drop forward decl
---
mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 729cbc3d837c3..5e66774d2f143 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -16,7 +16,6 @@
namespace mlir {
class ConversionTarget;
-struct SCFRotateWhileLoopPassOptions;
class TypeConverter;
namespace scf {
>From c85b9d2a8ee1f14763889d1b26883d73e54dd8a6 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Wed, 24 Jul 2024 13:05:18 +0100
Subject: [PATCH 5/5] Another round
---
.../SCF/Transforms/RotateWhileLoop.cpp | 43 ++++++++-----------
1 file changed, 19 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
index 5d72051060903..d18c9acb40d23 100644
--- a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
@@ -13,11 +13,8 @@
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
-#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/Diagnostics.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_SCFROTATEWHILELOOPPASS
@@ -28,42 +25,40 @@ using namespace mlir;
namespace {
struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
- RotateWhileLoopPattern(bool rotateLoop, MLIRContext *context,
- PatternBenefit benefit = 1,
- ArrayRef<StringRef> generatedNames = {})
- : OpRewritePattern<scf::WhileOp>(context, benefit, generatedNames),
- forceCreateCheck(rotateLoop) {}
+ using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
LogicalResult matchAndRewrite(scf::WhileOp whileOp,
PatternRewriter &rewriter) const final {
+ // Setting this option would lead to infinite recursion on a greedy driver
+ // as 'do-while' loops wouldn't be skipped.
+ constexpr bool forceCreateCheck = false;
FailureOr<scf::WhileOp> result =
scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck);
+ // scf::wrapWhileLoopInZeroTripCheck hasn't yet implemented a failure
+ // mechanism. 'do-while' loops are simply returned unmodified. In order to
+ // stop recursion, we check input and output operations differ.
return success(succeeded(result) && *result != whileOp);
}
-
- bool forceCreateCheck;
};
-static void populateSCFRotateWhileLoopPatterns(
- RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options) {
- patterns.add<RotateWhileLoopPattern>(options.forceCreateCheck,
- patterns.getContext());
-}
-
+/// We do not use the above pattern in this pass as we can simply walk over the
+/// `scf.while` operations and run the function.
struct SCFRotateWhileLoopPass
: impl::SCFRotateWhileLoopPassBase<SCFRotateWhileLoopPass> {
using Base::Base;
void runOnOperation() final {
Operation *parentOp = getOperation();
+
+ SmallVector<scf::WhileOp> workList;
+ parentOp->walk(
+ [&workList](scf::WhileOp whileOp) { workList.push_back(whileOp); });
+
MLIRContext *context = &getContext();
- RewritePatternSet patterns(context);
- SCFRotateWhileLoopPassOptions options{forceCreateCheck};
- populateSCFRotateWhileLoopPatterns(patterns, options);
- // Avoid applying the pattern to a loop more than once.
- GreedyRewriteConfig config;
- config.strictMode = GreedyRewriteStrictness::ExistingOps;
- (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns), config);
+ PatternRewriter rewriter(context);
+ for (scf::WhileOp whileOp : workList)
+ (void)scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter,
+ forceCreateCheck);
}
};
} // namespace
@@ -71,7 +66,7 @@ struct SCFRotateWhileLoopPass
namespace mlir {
namespace scf {
void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns) {
- ::populateSCFRotateWhileLoopPatterns(patterns, {});
+ patterns.add<RotateWhileLoopPattern>(patterns.getContext());
}
} // namespace scf
} // namespace mlir
More information about the Mlir-commits
mailing list