[Mlir-commits] [mlir] [mlir][SCF] Fix loop coalescing with iteration arguments (PR #105488)

Amir Bishara llvmlistbot at llvm.org
Wed Aug 21 02:18:07 PDT 2024


https://github.com/amirBish created https://github.com/llvm/llvm-project/pull/105488

Fix a bug in loop coalescing for loops that have iteration arguments: the inner loop's terminator may yield the inner loop's own iteration arguments, which are about to be replaced by the outer loop's iteration arguments when the inner loop body is inlined.

The current flow leads to a crash while rewriting the IR.

>From 712264b898c4d00806f176465634ab9f7fcf0e46 Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Tue, 20 Aug 2024 15:22:01 +0300
Subject: [PATCH] [mlir][SCF] Fix loop coalescing with iteration arguments

Fix a bug in loop coalescing for loops that have iteration
arguments: the inner loop's terminator may yield the inner loop's
own iteration arguments, which are about to be replaced by the
outer loop's iteration arguments when the inner body is inlined.

The current flow leads to a crash while rewriting the IR.
---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          |  12 ++
 mlir/test/Dialect/Affine/loop-coalescing.mlir | 120 ++++++++++++++++++
 2 files changed, 132 insertions(+)

diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 40f82557d2eb8a..23e6d511d24fb3 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -864,6 +864,18 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
 
     Operation *innerTerminator = innerLoop.getBody()->getTerminator();
     auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
+    // The inner body is inlined with the outer loop's iter args standing in
+    // for its own, so its init args must be exactly those outer iter args.
+    assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
+    auto innerIterArgs = innerLoop.getRegionIterArgs();
+    for (Value &yieldedVal : yieldedVals) {
+      // A yielded value may be one of the inner loop's iteration arguments,
+      // which are about to be replaced by the matching outer iteration
+      // arguments; yield that replacement directly instead.
+      auto it = llvm::find(innerIterArgs, yieldedVal);
+      if (it != innerIterArgs.end())
+        yieldedVal = outerLoop.getRegionIterArgs()[it - innerIterArgs.begin()];
+    }
     rewriter.eraseOp(innerTerminator);
 
     SmallVector<Value> innerBlockArgs;
diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir
index 0235000aeac538..45dd299295f640 100644
--- a/mlir/test/Dialect/Affine/loop-coalescing.mlir
+++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir
@@ -114,6 +114,126 @@ func.func @unnormalized_loops() {
   return
 }
 
+// CHECK-LABEL: @noramalized_loops_with_yielded_iter_args
+func.func @noramalized_loops_with_yielded_iter_args() {
+  // CHECK: %[[orig_lb:.*]] = arith.constant 0
+  // CHECK: %[[orig_step:.*]] = arith.constant 1
+  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
+  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
+  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c42 = arith.constant 42 : index
+  %c56 = arith.constant 56 : index
+  // The range of the new scf.for is the product of the original upper bounds.
+  // CHECK:     %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
+  // CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
+  // Updated loop bounds.
+  // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]]) -> (index) {
+  %2:1 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0) -> (index) {
+    // Inner loops must have been removed.
+    // CHECK-NOT: scf.for
+
+    // Reconstruct original IVs from the linearized one.
+    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
+    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+    %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){
+      %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
+        // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+        "use"(%i, %j, %k) : (index, index, index) -> ()
+        // CHECK: scf.yield %[[VAL_1]] : index
+        scf.yield %arg2 : index
+      }
+      scf.yield %0#0 : index
+    }
+    scf.yield %1#0 : index
+  }
+  return
+}
+
+// CHECK-LABEL: @noramalized_loops_with_shuffled_yielded_iter_args
+func.func @noramalized_loops_with_shuffled_yielded_iter_args() {
+  // CHECK: %[[orig_lb:.*]] = arith.constant 0
+  // CHECK: %[[orig_step:.*]] = arith.constant 1
+  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
+  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
+  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c42 = arith.constant 42 : index
+  %c56 = arith.constant 56 : index
+  // The range of the new scf.for is the product of the original upper bounds.
+  // CHECK:     %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
+  // CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
+  // Updated loop bounds.
+  // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]], %[[VAL_2:.*]] = %[[orig_lb]]) -> (index, index) {
+  %2:2 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0, %arg1 = %c0) -> (index, index) {
+    // Inner loops must have been removed.
+    // CHECK-NOT: scf.for
+
+    // Reconstruct original IVs from the linearized one.
+    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
+    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+    %1:2 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (index, index){
+      %0:2 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (index, index) {
+        // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+        "use"(%i, %j, %k) : (index, index, index) -> ()
+        // CHECK: scf.yield %[[VAL_2]], %[[VAL_1]] : index, index
+        scf.yield %arg5, %arg4 : index, index
+      }
+      scf.yield %0#0, %0#1 : index, index
+    }
+    scf.yield %1#0, %1#1 : index, index
+  }
+  return
+}
+
+// CHECK-LABEL: @noramalized_loops_with_yielded_non_iter_args
+func.func @noramalized_loops_with_yielded_non_iter_args() {
+  // CHECK: %[[orig_lb:.*]] = arith.constant 0
+  // CHECK: %[[orig_step:.*]] = arith.constant 1
+  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
+  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
+  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c42 = arith.constant 42 : index
+  %c56 = arith.constant 56 : index
+  // The range of the new scf.for is the product of the original upper bounds.
+  // CHECK:     %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
+  // CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
+  // Updated loop bounds.
+  // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]]) -> (index) {
+  %2:1 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0) -> (index) {
+    // Inner loops must have been removed.
+    // CHECK-NOT: scf.for
+
+    // Reconstruct original IVs from the linearized one.
+    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
+    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+    %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){
+      %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
+        // CHECK: %[[res:.*]] = "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+        %res = "use"(%i, %j, %k) : (index, index, index) -> (index)
+        // CHECK: scf.yield %[[res]] : index
+        scf.yield %res : index
+      }
+      scf.yield %0#0 : index
+    }
+    scf.yield %1#0 : index
+  }
+  return
+}
+
 // Check with parametric loop bounds and steps, capture the bounds here.
 // CHECK-LABEL: @parametric
 // CHECK-SAME: %[[orig_lb1:[A-Za-z0-9]+]]:



More information about the Mlir-commits mailing list