[Mlir-commits] [mlir] [MLIR][Vector] Add warp distribution for `vector.step` op (PR #155425)

Artem Kroviakov llvmlistbot at llvm.org
Wed Aug 27 02:28:34 PDT 2025


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/155425

>From 2b24134b8ab269cd7c4aba149a3e00b01955d0ad Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 26 Aug 2025 14:41:21 +0000
Subject: [PATCH 1/2] [MLIR][Vector] Step op warp distribution

---
 .../Vector/Transforms/VectorDistribute.cpp    | 41 ++++++++++++++++++-
 .../Vector/vector-warp-distribute.mlir        | 15 +++++++
 2 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 60aa0e9bae64a..df55bf083cf7b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -705,6 +705,45 @@ struct WarpOpConstant : public WarpDistributionPattern {
   }
 };
 
+/// Sink out step op feeding into a warp op yield.
+/// Vector step op is treated similar to arith.constant, apart from
+/// the result that represents a sequence [0, vec_size).
+/// The sequence is semantically equivalent to warp's threads/lanes indices.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
+///   ...
+///   %cst = vector.step : vector<32xindex>
+///   gpu.yield %cst : vector<32xindex>
+/// }
+/// ```
+/// To
+/// ```
+/// gpu.warp_execute_on_lane_0(%arg0) {
+///   ...
+/// }
+/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
+struct WarpOpStep final : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
+    if (!yieldOperand)
+      return failure();
+    auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
+    VectorType resTy = stepOp.getResult().getType();
+    rewriter.startOpModification(warpOp);
+    rewriter.setInsertionPointAfter(warpOp);
+    Value laneIdVec = vector::BroadcastOp::create(
+        rewriter, warpOp.getLoc(), VectorType::get({1}, resTy.getElementType()),
+        warpOp.getLaneid());
+    const unsigned operandIdx = yieldOperand->getOperandNumber();
+    rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
+    rewriter.finalizeOpModification(warpOp);
+    return success();
+  }
+};
+
 /// Sink out transfer_read op feeding into a warp op yield.
 /// ```
 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
@@ -2016,7 +2055,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
       .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
            WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
            WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
-           WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
+           WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
           patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 4d2c964a6df3c..96becb44777eb 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1824,3 +1824,18 @@ func.func @warp_propagate_duplicated_operands_in_yield(%laneid: index)  {
 // CHECK-PROP       :   }
 // CHECK-PROP       :   %[T1:.*] = math.exp %[[W]] : vector<1xf32>
 // CHECK-PROP       :   "some_use"(%[[T1]]) : (vector<1xf32>) -> ()
+
+// -----
+
+func.func @warp_step_distribute(%laneid: index, %buffer: memref<128xindex>)  {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xindex>) {
+    %seq = vector.step : vector<32xindex>
+    gpu.yield %seq : vector<32xindex>
+  }
+  vector.transfer_write %r, %buffer[%laneid] : vector<1xindex>, memref<128xindex>
+  return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_step_distribute
+//       CHECK-PROP:   %[[LANE_ID_VEC:.*]] = vector.broadcast %{{.*}} : index to vector<1xindex>
+//       CHECK-PROP:   vector.transfer_write %[[LANE_ID_VEC]], %{{.*}} : vector<1xindex>, memref<128xindex>

>From 05438510396b3bd1f6d418d69684cbcd9c194cc0 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 27 Aug 2025 09:28:13 +0000
Subject: [PATCH 2/2] Address feedback

---
 .../Vector/Transforms/VectorDistribute.cpp    | 31 ++++++++++++++-----
 .../Vector/vector-warp-distribute.mlir        | 15 +++++++++
 2 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index df55bf083cf7b..c19c22b92b344 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -708,7 +708,9 @@ struct WarpOpConstant : public WarpDistributionPattern {
 /// Sink out step op feeding into a warp op yield.
 /// Vector step op is treated similar to arith.constant, apart from
 /// the result that represents a sequence [0, vec_size).
-/// The sequence is semantically equivalent to warp's threads/lanes indices.
+/// Only vec_size == warp_size is supported: lane `i` then owns exactly
+/// element `i` of the sequence, so the distributed value is the lane id
+/// wrapped into a single-element vector (i.e., a broadcast).
 /// ```
 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
 ///   ...
@@ -722,6 +724,7 @@ struct WarpOpConstant : public WarpDistributionPattern {
 ///   ...
 /// }
 /// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
+/// ```
 struct WarpOpStep final : public WarpDistributionPattern {
   using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
@@ -730,16 +733,28 @@ struct WarpOpStep final : public WarpDistributionPattern {
         getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
     if (!yieldOperand)
       return failure();
+    const unsigned operandIdx = yieldOperand->getOperandNumber();
     auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
     VectorType resTy = stepOp.getResult().getType();
-    rewriter.startOpModification(warpOp);
+    // Compare as int64_t: getNumElements() is signed while the warp size
+    // accessor returns an unsigned value (avoids -Wsign-compare).
+    if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
+                        resTy.getNumElements(), warpOp.getWarpSize()));
+    VectorType newVecTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    // Broadcasting the lane id is only correct when every lane owns exactly
+    // one element. An undistributed result (same type as the yielded value)
+    // must see the full [0, warp_size) sequence on every lane instead.
+    if (newVecTy.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expected distributed result to have a single element");
     rewriter.setInsertionPointAfter(warpOp);
-    Value laneIdVec = vector::BroadcastOp::create(
-        rewriter, warpOp.getLoc(), VectorType::get({1}, resTy.getElementType()),
-        warpOp.getLaneid());
-    const unsigned operandIdx = yieldOperand->getOperandNumber();
-    rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
-    rewriter.finalizeOpModification(warpOp);
+    Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
+                                                  newVecTy, warpOp.getLaneid());
+    rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 96becb44777eb..f7cf4d9638bdf 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1839,3 +1839,18 @@ func.func @warp_step_distribute(%laneid: index, %buffer: memref<128xindex>)  {
 // CHECK-PROP-LABEL: func.func @warp_step_distribute
 //       CHECK-PROP:   %[[LANE_ID_VEC:.*]] = vector.broadcast %{{.*}} : index to vector<1xindex>
 //       CHECK-PROP:   vector.transfer_write %[[LANE_ID_VEC]], %{{.*}} : vector<1xindex>, memref<128xindex>
+
+// -----
+
+func.func @warp_step_non_warp_size(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xindex>) {
+    %seq = vector.step : vector<64xindex>
+    gpu.yield %seq : vector<64xindex>
+  }
+  "some_use"(%r) : (vector<2xindex>) -> ()
+  return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_step_non_warp_size
+//       CHECK-PROP:   gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xindex>)
+//       CHECK-PROP:     vector.step : vector<64xindex>



More information about the Mlir-commits mailing list