[Mlir-commits] [mlir] [MLIR][Vector] Add warp distribution for `vector.step` op (PR #155425)

Artem Kroviakov llvmlistbot at llvm.org
Thu Aug 28 02:04:15 PDT 2025


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/155425

>From 27db45ed838e64f13f419f132ee0d628b2a30133 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 26 Aug 2025 14:41:21 +0000
Subject: [PATCH 1/4] [MLIR][Vector] Step op warp distribution

---
 .../Vector/Transforms/VectorDistribute.cpp    | 41 ++++++++++++++++++-
 .../Vector/vector-warp-distribute.mlir        | 15 +++++++
 2 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 60aa0e9bae64a..df55bf083cf7b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -705,6 +705,45 @@ struct WarpOpConstant : public WarpDistributionPattern {
   }
 };
 
+/// Sink out step op feeding into a warp op yield.
+/// Vector step op is treated similar to arith.constant, apart from
+/// the result that represents a sequence [0, vec_size).
+/// The sequence is semantically equivalent to warp's threads/lanes indices.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
+///   ...
+///   %cst = vector.step : vector<32xindex>
+///   gpu.yield %cst : vector<32xindex>
+/// }
+/// ```
+/// To
+/// ```
+/// gpu.warp_execute_on_lane_0(%arg0) {
+///   ...
+/// }
+/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
+struct WarpOpStep final : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
+    if (!yieldOperand)
+      return failure();
+    auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
+    VectorType resTy = stepOp.getResult().getType();
+    rewriter.startOpModification(warpOp);
+    rewriter.setInsertionPointAfter(warpOp);
+    Value laneIdVec = vector::BroadcastOp::create(
+        rewriter, warpOp.getLoc(), VectorType::get({1}, resTy.getElementType()),
+        warpOp.getLaneid());
+    const unsigned operandIdx = yieldOperand->getOperandNumber();
+    rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
+    rewriter.finalizeOpModification(warpOp);
+    return success();
+  }
+};
+
 /// Sink out transfer_read op feeding into a warp op yield.
 /// ```
 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
@@ -2016,7 +2055,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
       .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
            WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
            WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
-           WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
+           WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
           patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 4d2c964a6df3c..96becb44777eb 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1824,3 +1824,18 @@ func.func @warp_propagate_duplicated_operands_in_yield(%laneid: index)  {
 // CHECK-PROP       :   }
 // CHECK-PROP       :   %[T1:.*] = math.exp %[[W]] : vector<1xf32>
 // CHECK-PROP       :   "some_use"(%[[T1]]) : (vector<1xf32>) -> ()
+
+// -----
+
+func.func @warp_step_distribute(%laneid: index, %buffer: memref<128xindex>)  {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xindex>) {
+    %seq = vector.step : vector<32xindex>
+    gpu.yield %seq : vector<32xindex>
+  }
+  vector.transfer_write %r, %buffer[%laneid] : vector<1xindex>, memref<128xindex>
+  return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_step_distribute
+//       CHECK-PROP:   %[[LANE_ID_VEC:.*]] = vector.broadcast %{{.*}} : index to vector<1xindex>
+//       CHECK-PROP:   vector.transfer_write %[[LANE_ID_VEC]], %{{.*}} : vector<1xindex>, memref<128xindex>

>From 7ed02b7e35f13f237392d16e46f93d3b234408e4 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 27 Aug 2025 09:28:13 +0000
Subject: [PATCH 2/4] Address feedback

---
 .../Vector/Transforms/VectorDistribute.cpp    | 23 ++++++++++++-------
 .../Vector/vector-warp-distribute.mlir        |  6 +++--
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index df55bf083cf7b..c19c22b92b344 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -708,7 +708,6 @@ struct WarpOpConstant : public WarpDistributionPattern {
 /// Sink out step op feeding into a warp op yield.
 /// Vector step op is treated similar to arith.constant, apart from
 /// the result that represents a sequence [0, vec_size).
-/// The sequence is semantically equivalent to warp's threads/lanes indices.
 /// ```
 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
 ///   ...
@@ -730,16 +729,24 @@ struct WarpOpStep final : public WarpDistributionPattern {
         getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
     if (!yieldOperand)
       return failure();
+    const unsigned operandIdx = yieldOperand->getOperandNumber();
     auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
     VectorType resTy = stepOp.getResult().getType();
-    rewriter.startOpModification(warpOp);
+    if (resTy.getNumElements() != warpOp.getWarpSize())
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
+                        resTy.getNumElements(), warpOp.getWarpSize()));
+    VectorType newVecTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
     rewriter.setInsertionPointAfter(warpOp);
-    Value laneIdVec = vector::BroadcastOp::create(
-        rewriter, warpOp.getLoc(), VectorType::get({1}, resTy.getElementType()),
-        warpOp.getLaneid());
-    const unsigned operandIdx = yieldOperand->getOperandNumber();
-    rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
-    rewriter.finalizeOpModification(warpOp);
+    auto loc = warpOp.getLoc();
+    Value stepResult =
+        vector::StepOp::create(rewriter, warpOp.getLoc(), newVecTy);
+    Value laneId = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
+                                               newVecTy, warpOp.getLaneid());
+    stepResult = rewriter.create<arith::AddIOp>(loc, stepResult, laneId);
+    rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), stepResult);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 96becb44777eb..f7cf4d9638bdf 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1837,5 +1837,7 @@ func.func @warp_step_distribute(%laneid: index, %buffer: memref<128xindex>)  {
 }
 
 // CHECK-PROP-LABEL: func.func @warp_step_distribute
