[Mlir-commits] [mlir] [mlir][SCF]-Fix loop coalescing with iteration arguments (PR #105488)
    Amir Bishara 
    llvmlistbot at llvm.org
       
    Wed Aug 21 11:27:14 PDT 2024
    
    
  
https://github.com/amirBish updated https://github.com/llvm/llvm-project/pull/105488
>From cf3526e4a0e3f969e229be4628282dccca82e0b3 Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Tue, 20 Aug 2024 15:22:01 +0300
Subject: [PATCH] [mlir][SCF]-Fix loop coalescing with iteration arguments
Fix a bug found when coalescing loops that have iteration
arguments: the inner loop's terminator may yield the inner
loop's own iteration arguments, which are about to be replaced
by the outer loop's iteration arguments.
The current flow leads to a crash while rewriting the IR.
---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          |  12 ++
 mlir/test/Dialect/Affine/loop-coalescing.mlir | 120 ++++++++++++++++++
 2 files changed, 132 insertions(+)
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 40f82557d2eb8a..6f1a7df0ddbfd0 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -864,6 +864,18 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
 
     Operation *innerTerminator = innerLoop.getBody()->getTerminator();
     auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
+    assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
+    for (Value &yieldedVal : yieldedVals) {
+      // The yielded value may be an iteration argument of the inner loop
+      // which is about to be inlined.
+      auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
+      if (iter != innerLoop.getRegionIterArgs().end()) {
+        unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
+        // `outerLoop` iter args identical to the `innerLoop` init args.
+        assert(iterArgIndex < outerLoop.getRegionIterArgs().size());
+        yieldedVal = outerLoop.getRegionIterArgs()[iterArgIndex];
+      }
+    }
     rewriter.eraseOp(innerTerminator);
 
     SmallVector<Value> innerBlockArgs;
diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir
index 0235000aeac538..45dd299295f640 100644
--- a/mlir/test/Dialect/Affine/loop-coalescing.mlir
+++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir
@@ -114,6 +114,126 @@ func.func @unnormalized_loops() {
   return
 }
 
+// CHECK-LABEL: @normalized_loops_with_yielded_iter_args
+func.func @normalized_loops_with_yielded_iter_args() {
+  // CHECK: %[[orig_lb:.*]] = arith.constant 0
+  // CHECK: %[[orig_step:.*]] = arith.constant 1
+  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
+  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
+  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c42 = arith.constant 42 : index
+  %c56 = arith.constant 56 : index
+  // Upper bound of the coalesced scf.for: the product of the trip counts.
+  // CHECK:     %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
+  // CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
+  // Updated loop bounds.
+  // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]]) -> (index) {
+  %2:1 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0) -> (index) {
+    // Inner loops must have been removed.
+    // CHECK-NOT: scf.for
+
+    // Reconstruct original IVs from the linearized one.
+    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
+    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+    %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){
+      %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
+        // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+        "use"(%i, %j, %k) : (index, index, index) -> ()
+        // CHECK: scf.yield %[[VAL_1]] : index
+        scf.yield %arg2 : index // Inner iter arg: must be remapped to the outer one.
+      }
+      scf.yield %0#0 : index
+    }
+    scf.yield %1#0 : index
+  }
+  return
+}
+
+// CHECK-LABEL: @normalized_loops_with_shuffled_yielded_iter_args
+func.func @normalized_loops_with_shuffled_yielded_iter_args() {
+  // CHECK: %[[orig_lb:.*]] = arith.constant 0
+  // CHECK: %[[orig_step:.*]] = arith.constant 1
+  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
+  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
+  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c42 = arith.constant 42 : index
+  %c56 = arith.constant 56 : index
+  // Upper bound of the coalesced scf.for: the product of the trip counts.
+  // CHECK:     %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
+  // CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
+  // Updated loop bounds.
+  // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]], %[[VAL_2:.*]] = %[[orig_lb]]) -> (index, index) {
+  %2:2 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0, %arg1 = %c0) -> (index, index) {
+    // Inner loops must have been removed.
+    // CHECK-NOT: scf.for
+
+    // Reconstruct original IVs from the linearized one.
+    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
+    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+    %1:2 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (index, index){
+      %0:2 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (index, index) {
+        // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+        "use"(%i, %j, %k) : (index, index, index) -> ()
+        // CHECK: scf.yield %[[VAL_2]], %[[VAL_1]] : index, index
+        scf.yield %arg5, %arg4 : index, index // Swapped inner iter args: remapped positionally.
+      }
+      scf.yield %0#0, %0#1 : index, index
+    }
+    scf.yield %1#0, %1#1 : index, index
+  }
+  return
+}
+
+// CHECK-LABEL: @normalized_loops_with_yielded_non_iter_args
+func.func @normalized_loops_with_yielded_non_iter_args() {
+  // CHECK: %[[orig_lb:.*]] = arith.constant 0
+  // CHECK: %[[orig_step:.*]] = arith.constant 1
+  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
+  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
+  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c42 = arith.constant 42 : index
+  %c56 = arith.constant 56 : index
+  // Upper bound of the coalesced scf.for: the product of the trip counts.
+  // CHECK:     %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
+  // CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
+  // Updated loop bounds.
+  // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]]) -> (index) {
+  %2:1 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0) -> (index) {
+    // Inner loops must have been removed.
+    // CHECK-NOT: scf.for
+
+    // Reconstruct original IVs from the linearized one.
+    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
+    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+    %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){
+      %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
+        // CHECK: %[[res:.*]] = "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+        %res = "use"(%i, %j, %k) : (index, index, index) -> (index)
+        // CHECK: scf.yield %[[res]] : index
+        scf.yield %res : index // Not an iter arg: must be yielded unchanged.
+      }
+      scf.yield %0#0 : index
+    }
+    scf.yield %1#0 : index
+  }
+  return
+}
+
 // Check with parametric loop bounds and steps, capture the bounds here.
 // CHECK-LABEL: @parametric
 // CHECK-SAME: %[[orig_lb1:[A-Za-z0-9]+]]:
    
    
More information about the Mlir-commits
mailing list