[Mlir-commits] [mlir] [mlir][vector] Add support for yielding loop bounds in `scf.for` distribution. (PR #163443)
Charitha Saumya
llvmlistbot at llvm.org
Wed Oct 15 15:39:33 PDT 2025
https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/163443
>From d9ad36dfc4c47e64e2ca0170f13569c3e904dd79 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 14 Oct 2025 20:01:35 +0000
Subject: [PATCH 1/2] add fix
---
.../Vector/Transforms/VectorDistribute.cpp | 25 +++++++++----
.../Vector/vector-warp-distribute.mlir | 35 +++++++++++++++++++
2 files changed, 53 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e95338f7d18be..2ee65dc0f902a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2038,11 +2038,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
}
// Newly created `WarpOp` will yield values in following order:
- // 1. All init args of the `ForOp`.
- // 2. All escaping values.
- // 3. All non-`ForOp` yielded values.
+ // 1. Loop bounds.
+ // 2. All init args of the `ForOp`.
+ // 3. All escaping values.
+ // 4. All non-`ForOp` yielded values.
SmallVector<Value> newWarpOpYieldValues;
SmallVector<Type> newWarpOpDistTypes;
+ newWarpOpYieldValues.insert(
+ newWarpOpYieldValues.end(),
+ {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
+ newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+ {forOp.getLowerBound().getType(),
+ forOp.getUpperBound().getType(),
+ forOp.getStep().getType()});
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
newWarpOpYieldValues.push_back(initArg);
// Compute the distributed type for this init arg.
@@ -2081,20 +2089,23 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
+ const unsigned initArgsStartIdx = 3; // After loop bounds.
const unsigned escapingValuesStartIdx =
+ initArgsStartIdx +
forOp.getInitArgs().size(); // `ForOp` init args are positioned before
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
- for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+ for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
newForOpOperands.push_back(newWarpOp.getResult(i));
// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newForOp = scf::ForOp::create(
- rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
- forOp.getUnsignedCmp());
+ rewriter, forOp.getLoc(), /**LowerBound=**/ newWarpOp.getResult(0),
+ /**UpperBound=**/ newWarpOp.getResult(1),
+ /**Step=**/ newWarpOp.getResult(2), newForOpOperands,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
// newly created `ForOp`. This `WarpOp` will contain all ops that were
// contained within the original `ForOp` body.
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bb7639204022f..ab87684dbb01a 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -473,6 +473,41 @@ func.func @warp_scf_for_use_from_above(%arg0: index) {
return
}
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_local_loop_bounds
+// CHECK-PROP: (%{{.*}}: index, %[[ARG1:[a-zA-Z0-9]+]]: index) {
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%[[ARG1]] : index) -> (vector<4xf32>) {
+// CHECK-PROP: ^bb0(%{{.*}}: index):
+// CHECK-PROP: %[[T2:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: gpu.yield %[[T2]] : vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[FOR:.*]] = scf.for %{{.*}} to %[[ARG1]] step %{{.*}} iter_args(%{{.*}}) -> (vector<4xf32>) {
+// CHECK-PROP: %[[W2:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32]
+// CHECK-PROP-SAME: args(%{{.*}} : vector<4xf32>) -> (vector<4xf32>) {
+// CHECK-PROP: ^bb0(%{{.*}}: vector<128xf32>):
+// CHECK-PROP: gpu.yield %{{.*}} : vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: scf.yield %[[W2]] : vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[FOR]]) : (vector<4xf32>) -> ()
+// CHECK-PROP: return
+func.func @warp_scf_for_local_loop_bounds(%arg0: index, %bound: index) { // Test input: an `scf.for` inside a warp region whose upper bound is a block argument of that region.
+ %c1 = arith.constant 1 : index // Used as the loop step below.
+ %c0 = arith.constant 0 : index // Used as the loop lower bound below.
+ %0 = gpu.warp_execute_on_lane_0(%arg0)[32]
+ args(%bound : index) -> (vector<4xf32>) { // `%bound` is handed to the region through `args`.
+ ^bb0(%arg1: index): // Region block argument for the `args` operand.
+ %ini = "some_def"() : () -> (vector<128xf32>) // Initial value of the loop-carried vector.
+ %3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) { // Upper bound `%arg1` only exists inside the warp region.
+ %acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
+ scf.yield %acc : vector<128xf32>
+ }
+ gpu.yield %3 : vector<128xf32> // Yielded as vector<128xf32>; the warp op result type is vector<4xf32>.
+ }
+ "some_use"(%0) : (vector<4xf32>) -> () // Consumes the warp op result outside the region.
+ return
+}
+
// -----
// CHECK-PROP-LABEL: func @warp_scf_for_swap(
>From 7224394205570ee7381e202f1e5239480f25fc65 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 15 Oct 2025 22:39:07 +0000
Subject: [PATCH 2/2] fix
---
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 7752f9e0a263b..7c019e7d25bf2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2087,15 +2087,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
- newForOpOperands.push_back(newWarpOp.getResult(i));
+ newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newForOp = scf::ForOp::create(
- rewriter, forOp.getLoc(), /**LowerBound=**/ newWarpOp.getResult(0),
- /**UpperBound=**/ newWarpOp.getResult(1),
- /**Step=**/ newWarpOp.getResult(2), newForOpOperands,
+ rewriter, forOp.getLoc(),
+ /**LowerBound=**/ newWarpOp.getResult(newIndices[0]),
+ /**UpperBound=**/ newWarpOp.getResult(newIndices[1]),
+ /**Step=**/ newWarpOp.getResult(newIndices[2]), newForOpOperands,
/*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
// newly created `ForOp`. This `WarpOp` will contain all ops that were
More information about the Mlir-commits
mailing list