-//       CHECK-PROP:   %[[LANE_ID_VEC:.*]] = vector.broadcast %{{.*}} : index to vector<1xindex>
-//       CHECK-PROP:   vector.transfer_write %[[LANE_ID_VEC]], %{{.*}} : vector<1xindex>, memref<128xindex>
+//       CHECK-PROP:   %[[DISTRIBUTED_STEP:.*]] = vector.step : vector<1xindex>
+//       CHECK-PROP:   %[[LANE_ID_VEC:.*]] = vector.broadcast %arg0 : index to vector<1xindex>
+//       CHECK-PROP:   %[[LANE_STEP:.*]] = arith.addi %[[DISTRIBUTED_STEP]], %[[LANE_ID_VEC]] : vector<1xindex>
+//       CHECK-PROP:   vector.transfer_write %[[LANE_STEP]], %{{.*}} : vector<1xindex>, memref<128xindex>

>From 575606cd6abc8fbe57b5e44071ab90c32e260e4a Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 27 Aug 2025 14:02:35 +0000
Subject: [PATCH 3/4] Remove warning, add negative test

---
 .../Dialect/Vector/Transforms/VectorDistribute.cpp |  2 +-
 .../Dialect/Vector/vector-warp-distribute.mlir     | 14 ++++++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c19c22b92b344..43f46f69e23e5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -732,7 +732,7 @@ struct WarpOpStep final : public WarpDistributionPattern {
     const unsigned operandIdx = yieldOperand->getOperandNumber();
     auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
     VectorType resTy = stepOp.getResult().getType();
-    if (resTy.getNumElements() != warpOp.getWarpSize())
+    if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
       return rewriter.notifyMatchFailure(
           warpOp,
           llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index f7cf4d9638bdf..ce3f9c256ec26 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1841,3 +1841,17 @@ func.func @warp_step_distribute(%laneid: index, %buffer: memref<128xindex>)  {
 //       CHECK-PROP:   %[[LANE_ID_VEC:.*]] = vector.broadcast %arg0 : index to vector<1xindex>
 //       CHECK-PROP:   %[[LANE_STEP:.*]] = arith.addi %[[DISTRIBUTED_STEP]], %[[LANE_ID_VEC]] : vector<1xindex>
 //       CHECK-PROP:   vector.transfer_write %[[LANE_STEP]], %{{.*}} : vector<1xindex>, memref<128xindex>
+
+// -----
+
+func.func @negative_warp_step_distribute(%laneid: index, %buffer: memref<128xindex>)  {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xindex>) {
+    %seq = vector.step : vector<64xindex>
+    gpu.yield %seq : vector<64xindex>
+  }
+  vector.transfer_write %r, %buffer[%laneid] : vector<2xindex>, memref<128xindex>
+  return
+}
+
+// CHECK-PROP-LABEL: @negative_warp_step_distribute
+// CHECK-PROP-NOT: vector.broadcast

>From 1787ae370a2533d2e2479174e3983271b2bb3cb8 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 28 Aug 2025 09:03:52 +0000
Subject: [PATCH 4/4] Use simple broadcast

---
 .../Vector/Transforms/VectorDistribute.cpp       | 14 +++++++-------
 .../Dialect/Vector/vector-warp-distribute.mlir   | 16 ++++++++--------
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 43f46f69e23e5..c84eb2c9f8857 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -708,6 +708,10 @@ struct WarpOpConstant : public WarpDistributionPattern {
 /// Sink out step op feeding into a warp op yield.
 /// Vector step op is treated similar to arith.constant, apart from
 /// the result that represents a sequence [0, vec_size).
+/// Due to the vec_size == warp_size limitation,
+/// we can simply wrap the lane id into a vector (i.e., broadcast).
+/// Supporting vec_size != warp_size may involve preserving the step
+/// result and using additional arith ops (the exact details are TBD).
 /// ```
 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
 ///   ...
@@ -740,13 +744,9 @@ struct WarpOpStep final : public WarpDistributionPattern {
     VectorType newVecTy =
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
     rewriter.setInsertionPointAfter(warpOp);
-    auto loc = warpOp.getLoc();
-    Value stepResult =
-        vector::StepOp::create(rewriter, warpOp.getLoc(), newVecTy);
-    Value laneId = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
-                                               newVecTy, warpOp.getLaneid());
-    stepResult = rewriter.create<arith::AddIOp>(loc, stepResult, laneId);
-    rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), stepResult);
+    Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
+                                                  newVecTy, warpOp.getLaneid());
+    rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index ce3f9c256ec26..521c0d0fe0a67 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1827,7 +1827,8 @@ func.func @warp_propagate_duplicated_operands_in_yield(%laneid: index)  {
 
 // -----
 
-func.func @warp_step_distribute(%laneid: index, %buffer: memref<128xindex>)  {
+func.func @warp_step_distribute(%buffer: memref<128xindex>)  {
+  %laneid = gpu.lane_id
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xindex>) {
     %seq = vector.step : vector<32xindex>
     gpu.yield %seq : vector<32xindex>
@@ -1836,15 +1837,14 @@ func.func @warp_step_distribute(%laneid: index, %buffer: memref<128xindex>)  {
   return
 }
 
-// CHECK-PROP-LABEL: func.func @warp_step_distribute
-//       CHECK-PROP:   %[[DISTRIBUTED_STEP:.*]] = vector.step : vector<1xindex>
-//       CHECK-PROP:   %[[LANE_ID_VEC:.*]] = vector.broadcast %arg0 : index to vector<1xindex>
-//       CHECK-PROP:   %[[LANE_STEP:.*]] = arith.addi %[[DISTRIBUTED_STEP]], %[[LANE_ID_VEC]] : vector<1xindex>
-//       CHECK-PROP:   vector.transfer_write %[[LANE_STEP]], %{{.*}} : vector<1xindex>, memref<128xindex>
+// CHECK-PROP-LABEL: func.func @warp_step_distribute(
+//       CHECK-PROP:   %[[LANE_ID:.*]] = gpu.lane_id
+//       CHECK-PROP:   %[[LANE_ID_VEC:.*]] = vector.broadcast %[[LANE_ID]] : index to vector<1xindex>
+//       CHECK-PROP:   vector.transfer_write %[[LANE_ID_VEC]], %{{.*}} : vector<1xindex>, memref<128xindex>
 
 // -----
 
-func.func @negative_warp_step_distribute(%laneid: index, %buffer: memref<128xindex>)  {
+func.func @negative_warp_step_more_than_warp_size(%laneid: index, %buffer: memref<128xindex>)  {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xindex>) {
     %seq = vector.step : vector<64xindex>
     gpu.yield %seq : vector<64xindex>
@@ -1853,5 +1853,5 @@ func.func @negative_warp_step_distribute(%laneid: index, %buffer: memref<128xind
   return
 }
 
-// CHECK-PROP-LABEL: @negative_warp_step_distribute
+// CHECK-PROP-LABEL: @negative_warp_step_more_than_warp_size
 // CHECK-PROP-NOT: vector.broadcast



More information about the Mlir-commits mailing list