[Mlir-commits] [mlir] [MLIR][Vector] Add unrolling pattern for vector StepOp (PR #157752)

Jakub Kuderski llvmlistbot at llvm.org
Tue Sep 16 11:08:24 PDT 2025


================
@@ -809,6 +809,81 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// This pattern unrolls `vector.step` operations according to the provided
+/// target unroll shape. It decomposes a large step vector into smaller step
+/// vectors (segments) and assembles the result by inserting each computed
+/// segment at its offset in the result vector.
+///
+/// The pattern does not support scalable vectors and will fail to match them.
+///
+/// For each segment, it adds the segment's offset to the base step vector
+/// and inserts the sum into the output vector at the corresponding
+/// position.
+///
+/// Example:
+///   Given a step operation:
+///     %0 = vector.step : vector<8xindex>
+///
+///   and a target unroll shape of <4>, the pattern produces:
+///
+///     %base = vector.step : vector<4xindex>
+///     %zero = arith.constant dense<0> : vector<8xindex>
+///     %result0 = vector.insert_strided_slice %base, %zero
+///       {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
+///     %offset = arith.constant dense<4> : vector<4xindex>
+///     %segment1 = arith.addi %base, %offset : vector<4xindex>
+///     %result1 = vector.insert_strided_slice %segment1, %result0
+///       {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
+///
+struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
+  UnrollStepPattern(MLIRContext *context,
+                    const vector::UnrollVectorOptions &options,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    std::optional<SmallVector<int64_t>> targetShape =
+        getTargetShape(options, stepOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType vecType = stepOp.getType();
+    if (vecType.isScalable()) {
+      // Scalable vectors are not supported by this pattern.
+      return failure();
+    }
+    int64_t originalSize = vecType.getShape()[0];
+    Location loc = stepOp.getLoc();
+    SmallVector<int64_t> strides(1, 1);
+
+    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
+                                             rewriter.getZeroAttr(vecType));
+
+    VectorType targetVecType =
+        VectorType::get(*targetShape, vecType.getElementType());
+    Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
+    for (SmallVector<int64_t> offsets :
----------------
kuhar wrote:

```suggestion
    for (const SmallVector<int64_t> &offsets :
```
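The suggestion declares the loop variable as a const reference. The sketch below is a plain C++ illustration, not MLIR code. It uses a hypothetical `CountedOffsets` type as a stand-in for `SmallVector<int64_t>` and counts copy constructions. When the range yields references to stored elements, a by-value loop variable copies every element and a `const &` one copies none. If the range's iterator instead returns each offsets vector by value (as a temporary), the two spellings cost about the same, because the const reference just extends the temporary's lifetime. Which case applies here depends on the iterator's `operator*`, and the range expression is cut off in the quoted hunk.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for SmallVector<int64_t> that counts how many times
// it is copy-constructed.
int gCopyCount = 0;

struct CountedOffsets {
  std::vector<long> values;
  explicit CountedOffsets(std::vector<long> v) : values(v) {}
  CountedOffsets(const CountedOffsets &other) : values(other.values) {
    ++gCopyCount;
  }
};

// Builds a range of `n` stored elements (reserve avoids reallocation copies).
std::vector<CountedOffsets> makeRange(std::size_t n) {
  std::vector<CountedOffsets> range;
  range.reserve(n);
  for (std::size_t i = 0; i < n; ++i)
    range.emplace_back(std::vector<long>{static_cast<long>(i)});
  return range;
}

// Loop variable declared by value: copy-constructs once per element.
int copiesByValue(const std::vector<CountedOffsets> &range) {
  gCopyCount = 0;
  std::size_t total = 0;
  for (CountedOffsets offsets : range)
    total += offsets.values.size();
  (void)total;
  return gCopyCount;
}

// Loop variable declared as a const reference (the suggested spelling):
// binds directly to each stored element, so nothing is copied.
int copiesByConstRef(const std::vector<CountedOffsets> &range) {
  gCopyCount = 0;
  std::size_t total = 0;
  for (const CountedOffsets &offsets : range)
    total += offsets.values.size();
  (void)total;
  return gCopyCount;
}
```

With three stored elements, `copiesByValue` reports three copies and `copiesByConstRef` reports zero.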

https://github.com/llvm/llvm-project/pull/157752
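The decomposition described in the `UnrollStepPattern` doc comment can be checked with plain arithmetic. The sketch below uses `std::vector` in place of MLIR vectors and a hypothetical helper, `unrolledStep`. It covers only the 1-D, fixed-length case and assumes the target size evenly divides the original size. It starts from a zero vector, builds one base step of the target size, and writes `base + offset` into each segment. The doc-comment example omits the `arith.addi` for the first segment. The sketch adds the zero offset anyway, which leaves the values unchanged. Whether the pattern special-cases that segment is not visible in the quoted hunk.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Models the values produced by the unrolled IR for a 1-D, fixed-length
// `vector.step`. Assumes targetSize > 0 and originalSize % targetSize == 0.
std::vector<int64_t> unrolledStep(int64_t originalSize, int64_t targetSize) {
  assert(targetSize > 0 && originalSize % targetSize == 0);

  // %zero = arith.constant dense<0> : vector<originalSize x index>
  std::vector<int64_t> result(originalSize, 0);

  // %base = vector.step : vector<targetSize x index>, i.e. [0, ..., targetSize - 1]
  std::vector<int64_t> base(targetSize);
  for (int64_t i = 0; i < targetSize; ++i)
    base[i] = i;

  // One segment per offset: 0, targetSize, 2 * targetSize, ...
  for (int64_t offset = 0; offset < originalSize; offset += targetSize) {
    for (int64_t i = 0; i < targetSize; ++i) {
      // arith.addi %base, dense<offset>, then vector.insert_strided_slice
      // into the result at [offset] with stride 1.
      result[offset + i] = base[i] + offset;
    }
  }
  return result;
}
```

Each segment is `base + offset`, so the assembled vector matches a single `vector.step` of the original length. For example, `unrolledStep(8, 4)` yields `{0, 1, 2, 3, 4, 5, 6, 7}`, which mirrors the `vector<8xindex>` example with a target shape of `<4>`.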
