[Mlir-commits] [mlir] [MLIR][Vector] Add warp distribution for `vector.step` op (PR #155425)

Artem Kroviakov llvmlistbot at llvm.org
Thu Aug 28 02:05:38 PDT 2025


================
@@ -705,6 +705,52 @@ struct WarpOpConstant : public WarpDistributionPattern {
   }
 };
 
+/// Sink out a step op feeding into a warp op yield.
+/// A vector step op is treated similarly to arith.constant, except that
+/// its result represents the sequence [0, vec_size).
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
+///   ...
+///   %cst = vector.step : vector<32xindex>
+///   gpu.yield %cst : vector<1xindex>
+/// }
+/// ```
+/// To
+/// ```
+/// gpu.warp_execute_on_lane_0(%arg0) {
+///   ...
+/// }
+/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
+/// ```
+struct WarpOpStep final : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
+    if (!yieldOperand)
+      return failure();
+    const unsigned operandIdx = yieldOperand->getOperandNumber();
+    auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
+    VectorType resTy = stepOp.getResult().getType();
+    if (resTy.getNumElements() != warpOp.getWarpSize())
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
+                        resTy.getNumElements(), warpOp.getWarpSize()));
+    VectorType newVecTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    rewriter.setInsertionPointAfter(warpOp);
+    auto loc = warpOp.getLoc();
+    Value stepResult =
+        vector::StepOp::create(rewriter, warpOp.getLoc(), newVecTy);
+    Value laneId = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
+                                               newVecTy, warpOp.getLaneid());
+    stepResult = rewriter.create<arith::AddIOp>(loc, stepResult, laneId);
----------------
akroviakov wrote:

Changed to a simple broadcast, added a comment

https://github.com/llvm/llvm-project/pull/155425


More information about the Mlir-commits mailing list