[Mlir-commits] [mlir] [MLIR][SCF] Define `-scf-rotate-while` pass (PR #99850)

Victor Perez llvmlistbot at llvm.org
Mon Jul 29 05:56:54 PDT 2024


https://github.com/victor-eds updated https://github.com/llvm/llvm-project/pull/99850

>From f746d30b74fb8b59778b9efab4174aa1248030e3 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Mon, 22 Jul 2024 09:41:10 +0100
Subject: [PATCH 1/8] [MLIR][SCF] Define `scf.while` rotation pass

Define a pass that rotates `scf.while` loops by exposing the existing
`mlir::scf::wrapWhileLoopInZeroTripCheck` utility as a pass.

This pass rotates `scf.while` operations, turning them from "while" loops
into "do-while" loops. A guard checking the condition for the first
iteration is inserted. Note that this guard can be optimized away if the
compiler can prove the loop will be executed at least once.

Using this pass, the following while loop:

```mlir
scf.while (%arg0 = %init) : (i32) -> i64 {
  %val = .., %arg0 : i64
  %cond = arith.cmpi .., %arg0 : i32
  scf.condition(%cond) %val : i64
} do {
^bb0(%arg1: i64):
  %next = .., %arg1 : i32
  scf.yield %next : i32
}
```

Can be transformed into:

```mlir
%pre_val = .., %init : i64
%pre_cond = arith.cmpi .., %init : i32
scf.if %pre_cond -> i64 {
  %res = scf.while (%arg1 = %pre_val) : (i64) -> i64 {
    // Original after block
    %next = .., %arg1 : i32
    // Original before block
    %val = .., %next : i64
    %cond = arith.cmpi .., %next : i32
    scf.condition(%cond) %val : i64
  } do {
  ^bb0(%arg2: i64):
    scf.yield %arg2 : i64
  }
  scf.yield %res : i64
} else {
  scf.yield %pre_val : i64
}
```

Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
 .../mlir/Dialect/SCF/Transforms/Passes.td     | 51 ++++++++++++
 .../mlir/Dialect/SCF/Transforms/Patterns.h    |  5 ++
 .../lib/Dialect/SCF/Transforms/CMakeLists.txt |  1 +
 .../SCF/Transforms/RotateWhileLoop.cpp        | 82 +++++++++++++++++++
 ...zero-trip-check.mlir => rotate-while.mlir} |  4 +-
 mlir/test/lib/Dialect/SCF/CMakeLists.txt      |  1 -
 .../SCF/TestSCFWrapInZeroTripCheck.cpp        | 72 ----------------
 mlir/tools/mlir-opt/mlir-opt.cpp              |  2 -
 8 files changed, 141 insertions(+), 77 deletions(-)
 create mode 100644 mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
 rename mlir/test/Dialect/SCF/{wrap-while-loop-in-zero-trip-check.mlir => rotate-while.mlir} (95%)
 delete mode 100644 mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 9b29affb97c43..beb52130e0137 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -164,4 +164,55 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   }];
 }
 
