[Mlir-commits] [mlir] [MLIR][SCF] Define `scf.while` rotation pass (PR #99850)

Victor Perez llvmlistbot at llvm.org
Mon Jul 22 01:49:14 PDT 2024


https://github.com/victor-eds created https://github.com/llvm/llvm-project/pull/99850

Define a pass that rotates `scf.while` loops by exposing the existing `mlir::scf::wrapWhileLoopInZeroTripCheck` utility as a pass.

This pass rotates `scf.while` ops, turning them from "while" loops into "do-while" loops. A guard checking the condition for the first iteration is inserted. Note that this guard can be optimized away if the compiler can prove the loop will be executed at least once.

Using this pass, the following while loop:

```mlir
scf.while (%arg0 = %init) : (i32) -> i64 {
  %val = .., %arg0 : i64
  %cond = arith.cmpi .., %arg0 : i32
  scf.condition(%cond) %val : i64
} do {
^bb0(%arg1: i64):
  %next = .., %arg1 : i32
  scf.yield %next : i32
}
```

Can be transformed into:

``` mlir
%pre_val = .., %init : i64
%pre_cond = arith.cmpi .., %init : i32
scf.if %pre_cond -> i64 {
  %res = scf.while (%arg1 = %pre_val) : (i64) -> i64 {
    // Original after block
    %next = .., %arg1 : i32
    // Original before block
    %val = .., %next : i64
    %cond = arith.cmpi .., %next : i32
    scf.condition(%cond) %val : i64
  } do {
  ^bb0(%arg2: i64):
    scf.yield %arg2 : i64
  }
  scf.yield %res : i64
} else {
  scf.yield %pre_val : i64
}
```

>From f746d30b74fb8b59778b9efab4174aa1248030e3 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Mon, 22 Jul 2024 09:41:10 +0100
Subject: [PATCH] [MLIR][SCF] Define `scf.while` rotation pass

Define a pass that rotates `scf.while` loops by exposing the existing
`mlir::scf::wrapWhileLoopInZeroTripCheck` utility as a pass.

This pass rotates `scf.while` ops, turning them from "while" loops into
"do-while" loops. A guard checking the condition for the first
iteration is inserted. Note that this guard can be optimized away if the
compiler can prove the loop will be executed at least once.

Using this pass, the following while loop:

```mlir
scf.while (%arg0 = %init) : (i32) -> i64 {
  %val = .., %arg0 : i64
  %cond = arith.cmpi .., %arg0 : i32
  scf.condition(%cond) %val : i64
} do {
^bb0(%arg1: i64):
  %next = .., %arg1 : i32
  scf.yield %next : i32
}
```

Can be transformed into:

``` mlir
%pre_val = .., %init : i64
%pre_cond = arith.cmpi .., %init : i32
scf.if %pre_cond -> i64 {
  %res = scf.while (%arg1 = %pre_val) : (i64) -> i64 {
    // Original after block
    %next = .., %arg1 : i32
    // Original before block
    %val = .., %next : i64
    %cond = arith.cmpi .., %next : i32
    scf.condition(%cond) %val : i64
  } do {
  ^bb0(%arg2: i64):
    scf.yield %arg2 : i64
  }
  scf.yield %res : i64
} else {
  scf.yield %pre_val : i64
}
```

Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
 .../mlir/Dialect/SCF/Transforms/Passes.td     | 51 ++++++++++++
 .../mlir/Dialect/SCF/Transforms/Patterns.h    |  5 ++
 .../lib/Dialect/SCF/Transforms/CMakeLists.txt |  1 +
 .../SCF/Transforms/RotateWhileLoop.cpp        | 82 +++++++++++++++++++
 ...zero-trip-check.mlir => rotate-while.mlir} |  4 +-
 mlir/test/lib/Dialect/SCF/CMakeLists.txt      |  1 -
 .../SCF/TestSCFWrapInZeroTripCheck.cpp        | 72 ----------------
 mlir/tools/mlir-opt/mlir-opt.cpp              |  2 -
 8 files changed, 141 insertions(+), 77 deletions(-)
 create mode 100644 mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
 rename mlir/test/Dialect/SCF/{wrap-while-loop-in-zero-trip-check.mlir => rotate-while.mlir} (95%)
 delete mode 100644 mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 9b29affb97c43..beb52130e0137 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -164,4 +164,55 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   }];
 }
 
