[Mlir-commits] [mlir] [mlir][scf] Add simple LICM pattern for `scf.while` (PR #76370)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 25 12:34:23 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

Move memory-effect-free ops out of the `before` block if all their operands are defined outside the loop. 
Doesn't visit nested regions.
This cleanup is needed for `scf.while` -> `scf.for` uplifting https://github.com/llvm/llvm-project/pull/76108, as it expects the `before` block to consist of a single `cmp` op.

---
Full diff: https://github.com/llvm/llvm-project/pull/76370.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+31-1) 
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+26-3) 


``````````diff
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5570c2ec688c8a..1f5c420e4775fa 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3810,6 +3810,36 @@ struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
   }
 };
 
+/// Simple loop-invariant code motion (LICM) pattern for the `scf.while` op.
+///
+/// Hoists memory-effect-free ops from the top level of the `before` block to
+/// just before the loop when every value they use -- including values used by
+/// ops nested inside their regions -- is defined outside the loop. Nested
+/// regions are not searched for candidates. Ops are visited in program order,
+/// so a chain of invariant ops is hoisted in a single application.
+///
+/// `scf.while` to `scf.for` uplifting expects the `before` block to consist of
+/// a single `cmp` op; this pattern performs the required cleanup.
+struct SCFWhileLICM : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    Region &beforeRegion = loop.getBefore();
+
+    // Returns true if every value used by `op`, or by any op nested in its
+    // regions, is available outside the loop. Checking only `op`'s own
+    // operands is not enough: a nested op may capture a `before`-block
+    // argument or a sibling's result, and hoisting would break dominance.
+    // Values from the `after` region are never visible here, so only values
+    // defined in the `before` region can pin `op` inside the loop.
+    auto usesOnlyInvariantValues = [&](Operation &op) {
+      WalkResult result = op.walk([&](Operation *nested) {
+        for (Value operand : nested->getOperands()) {
+          Region *defRegion = operand.getParentRegion();
+          // Defined outside the loop, or by an op hoisted earlier in this
+          // application (which now lives outside the loop).
+          if (!beforeRegion.isAncestor(defRegion))
+            continue;
+          // Defined inside `op` itself; it moves together with `op`.
+          if (op.isAncestor(defRegion->getParentOp()))
+            continue;
+          return WalkResult::interrupt();
+        }
+        return WalkResult::advance();
+      });
+      return !result.wasInterrupted();
+    };
+
+    bool changed = false;
+    Block *body = loop.getBeforeBody();
+    for (Operation &op :
+         llvm::make_early_inc_range(body->without_terminator())) {
+      if (!isMemoryEffectFree(&op) || !usesOnlyInvariantValues(op))
+        continue;
+
+      // Notify the rewriter that `op` was modified (it changed blocks).
+      rewriter.updateRootInPlace(&op, [&]() { op.moveBefore(loop); });
+      changed = true;
+    }
+    return success(changed);
+  }
+};
+
 /// Remove duplicated ConditionOp args.
 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -3879,7 +3909,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
               RemoveLoopInvariantValueYielded, WhileConditionTruth,
               WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
-              WhileRemoveUnusedArgs>(context);
+              SCFWhileLICM, WhileRemoveUnusedArgs>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 52e0fdfa36d6cd..f8f4d737b381df 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1022,10 +1022,10 @@ func.func @while_loop_invariant_argument_different_order(%arg : tensor<i32>) ->
 // CHECK-SAME: (%[[ARG:.+]]: tensor<i32>)
 // CHECK:    %[[ZERO:.*]] = arith.constant dense<0>
 // CHECK:    %[[ONE:.*]] = arith.constant dense<1>
+// CHECK:    %[[COND:.*]] = arith.cmpi sgt, %[[ARG]], %[[ZERO]]
+// CHECK:    %[[COND1:.*]] = tensor.extract %[[COND]][]
 // CHECK:    %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
-// CHECK:       arith.cmpi sgt, %[[ARG]], %[[ZERO]]
-// CHECK:       tensor.extract %{{.*}}[]
-// CHECK:       scf.condition(%{{.*}}) %[[ARG1]], %[[ARG4]]
+// CHECK:       scf.condition(%[[COND1]]) %[[ARG1]], %[[ARG4]]
 // CHECK:    } do {
 // CHECK:     ^{{.*}}(%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
 // CHECK:       scf.yield %[[ZERO]], %[[ONE]]
@@ -1144,6 +1144,29 @@ func.func @while_duplicated_res() -> (i32, i32) {
 // CHECK:         }
 // CHECK:         return %[[RES]], %[[RES]] : i32, i32
 
+// -----
+
+// Both `arith.addi` ops use only function arguments (the second one uses the
+// first one's result), so the whole chain is expected to be hoisted above the
+// `scf.while`. `test.condition` is not known to be memory-effect free, so it
+// stays in the `before` block even though its operand becomes loop-invariant.
+func.func @while_licm(%arg1: i32, %arg2: i32, %arg3: i32) {
+  scf.while () : () -> () {
+    %val0 = arith.addi %arg1, %arg2 : i32
+    %val = arith.addi %val0, %arg3 : i32
+    %condition = "test.condition"(%val) : (i32) -> i1
+    scf.condition(%condition)
+  } do {
+  ^bb0():
+    "test.test"() : () -> ()
+    scf.yield
+  }
+  return
+}
+// CHECK-LABEL: @while_licm
+//  CHECK-SAME:  (%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32)
+//       CHECK:  %[[VAL0:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
+//       CHECK:  %[[VAL1:.*]] = arith.addi %[[VAL0]], %[[ARG3]] : i32
+//       CHECK:  scf.while
+//  CHECK-NEXT:  %[[COND:.*]] = "test.condition"(%[[VAL1]]) : (i32) -> i1
+//  CHECK-NEXT:  scf.condition(%[[COND]])
+
 
 // -----
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/76370


More information about the Mlir-commits mailing list