[Mlir-commits] [mlir] [mlir][SCF]-Fix loop coalescing with iteration arguments (PR #105488)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 21 02:18:37 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir-scf
Author: Amir Bishara (amirBish)
<details>
<summary>Changes</summary>
Fix a bug found when coalescing loops that have iteration arguments: the inner loop's terminator may yield the inner loop's own iteration arguments, which are about to be replaced by the outer loop's iteration arguments.
The current flow leads to a crash.
---
Full diff: https://github.com/llvm/llvm-project/pull/105488.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+12)
- (modified) mlir/test/Dialect/Affine/loop-coalescing.mlir (+120)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 40f82557d2eb8a..23e6d511d24fb3 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -864,6 +864,18 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
Operation *innerTerminator = innerLoop.getBody()->getTerminator();
auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
+ assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()) && "inner loop must be initialized with the outer loop's iteration arguments");
+ for (Value &yieldedVal : yieldedVals) {
+ // The yielded value may be an iteration argument of the inner loop
+ // which is about to be inlined.
+ auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
+ if (iter != innerLoop.getRegionIterArgs().end()) {
+ unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
+ // `outerLoop` iter args identical to the `innerLoop` init args.
+ assert(iterArgIndex < outerLoop.getRegionIterArgs().size());
+ yieldedVal = outerLoop.getRegionIterArgs()[iterArgIndex];
+ }
+ }
rewriter.eraseOp(innerTerminator);
SmallVector<Value> innerBlockArgs;
diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir
index 0235000aeac538..45dd299295f640 100644
--- a/mlir/test/Dialect/Affine/loop-coalescing.mlir
+++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir
@@ -114,6 +114,126 @@ func.func @unnormalized_loops() {
return
}
+// CHECK-LABEL: @normalized_loops_with_yielded_iter_args
+func.func @normalized_loops_with_yielded_iter_args() {
+ // CHECK: %[[orig_lb:.*]] = arith.constant 0
+ // CHECK: %[[orig_step:.*]] = arith.constant 1
+ // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
+ // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
+ // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %c42 = arith.constant 42 : index
+ %c56 = arith.constant 56 : index
+ // The range of the new scf.for.
+ // CHECK: %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
+ // CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
+ // Updated loop bounds.
+ // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]]) -> (index) {
+ %2:1 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0) -> (index) {
+ // Inner loops must have been removed.
+ // CHECK-NOT: scf.for
+
+ // Reconstruct original IVs from the linearized one.
+ // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
+ // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
+ // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
+ // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+ %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){
+ %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
+ // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+ "use"(%i, %j, %k) : (index, index, index) -> ()
+ // CHECK: scf.yield %[[VAL_1]] : index
+ scf.yield %arg2 : index
+ }
+ scf.yield %0#0 : index
+ }
+ scf.yield %1#0 : index
+ }
+ return
+}
+
+// CHECK-LABEL: @normalized_loops_with_shuffled_yielded_iter_args
+func.func @normalized_loops_with_shuffled_yielded_iter_args() {
+ // CHECK: %[[orig_lb:.*]] = arith.constant 0
+ // CHECK: %[[orig_step:.*]] = arith.constant 1
+ // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
+ // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
+ // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %c42 = arith.constant 42 : index
+ %c56 = arith.constant 56 : index
+ // The range of the new scf.for.
+ // CHECK: %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
+ // CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
+ // Updated loop bounds; the yield below swaps the two iter_args.
+ // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]], %[[VAL_2:.*]] = %[[orig_lb]]) -> (index, index) {
+ %2:2 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0, %arg1 = %c0) -> (index, index) {
+ // Inner loops must have been removed.
+ // CHECK-NOT: scf.for
+
+ // Reconstruct original IVs from the linearized one.
+ // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
+ // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
+ // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
+ // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+ %1:2 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (index, index){
+ %0:2 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (index, index) {
+ // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+ "use"(%i, %j, %k) : (index, index, index) -> ()
+ // CHECK: scf.yield %[[VAL_2]], %[[VAL_1]] : index, index
+ scf.yield %arg5, %arg4 : index, index
+ }
+ scf.yield %0#0, %0#1 : index, index
+ }
+ scf.yield %1#0, %1#1 : index, index
+ }
+ return
+}
+
+// CHECK-LABEL: @normalized_loops_with_yielded_non_iter_args
+func.func @normalized_loops_with_yielded_non_iter_args() {
+ // CHECK: %[[orig_lb:.*]] = arith.constant 0
+ // CHECK: %[[orig_step:.*]] = arith.constant 1
+ // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
+ // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
+ // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %c42 = arith.constant 42 : index
+ %c56 = arith.constant 56 : index
+ // The range of the new scf.for.
+ // CHECK: %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
+ // CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
+ // Updated loop bounds; the innermost yield is not an iter_arg and must be kept.
+ // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]]) -> (index) {
+ %2:1 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0) -> (index) {
+ // Inner loops must have been removed.
+ // CHECK-NOT: scf.for
+
+ // Reconstruct original IVs from the linearized one.
+ // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
+ // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
+ // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
+ // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+ %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){
+ %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
+ // CHECK: %[[res:.*]] = "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+ %res = "use"(%i, %j, %k) : (index, index, index) -> (index)
+ // CHECK: scf.yield %[[res]] : index
+ scf.yield %res : index
+ }
+ scf.yield %0#0 : index
+ }
+ scf.yield %1#0 : index
+ }
+ return
+}
+
// Check with parametric loop bounds and steps, capture the bounds here.
// CHECK-LABEL: @parametric
// CHECK-SAME: %[[orig_lb1:[A-Za-z0-9]+]]:
``````````
</details>
https://github.com/llvm/llvm-project/pull/105488
More information about the Mlir-commits
mailing list