[Mlir-commits] [mlir] [mlir][scf] Fix `FoldTensorCastOfOutputIntoForallOp` for multi-result scf.forall (PR #173271)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 22 07:31:13 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
This PR fixes a bug in `FoldTensorCastOfOutputIntoForallOp` that caused incorrect folding when an `scf.forall` has multiple results and its `parallel_insert_slice` operations appear in the terminator in a different order than the output iterator arguments they write to. The previous code paired yielding ops with output block arguments by position; the fix introduces a `yieldOpToIterArgsIndex` mapping that associates each yielding op with its corresponding output iterator argument. Fixes #172981.
---
Full diff: https://github.com/llvm/llvm-project/pull/173271.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+14-7)
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+37)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 652414f6cbe54..82bda1daa8e64 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1986,6 +1986,15 @@ struct FoldTensorCastOfOutputIntoForallOp
if (tensorCastProducers.empty())
return failure();
+ llvm::SmallMapVector<Operation *, int64_t, 2> yieldOpToIterArgsIndex;
+ for (auto [index, iterArg] :
+ llvm::enumerate(forallOp.getRegionIterArgs())) {
+ for (Operation *user : iterArg.getUsers()) {
+ if (isa<ParallelCombiningOpInterface>(user))
+ yieldOpToIterArgsIndex[user] = index;
+ }
+ }
+
// Create new loop.
Location loc = forallOp.getLoc();
auto newForallOp = ForallOp::create(
@@ -2012,13 +2021,11 @@ struct FoldTensorCastOfOutputIntoForallOp
// After `mergeBlocks` happened, the destinations in the terminator were
// mapped to the tensor.cast old-typed results of the output bbArgs. The
// destination have to be updated to point to the output bbArgs directly.
- auto terminator = newForallOp.getTerminator();
- for (auto [yieldingOp, outputBlockArg] : llvm::zip(
- terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
- if (auto parallelCombingingOp =
- dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
- parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
- }
+ auto newOutputIterArgs = newForallOp.getRegionIterArgs();
+ for (auto [yieldOp, iterArgsIndex] : yieldOpToIterArgsIndex) {
+ auto parallelCombiningOp = cast<ParallelCombiningOpInterface>(yieldOp);
+ parallelCombiningOp.getUpdatedDestinations().assign(
+ newOutputIterArgs[iterArgsIndex]);
}
// Cast results back to the original types.
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ac590fc0c47b9..3cd018d4729cf 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2028,6 +2028,43 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
// -----
+// CHECK-LABEL: func.func @fold_tensor_cast_into_forall_with_multiple_result(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<16xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<8xf32>) -> (tensor<?xf32>, tensor<64xf32>) {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 8 : index
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 16 : index
+// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<32xf32>
+// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[FORALL_0:.*]]:2 = scf.forall (%[[VAL_0:.*]]) in (4) shared_outs(%[[VAL_1:.*]] = %[[EMPTY_0]], %[[VAL_2:.*]] = %[[EMPTY_1]]) -> (tensor<32xf32>, tensor<64xf32>) {
+// CHECK: %[[MULI_0:.*]] = arith.muli %[[VAL_0]], %[[CONSTANT_0]] : index
+// CHECK: %[[MULI_1:.*]] = arith.muli %[[VAL_0]], %[[CONSTANT_1]] : index
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ARG0]] into %[[VAL_2]]{{\[}}%[[MULI_1]]] [16] [1] : tensor<16xf32> into tensor<64xf32>
+// CHECK: tensor.parallel_insert_slice %[[ARG1]] into %[[VAL_1]]{{\[}}%[[MULI_0]]] [8] [1] : tensor<8xf32> into tensor<32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[CAST_0:.*]] = tensor.cast %[[FORALL_0]]#0 : tensor<32xf32> to tensor<?xf32>
+// CHECK: return %[[CAST_0]], %[[FORALL_0]]#1 : tensor<?xf32>, tensor<64xf32>
+// CHECK: }
+func.func @fold_tensor_cast_into_forall_with_multiple_result(%arg0: tensor<16xf32>, %arg1: tensor<8xf32>) -> (tensor<?xf32>, tensor<64xf32>) {
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %0 = tensor.empty(%c32) : tensor<?xf32>
+ %1 = tensor.empty() : tensor<64xf32>
+ %2:2 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0, %arg4 = %1) -> (tensor<?xf32>, tensor<64xf32>) {
+ %3 = arith.muli %c8, %arg2 : index
+ %4 = arith.muli %c16, %arg2 : index
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %arg0 into %arg4[%4] [16] [1] : tensor<16xf32> into tensor<64xf32>
+ tensor.parallel_insert_slice %arg1 into %arg3[%3] [8] [1] : tensor<8xf32> into tensor<?xf32>
+ }
+ }
+ return %2#0, %2#1 : tensor<?xf32>, tensor<64xf32>
+}
+
+// -----
+
#map = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
``````````
</details>
https://github.com/llvm/llvm-project/pull/173271
More information about the Mlir-commits mailing list