[Mlir-commits] [mlir] [mlir][scf] Fix `FoldTensorCastOfOutputIntoForallOp` for multi-result scf.forall (PR #173271)
Longsheng Mou
llvmlistbot at llvm.org
Thu Dec 25 05:32:37 PST 2025
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/173271
>From e2407ddc5b807b26aef076cf51410e4a92fcceda Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Mon, 22 Dec 2025 23:24:56 +0800
Subject: [PATCH 1/3] [mlir][scf] Fix FoldTensorCastOfOutputIntoForallOp for
multi-result scf.forall
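This PR fixes a bug in `FoldTensorCastOfOutputIntoForallOp`. After the loop body is merged into the new `scf.forall`, the pattern updated each yielding op in the terminator to point at an output block argument purely by position. It zipped the terminator's yielding ops with the new region iter args, which implicitly assumes that the i-th yielding op writes into the i-th shared output. When `scf.forall` has multiple results and the `parallel_insert_slice` ops do not appear in the same order as the `shared_outs`, this assumption fails and the ops are rewired to the wrong destinations. The fix builds a `yieldOpToIterArgsIndex` map before rewriting, recording which output iter arg each yielding op uses. It then uses that map to update each op's destination to the matching block argument of the new loop.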
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 21 +++++++++-----
mlir/test/Dialect/SCF/canonicalize.mlir | 37 +++++++++++++++++++++++++
2 files changed, 51 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 652414f6cbe54..82bda1daa8e64 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1986,6 +1986,15 @@ struct FoldTensorCastOfOutputIntoForallOp
if (tensorCastProducers.empty())
return failure();
+ llvm::SmallMapVector<Operation *, int64_t, 2> yieldOpToIterArgsIndex;
+ for (auto [index, iterArg] :
+ llvm::enumerate(forallOp.getRegionIterArgs())) {
+ for (Operation *user : iterArg.getUsers()) {
+ if (isa<ParallelCombiningOpInterface>(user))
+ yieldOpToIterArgsIndex[user] = index;
+ }
+ }
+
// Create new loop.
Location loc = forallOp.getLoc();
auto newForallOp = ForallOp::create(
@@ -2012,13 +2021,11 @@ struct FoldTensorCastOfOutputIntoForallOp
// After `mergeBlocks` happened, the destinations in the terminator were
// mapped to the tensor.cast old-typed results of the output bbArgs. The
// destination have to be updated to point to the output bbArgs directly.
- auto terminator = newForallOp.getTerminator();
- for (auto [yieldingOp, outputBlockArg] : llvm::zip(
- terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
- if (auto parallelCombingingOp =
- dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
- parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
- }
+ auto newOutputIterArgs = newForallOp.getRegionIterArgs();
+ for (auto [yieldOp, iterArgsIndex] : yieldOpToIterArgsIndex) {
+ auto parallelCombiningOp = cast<ParallelCombiningOpInterface>(yieldOp);
+ parallelCombiningOp.getUpdatedDestinations().assign(
+ newOutputIterArgs[iterArgsIndex]);
}
// Cast results back to the original types.
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ac590fc0c47b9..3cd018d4729cf 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2028,6 +2028,43 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
// -----
+// CHECK-LABEL: func.func @fold_tensor_cast_into_forall_with_multiple_result(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<16xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<8xf32>) -> (tensor<?xf32>, tensor<64xf32>) {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 8 : index
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 16 : index
+// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<32xf32>
+// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[FORALL_0:.*]]:2 = scf.forall (%[[VAL_0:.*]]) in (4) shared_outs(%[[VAL_1:.*]] = %[[EMPTY_0]], %[[VAL_2:.*]] = %[[EMPTY_1]]) -> (tensor<32xf32>, tensor<64xf32>) {
+// CHECK: %[[MULI_0:.*]] = arith.muli %[[VAL_0]], %[[CONSTANT_0]] : index
+// CHECK: %[[MULI_1:.*]] = arith.muli %[[VAL_0]], %[[CONSTANT_1]] : index
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ARG0]] into %[[VAL_2]]{{\[}}%[[MULI_1]]] [16] [1] : tensor<16xf32> into tensor<64xf32>
+// CHECK: tensor.parallel_insert_slice %[[ARG1]] into %[[VAL_1]]{{\[}}%[[MULI_0]]] [8] [1] : tensor<8xf32> into tensor<32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[CAST_0:.*]] = tensor.cast %[[FORALL_0]]#0 : tensor<32xf32> to tensor<?xf32>
+// CHECK: return %[[CAST_0]], %[[FORALL_0]]#1 : tensor<?xf32>, tensor<64xf32>
+// CHECK: }
+func.func @fold_tensor_cast_into_forall_with_multiple_result(%arg0: tensor<16xf32>, %arg1: tensor<8xf32>) -> (tensor<?xf32>, tensor<64xf32>) {
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %0 = tensor.empty(%c32) : tensor<?xf32>
+ %1 = tensor.empty() : tensor<64xf32>
+ %2:2 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0, %arg4 = %1) -> (tensor<?xf32>, tensor<64xf32>) {
+ %3 = arith.muli %c8, %arg2 : index
+ %4 = arith.muli %c16, %arg2 : index
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %arg0 into %arg4[%4] [16] [1] : tensor<16xf32> into tensor<64xf32>
+ tensor.parallel_insert_slice %arg1 into %arg3[%3] [8] [1] : tensor<8xf32> into tensor<?xf32>
+ }
+ }
+ return %2#0, %2#1 : tensor<?xf32>, tensor<64xf32>
+}
+
+// -----
+
#map = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
>From a557035f70f3cbe4fa0830656eddd8ad6204fff6 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Tue, 23 Dec 2025 21:20:02 +0800
Subject: [PATCH 2/3] Spell out `auto` as `ArrayRef<BlockArgument>` and rename to `newIterArgs`
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 82bda1daa8e64..5194433019ec4 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2021,11 +2021,11 @@ struct FoldTensorCastOfOutputIntoForallOp
// After `mergeBlocks` happened, the destinations in the terminator were
// mapped to the tensor.cast old-typed results of the output bbArgs. The
// destination have to be updated to point to the output bbArgs directly.
- auto newOutputIterArgs = newForallOp.getRegionIterArgs();
+ ArrayRef<BlockArgument> newIterArgs = newForallOp.getRegionIterArgs();
for (auto [yieldOp, iterArgsIndex] : yieldOpToIterArgsIndex) {
auto parallelCombiningOp = cast<ParallelCombiningOpInterface>(yieldOp);
parallelCombiningOp.getUpdatedDestinations().assign(
- newOutputIterArgs[iterArgsIndex]);
+ newIterArgs[iterArgsIndex]);
}
// Cast results back to the original types.
>From e2535923f3a5092ebab3edd09e409f9cf93e1564 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 25 Dec 2025 21:26:44 +0800
Subject: [PATCH 3/3] Fail the match when a yielding op uses more than one iter arg
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5194433019ec4..87ee611bd63f0 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1990,8 +1990,13 @@ struct FoldTensorCastOfOutputIntoForallOp
for (auto [index, iterArg] :
llvm::enumerate(forallOp.getRegionIterArgs())) {
for (Operation *user : iterArg.getUsers()) {
- if (isa<ParallelCombiningOpInterface>(user))
- yieldOpToIterArgsIndex[user] = index;
+ if (isa<ParallelCombiningOpInterface>(user)) {
+ auto [it, inserted] = yieldOpToIterArgsIndex.try_emplace(user, index);
+ if (!inserted) {
+ return rewriter.notifyMatchFailure(
+ forallOp, "expected exactly one iter arg per yielding op");
+ }
+ }
}
}
More information about the Mlir-commits
mailing list