[Mlir-commits] [mlir] [MLIR][Vector] Fix `scf.for` block-argument yields in warp distribution (PR #192247)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 15 05:46:12 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Akimasa Watanuki (Men-cotton)
<details>
<summary>Changes</summary>
Teach WarpOpScfForOp to remap yielded `scf.for` body block arguments through `argMapping` before creating the replacement `gpu.yield`.
Map yielded loop-carried values and other `scf.for` body block arguments (such as the induction variable) to their counterparts in the new inner warp op, instead of reusing the stale pre-merge block arguments that no longer exist once the loop body has been merged.
Add a regression test for yielding a loop-carried block argument during warp distribution.
Fixes #186573
---
Full diff: https://github.com/llvm/llvm-project/pull/192247.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+7-1)
- (added) mlir/test/Dialect/Vector/warp-distribute-scf-for-block-args.mlir (+22)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2e0e650f2bb9c..3f8bfeec48c87 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2203,8 +2203,14 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
argMapping.resize(forOp.getBody()->getNumArguments());
SmallVector<Value> yieldOperands;
- for (Value operand : forOp.getBody()->getTerminator()->getOperands())
+ for (Value operand : forOp.getBody()->getTerminator()->getOperands()) {
+ if (BlockArgument blockArg = dyn_cast<BlockArgument>(operand);
+ blockArg && blockArg.getOwner() == forOp.getBody()) {
+ yieldOperands.push_back(argMapping[blockArg.getArgNumber()]);
+ continue;
+ }
yieldOperands.push_back(operand);
+ }
rewriter.eraseOp(forOp.getBody()->getTerminator());
rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
diff --git a/mlir/test/Dialect/Vector/warp-distribute-scf-for-block-args.mlir b/mlir/test/Dialect/Vector/warp-distribute-scf-for-block-args.mlir
new file mode 100644
index 0000000000000..95a3795936340
--- /dev/null
+++ b/mlir/test/Dialect/Vector/warp-distribute-scf-for-block-args.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s --test-vector-warp-distribute=propagate-distribution | FileCheck %s
+
+// Yielding a loop-carried block argument used to crash when sinking scf.for
+// out of gpu.warp_execute_on_lane_0.
+// CHECK-LABEL: func.func @warp_scf_for_yield_loop_carried_arg
+// CHECK-NOT: gpu.warp_execute_on_lane_0
+// CHECK: %[[FOR:.*]] = scf.for
+// CHECK-SAME: iter_args(%[[ARG:.*]] = %{{.*}})
+// CHECK: scf.yield %[[ARG]]
+// CHECK: return %[[FOR]]
+func.func @warp_scf_for_yield_loop_carried_arg(%laneid: index) -> index {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %result = gpu.warp_execute_on_lane_0(%laneid)[32] -> (index) {
+ %loopResult =
+ scf.for %i = %c0 to %c1 step %c1 iter_args(%loopCarried = %c0) -> (index) {
+ scf.yield %loopCarried : index
+ }
+ gpu.yield %loopResult : index
+ }
+ return %result : index
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/192247
More information about the Mlir-commits mailing list