+def SCFRotateWhileLoopPass : Pass<"scf-rotate-while"> {
+  let summary = "Rotate while loops, turning them into do-while loops.";
+  let description = [{
+    This pass rotates SCF.WhileOp, mutating them from "while" loops to
+    "do-while" loops. A guard checking the condition for the first iteration is
+    inserted. Note this guard can be optimized away if the compiler can prove
+    the loop will be executed at least once.
+
+    Using this pass, the following while loop:
+
+    ```mlir
+# Before:
+scf.while (%arg0 = %init) : (i32) -> i64 {
+  %val = .., %arg0 : i64
+  %cond = arith.cmpi .., %arg0 : i32
+  scf.condition(%cond) %val : i64
+} do {
+^bb0(%arg1: i64):
+  %next = .., %arg1 : i32
+  scf.yield %next : i32
+}
+    ```
+
+    Can be transformed into:
+    ``` mlir
+%pre_val = .., %init : i64
+%pre_cond = arith.cmpi .., %init : i32
+scf.if %pre_cond -> i64 {
+  %res = scf.while (%arg1 = %va0) : (i64) -> i64 {
+    // Original after block
+    %next = .., %arg1 : i32
+    // Original before block
+    %val = .., %next : i64
+    %cond = arith.cmpi .., %next : i32
+    scf.condition(%cond) %val : i64
+  } do {
+  ^bb0(%arg2: i64):
+    %scf.yield %arg2 : i32
+  }
+  scf.yield %res : i64
+} else {
+  scf.yield %pre_val : i64
+}
+    ```
+  }];
+  let options = [
+    Option<"forceCreateCheck", "force-create-check", "bool", /*default=*/"false",
+           "Create loop guard even if the loop is already in a do-while form.">
+  ];
+}
+
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index fdf2570626980..3a6ed32f37c67 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -16,6 +16,7 @@
 namespace mlir {
 
 class ConversionTarget;
+class SCFRotateWhileLoopPassOptions;
 class TypeConverter;
 
 namespace scf {
@@ -85,6 +86,10 @@ void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 ///  * `after` block containing arith.addi
 void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
 
+/// Populate patterns to rotate `scf.while` ops, constructing `do-while` loops
+/// from `while` loops.
+void populateSCFRotateWhileLoopPatterns(
+    RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options);
 } // namespace scf
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index d363ffe941fce..8c73515c608f5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ParallelLoopCollapsing.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
+  RotateWhileLoop.cpp
   StructuralTypeConversions.cpp
   TileUsingInterface.cpp
   WrapInZeroTripCheck.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
new file mode 100644
index 0000000000000..f493adaad6be9
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
@@ -0,0 +1,82 @@
+//===- RotateWhileLoop.cpp - scf.while loop rotation ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Rotates `scf.while` loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "scf-rotate-while"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFROTATEWHILELOOPPASS
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
+  RotateWhileLoopPattern(bool rotateLoop, MLIRContext *context,
+                         PatternBenefit benefit = 1,
+                         ArrayRef<StringRef> generatedNames = {})
+      : OpRewritePattern<scf::WhileOp>(context, benefit, generatedNames),
+        forceCreateCheck(rotateLoop) {}
+
+  LogicalResult matchAndRewrite(scf::WhileOp whileOp,
+                                PatternRewriter &rewriter) const final {
+    FailureOr<scf::WhileOp> result =
+        scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck);
+    if (failed(result) || *result == whileOp) {
+      LLVM_DEBUG(whileOp->emitRemark("Failed to rotate loop"));
+      return failure();
+    };
+    return success();
+  }
+
+  bool forceCreateCheck;
+};
+
+struct SCFRotateWhileLoopPass
+    : impl::SCFRotateWhileLoopPassBase<SCFRotateWhileLoopPass> {
+  using Base::Base;
+
+  void runOnOperation() final {
+    Operation *parentOp = getOperation();
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    SCFRotateWhileLoopPassOptions options{forceCreateCheck};
+    scf::populateSCFRotateWhileLoopPatterns(patterns, options);
+    // Avoid applying the pattern to a loop more than once.
+    GreedyRewriteConfig config;
+    config.strictMode = GreedyRewriteStrictness::ExistingOps;
+    [[maybe_unused]] LogicalResult success =
+        applyPatternsAndFoldGreedily(parentOp, std::move(patterns), config);
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace scf {
+void populateSCFRotateWhileLoopPatterns(
+    RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options) {
+  patterns.add<RotateWhileLoopPattern>(options.forceCreateCheck,
+                                       patterns.getContext());
+}
+} // namespace scf
+} // namespace mlir
diff --git a/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir b/mlir/test/Dialect/SCF/rotate-while.mlir
similarity index 95%
rename from mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
rename to mlir/test/Dialect/SCF/rotate-while.mlir
index 8954839c3c93e..d6b205ee330d0 100644
--- a/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
+++ b/mlir/test/Dialect/SCF/rotate-while.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-wrap-scf-while-loop-in-zero-trip-check -split-input-file  | FileCheck %s
-// RUN: mlir-opt %s -test-wrap-scf-while-loop-in-zero-trip-check='force-create-check=true' -split-input-file  | FileCheck %s --check-prefix FORCE-CREATE-CHECK
+// RUN: mlir-opt %s -scf-rotate-while -split-input-file  | FileCheck %s
+// RUN: mlir-opt %s -scf-rotate-while='force-create-check=true' -split-input-file  | FileCheck %s --check-prefix FORCE-CREATE-CHECK
 
 func.func @wrap_while_loop_in_zero_trip_check(%bound : i32) -> i32 {
   %cst0 = arith.constant 0 : i32
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 792430cc84b65..2fa550e2d9f87 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -3,7 +3,6 @@ add_mlir_library(MLIRSCFTestPasses
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
   TestSCFUtils.cpp
-  TestSCFWrapInZeroTripCheck.cpp
   TestUpliftWhileToFor.cpp
   TestWhileOpBuilder.cpp
 
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
deleted file mode 100644
index 10206dd7cedf6..0000000000000
--- a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
+++ /dev/null
@@ -1,72 +0,0 @@
-//===- TestWrapInZeroTripCheck.cpp -- Passes to test SCF zero-trip-check --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements the passes to test wrap-in-zero-trip-check transforms on
-// SCF loop ops.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Transforms.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-
-using namespace mlir;
-
-namespace {
-
-struct TestWrapWhileLoopInZeroTripCheckPass
-    : public PassWrapper<TestWrapWhileLoopInZeroTripCheckPass,
-                         OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      TestWrapWhileLoopInZeroTripCheckPass)
-
-  StringRef getArgument() const final {
-    return "test-wrap-scf-while-loop-in-zero-trip-check";
-  }
-
-  StringRef getDescription() const final {
-    return "test scf::wrapWhileLoopInZeroTripCheck";
-  }
-
-  TestWrapWhileLoopInZeroTripCheckPass() = default;
-  TestWrapWhileLoopInZeroTripCheckPass(
-      const TestWrapWhileLoopInZeroTripCheckPass &) {}
-  explicit TestWrapWhileLoopInZeroTripCheckPass(bool forceCreateCheckParam) {
-    forceCreateCheck = forceCreateCheckParam;
-  }
-
-  void runOnOperation() override {
-    func::FuncOp func = getOperation();
-    MLIRContext *context = &getContext();
-    IRRewriter rewriter(context);
-    func.walk([&](scf::WhileOp op) {
-      FailureOr<scf::WhileOp> result =
-          scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
-      // Ignore not implemented failure in tests. The expected output should
-      // catch problems (e.g. transformation doesn't happen).
-      (void)result;
-    });
-  }
-
-  Option<bool> forceCreateCheck{
-      *this, "force-create-check",
-      llvm::cl::desc("Force to create zero-trip-check."),
-      llvm::cl::init(false)};
-};
-
-} // namespace
-
-namespace mlir {
-namespace test {
-void registerTestSCFWrapInZeroTripCheckPasses() {
-  PassRegistration<TestWrapWhileLoopInZeroTripCheckPass>();
-}
-} // namespace test
-} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 8cafb0afac9ae..af99bf55a5e7c 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -138,7 +138,6 @@ void registerTestRecursiveTypesPass();
 void registerTestSCFUpliftWhileToFor();
 void registerTestSCFUtilsPass();
 void registerTestSCFWhileOpBuilderPass();
-void registerTestSCFWrapInZeroTripCheckPasses();
 void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorCopyInsertionPass();
@@ -270,7 +269,6 @@ void registerTestPasses() {
   mlir::test::registerTestSCFUpliftWhileToFor();
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSCFWhileOpBuilderPass();
-  mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
   mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestTensorCopyInsertionPass();

>From d4869d59828dd67b3e64d016e89b8071f6c8bf1d Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Mon, 22 Jul 2024 10:00:10 +0100
Subject: [PATCH 2/8] Fix build failure

---
 mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 3a6ed32f37c67..9a8d840ef1ee3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -16,7 +16,7 @@
 namespace mlir {
 
 class ConversionTarget;
-class SCFRotateWhileLoopPassOptions;
+struct SCFRotateWhileLoopPassOptions;
 class TypeConverter;
 
 namespace scf {
@@ -88,6 +88,10 @@ void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
 
 /// Populate patterns to rotate `scf.while` ops, constructing `do-while` loops
 /// from `while` loops.
+///
+/// Note applying these patterns recursively to newly created operations will
+/// lead to infinite recursion, so `mlir::GreedyRewriteStrictness::ExistingOps`
+/// must be used in passes using these patterns.
 void populateSCFRotateWhileLoopPatterns(
     RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options);
 } // namespace scf

>From a3dd6a3a45015c58bc2e288954a84fad5d1f8253 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Wed, 24 Jul 2024 10:15:41 +0100
Subject: [PATCH 3/8] Second round of review comments

---
 .../mlir/Dialect/SCF/Transforms/Patterns.h    |  7 +----
 .../SCF/Transforms/RotateWhileLoop.cpp        | 27 ++++++++-----------
 2 files changed, 12 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 9a8d840ef1ee3..729cbc3d837c3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -88,12 +88,7 @@ void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
 
 /// Populate patterns to rotate `scf.while` ops, constructing `do-while` loops
 /// from `while` loops.
-///
-/// Note applying these patterns recursively to newly created operations will
-/// lead to infinite recursion, so `mlir::GreedyRewriteStrictness::ExistingOps`
-/// must be used in passes using these patterns.
-void populateSCFRotateWhileLoopPatterns(
-    RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options);
+void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns);
 } // namespace scf
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
index f493adaad6be9..5d72051060903 100644
--- a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
@@ -19,10 +19,6 @@
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-#include "llvm/Support/Debug.h"
-
-#define DEBUG_TYPE "scf-rotate-while"
-
 namespace mlir {
 #define GEN_PASS_DEF_SCFROTATEWHILELOOPPASS
 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -42,16 +38,18 @@ struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
                                 PatternRewriter &rewriter) const final {
     FailureOr<scf::WhileOp> result =
         scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck);
-    if (failed(result) || *result == whileOp) {
-      LLVM_DEBUG(whileOp->emitRemark("Failed to rotate loop"));
-      return failure();
-    };
-    return success();
+    return success(succeeded(result) && *result != whileOp);
   }
 
   bool forceCreateCheck;
 };
 
