[Mlir-commits] [mlir] [mlir][scf] Fix FoldTensorCastOfOutputIntoForallOp write order bug (PR #189162)

Mehdi Amini llvmlistbot at llvm.org
Sat Mar 28 05:41:28 PDT 2026


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/189162

`FoldTensorCastOfOutputIntoForallOp` incorrectly updated the destinations of `tensor.parallel_insert_slice` ops in the `in_parallel` block by zipping `getYieldingOps()` with `getRegionIterArgs()` positionally. This assumed that the i-th yielding op writes to the i-th shared output, which the IR semantics do not require. When slices are written to shared outputs in non-positional order, the canonicalization silently redirected writes to the wrong shared outputs (swapping them in the two-output case), producing incorrect results.
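To make the failure mode concrete, here is a minimal sketch that uses a hypothetical plain-C++ model rather than the MLIR API. `InsertSlice`, `rewritePositionally`, and `outputContents` are illustrative names only. Each yielding op records which shared output it writes to, and the positional rewrite mimics the old `llvm::zip` loop by forcing the i-th op onto the i-th output.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical model (not MLIR API): a yielding op in the in_parallel block
// inserts `source` into the shared output at index `destIndex`.
struct InsertSlice {
  std::string source;
  std::size_t destIndex;
};

// Mimics the old logic: pair yielding ops with region iter args by position
// and force the i-th op to write into the i-th shared output. Like llvm::zip,
// it stops at the shorter of the two ranges.
std::vector<InsertSlice> rewritePositionally(std::vector<InsertSlice> ops,
                                             std::size_t numOutputs) {
  std::size_t n = std::min(ops.size(), numOutputs);
  for (std::size_t i = 0; i < n; ++i)
    ops[i].destIndex = i;
  return ops;
}

// Returns, for each shared output, the name of the source written into it.
std::vector<std::string> outputContents(const std::vector<InsertSlice> &ops,
                                        std::size_t numOutputs) {
  std::vector<std::string> contents(numOutputs);
  for (const InsertSlice &op : ops)
    contents[op.destIndex] = op.source;
  return contents;
}
```

With `ops = {{"%arg0", 1}, {"%arg1", 0}}` (the regression test's write order), `outputContents(ops, 2)` yields `{"%arg1", "%arg0"}`. After `rewritePositionally(ops, 2)` it yields `{"%arg0", "%arg1"}`, so the two results are swapped.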

Fix by replacing the positional zip with a per-destination check: for each yielding op's destination operand, if it is a `tensor.cast` result whose source is one of the new `scf.forall` region iter args (i.e., a cast we introduced to bridge the type change), replace the destination with the cast's source directly. This correctly handles all orderings.
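The per-destination check can be sketched with the same kind of hypothetical model. SSA values are plain strings, and a map from cast results to their sources stands in for `getDefiningOp<tensor::CastOp>()`. `InsertSlice`, `CastSources`, and `bypassCasts` are illustrative names, not MLIR API. Each destination is resolved from its own operand, so the order of the yielding ops no longer matters.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical model (not MLIR API): `dest` is the SSA name of the
// destination operand of a parallel_insert_slice.
struct InsertSlice {
  std::string source;
  std::string dest;
};

// Stand-in for getDefiningOp<tensor::CastOp>(): maps a cast result to the
// value it casts. Values absent from the map are not cast results.
using CastSources = std::unordered_map<std::string, std::string>;

// For each destination, if it is a cast whose source is one of the new
// region iter args, point the destination at that iter arg directly.
void bypassCasts(std::vector<InsertSlice> &ops,
                 const std::vector<std::string> &newIterArgs,
                 const CastSources &castSources) {
  std::unordered_set<std::string> iterArgSet(newIterArgs.begin(),
                                             newIterArgs.end());
  for (InsertSlice &op : ops) {
    auto it = castSources.find(op.dest);
    if (it != castSources.end() && iterArgSet.count(it->second))
      op.dest = it->second;
  }
}
```

Calling `bypassCasts` on `{{"%arg0", "%cast1"}, {"%arg1", "%cast0"}}` with casts `%cast0 -> %iter0` and `%cast1 -> %iter1` retargets the first op to `%iter1` and the second to `%iter0`, which preserves the original write order. A cast whose source is not a new iter arg is left untouched.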

Add a regression test that exercises the multi-result case where `parallel_insert_slice` ops write to shared outputs in non-sequential order.

Fixes #172981

Assisted-by: Claude Code

From a89111e04967c203488fa3472e06b74dc46570ca Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 04:03:10 -0700
Subject: [PATCH] [mlir][scf] Fix FoldTensorCastOfOutputIntoForallOp write
 order bug

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 28 ++++++++++++------
 mlir/test/Dialect/SCF/canonicalize.mlir | 38 +++++++++++++++++++++++++
 2 files changed, 58 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 76467154e869f..91375d2f56b3e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1756,15 +1756,27 @@ struct FoldTensorCastOfOutputIntoForallOp
                                bbArgs.front().getParentBlock(), ivsBlockArgs);
         });
 
-    // After `mergeBlocks` happened, the destinations in the terminator were
-    // mapped to the tensor.cast old-typed results of the output bbArgs. The
-    // destination have to be updated to point to the output bbArgs directly.
+    // After `mergeBlocks` happened, the destinations in the terminator may be
+    // mapped to tensor.cast values wrapping the new output bbArgs (introduced
+    // for indices in `tensorCastProducers`). Update those destinations to
+    // point directly to the output bbArgs, bypassing the casts.
+    //
+    // Note: we cannot zip yieldingOps with regionIterArgs by position because
+    // a parallel_insert_slice inside in_parallel may write to any shared
+    // output, not necessarily the one at the same position.
+    llvm::SmallDenseSet<Value> newIterArgSet(
+        newForallOp.getRegionIterArgs().begin(),
+        newForallOp.getRegionIterArgs().end());
     auto terminator = newForallOp.getTerminator();
-    for (auto [yieldingOp, outputBlockArg] : llvm::zip(
-             terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
-      if (auto parallelCombingingOp =
-              dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
-        parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg);
+    for (auto &yieldingOp : terminator.getYieldingOps()) {
+      auto parallelCombiningOp =
+          dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+      if (!parallelCombiningOp)
+        continue;
+      for (OpOperand &dest : parallelCombiningOp.getUpdatedDestinations()) {
+        auto castOp = dest.get().getDefiningOp<tensor::CastOp>();
+        if (castOp && newIterArgSet.contains(castOp.getSource()))
+          dest.set(castOp.getSource());
       }
     }
 
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 62dc2305b5857..c324d34942bf8 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2322,3 +2322,41 @@ func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : in
   }
   return %res#0, %res#1, %res#2 : i32, i32, i32
 }
+
+// -----
+
+// Test that FoldTensorCastOfOutputIntoForallOp correctly handles the case
+// where parallel_insert_slice ops write to shared outputs in a non-sequential
+// order. The fix ensures that each yielding op's destination is matched to the
+// correct regionIterArg based on what it actually writes to, not positionally.
+// CHECK-LABEL: func @fold_tensor_cast_into_forall_non_sequential_writes
+//  CHECK-SAME:   (%[[ARG0:.*]]: tensor<8x32xf32>, %[[ARG1:.*]]: tensor<8x32xf32>)
+//       CHECK:   %[[FORALL:.*]]:2 = scf.forall
+//  CHECK-SAME:     shared_outs(%[[ITER0:.*]] = {{.*}}, %[[ITER1:.*]] = {{.*}}) -> (tensor<32x32xf32>, tensor<32x32xf32>)
+//       CHECK:     scf.forall.in_parallel {
+// CHECK-NEXT:        tensor.parallel_insert_slice %[[ARG0]] into %[[ITER1]]
+// CHECK-NEXT:        tensor.parallel_insert_slice %[[ARG1]] into %[[ITER0]]
+//       CHECK:     }
+//       CHECK:   %[[CAST0:.*]] = tensor.cast %[[FORALL]]#0 : tensor<32x32xf32> to tensor<?x32xf32>
+//       CHECK:   %[[CAST1:.*]] = tensor.cast %[[FORALL]]#1 : tensor<32x32xf32> to tensor<?x32xf32>
+//       CHECK:   return %[[CAST0]], %[[CAST1]]
+func.func @fold_tensor_cast_into_forall_non_sequential_writes(
+    %arg0: tensor<8x32xf32>, %arg1: tensor<8x32xf32>) -> (tensor<?x32xf32>, tensor<?x32xf32>) {
+  %c8 = arith.constant 8 : index
+  %c32 = arith.constant 32 : index
+  %init = tensor.empty(%c32) : tensor<?x32xf32>
+  %0:2 = scf.forall (%tidx) in (4) shared_outs(%arg2 = %init, %arg3 = %init)
+      -> (tensor<?x32xf32>, tensor<?x32xf32>) {
+    %pos = arith.muli %c8, %tidx : index
+    scf.forall.in_parallel {
+      // Write %arg0 to %arg3 (second shared output).
+      tensor.parallel_insert_slice %arg0 into %arg3[%pos, 0] [8, 32] [1, 1]
+          : tensor<8x32xf32> into tensor<?x32xf32>
+      // Write %arg1 to %arg2 (first shared output).
+      tensor.parallel_insert_slice %arg1 into %arg2[%pos, 0] [8, 32] [1, 1]
+          : tensor<8x32xf32> into tensor<?x32xf32>
+    }
+  }
+  // %0#0 contains %arg1 data; %0#1 contains %arg0 data.
+  return %0#0, %0#1 : tensor<?x32xf32>, tensor<?x32xf32>
+}
