[Mlir-commits] [mlir] [MLIR][Vector] Fix `scf.for` block-argument yields in warp distribution (PR #192247)

llvmlistbot at llvm.org
Wed Apr 15 05:46:12 PDT 2026


llvmbot wrote:


@llvm/pr-subscribers-mlir-vector

Author: Akimasa Watanuki (Men-cotton)

<details>
<summary>Changes</summary>

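Teach `WarpOpScfForOp` to remap yielded `scf.for` body block arguments through `argMapping` before creating the replacement `gpu.yield`.

Previously, the pattern collected the `scf.for` terminator operands unchanged and reused them after merging the loop body into the new inner warp op. A yielded loop-carried value, or any other `scf.for` body block argument, therefore still referred to the pre-merge block argument. Such operands are now replaced by their `argMapping` entries.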

Add a regression test for yielding a loop-carried block argument during warp distribution.
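
For readers unfamiliar with the pattern, the remapping step can be sketched outside of MLIR as a toy model. This is only an illustration under stated assumptions: `ToyValue` and `remapYieldOperands` are hypothetical names invented here, not MLIR APIs, and blocks are identified by plain integers rather than `Block *`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an SSA value (not an MLIR type).
// ownerBlock is the id of the block this value is an argument of,
// or -1 when the value is not a block argument.
struct ToyValue {
  std::string name;
  int ownerBlock;
  int argNumber;
};

// Mirrors the idea of the fix: a yielded operand that is an argument of the
// loop body is swapped for its argMapping entry; every other operand is
// forwarded unchanged.
std::vector<ToyValue> remapYieldOperands(const std::vector<ToyValue> &yielded,
                                         int loopBody,
                                         const std::vector<ToyValue> &argMapping) {
  std::vector<ToyValue> result;
  result.reserve(yielded.size());
  for (const ToyValue &operand : yielded) {
    if (operand.ownerBlock == loopBody) {
      // The mapping must cover every argument of the loop body.
      assert(operand.argNumber >= 0 &&
             static_cast<std::size_t>(operand.argNumber) < argMapping.size());
      result.push_back(argMapping[operand.argNumber]);
      continue;
    }
    result.push_back(operand);
  }
  return result;
}
```

In this model, yielding argument 1 of the loop body produces whatever `argMapping[1]` holds, while a value defined elsewhere passes through untouched.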

Fixes #186573

---
Full diff: https://github.com/llvm/llvm-project/pull/192247.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+7-1) 
- (added) mlir/test/Dialect/Vector/warp-distribute-scf-for-block-args.mlir (+22) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2e0e650f2bb9c..3f8bfeec48c87 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2203,8 +2203,14 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
 
     argMapping.resize(forOp.getBody()->getNumArguments());
     SmallVector<Value> yieldOperands;
-    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
+    for (Value operand : forOp.getBody()->getTerminator()->getOperands()) {
+      if (BlockArgument blockArg = dyn_cast<BlockArgument>(operand);
+          blockArg && blockArg.getOwner() == forOp.getBody()) {
+        yieldOperands.push_back(argMapping[blockArg.getArgNumber()]);
+        continue;
+      }
       yieldOperands.push_back(operand);
+    }
 
     rewriter.eraseOp(forOp.getBody()->getTerminator());
     rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
diff --git a/mlir/test/Dialect/Vector/warp-distribute-scf-for-block-args.mlir b/mlir/test/Dialect/Vector/warp-distribute-scf-for-block-args.mlir
new file mode 100644
index 0000000000000..95a3795936340
--- /dev/null
+++ b/mlir/test/Dialect/Vector/warp-distribute-scf-for-block-args.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s --test-vector-warp-distribute=propagate-distribution | FileCheck %s
+
+// Yielding a loop-carried block argument used to crash when sinking scf.for
+// out of gpu.warp_execute_on_lane_0.
+// CHECK-LABEL: func.func @warp_scf_for_yield_loop_carried_arg
+// CHECK-NOT: gpu.warp_execute_on_lane_0
+// CHECK: %[[FOR:.*]] = scf.for
+// CHECK-SAME: iter_args(%[[ARG:.*]] = %{{.*}})
+// CHECK:   scf.yield %[[ARG]]
+// CHECK: return %[[FOR]]
+func.func @warp_scf_for_yield_loop_carried_arg(%laneid: index) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %result = gpu.warp_execute_on_lane_0(%laneid)[32] -> (index) {
+    %loopResult =
+        scf.for %i = %c0 to %c1 step %c1 iter_args(%loopCarried = %c0) -> (index) {
+      scf.yield %loopCarried : index
+    }
+    gpu.yield %loopResult : index
+  }
+  return %result : index
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/192247

