[Mlir-commits] [mlir] [MLIR][Vector] Add warp distribution for `vector.step` op (PR #155425)
Artem Kroviakov
llvmlistbot at llvm.org
Wed Aug 27 02:28:34 PDT 2025
https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/155425
>From 2b24134b8ab269cd7c4aba149a3e00b01955d0ad Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 26 Aug 2025 14:41:21 +0000
Subject: [PATCH 1/2] [MLIR][Vector] Step op warp distribution
---
.../Vector/Transforms/VectorDistribute.cpp | 41 ++++++++++++++++++-
.../Vector/vector-warp-distribute.mlir | 15 +++++++
2 files changed, 55 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 60aa0e9bae64a..df55bf083cf7b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -705,6 +705,45 @@ struct WarpOpConstant : public WarpDistributionPattern {
}
};
+/// Sink out step op feeding into a warp op yield.
+/// Vector step op is treated similarly to arith.constant, apart from
+/// the result that represents a sequence [0, vec_size).
+/// The sequence is semantically equivalent to warp's threads/lanes indices.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
+/// ...
+/// %cst = vector.step : vector<32xindex>
+/// gpu.yield %cst : vector<32xindex>
+/// }
+/// ```
+/// To
+/// ```
+/// gpu.warp_execute_on_lane_0(%arg0) {
+/// ...
+/// }
+/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
+/// ```
+struct WarpOpStep final : public WarpDistributionPattern {
+ using Base::Base;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
+ if (!yieldOperand)
+ return failure();
+ auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
+ VectorType resTy = stepOp.getResult().getType();
+ rewriter.startOpModification(warpOp);
+ rewriter.setInsertionPointAfter(warpOp);
+ Value laneIdVec = vector::BroadcastOp::create(
+ rewriter, warpOp.getLoc(), VectorType::get({1}, resTy.getElementType()),
+ warpOp.getLaneid());
+ const unsigned operandIdx = yieldOperand->getOperandNumber();
+ rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
+ rewriter.finalizeOpModification(warpOp);
+ return success();
+ }
+};
+
/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
@@ -2016,7 +2055,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
- WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
+ WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 4d2c964a6df3c..96becb44777eb 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1824,3 +1824,18 @@ func.func @warp_propagate_duplicated_operands_in_yield(%laneid: index) {
// CHECK-PROP : }
// CHECK-PROP : %[T1:.*] = math.exp %[[W]] : vector<1xf32>
// CHECK-PROP : "some_use"(%[[T1]]) : (vector<1xf32>) -> ()
+
+// -----
+
+func.func @warp_step_distribute(%laneid: index, %buffer: memref<128xindex>) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xindex>) {
+ %seq = vector.step : vector<32xindex>
+ gpu.yield %seq : vector<32xindex>
+ }
+ vector.transfer_write %r, %buffer[%laneid] : vector<1xindex>, memref<128xindex>
+ return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_step_distribute
+// CHECK-PROP: %[[LANE_ID_VEC:.*]] = vector.broadcast %{{.*}} : index to vector<1xindex>
+// CHECK-PROP: vector.transfer_write %[[LANE_ID_VEC]], %{{.*}} : vector<1xindex>, memref<128xindex>
>From 05438510396b3bd1f6d418d69684cbcd9c194cc0 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 27 Aug 2025 09:28:13 +0000
Subject: [PATCH 2/2] Address feedback
---
.../Vector/Transforms/VectorDistribute.cpp | 23 ++++++++++++-------
.../Vector/vector-warp-distribute.mlir | 6 +++--
2 files changed, 19 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index df55bf083cf7b..c19c22b92b344 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -708,7 +708,6 @@ struct WarpOpConstant : public WarpDistributionPattern {
/// Sink out step op feeding into a warp op yield.
/// Vector step op is treated similarly to arith.constant, apart from
/// the result that represents a sequence [0, vec_size).
-/// The sequence is semantically equivalent to warp's threads/lanes indices.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
/// ...
@@ -730,16 +729,24 @@ struct WarpOpStep final : public WarpDistributionPattern {
getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
if (!yieldOperand)
return failure();
+ const unsigned operandIdx = yieldOperand->getOperandNumber();
auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
VectorType resTy = stepOp.getResult().getType();
- rewriter.startOpModification(warpOp);
+ if (resTy.getNumElements() != warpOp.getWarpSize())
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
+ resTy.getNumElements(), warpOp.getWarpSize()));
+ VectorType newVecTy =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
rewriter.setInsertionPointAfter(warpOp);
- Value laneIdVec = vector::BroadcastOp::create(
- rewriter, warpOp.getLoc(), VectorType::get({1}, resTy.getElementType()),
- warpOp.getLaneid());
- const unsigned operandIdx = yieldOperand->getOperandNumber();
- rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
- rewriter.finalizeOpModification(warpOp);
+ auto loc = warpOp.getLoc();
+ Value stepResult =
+ vector::StepOp::create(rewriter, warpOp.getLoc(), newVecTy);
+ Value laneId = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
+ newVecTy, warpOp.getLaneid());
+ stepResult = rewriter.create<arith::AddIOp>(loc, stepResult, laneId);
+ rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), stepResult);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 96becb44777eb..f7cf4d9638bdf 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1837,5 +1837,7 @@ func.func @warp_step_distribute(%laneid: index, %buffer: memref<128xindex>) {
}
// CHECK-PROP-LABEL: func.func @warp_step_distribute
-// CHECK-PROP: %[[LANE_ID_VEC:.*]] = vector.broadcast %{{.*}} : index to vector<1xindex>
-// CHECK-PROP: vector.transfer_write %[[LANE_ID_VEC]], %{{.*}} : vector<1xindex>, memref<128xindex>
+// CHECK-PROP: %[[DISTRIBUTED_STEP:.*]] = vector.step : vector<1xindex>
+// CHECK-PROP: %[[LANE_ID_VEC:.*]] = vector.broadcast %arg0 : index to vector<1xindex>
+// CHECK-PROP: %[[LANE_STEP:.*]] = arith.addi %[[DISTRIBUTED_STEP]], %[[LANE_ID_VEC]] : vector<1xindex>
+// CHECK-PROP: vector.transfer_write %[[LANE_STEP]], %{{.*}} : vector<1xindex>, memref<128xindex>
More information about the Mlir-commits
mailing list