+static void populateSCFRotateWhileLoopPatterns(
+    RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options) {
+  patterns.add<RotateWhileLoopPattern>(options.forceCreateCheck,
+                                       patterns.getContext());
+}
+
 struct SCFRotateWhileLoopPass
     : impl::SCFRotateWhileLoopPassBase<SCFRotateWhileLoopPass> {
   using Base::Base;
@@ -61,22 +59,19 @@ struct SCFRotateWhileLoopPass
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
     SCFRotateWhileLoopPassOptions options{forceCreateCheck};
-    scf::populateSCFRotateWhileLoopPatterns(patterns, options);
+    populateSCFRotateWhileLoopPatterns(patterns, options);
     // Avoid applying the pattern to a loop more than once.
     GreedyRewriteConfig config;
     config.strictMode = GreedyRewriteStrictness::ExistingOps;
-    [[maybe_unused]] LogicalResult success =
-        applyPatternsAndFoldGreedily(parentOp, std::move(patterns), config);
+    (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns), config);
   }
 };
 } // namespace
 
 namespace mlir {
 namespace scf {
-void populateSCFRotateWhileLoopPatterns(
-    RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options) {
-  patterns.add<RotateWhileLoopPattern>(options.forceCreateCheck,
-                                       patterns.getContext());
+void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns) {
+  ::populateSCFRotateWhileLoopPatterns(patterns, {});
 }
 } // namespace scf
 } // namespace mlir

