[Mlir-commits] [mlir] [MLIR][Vector] Add unrolling pattern for vector StepOp (PR #157752)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 9 14:48:29 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Nishant Patel (nbpatel)

<details>
<summary>Changes</summary>

This PR adds an unrolling pattern for the vector.step op to the VectorUnroll transform.

---
Full diff: https://github.com/llvm/llvm-project/pull/157752.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+49-1) 
- (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+23) 
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 77e26cca1607f..3d43e26c6be42 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3018,6 +3018,7 @@ def Vector_ScanOp :
 
 def Vector_StepOp : Vector_Op<"step", [
     Pure,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
   ]> {
   let summary = "A linear sequence of values from 0 to N";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 85e485c28c74e..7ea13b5723ad8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7442,6 +7442,10 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), result);
 }
 
+std::optional<SmallVector<int64_t, 4>> StepOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(llvm::cast<VectorType>(getType()).getShape());
+}
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e8ecb0c0be846..0671dd1c4bea2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -809,6 +809,54 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
+  UnrollStepPattern(MLIRContext *context,
+                    const vector::UnrollVectorOptions &options,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, stepOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType vecType = stepOp.getType();
+    if (vecType.isScalable()) {
+      // Scalable vectors are not supported by this pattern.
+      return failure();
+    }
+    int64_t originalSize = vecType.getShape()[0];
+    Location loc = stepOp.getLoc();
+    SmallVector<int64_t> strides(1, 1);
+
+    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
+                                             rewriter.getZeroAttr(vecType));
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange({originalSize}, *targetShape)) {
+      int64_t tileOffset = offsets[0];
+      auto targetVecType =
+          VectorType::get(*targetShape, vecType.getElementType());
+      Value baseStep = rewriter.create<vector::StepOp>(loc, targetVecType);
+      Value offsetVal =
+          rewriter.create<arith::ConstantIndexOp>(loc, tileOffset);
+      Value bcastOffset =
+          rewriter.create<vector::BroadcastOp>(loc, targetVecType, offsetVal);
+      Value tileStep =
+          rewriter.create<arith::AddIOp>(loc, baseStep, bcastOffset);
+
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, tileStep, result, offsets, strides);
+    }
+    rewriter.replaceOp(stepOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -818,6 +866,6 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollContractionPattern, UnrollElementwisePattern,
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
-               UnrollStorePattern, UnrollBroadcastPattern>(
+               UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
       patterns.getContext(), options, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e129cd5c40b9c..777af995b4554 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -420,3 +420,26 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
   // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
   // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
   // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+
+
+func.func @vector_step() -> vector<32xindex> {
+    %0 = vector.step : vector<32xindex>
+    return %0 : vector<32xindex>
+}
+// CHECK-LABEL: func @vector_step
+// CHECK: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex>
+// CHECK: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex>
+// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
+// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex>
+// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex>
+// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex>
+// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP3]], %[[CST3]] : vector<8xindex>
+// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: return %[[INS3]] : vector<32xindex>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..1cd092cec2b81 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns
                       .setFilterConstraint([](Operation *op) {
                         return success(isa<vector::ReductionOp>(op));
                       }));
+    populateVectorUnrollPatterns(patterns,
+                                 UnrollVectorOptions()
+                                     .setNativeShape(ArrayRef<int64_t>{8})
+                                     .setFilterConstraint([](Operation *op) {
+                                       return success(isa<vector::StepOp>(op));
+                                     }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})

``````````

</details>


https://github.com/llvm/llvm-project/pull/157752


More information about the Mlir-commits mailing list