+def SCFRotateWhileLoopPass : Pass<"scf-rotate-while"> {
+  let summary = "Rotate while loops, turning them into do-while loops.";
+  let description = [{
+    This pass rotates `scf.while` ops, turning them from "while" loops into
+    "do-while" loops. A guard checking the condition for the first iteration is
+    inserted. Note that this guard can be optimized away if the compiler can
+    prove the loop will be executed at least once.
+
+    Using this pass, the following while loop:
+
+    ```mlir
+// Before:
+scf.while (%arg0 = %init) : (i32) -> i64 {
+  %val = .., %arg0 : i64
+  %cond = arith.cmpi .., %arg0 : i32
+  scf.condition(%cond) %val : i64
+} do {
+^bb0(%arg1: i64):
+  %next = .., %arg1 : i32
+  scf.yield %next : i32
+}
+    ```
+
+    Can be transformed into:
+    ``` mlir
+%pre_val = .., %init : i64
+%pre_cond = arith.cmpi .., %init : i32
+scf.if %pre_cond -> i64 {
+  %res = scf.while (%arg1 = %pre_val) : (i64) -> i64 {
+    // Original after block
+    %next = .., %arg1 : i32
+    // Original before block
+    %val = .., %next : i64
+    %cond = arith.cmpi .., %next : i32
+    scf.condition(%cond) %val : i64
+  } do {
+  ^bb0(%arg2: i64):
+    scf.yield %arg2 : i64
+  }
+  scf.yield %res : i64
+} else {
+  scf.yield %pre_val : i64
+}
+    ```
+  }];
+  let options = [
+    Option<"forceCreateCheck", "force-create-check", "bool", /*default=*/"false",
+           "Create loop guard even if the loop is already in a do-while form.">
+  ];
+}
+
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index fdf2570626980..3a6ed32f37c67 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -16,6 +16,7 @@
 namespace mlir {
 
 class ConversionTarget;
+class SCFRotateWhileLoopPassOptions;
 class TypeConverter;
 
 namespace scf {
@@ -85,6 +86,10 @@ void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 ///  * `after` block containing arith.addi
 void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
 
+/// Populate patterns to rotate `scf.while` ops, constructing `do-while` loops
+/// from `while` loops.
+void populateSCFRotateWhileLoopPatterns(
+    RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options);
 } // namespace scf
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index d363ffe941fce..8c73515c608f5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ParallelLoopCollapsing.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
+  RotateWhileLoop.cpp
   StructuralTypeConversions.cpp
   TileUsingInterface.cpp
   WrapInZeroTripCheck.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
new file mode 100644
index 0000000000000..f493adaad6be9
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
@@ -0,0 +1,82 @@
+//===- RotateWhileLoop.cpp - scf.while loop rotation ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Rotates `scf.while` loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "scf-rotate-while"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFROTATEWHILELOOPPASS
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
+  RotateWhileLoopPattern(bool rotateLoop, MLIRContext *context,
+                         PatternBenefit benefit = 1,
+                         ArrayRef<StringRef> generatedNames = {})
+      : OpRewritePattern<scf::WhileOp>(context, benefit, generatedNames),
+        forceCreateCheck(rotateLoop) {}
+
+  LogicalResult matchAndRewrite(scf::WhileOp whileOp,
+                                PatternRewriter &rewriter) const final {
+    FailureOr<scf::WhileOp> result =
+        scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck);
+    if (failed(result) || *result == whileOp) {
+      LLVM_DEBUG(whileOp->emitRemark("Failed to rotate loop"));
+      return failure();
+    };
+    return success();
+  }
+
+  bool forceCreateCheck;
+};
+
+struct SCFRotateWhileLoopPass
+    : impl::SCFRotateWhileLoopPassBase<SCFRotateWhileLoopPass> {
+  using Base::Base;
+
+  void runOnOperation() final {
+    Operation *parentOp = getOperation();
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    SCFRotateWhileLoopPassOptions options{forceCreateCheck};
+    scf::populateSCFRotateWhileLoopPatterns(patterns, options);
+    // Avoid applying the pattern to a loop more than once.
+    GreedyRewriteConfig config;
+    config.strictMode = GreedyRewriteStrictness::ExistingOps;
+    [[maybe_unused]] LogicalResult success =
+        applyPatternsAndFoldGreedily(parentOp, std::move(patterns), config);
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace scf {
+void populateSCFRotateWhileLoopPatterns(
+    RewritePatternSet &patterns, const SCFRotateWhileLoopPassOptions &options) {
+  patterns.add<RotateWhileLoopPattern>(options.forceCreateCheck,
+                                       patterns.getContext());
+}
+} // namespace scf
+} // namespace mlir
diff --git a/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir b/mlir/test/Dialect/SCF/rotate-while.mlir
similarity index 95%
rename from mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
rename to mlir/test/Dialect/SCF/rotate-while.mlir
index 8954839c3c93e..d6b205ee330d0 100644
--- a/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
+++ b/mlir/test/Dialect/SCF/rotate-while.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-wrap-scf-while-loop-in-zero-trip-check -split-input-file  | FileCheck %s
-// RUN: mlir-opt %s -test-wrap-scf-while-loop-in-zero-trip-check='force-create-check=true' -split-input-file  | FileCheck %s --check-prefix FORCE-CREATE-CHECK
+// RUN: mlir-opt %s -scf-rotate-while -split-input-file  | FileCheck %s
+// RUN: mlir-opt %s -scf-rotate-while='force-create-check=true' -split-input-file  | FileCheck %s --check-prefix FORCE-CREATE-CHECK
 
 func.func @wrap_while_loop_in_zero_trip_check(%bound : i32) -> i32 {
   %cst0 = arith.constant 0 : i32
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 792430cc84b65..2fa550e2d9f87 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -3,7 +3,6 @@ add_mlir_library(MLIRSCFTestPasses
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
   TestSCFUtils.cpp
-  TestSCFWrapInZeroTripCheck.cpp
   TestUpliftWhileToFor.cpp
   TestWhileOpBuilder.cpp
 
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
deleted file mode 100644
index 10206dd7cedf6..0000000000000
--- a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
+++ /dev/null
@@ -1,72 +0,0 @@
-//===- TestWrapInZeroTripCheck.cpp -- Passes to test SCF zero-trip-check --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements the passes to test wrap-in-zero-trip-check transforms on
-// SCF loop ops.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Transforms.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-
-using namespace mlir;
-
-namespace {
-
-struct TestWrapWhileLoopInZeroTripCheckPass
-    : public PassWrapper<TestWrapWhileLoopInZeroTripCheckPass,
-                         OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      TestWrapWhileLoopInZeroTripCheckPass)
-
-  StringRef getArgument() const final {
-    return "test-wrap-scf-while-loop-in-zero-trip-check";
-  }
-
-  StringRef getDescription() const final {
-    return "test scf::wrapWhileLoopInZeroTripCheck";
-  }
-
-  TestWrapWhileLoopInZeroTripCheckPass() = default;
-  TestWrapWhileLoopInZeroTripCheckPass(
-      const TestWrapWhileLoopInZeroTripCheckPass &) {}
-  explicit TestWrapWhileLoopInZeroTripCheckPass(bool forceCreateCheckParam) {
-    forceCreateCheck = forceCreateCheckParam;
-  }
-
-  void runOnOperation() override {
-    func::FuncOp func = getOperation();
-    MLIRContext *context = &getContext();
-    IRRewriter rewriter(context);
-    func.walk([&](scf::WhileOp op) {
-      FailureOr<scf::WhileOp> result =
-          scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
-      // Ignore not implemented failure in tests. The expected output should
-      // catch problems (e.g. transformation doesn't happen).
-      (void)result;
-    });
-  }
-
-  Option<bool> forceCreateCheck{
-      *this, "force-create-check",
-      llvm::cl::desc("Force to create zero-trip-check."),
-      llvm::cl::init(false)};
-};
-
-} // namespace
-
-namespace mlir {
-namespace test {
-void registerTestSCFWrapInZeroTripCheckPasses() {
-  PassRegistration<TestWrapWhileLoopInZeroTripCheckPass>();
-}
-} // namespace test
-} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 8cafb0afac9ae..af99bf55a5e7c 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -138,7 +138,6 @@ void registerTestRecursiveTypesPass();
 void registerTestSCFUpliftWhileToFor();
 void registerTestSCFUtilsPass();
 void registerTestSCFWhileOpBuilderPass();
-void registerTestSCFWrapInZeroTripCheckPasses();
 void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorCopyInsertionPass();
@@ -270,7 +269,6 @@ void registerTestPasses() {
   mlir::test::registerTestSCFUpliftWhileToFor();
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSCFWhileOpBuilderPass();
-  mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
   mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestTensorCopyInsertionPass();



More information about the Mlir-commits mailing list