>From 172d25e86b79f6c94b020bfa67e5de15dece10f7 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Wed, 24 Jul 2024 10:16:48 +0100
Subject: [PATCH 4/8] Drop forward decl

---
 mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 729cbc3d837c3..5e66774d2f143 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -16,7 +16,6 @@
 namespace mlir {
 
 class ConversionTarget;
-struct SCFRotateWhileLoopPassOptions;
 class TypeConverter;
 
 namespace scf {

>From c85b9d2a8ee1f14763889d1b26883d73e54dd8a6 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Wed, 24 Jul 2024 13:05:18 +0100
Subject: [PATCH 5/8] Another round

---
 .../SCF/Transforms/RotateWhileLoop.cpp        | 43 ++++++++-----------
 1 file changed, 19 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
index 5d72051060903..d18c9acb40d23 100644
--- a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
@@ -13,11 +13,8 @@
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
-#include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/IR/Diagnostics.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_SCFROTATEWHILELOOPPASS
@@ -28,42 +25,40 @@ using namespace mlir;
 
 namespace {
 struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
-  RotateWhileLoopPattern(bool rotateLoop, MLIRContext *context,
-                         PatternBenefit benefit = 1,
-                         ArrayRef<StringRef> generatedNames = {})
-      : OpRewritePattern<scf::WhileOp>(context, benefit, generatedNames),
-        forceCreateCheck(rotateLoop) {}
+  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(scf::WhileOp whileOp,
                                 PatternRewriter &rewriter) const final {
+    // Setting this option would lead to infinite recursion on a greedy driver
+    // as 'do-while' loops wouldn't be skipped.
+    constexpr bool forceCreateCheck = false;
     FailureOr<scf::WhileOp> result =
         scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck);
+    // scf::wrapWhileLoopInZeroTripCheck hasn't yet implemented a failure
+    // mechanism. 'do-while' loops are simply returned unmodified. In order to
+    // stop recursion, we check input and output operations differ.
     return success(succeeded(result) && *result != whileOp);
   }
-
-  bool forceCreateCheck;
 };
 
-static void populateSCFRotateWhileLoopPatterns(
-    RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options) {
-  patterns.add<RotateWhileLoopPattern>(options.forceCreateCheck,
-                                       patterns.getContext());
-}
-
+/// We do not use the above pattern in this pass as we can simply walk over the
+/// `scf.while` operations and run the function.
 struct SCFRotateWhileLoopPass
     : impl::SCFRotateWhileLoopPassBase<SCFRotateWhileLoopPass> {
   using Base::Base;
 
   void runOnOperation() final {
     Operation *parentOp = getOperation();
+
+    SmallVector<scf::WhileOp> workList;
+    parentOp->walk(
+        [&workList](scf::WhileOp whileOp) { workList.push_back(whileOp); });
+
     MLIRContext *context = &getContext();
-    RewritePatternSet patterns(context);
-    SCFRotateWhileLoopPassOptions options{forceCreateCheck};
-    populateSCFRotateWhileLoopPatterns(patterns, options);
-    // Avoid applying the pattern to a loop more than once.
-    GreedyRewriteConfig config;
-    config.strictMode = GreedyRewriteStrictness::ExistingOps;
-    (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns), config);
+    PatternRewriter rewriter(context);
+    for (scf::WhileOp whileOp : workList)
+      (void)scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter,
+                                              forceCreateCheck);
   }
 };
 } // namespace
