[Mlir-commits] [mlir] f7a5264 - [mlir][vector] Add support for yielding loop bounds in `scf.for` distribution. (#163443)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 17 09:07:21 PDT 2025
Author: Charitha Saumya
Date: 2025-10-17T09:07:17-07:00
New Revision: f7a5264890fe050124cd576410695a7c90c4d8d8
URL: https://github.com/llvm/llvm-project/commit/f7a5264890fe050124cd576410695a7c90c4d8d8
DIFF: https://github.com/llvm/llvm-project/commit/f7a5264890fe050124cd576410695a7c90c4d8d8.diff
LOG: [mlir][vector] Add support for yielding loop bounds in `scf.for` distribution. (#163443)
In some cases, the loop bounds (lower bound, upper bound, and step) of an `scf.for` can come
locally from the parent warp op of the `scf.for`. The current logic does not
yield the loop bounds from the new warp op generated during lowering,
causing the sunk `scf.for` to have non-dominating uses.
This PR adds logic to yield the loop bounds by default (treating them like
the other operands of `scf.for`), which fixes this bug.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 12e6475fa66e3..7c019e7d25bf2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2032,11 +2032,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
}
// Newly created `WarpOp` will yield values in following order:
- // 1. All init args of the `ForOp`.
- // 2. All escaping values.
- // 3. All non-`ForOp` yielded values.
+ // 1. Loop bounds.
+ // 2. All init args of the `ForOp`.
+ // 3. All escaping values.
+ // 4. All non-`ForOp` yielded values.
SmallVector<Value> newWarpOpYieldValues;
SmallVector<Type> newWarpOpDistTypes;
+ newWarpOpYieldValues.insert(
+ newWarpOpYieldValues.end(),
+ {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
+ newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+ {forOp.getLowerBound().getType(),
+ forOp.getUpperBound().getType(),
+ forOp.getStep().getType()});
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
newWarpOpYieldValues.push_back(initArg);
// Compute the distributed type for this init arg.
@@ -2072,20 +2080,24 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
+ const unsigned initArgsStartIdx = 3; // After loop bounds.
const unsigned escapingValuesStartIdx =
+ initArgsStartIdx +
forOp.getInitArgs().size(); // `ForOp` init args are positioned before
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
- for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+ for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newForOp = scf::ForOp::create(
- rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
- forOp.getUnsignedCmp());
+ rewriter, forOp.getLoc(),
+ /**LowerBound=**/ newWarpOp.getResult(newIndices[0]),
+ /**UpperBound=**/ newWarpOp.getResult(newIndices[1]),
+ /**Step=**/ newWarpOp.getResult(newIndices[2]), newForOpOperands,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
// newly created `ForOp`. This `WarpOp` will contain all ops that were
// contained within the original `ForOp` body.
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 401cdd29b281c..0cf6dd151e16c 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -473,6 +473,41 @@ func.func @warp_scf_for_use_from_above(%arg0: index) {
return
}
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_local_loop_bounds
+// CHECK-PROP: (%{{.*}}: index, %[[ARG1:[a-zA-Z0-9]+]]: index) {
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%[[ARG1]] : index) -> (vector<4xf32>) {
+// CHECK-PROP: ^bb0(%{{.*}}: index):
+// CHECK-PROP: %[[T2:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: gpu.yield %[[T2]] : vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[FOR:.*]] = scf.for %{{.*}} to %[[ARG1]] step %{{.*}} iter_args(%{{.*}}) -> (vector<4xf32>) {
+// CHECK-PROP: %[[W2:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32]
+// CHECK-PROP-SAME: args(%{{.*}} : vector<4xf32>) -> (vector<4xf32>) {
+// CHECK-PROP: ^bb0(%{{.*}}: vector<128xf32>):
+// CHECK-PROP: gpu.yield %{{.*}} : vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: scf.yield %[[W2]] : vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[FOR]]) : (vector<4xf32>) -> ()
+// CHECK-PROP: return
+func.func @warp_scf_for_local_loop_bounds(%arg0: index, %bound: index) {
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %0 = gpu.warp_execute_on_lane_0(%arg0)[32]
+ args(%bound : index) -> (vector<4xf32>) {
+ ^bb0(%arg1: index):
+ %ini = "some_def"() : () -> (vector<128xf32>)
+ %3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
+ %acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
+ scf.yield %acc : vector<128xf32>
+ }
+ gpu.yield %3 : vector<128xf32>
+ }
+ "some_use"(%0) : (vector<4xf32>) -> ()
+ return
+}
+
// -----
// CHECK-PROP-LABEL: func @warp_scf_for_swap(
More information about the Mlir-commits
mailing list