[Mlir-commits] [mlir] e8f07cd - [MLIR][SCF] Define `-scf-rotate-while` pass (#99850)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 30 01:06:05 PDT 2024


Author: Victor Perez
Date: 2024-07-30T10:06:01+02:00
New Revision: e8f07cdb57602d71f8960c0499765bcb000745c2

URL: https://github.com/llvm/llvm-project/commit/e8f07cdb57602d71f8960c0499765bcb000745c2
DIFF: https://github.com/llvm/llvm-project/commit/e8f07cdb57602d71f8960c0499765bcb000745c2.diff

LOG: [MLIR][SCF] Define `-scf-rotate-while` pass (#99850)

Define SCF dialect patterns rotating `scf.while` loops leveraging
existing `mlir::scf::wrapWhileLoopInZeroTripCheck`. `forceCreateCheck`
is always `false` as the pattern would lead to an infinite recursion
otherwise.

This pattern rotates `scf.while` ops, mutating them from "while" loops to
"do-while" loops. A guard checking the condition for the first iteration
is inserted. Note this guard can be optimized away if the compiler can
prove the loop will be executed at least once.

Using this pattern, the following while loop:

```mlir
scf.while (%arg0 = %init) : (i32) -> i64 {
  %val = .., %arg0 : i64
  %cond = arith.cmpi .., %arg0 : i32
  scf.condition(%cond) %val : i64
} do {
^bb0(%arg1: i64):
  %next = .., %arg1 : i32
  scf.yield %next : i32
}
```

Can be transformed into:

``` mlir
%pre_val = .., %init : i64
%pre_cond = arith.cmpi .., %init : i32
scf.if %pre_cond -> i64 {
  %res = scf.while (%arg1 = %pre_val) : (i64) -> i64 {
    // Original after block
    %next = .., %arg1 : i32
    // Original before block
    %val = .., %next : i64
    %cond = arith.cmpi .., %next : i32
    scf.condition(%cond) %val : i64
  } do {
  ^bb0(%arg2: i64):
    scf.yield %arg2 : i64
  }
  scf.yield %res : i64
} else {
  scf.yield %pre_val : i64
}
```

The test pass for `wrapWhileLoopInZeroTripCheck` has been modified to
use the new pattern when `forceCreateCheck=false`.

---------

Signed-off-by: Victor Perez <victor.perez at codeplay.com>

Added: 
    mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
    mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
    mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
    mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index fdf2570626980..5e66774d2f143 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -85,6 +85,9 @@ void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 ///  * `after` block containing arith.addi
 void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
 
+/// Populate patterns to rotate `scf.while` ops, constructing `do-while` loops
+/// from `while` loops.
+void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns);
 } // namespace scf
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 71835cd178930..ea2f457c4e889 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -228,6 +228,11 @@ FailureOr<ForOp> pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
 ///   } else {
 ///     scf.yield %pre_val : i64
 ///   }
+///
+/// Failure mechanism is not implemented for this function, so it currently
+/// always returns a `WhileOp` operation: a new one if the transformation took
+/// place or the input `whileOp` if the loop was already in a `do-while` form
+/// and `forceCreateCheck` is `false`.
 FailureOr<WhileOp> wrapWhileLoopInZeroTripCheck(WhileOp whileOp,
                                                 RewriterBase &rewriter,
                                                 bool forceCreateCheck = false);

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index d363ffe941fce..8c73515c608f5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ParallelLoopCollapsing.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
+  RotateWhileLoop.cpp
   StructuralTypeConversions.cpp
   TileUsingInterface.cpp
   WrapInZeroTripCheck.cpp

diff  --git a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
new file mode 100644
index 0000000000000..8707ec91328dc
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
@@ -0,0 +1,44 @@
+//===- RotateWhileLoop.cpp - scf.while loop rotation ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Rotates `scf.while` loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+
+using namespace mlir;
+
+namespace {
+struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::WhileOp whileOp,
+                                PatternRewriter &rewriter) const final {
+    // Setting this option would lead to infinite recursion on a greedy driver
+    // as 'do-while' loops wouldn't be skipped.
+    constexpr bool forceCreateCheck = false;
+    FailureOr<scf::WhileOp> result =
+        scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck);
+    // scf::wrapWhileLoopInZeroTripCheck hasn't yet implemented a failure
+    // mechanism. 'do-while' loops are simply returned unmodified. In order to
+    // stop recursion, we check input and output operations differ.
+    return success(succeeded(result) && *result != whileOp);
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace scf {
+void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns) {
+  patterns.add<RotateWhileLoopPattern>(patterns.getContext());
+}
+} // namespace scf
+} // namespace mlir

diff  --git a/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir b/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
index 8954839c3c93e..43a4693ca3d2b 100644
--- a/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
+++ b/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
@@ -20,7 +20,7 @@ func.func @wrap_while_loop_in_zero_trip_check(%bound : i32) -> i32 {
 // CHECK-SAME:      %[[BOUND:.*]]: i32) -> i32 {
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : i32
 // CHECK-DAG:     %[[C5:.*]] = arith.constant 5 : i32
-// CHECK-DAG:     %[[PRE_COND:.*]] = arith.cmpi slt, %[[C0]], %[[BOUND]] : i32
+// CHECK-DAG:     %[[PRE_COND:.*]] = arith.cmpi sgt, %[[BOUND]], %[[C0]] : i32
 // CHECK-DAG:     %[[PRE_INV:.*]] = arith.addi %[[BOUND]], %[[C5]] : i32
 // CHECK:         %[[IF:.*]]:2 = scf.if %[[PRE_COND]] -> (i32, i32) {
 // CHECK:           %[[WHILE:.*]]:2 = scf.while (

diff  --git a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
index 10206dd7cedf6..7e51d67702b05 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
@@ -1,4 +1,4 @@
-//===- TestWrapInZeroTripCheck.cpp -- Passes to test SCF zero-trip-check --===//
+//===- TestSCFWrapInZeroTripCheck.cpp -- Pass to test SCF zero-trip-check -===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -13,9 +13,11 @@
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 
@@ -46,13 +48,19 @@ struct TestWrapWhileLoopInZeroTripCheckPass
     func::FuncOp func = getOperation();
     MLIRContext *context = &getContext();
     IRRewriter rewriter(context);
-    func.walk([&](scf::WhileOp op) {
-      FailureOr<scf::WhileOp> result =
-          scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
-      // Ignore not implemented failure in tests. The expected output should
-      // catch problems (e.g. transformation doesn't happen).
-      (void)result;
-    });
+    if (forceCreateCheck) {
+      func.walk([&](scf::WhileOp op) {
+        FailureOr<scf::WhileOp> result =
+            scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
+        // Ignore not implemented failure in tests. The expected output should
+        // catch problems (e.g. transformation doesn't happen).
+        (void)result;
+      });
+    } else {
+      RewritePatternSet patterns(context);
+      scf::populateSCFRotateWhileLoopPatterns(patterns);
+      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+    }
   }
 
   Option<bool> forceCreateCheck{


        


More information about the Mlir-commits mailing list