[Mlir-commits] [mlir] [mlir][scf] Fix FoldTensorCastOfOutputIntoForallOp write order bug (PR #189162)
Mehdi Amini
llvmlistbot at llvm.org
Sat Mar 28 05:41:28 PDT 2026
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/189162
`FoldTensorCastOfOutputIntoForallOp` incorrectly updated the destinations of `tensor.parallel_insert_slice` ops in the `in_parallel` block by zipping `getYieldingOps()` with `getRegionIterArgs()` positionally. This assumed that the i-th yielding op writes to the i-th shared output, which is not required by the IR semantics. When slices are written to shared outputs in non-positional order, the canonicalization would silently redirect each write to the shared output at the op's own position (swapping the targets in the two-output case), producing incorrect output.
Fix by replacing the positional zip with a per-destination check: for each yielding op's destination operand, if it is a `tensor.cast` result whose source is one of the new `scf.forall` region iter args (i.e., a cast we introduced to bridge the type change), replace the destination with the cast's source directly. This correctly handles all orderings.
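To make the difference between the two strategies concrete, here is a minimal, self-contained sketch in plain C++. It is a toy model, not MLIR code: `ValueId`, `ToyInsertSlice`, `retargetPositionally`, and `retargetThroughCasts` are hypothetical names introduced only for illustration, SSA values are modeled as integers, and a `tensor.cast` is modeled as an entry in a map from the cast result to its source.

```cpp
#include <algorithm>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Toy stand-in for an SSA value: just an integer id (hypothetical).
using ValueId = int;

// Toy stand-in for a yielding op: only its destination operand is modeled.
struct ToyInsertSlice {
  ValueId dest;
};

// Old strategy: assume the i-th yielding op writes to the i-th iter arg and
// overwrite its destination accordingly.
void retargetPositionally(std::vector<ToyInsertSlice> &ops,
                          const std::vector<ValueId> &iterArgs) {
  for (std::size_t i = 0; i < ops.size() && i < iterArgs.size(); ++i)
    ops[i].dest = iterArgs[i];
}

// New strategy: inspect each destination. If it is a cast result whose source
// is one of the iter args, bypass the cast; otherwise leave it untouched.
// `castSource` maps a cast result id to the id of the value being cast.
void retargetThroughCasts(
    std::vector<ToyInsertSlice> &ops, const std::vector<ValueId> &iterArgs,
    const std::unordered_map<ValueId, ValueId> &castSource) {
  for (ToyInsertSlice &op : ops) {
    auto it = castSource.find(op.dest);
    if (it == castSource.end())
      continue; // Destination is not a cast result.
    bool sourceIsIterArg = std::find(iterArgs.begin(), iterArgs.end(),
                                     it->second) != iterArgs.end();
    if (sourceIsIterArg)
      op.dest = it->second;
  }
}
```

With iter args `{0, 1}`, casts `10 -> 0` and `11 -> 1`, and two ops whose destinations are `11` then `10`, the positional version rewrites the destinations to `0` then `1` (swapped relative to the original intent), while the per-destination version yields `1` then `0`.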
Add a regression test that exercises the multi-result case where `parallel_insert_slice` ops write to shared outputs in non-sequential order.
Fixes #172981
Assisted-by: Claude Code
From a89111e04967c203488fa3472e06b74dc46570ca Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 04:03:10 -0700
Subject: [PATCH] [mlir][scf] Fix FoldTensorCastOfOutputIntoForallOp write order bug
`FoldTensorCastOfOutputIntoForallOp` incorrectly updated the destinations
of `tensor.parallel_insert_slice` ops in the `in_parallel` block by
zipping `getYieldingOps()` with `getRegionIterArgs()` positionally. This
assumed that the i-th yielding op writes to the i-th shared output, which
is not required by the IR semantics. When slices are written to shared
outputs in non-positional order, the canonicalization would silently
redirect each write to the shared output at the op's own position,
producing incorrect output.
Fix by replacing the positional zip with a per-destination check: for
each yielding op's destination operand, if it is a `tensor.cast` result
whose source is one of the new `scf.forall` region iter args (i.e., a
cast we introduced to bridge the type change), replace the destination
with the cast's source directly. This correctly handles all orderings.
Add a regression test that exercises the multi-result case where
`parallel_insert_slice` ops write to shared outputs in non-sequential
order.
Fixes #172981
Assisted-by: Claude Code
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 28 ++++++++++++------
mlir/test/Dialect/SCF/canonicalize.mlir | 38 +++++++++++++++++++++++++
2 files changed, 58 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 76467154e869f..91375d2f56b3e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1756,15 +1756,27 @@ struct FoldTensorCastOfOutputIntoForallOp
bbArgs.front().getParentBlock(), ivsBlockArgs);
});
- // After `mergeBlocks` happened, the destinations in the terminator were
- // mapped to the tensor.cast old-typed results of the output bbArgs. The
- // destination have to be updated to point to the output bbArgs directly.
+ // After `mergeBlocks` happened, the destinations in the terminator may be
+ // mapped to tensor.cast values wrapping the new output bbArgs (introduced
+ // for indices in `tensorCastProducers`). Update those destinations to
+ // point directly to the output bbArgs, bypassing the casts.
+ //
+ // Note: we cannot zip yieldingOps with regionIterArgs by position because
+ // a parallel_insert_slice inside in_parallel may write to any shared
+ // output, not necessarily the one at the same position.
+ llvm::SmallDenseSet<Value> newIterArgSet(
+ newForallOp.getRegionIterArgs().begin(),
+ newForallOp.getRegionIterArgs().end());
auto terminator = newForallOp.getTerminator();
- for (auto [yieldingOp, outputBlockArg] : llvm::zip(
- terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
- if (auto parallelCombingingOp =
- dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
- parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
+ for (auto &yieldingOp : terminator.getYieldingOps()) {
+ auto parallelCombiningOp =
+ dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+ if (!parallelCombiningOp)
+ continue;
+ for (OpOperand &dest : parallelCombiningOp.getUpdatedDestinations()) {
+ auto castOp = dest.get().getDefiningOp<tensor::CastOp>();
+ if (castOp && newIterArgSet.contains(castOp.getSource()))
+ dest.set(castOp.getSource());
}
}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 62dc2305b5857..c324d34942bf8 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2322,3 +2322,41 @@ func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : in
}
return %res#0, %res#1, %res#2 : i32, i32, i32
}
+
+// -----
+
+// Test that FoldTensorCastOfOutputIntoForallOp correctly handles the case
+// where parallel_insert_slice ops write to shared outputs in a non-sequential
+// order. The fix ensures that each yielding op's destination is matched to the
+// correct regionIterArg based on what it actually writes to, not positionally.
+// CHECK-LABEL: func @fold_tensor_cast_into_forall_non_sequential_writes
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x32xf32>, %[[ARG1:.*]]: tensor<8x32xf32>)
+// CHECK: %[[FORALL:.*]]:2 = scf.forall
+// CHECK-SAME: shared_outs(%[[ITER0:.*]] = {{.*}}, %[[ITER1:.*]] = {{.*}}) -> (tensor<32x32xf32>, tensor<32x32xf32>)
+// CHECK: scf.forall.in_parallel {
+// CHECK-NEXT: tensor.parallel_insert_slice %[[ARG0]] into %[[ITER1]]
+// CHECK-NEXT: tensor.parallel_insert_slice %[[ARG1]] into %[[ITER0]]
+// CHECK: }
+// CHECK: %[[CAST0:.*]] = tensor.cast %[[FORALL]]#0 : tensor<32x32xf32> to tensor<?x32xf32>
+// CHECK: %[[CAST1:.*]] = tensor.cast %[[FORALL]]#1 : tensor<32x32xf32> to tensor<?x32xf32>
+// CHECK: return %[[CAST0]], %[[CAST1]]
+func.func @fold_tensor_cast_into_forall_non_sequential_writes(
+ %arg0: tensor<8x32xf32>, %arg1: tensor<8x32xf32>) -> (tensor<?x32xf32>, tensor<?x32xf32>) {
+ %c8 = arith.constant 8 : index
+ %c32 = arith.constant 32 : index
+ %init = tensor.empty(%c32) : tensor<?x32xf32>
+ %0:2 = scf.forall (%tidx) in (4) shared_outs(%arg2 = %init, %arg3 = %init)
+ -> (tensor<?x32xf32>, tensor<?x32xf32>) {
+ %pos = arith.muli %c8, %tidx : index
+ scf.forall.in_parallel {
+ // Write %arg0 to %arg3 (second shared output).
+ tensor.parallel_insert_slice %arg0 into %arg3[%pos, 0] [8, 32] [1, 1]
+ : tensor<8x32xf32> into tensor<?x32xf32>
+ // Write %arg1 to %arg2 (first shared output).
+ tensor.parallel_insert_slice %arg1 into %arg2[%pos, 0] [8, 32] [1, 1]
+ : tensor<8x32xf32> into tensor<?x32xf32>
+ }
+ }
+ // %0#0 contains %arg1 data; %0#1 contains %arg0 data.
+ return %0#0, %0#1 : tensor<?x32xf32>, tensor<?x32xf32>
+}