[Mlir-commits] [mlir] e8f07cd - [MLIR][SCF] Define `-scf-rotate-while` pass (#99850)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 30 01:06:05 PDT 2024
Author: Victor Perez
Date: 2024-07-30T10:06:01+02:00
New Revision: e8f07cdb57602d71f8960c0499765bcb000745c2
URL: https://github.com/llvm/llvm-project/commit/e8f07cdb57602d71f8960c0499765bcb000745c2
DIFF: https://github.com/llvm/llvm-project/commit/e8f07cdb57602d71f8960c0499765bcb000745c2.diff
LOG: [MLIR][SCF] Define `-scf-rotate-while` pass (#99850)
Define SCF dialect patterns that rotate `scf.while` loops, leveraging the
existing `mlir::scf::wrapWhileLoopInZeroTripCheck`. `forceCreateCheck`
is always `false`, as the pattern would otherwise lead to infinite
recursion.
This pattern rotates `scf.while` ops, mutating them from "while" loops to
"do-while" loops. A guard checking the condition for the first iteration
is inserted. Note this guard can be optimized away if the compiler can
prove the loop will be executed at least once.
Using this pattern, the following while loop:
```mlir
scf.while (%arg0 = %init) : (i32) -> i64 {
%val = .., %arg0 : i64
%cond = arith.cmpi .., %arg0 : i32
scf.condition(%cond) %val : i64
} do {
^bb0(%arg1: i64):
%next = .., %arg1 : i32
scf.yield %next : i32
}
```
Can be transformed into:
```mlir
%pre_val = .., %init : i64
%pre_cond = arith.cmpi .., %init : i32
scf.if %pre_cond -> i64 {
%res = scf.while (%arg1 = %pre_val) : (i64) -> i64 {
// Original after block
%next = .., %arg1 : i32
// Original before block
%val = .., %next : i64
%cond = arith.cmpi .., %next : i32
scf.condition(%cond) %val : i64
} do {
^bb0(%arg2: i64):
scf.yield %arg2 : i64
}
scf.yield %res : i64
} else {
scf.yield %pre_val : i64
}
```
The test pass for `wrapWhileLoopInZeroTripCheck` has been modified to
use the new pattern when `forceCreateCheck=false`.
---------
Signed-off-by: Victor Perez <victor.perez at codeplay.com>
Added:
mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
Modified:
mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index fdf2570626980..5e66774d2f143 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -85,6 +85,9 @@ void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
/// * `after` block containing arith.addi
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);
+/// Populate patterns to rotate `scf.while` ops, constructing `do-while` loops
+/// from `while` loops.
+void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns);
} // namespace scf
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 71835cd178930..ea2f457c4e889 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -228,6 +228,11 @@ FailureOr<ForOp> pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
/// } else {
/// scf.yield %pre_val : i64
/// }
+///
+/// Failure mechanism is not implemented for this function, so it currently
+/// always returns a `WhileOp` operation: a new one if the transformation took
+/// place or the input `whileOp` if the loop was already in a `do-while` form
+/// and `forceCreateCheck` is `false`.
FailureOr<WhileOp> wrapWhileLoopInZeroTripCheck(WhileOp whileOp,
RewriterBase &rewriter,
bool forceCreateCheck = false);
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index d363ffe941fce..8c73515c608f5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
ParallelLoopCollapsing.cpp
ParallelLoopFusion.cpp
ParallelLoopTiling.cpp
+ RotateWhileLoop.cpp
StructuralTypeConversions.cpp
TileUsingInterface.cpp
WrapInZeroTripCheck.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
new file mode 100644
index 0000000000000..8707ec91328dc
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
@@ -0,0 +1,44 @@
+//===- RotateWhileLoop.cpp - scf.while loop rotation ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Rotates `scf.while` loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+
+using namespace mlir;
+
+namespace {
+/// Rewrite pattern rotating `scf.while` ops from "while" form into
+/// "do-while" form by delegating to `scf::wrapWhileLoopInZeroTripCheck`.
+/// Loops already in "do-while" form come back unmodified; the pattern then
+/// reports failure so a greedy driver does not keep re-applying it.
+struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
+ using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(scf::WhileOp whileOp,
+ PatternRewriter &rewriter) const final {
+ // Setting this option would lead to infinite recursion on a greedy driver
+ // as 'do-while' loops wouldn't be skipped.
+ constexpr bool forceCreateCheck = false;
+ FailureOr<scf::WhileOp> result =
+ scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck);
+ // scf::wrapWhileLoopInZeroTripCheck hasn't yet implemented a failure
+ // mechanism. 'do-while' loops are simply returned unmodified. In order to
+ // stop recursion, we check that the input and output operations differ.
+ return success(succeeded(result) && *result != whileOp);
+ }
+};
+} // namespace
+
+namespace mlir {
+namespace scf {
+/// Adds `RotateWhileLoopPattern` to `patterns`, constructed with the pattern
+/// set's own context.
+void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns) {
+ patterns.add<RotateWhileLoopPattern>(patterns.getContext());
+}
+} // namespace scf
+} // namespace mlir
diff --git a/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir b/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
index 8954839c3c93e..43a4693ca3d2b 100644
--- a/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
+++ b/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
@@ -20,7 +20,7 @@ func.func @wrap_while_loop_in_zero_trip_check(%bound : i32) -> i32 {
// CHECK-SAME: %[[BOUND:.*]]: i32) -> i32 {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
-// CHECK-DAG: %[[PRE_COND:.*]] = arith.cmpi slt, %[[C0]], %[[BOUND]] : i32
+// CHECK-DAG: %[[PRE_COND:.*]] = arith.cmpi sgt, %[[BOUND]], %[[C0]] : i32
// CHECK-DAG: %[[PRE_INV:.*]] = arith.addi %[[BOUND]], %[[C5]] : i32
// CHECK: %[[IF:.*]]:2 = scf.if %[[PRE_COND]] -> (i32, i32) {
// CHECK: %[[WHILE:.*]]:2 = scf.while (
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
index 10206dd7cedf6..7e51d67702b05 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
@@ -1,4 +1,4 @@
-//===- TestWrapInZeroTripCheck.cpp -- Passes to test SCF zero-trip-check --===//
+//===- TestSCFWrapInZeroTripCheck.cpp -- Pass to test SCF zero-trip-check -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -13,9 +13,11 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
@@ -46,13 +48,19 @@ struct TestWrapWhileLoopInZeroTripCheckPass
func::FuncOp func = getOperation();
MLIRContext *context = &getContext();
IRRewriter rewriter(context);
- func.walk([&](scf::WhileOp op) {
- FailureOr<scf::WhileOp> result =
- scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
- // Ignore not implemented failure in tests. The expected output should
- // catch problems (e.g. transformation doesn't happen).
- (void)result;
- });
+ if (forceCreateCheck) {
+ func.walk([&](scf::WhileOp op) {
+ FailureOr<scf::WhileOp> result =
+ scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
+ // Ignore not implemented failure in tests. The expected output should
+ // catch problems (e.g. transformation doesn't happen).
+ (void)result;
+ });
+ } else {
+ RewritePatternSet patterns(context);
+ scf::populateSCFRotateWhileLoopPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+ }
}
Option<bool> forceCreateCheck{
More information about the Mlir-commits
mailing list