[Mlir-commits] [mlir] [mlir][scf] Add simple LICM pattern for `scf.while` (PR #76370)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 25 12:34:23 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Ivan Butygin (Hardcode84)
<details>
<summary>Changes</summary>
Hoist side-effect-free ops out of the `before` block when all of their operands are defined outside the loop.
Nested regions are not visited.
This cleanup is needed for `scf.while` -> `scf.for` uplifting (https://github.com/llvm/llvm-project/pull/76108), which expects the `before` block to consist of a single `cmp` op.
---
Full diff: https://github.com/llvm/llvm-project/pull/76370.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+31-1)
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+26-3)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5570c2ec688c8a..1f5c420e4775fa 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3810,6 +3810,36 @@ struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
}
};
+/// Simple loop-invariant code motion for the `before` block of `scf.while`.
+/// Hoists memory-effect-free ops whose operands are all defined outside the
+/// loop. This is needed by `scf.while` -> `scf.for` uplifting, which expects
+/// the `before` block to consist of a single `cmp` op.
+/// Ops with regions are never hoisted: only their direct operands would be
+/// checked, so a nested region could still use values defined in the loop
+/// (e.g. `before` block arguments) and hoisting would break dominance.
+struct SCFWhileLICM : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    // A value is loop-invariant when it is not defined anywhere inside `loop`.
+    // Ops are visited in order, so results of ops hoisted earlier in this
+    // walk already count as invariant for their later users.
+    auto isDefinedOutside = [&](Value value) {
+      return !loop->isAncestor(value.getParentRegion()->getParentOp());
+    };
+
+    bool changed = false;
+    Block *body = loop.getBeforeBody();
+    for (Operation &op :
+         llvm::make_early_inc_range(body->without_terminator())) {
+      if (op.getNumRegions() != 0 || !isMemoryEffectFree(&op) ||
+          !llvm::all_of(op.getOperands(), isDefinedOutside))
+        continue;
+
+      rewriter.updateRootInPlace(&op, [&]() { op.moveBefore(loop); });
+      changed = true;
+    }
+    return success(changed);
+  }
+};
+
/// Remove duplicated ConditionOp args.
struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
using OpRewritePattern::OpRewritePattern;
@@ -3879,7 +3909,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
RemoveLoopInvariantValueYielded, WhileConditionTruth,
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
- WhileRemoveUnusedArgs>(context);
+ SCFWhileLICM, WhileRemoveUnusedArgs>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 52e0fdfa36d6cd..f8f4d737b381df 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1022,10 +1022,10 @@ func.func @while_loop_invariant_argument_different_order(%arg : tensor<i32>) ->
// CHECK-SAME: (%[[ARG:.+]]: tensor<i32>)
// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
+// CHECK: %[[COND:.*]] = arith.cmpi sgt, %[[ARG]], %[[ZERO]]
+// CHECK: %[[COND1:.*]] = tensor.extract %[[COND]][]
// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
-// CHECK: arith.cmpi sgt, %[[ARG]], %[[ZERO]]
-// CHECK: tensor.extract %{{.*}}[]
-// CHECK: scf.condition(%{{.*}}) %[[ARG1]], %[[ARG4]]
+// CHECK: scf.condition(%[[COND1]]) %[[ARG1]], %[[ARG4]]
// CHECK: } do {
// CHECK: ^{{.*}}(%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
// CHECK: scf.yield %[[ZERO]], %[[ONE]]
@@ -1144,6 +1144,29 @@ func.func @while_duplicated_res() -> (i32, i32) {
// CHECK: }
// CHECK: return %[[RES]], %[[RES]] : i32, i32
+// -----
+
+// Checks that the chain of arith.addi ops, whose operands are all defined
+// outside the loop, is hoisted out of the `before` block, while the
+// "test.condition" op stays inside the loop (presumably because its memory
+// effects are unknown).
+func.func @while_licm(%arg1: i32, %arg2: i32, %arg3: i32) {
+ scf.while () : () -> () {
+ %val0 = arith.addi %arg1, %arg2 : i32
+ %val = arith.addi %val0, %arg3 : i32
+ %condition = "test.condition"(%val) : (i32) -> i1
+ scf.condition(%condition)
+ } do {
+ ^bb0():
+ "test.test"() : () -> ()
+ scf.yield
+ }
+ return
+}
+// CHECK-LABEL: @while_licm
+// CHECK-SAME: (%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32)
+// CHECK: %[[VAL0:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
+// CHECK: %[[VAL1:.*]] = arith.addi %[[VAL0]], %[[ARG3]] : i32
+// CHECK: scf.while
+// CHECK-NEXT: %[[COND:.*]] = "test.condition"(%[[VAL1]]) : (i32) -> i1
+// CHECK-NEXT: scf.condition(%[[COND]])
+
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/76370
More information about the Mlir-commits
mailing list