[Mlir-commits] [mlir] [mlir][scf] Fix `FoldTensorCastOfOutputIntoForallOp` for multi-result scf.forall (PR #173271)

Longsheng Mou llvmlistbot at llvm.org
Thu Dec 25 05:32:37 PST 2025


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/173271

>From e2407ddc5b807b26aef076cf51410e4a92fcceda Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Mon, 22 Dec 2025 23:24:56 +0800
Subject: [PATCH 1/3] [mlir][scf] Fix FoldTensorCastOfOutputIntoForallOp for
 multi-result scf.forall

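This PR fixes a bug in `FoldTensorCastOfOutputIntoForallOp` that produced incorrect IR when an `scf.forall` has multiple results and the `parallel_insert_slice` operations in its terminator do not appear in the same order as the corresponding `shared_outs` block arguments. The previous code zipped the terminator's yielding ops positionally with the new loop's region iter args, so a destination could be rewired to the wrong output block argument. The fix records a `yieldOpToIterArgsIndex` mapping from each yielding op to the index of the output iter arg it actually uses, and rewires each destination through that mapping.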
---
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 21 +++++++++-----
 mlir/test/Dialect/SCF/canonicalize.mlir | 37 +++++++++++++++++++++++++
 2 files changed, 51 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 652414f6cbe54..82bda1daa8e64 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1986,6 +1986,15 @@ struct FoldTensorCastOfOutputIntoForallOp
     if (tensorCastProducers.empty())
       return failure();
 
+    llvm::SmallMapVector<Operation *, int64_t, 2> yieldOpToIterArgsIndex;
+    for (auto [index, iterArg] :
+         llvm::enumerate(forallOp.getRegionIterArgs())) {
+      for (Operation *user : iterArg.getUsers()) {
+        if (isa<ParallelCombiningOpInterface>(user))
+          yieldOpToIterArgsIndex[user] = index;
+      }
+    }
+
     // Create new loop.
     Location loc = forallOp.getLoc();
     auto newForallOp = ForallOp::create(
@@ -2012,13 +2021,11 @@ struct FoldTensorCastOfOutputIntoForallOp
     // After `mergeBlocks` happened, the destinations in the terminator were
     // mapped to the tensor.cast old-typed results of the output bbArgs. The
     // destination have to be updated to point to the output bbArgs directly.
-    auto terminator = newForallOp.getTerminator();
-    for (auto [yieldingOp, outputBlockArg] : llvm::zip(
-             terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
-      if (auto parallelCombingingOp =
-              dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
-        parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
-      }
+    auto newOutputIterArgs = newForallOp.getRegionIterArgs();
+    for (auto [yieldOp, iterArgsIndex] : yieldOpToIterArgsIndex) {
+      auto parallelCombiningOp = cast<ParallelCombiningOpInterface>(yieldOp);
+      parallelCombiningOp.getUpdatedDestinations().assign(
+          newOutputIterArgs[iterArgsIndex]);
     }
 
     // Cast results back to the original types.
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ac590fc0c47b9..3cd018d4729cf 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2028,6 +2028,43 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
 
 // -----
 
+// CHECK-LABEL:   func.func @fold_tensor_cast_into_forall_with_multiple_result(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<16xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: tensor<8xf32>) -> (tensor<?xf32>, tensor<64xf32>) {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 8 : index
+// CHECK:           %[[CONSTANT_1:.*]] = arith.constant 16 : index
+// CHECK:           %[[EMPTY_0:.*]] = tensor.empty() : tensor<32xf32>
+// CHECK:           %[[EMPTY_1:.*]] = tensor.empty() : tensor<64xf32>
+// CHECK:           %[[FORALL_0:.*]]:2 = scf.forall (%[[VAL_0:.*]]) in (4) shared_outs(%[[VAL_1:.*]] = %[[EMPTY_0]], %[[VAL_2:.*]] = %[[EMPTY_1]]) -> (tensor<32xf32>, tensor<64xf32>) {
+// CHECK:             %[[MULI_0:.*]] = arith.muli %[[VAL_0]], %[[CONSTANT_0]] : index
+// CHECK:             %[[MULI_1:.*]] = arith.muli %[[VAL_0]], %[[CONSTANT_1]] : index
+// CHECK:             scf.forall.in_parallel {
+// CHECK:               tensor.parallel_insert_slice %[[ARG0]] into %[[VAL_2]]{{\[}}%[[MULI_1]]] [16] [1] : tensor<16xf32> into tensor<64xf32>
+// CHECK:               tensor.parallel_insert_slice %[[ARG1]] into %[[VAL_1]]{{\[}}%[[MULI_0]]] [8] [1] : tensor<8xf32> into tensor<32xf32>
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[CAST_0:.*]] = tensor.cast %[[FORALL_0]]#0 : tensor<32xf32> to tensor<?xf32>
+// CHECK:           return %[[CAST_0]], %[[FORALL_0]]#1 : tensor<?xf32>, tensor<64xf32>
+// CHECK:         }
+func.func @fold_tensor_cast_into_forall_with_multiple_result(%arg0: tensor<16xf32>, %arg1: tensor<8xf32>) -> (tensor<?xf32>, tensor<64xf32>) {
+  %c8 = arith.constant 8 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %0 = tensor.empty(%c32) : tensor<?xf32>
+  %1 = tensor.empty() : tensor<64xf32>
+  %2:2 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0, %arg4 = %1) -> (tensor<?xf32>, tensor<64xf32>) {
+    %3 = arith.muli %c8, %arg2 : index
+    %4 = arith.muli %c16, %arg2 : index
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %arg0 into %arg4[%4] [16] [1] : tensor<16xf32> into tensor<64xf32>
+      tensor.parallel_insert_slice %arg1 into %arg3[%3] [8] [1] : tensor<8xf32> into tensor<?xf32>
+    }
+  }
+  return %2#0, %2#1 : tensor<?xf32>, tensor<64xf32>
+}
+
+// -----
+
 #map = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
 #map1 = affine_map<(d0)[s0] -> (d0 * s0)>
 #map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

>From a557035f70f3cbe4fa0830656eddd8ad6204fff6 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Tue, 23 Dec 2025 21:20:02 +0800
Subject: [PATCH 2/3] expand auto

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 82bda1daa8e64..5194433019ec4 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2021,11 +2021,11 @@ struct FoldTensorCastOfOutputIntoForallOp
     // After `mergeBlocks` happened, the destinations in the terminator were
     // mapped to the tensor.cast old-typed results of the output bbArgs. The
     // destination have to be updated to point to the output bbArgs directly.
-    auto newOutputIterArgs = newForallOp.getRegionIterArgs();
+    ArrayRef<BlockArgument> newIterArgs = newForallOp.getRegionIterArgs();
     for (auto [yieldOp, iterArgsIndex] : yieldOpToIterArgsIndex) {
       auto parallelCombiningOp = cast<ParallelCombiningOpInterface>(yieldOp);
       parallelCombiningOp.getUpdatedDestinations().assign(
-          newOutputIterArgs[iterArgsIndex]);
+          newIterArgs[iterArgsIndex]);
     }
 
     // Cast results back to the original types.

>From e2535923f3a5092ebab3edd09e409f9cf93e1564 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 25 Dec 2025 21:26:44 +0800
Subject: [PATCH 3/3] error out

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5194433019ec4..87ee611bd63f0 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1990,8 +1990,13 @@ struct FoldTensorCastOfOutputIntoForallOp
     for (auto [index, iterArg] :
          llvm::enumerate(forallOp.getRegionIterArgs())) {
       for (Operation *user : iterArg.getUsers()) {
-        if (isa<ParallelCombiningOpInterface>(user))
-          yieldOpToIterArgsIndex[user] = index;
+        if (isa<ParallelCombiningOpInterface>(user)) {
+          auto [it, inserted] = yieldOpToIterArgsIndex.try_emplace(user, index);
+          if (!inserted) {
+            return rewriter.notifyMatchFailure(
+                forallOp, "expected exactly one iter arg per yielding op");
+          }
+        }
       }
     }
 



More information about the Mlir-commits mailing list