@@ -71,7 +66,7 @@ struct SCFRotateWhileLoopPass
 namespace mlir {
 namespace scf {
 void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns) {
-  ::populateSCFRotateWhileLoopPatterns(patterns, {});
+  patterns.add<RotateWhileLoopPattern>(patterns.getContext());
 }
 } // namespace scf
 } // namespace mlir

>From d3b2c9f964f2d9f4c820ac57cbb10fcafab7417a Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 25 Jul 2024 10:09:13 +0100
Subject: [PATCH 6/8] Split in two passes for coverage

Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
 .../mlir/Dialect/SCF/Transforms/Passes.td     |  4 --
 .../SCF/Transforms/RotateWhileLoop.cpp        | 15 ++--
 mlir/test/Dialect/SCF/rotate-while.mlir       |  4 +-
 mlir/test/lib/Dialect/SCF/CMakeLists.txt      |  1 +
 .../SCF/TestSCFForceWrapInZeroTripCheck.cpp   | 68 +++++++++++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |  2 +
 6 files changed, 77 insertions(+), 17 deletions(-)
 create mode 100644 mlir/test/lib/Dialect/SCF/TestSCFForceWrapInZeroTripCheck.cpp

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index beb52130e0137..f76fedccc6c32 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -209,10 +209,6 @@ scf.if %pre_cond -> i64 {
 }
     ```
   }];
-  let options = [
-    Option<"forceCreateCheck", "force-create-check", "bool", /*default=*/"false",
-           "Create loop guard even if the loop is already in a do-while form.">
-  ];
 }
 
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
index d18c9acb40d23..1894ab5dffac2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/IR/Diagnostics.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_SCFROTATEWHILELOOPPASS
@@ -41,24 +42,16 @@ struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
   }
 };
 
-/// We do not use the above pattern in this pass as we can simply walk over the
-/// `scf.while` operations and run the function.
 struct SCFRotateWhileLoopPass
     : impl::SCFRotateWhileLoopPassBase<SCFRotateWhileLoopPass> {
   using Base::Base;
 
   void runOnOperation() final {
     Operation *parentOp = getOperation();
-
-    SmallVector<scf::WhileOp> workList;
-    parentOp->walk(
-        [&workList](scf::WhileOp whileOp) { workList.push_back(whileOp); });
-
     MLIRContext *context = &getContext();
-    PatternRewriter rewriter(context);
-    for (scf::WhileOp whileOp : workList)
-      (void)scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter,
-                                              forceCreateCheck);
+    RewritePatternSet patterns(context);
+    scf::populateSCFRotateWhileLoopPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
   }
 };
 } // namespace
diff --git a/mlir/test/Dialect/SCF/rotate-while.mlir b/mlir/test/Dialect/SCF/rotate-while.mlir
index d6b205ee330d0..be27a6f60aafa 100644
--- a/mlir/test/Dialect/SCF/rotate-while.mlir
+++ b/mlir/test/Dialect/SCF/rotate-while.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -scf-rotate-while -split-input-file  | FileCheck %s
-// RUN: mlir-opt %s -scf-rotate-while='force-create-check=true' -split-input-file  | FileCheck %s --check-prefix FORCE-CREATE-CHECK
+// RUN: mlir-opt %s -test-force-wrap-scf-while-loop-in-zero-trip-check -split-input-file  | FileCheck %s --check-prefix FORCE-CREATE-CHECK
 
 func.func @wrap_while_loop_in_zero_trip_check(%bound : i32) -> i32 {
   %cst0 = arith.constant 0 : i32
@@ -20,7 +20,7 @@ func.func @wrap_while_loop_in_zero_trip_check(%bound : i32) -> i32 {
 // CHECK-SAME:      %[[BOUND:.*]]: i32) -> i32 {
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : i32
 // CHECK-DAG:     %[[C5:.*]] = arith.constant 5 : i32
-// CHECK-DAG:     %[[PRE_COND:.*]] = arith.cmpi slt, %[[C0]], %[[BOUND]] : i32
+// CHECK-DAG:     %[[PRE_COND:.*]] = arith.cmpi sgt, %[[BOUND]], %[[C0]] : i32
 // CHECK-DAG:     %[[PRE_INV:.*]] = arith.addi %[[BOUND]], %[[C5]] : i32
 // CHECK:         %[[IF:.*]]:2 = scf.if %[[PRE_COND]] -> (i32, i32) {
 // CHECK:           %[[WHILE:.*]]:2 = scf.while (
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 2fa550e2d9f87..c92f809b4d718 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRSCFTestPasses
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
   TestSCFUtils.cpp
+  TestSCFForceWrapInZeroTripCheck.cpp
   TestUpliftWhileToFor.cpp
   TestWhileOpBuilder.cpp
 
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFForceWrapInZeroTripCheck.cpp b/mlir/test/lib/Dialect/SCF/TestSCFForceWrapInZeroTripCheck.cpp
new file mode 100644
index 0000000000000..9e3af278cee9f
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestSCFForceWrapInZeroTripCheck.cpp
@@ -0,0 +1,68 @@
+//===- TestSCFForceWrapInZeroTripCheck.cpp -- Pass to test SCF zero-trip-check//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the pass to test wrap-in-zero-trip-check transforms on
+// SCF loop ops. This pass only tests the case in which transformation is
+// forced, i.e., when `forceCreateCheck = true`, as the other case is covered by
+// the `-scf-rotate-while` pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestForceWrapWhileLoopInZeroTripCheckPass
+    : PassWrapper<TestForceWrapWhileLoopInZeroTripCheckPass,
+                  OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestForceWrapWhileLoopInZeroTripCheckPass)
+
+  using PassWrapper<TestForceWrapWhileLoopInZeroTripCheckPass,
+                    OperationPass<func::FuncOp>>::PassWrapper;
+
+  StringRef getArgument() const final {
+    return "test-force-wrap-scf-while-loop-in-zero-trip-check";
+  }
+
+  StringRef getDescription() const final {
+    return "test scf::wrapWhileLoopInZeroTripCheck whith forceCreateCheck=true";
+  }
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    MLIRContext *context = &getContext();
+    IRRewriter rewriter(context);
+    func.walk([&](scf::WhileOp op) {
+      // `forceCreateCheck=false` case is already tested by the
+      // `-scf-rotate-while` pass using this function in its pattern.
+      constexpr bool forceCreateCheck = true;
+      FailureOr<scf::WhileOp> result =
+          scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
+      // Ignore not implemented failure in tests. The expected output should
+      // catch problems (e.g. transformation doesn't happen).
+      (void)result;
+    });
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestSCFForceWrapInZeroTripCheckPasses() {
+  PassRegistration<TestForceWrapWhileLoopInZeroTripCheckPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 68bc3f710e6cf..fc204e25e6b58 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -138,6 +138,7 @@ void registerTestRecursiveTypesPass();
 void registerTestSCFUpliftWhileToFor();
 void registerTestSCFUtilsPass();
 void registerTestSCFWhileOpBuilderPass();
+void registerTestSCFForceWrapInZeroTripCheckPasses();
 void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
 void registerTestSPIRVFuncSignatureConversion();
@@ -270,6 +271,7 @@ void registerTestPasses() {
   mlir::test::registerTestSCFUpliftWhileToFor();
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSCFWhileOpBuilderPass();
+  mlir::test::registerTestSCFForceWrapInZeroTripCheckPasses();
   mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestSPIRVFuncSignatureConversion();

>From 8b145e0e12f66aa434b495302d8c969eab9b0b53 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 25 Jul 2024 15:46:13 +0100
Subject: [PATCH 7/8] Add failure documentation

---
 mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 71835cd178930..ea2f457c4e889 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -228,6 +228,11 @@ FailureOr<ForOp> pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
 ///   } else {
 ///     scf.yield %pre_val : i64
 ///   }
+///
+/// Failure mechanism is not implemented for this function, so it currently
+/// always returns a `WhileOp` operation: a new one if the transformation took
+/// place or the input `whileOp` if the loop was already in a `do-while` form
+/// and `forceCreateCheck` is `false`.
 FailureOr<WhileOp> wrapWhileLoopInZeroTripCheck(WhileOp whileOp,
                                                 RewriterBase &rewriter,
                                                 bool forceCreateCheck = false);

>From 6d40ce75f1629e5b6c495113fe2d03b9308cadc4 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Mon, 29 Jul 2024 13:56:33 +0100
Subject: [PATCH 8/8] Drop pass and only use test one

---
 .../mlir/Dialect/SCF/Transforms/Passes.td     | 47 -----------
 .../SCF/Transforms/RotateWhileLoop.cpp        | 23 +-----
 ...> wrap-while-loop-in-zero-trip-check.mlir} |  4 +-
 mlir/test/lib/Dialect/SCF/CMakeLists.txt      |  2 +-
 .../SCF/TestSCFForceWrapInZeroTripCheck.cpp   | 68 ----------------
 .../SCF/TestSCFWrapInZeroTripCheck.cpp        | 80 +++++++++++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |  4 +-
 7 files changed, 86 insertions(+), 142 deletions(-)
 rename mlir/test/Dialect/SCF/{rotate-while.mlir => wrap-while-loop-in-zero-trip-check.mlir} (95%)
 delete mode 100644 mlir/test/lib/Dialect/SCF/TestSCFForceWrapInZeroTripCheck.cpp
 create mode 100644 mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index f76fedccc6c32..9b29affb97c43 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -164,51 +164,4 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   }];
 }
 
-def SCFRotateWhileLoopPass : Pass<"scf-rotate-while"> {
-  let summary = "Rotate while loops, turning them into do-while loops.";
-  let description = [{
-    This pass rotates SCF.WhileOp, mutating them from "while" loops to
-    "do-while" loops. A guard checking the condition for the first iteration is
-    inserted. Note this guard can be optimized away if the compiler can prove
-    the loop will be executed at least once.
-
-    Using this pass, the following while loop:
-
-    ```mlir
-# Before:
-scf.while (%arg0 = %init) : (i32) -> i64 {
-  %val = .., %arg0 : i64
-  %cond = arith.cmpi .., %arg0 : i32
-  scf.condition(%cond) %val : i64
-} do {
-^bb0(%arg1: i64):
-  %next = .., %arg1 : i32
-  scf.yield %next : i32
-}
-    ```
-
-    Can be transformed into:
-    ``` mlir
-%pre_val = .., %init : i64
-%pre_cond = arith.cmpi .., %init : i32
-scf.if %pre_cond -> i64 {
-  %res = scf.while (%arg1 = %va0) : (i64) -> i64 {
-    // Original after block
-    %next = .., %arg1 : i32
-    // Original before block
-    %val = .., %next : i64
-    %cond = arith.cmpi .., %next : i32
-    scf.condition(%cond) %val : i64
-  } do {
-  ^bb0(%arg2: i64):
-    %scf.yield %arg2 : i32
-  }
-  scf.yield %res : i64
-} else {
-  scf.yield %pre_val : i64
-}
-    ```
-  }];
-}
-
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
index 1894ab5dffac2..8707ec91328dc 100644
--- a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
@@ -10,17 +10,9 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/SCF/Transforms/Passes.h"
-
-#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-namespace mlir {
-#define GEN_PASS_DEF_SCFROTATEWHILELOOPPASS
-#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
-} // namespace mlir
+#include "mlir/Dialect/SCF/IR/SCF.h"
 
 using namespace mlir;
 
@@ -41,19 +33,6 @@ struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
     return success(succeeded(result) && *result != whileOp);
   }
 };
-
-struct SCFRotateWhileLoopPass
-    : impl::SCFRotateWhileLoopPassBase<SCFRotateWhileLoopPass> {
-  using Base::Base;
-
-  void runOnOperation() final {
-    Operation *parentOp = getOperation();
-    MLIRContext *context = &getContext();
-    RewritePatternSet patterns(context);
-    scf::populateSCFRotateWhileLoopPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
-  }
-};
 } // namespace
 
 namespace mlir {
diff --git a/mlir/test/Dialect/SCF/rotate-while.mlir b/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
similarity index 95%
rename from mlir/test/Dialect/SCF/rotate-while.mlir
rename to mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
index be27a6f60aafa..43a4693ca3d2b 100644
--- a/mlir/test/Dialect/SCF/rotate-while.mlir
+++ b/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -scf-rotate-while -split-input-file  | FileCheck %s
-// RUN: mlir-opt %s -test-force-wrap-scf-while-loop-in-zero-trip-check -split-input-file  | FileCheck %s --check-prefix FORCE-CREATE-CHECK
+// RUN: mlir-opt %s -test-wrap-scf-while-loop-in-zero-trip-check -split-input-file  | FileCheck %s
+// RUN: mlir-opt %s -test-wrap-scf-while-loop-in-zero-trip-check='force-create-check=true' -split-input-file  | FileCheck %s --check-prefix FORCE-CREATE-CHECK
 
 func.func @wrap_while_loop_in_zero_trip_check(%bound : i32) -> i32 {
   %cst0 = arith.constant 0 : i32
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index c92f809b4d718..792430cc84b65 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -3,7 +3,7 @@ add_mlir_library(MLIRSCFTestPasses
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
   TestSCFUtils.cpp
-  TestSCFForceWrapInZeroTripCheck.cpp
+  TestSCFWrapInZeroTripCheck.cpp
   TestUpliftWhileToFor.cpp
   TestWhileOpBuilder.cpp
 
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFForceWrapInZeroTripCheck.cpp b/mlir/test/lib/Dialect/SCF/TestSCFForceWrapInZeroTripCheck.cpp
deleted file mode 100644
index 9e3af278cee9f..0000000000000
--- a/mlir/test/lib/Dialect/SCF/TestSCFForceWrapInZeroTripCheck.cpp
+++ /dev/null
@@ -1,68 +0,0 @@
-//===- TestSCFForceWrapInZeroTripCheck.cpp -- Pass to test SCF zero-trip-check//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements the pass to test wrap-in-zero-trip-check transforms on
-// SCF loop ops. This pass only tests the case in which transformation is
-// forced, i.e., when `forceCreateCheck = true`, as the other case is covered by
-// the `-scf-rotate-while` pass.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Transforms.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-
-using namespace mlir;
-
-namespace {
-
-struct TestForceWrapWhileLoopInZeroTripCheckPass
-    : PassWrapper<TestForceWrapWhileLoopInZeroTripCheckPass,
-                  OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      TestForceWrapWhileLoopInZeroTripCheckPass)
-
-  using PassWrapper<TestForceWrapWhileLoopInZeroTripCheckPass,
-                    OperationPass<func::FuncOp>>::PassWrapper;
-
-  StringRef getArgument() const final {
-    return "test-force-wrap-scf-while-loop-in-zero-trip-check";
-  }
-
-  StringRef getDescription() const final {
-    return "test scf::wrapWhileLoopInZeroTripCheck whith forceCreateCheck=true";
-  }
-
-  void runOnOperation() override {
-    func::FuncOp func = getOperation();
-    MLIRContext *context = &getContext();
-    IRRewriter rewriter(context);
-    func.walk([&](scf::WhileOp op) {
-      // `forceCreateCheck=false` case is already tested by the
-      // `-scf-rotate-while` pass using this function in its pattern.
-      constexpr bool forceCreateCheck = true;
-      FailureOr<scf::WhileOp> result =
-          scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
-      // Ignore not implemented failure in tests. The expected output should
-      // catch problems (e.g. transformation doesn't happen).
-      (void)result;
-    });
-  }
-};
-
-} // namespace
-
-namespace mlir {
-namespace test {
-void registerTestSCFForceWrapInZeroTripCheckPasses() {
-  PassRegistration<TestForceWrapWhileLoopInZeroTripCheckPass>();
-}
-} // namespace test
-} // namespace mlir
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
new file mode 100644
index 0000000000000..16eb328da468d
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
@@ -0,0 +1,80 @@
+//===- TestSCFWrapInZeroTripCheck.cpp -- Pass to test SCF zero-trip-check -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the pass to test wrap-in-zero-trip-check transforms on
+// SCF loop ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestWrapWhileLoopInZeroTripCheckPass
+    : PassWrapper<TestWrapWhileLoopInZeroTripCheckPass,
+                  OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestWrapWhileLoopInZeroTripCheckPass)
+
+  TestWrapWhileLoopInZeroTripCheckPass() = default;
+  TestWrapWhileLoopInZeroTripCheckPass(
+      const TestWrapWhileLoopInZeroTripCheckPass &) {}
+  explicit TestWrapWhileLoopInZeroTripCheckPass(bool forceCreateCheckParam) {
+    forceCreateCheck = forceCreateCheckParam;
+  }
+
+  StringRef getArgument() const final {
+    return "test-wrap-scf-while-loop-in-zero-trip-check";
+  }
+
+  StringRef getDescription() const final {
+    return "test scf::wrapWhileLoopInZeroTripCheck";
+  }
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    MLIRContext *context = &getContext();
+    IRRewriter rewriter(context);
+    if (forceCreateCheck) {
+      func.walk([&](scf::WhileOp op) {
+        FailureOr<scf::WhileOp> result =
+            scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
+        // Deliberately ignore "not implemented" failures in tests. The expected
+        // output should catch problems (e.g. transformation doesn't happen).
+        (void)result;
+      });
+    } else {
+      RewritePatternSet patterns(context);
+      scf::populateSCFRotateWhileLoopPatterns(patterns);
+      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+    }
+  }
+
+  Option<bool> forceCreateCheck{
+      *this, "force-create-check",
+      llvm::cl::desc("Force to create zero-trip-check."),
+      llvm::cl::init(false)};
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestSCFWrapInZeroTripCheckPasses() {
+  PassRegistration<TestWrapWhileLoopInZeroTripCheckPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 84d76016f2da0..1842fa158e75a 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -139,7 +139,7 @@ void registerTestRecursiveTypesPass();
 void registerTestSCFUpliftWhileToFor();
 void registerTestSCFUtilsPass();
 void registerTestSCFWhileOpBuilderPass();
-void registerTestSCFForceWrapInZeroTripCheckPasses();
+void registerTestSCFWrapInZeroTripCheckPasses();
 void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
 void registerTestSPIRVFuncSignatureConversion();
@@ -274,7 +274,7 @@ void registerTestPasses() {
   mlir::test::registerTestSCFUpliftWhileToFor();
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSCFWhileOpBuilderPass();
-  mlir::test::registerTestSCFForceWrapInZeroTripCheckPasses();
+  mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
   mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestSPIRVFuncSignatureConversion();



More information about the Mlir-commits mailing list