[Mlir-commits] [mlir] [MLIR][Vector] Add unrolling pattern for vector StepOp (PR #157752)

Nishant Patel llvmlistbot at llvm.org
Tue Sep 16 09:19:39 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/157752

>From 983c12b51e8b03830f9e04b72f7ded5fab87be86 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 9 Sep 2025 18:45:29 +0000
Subject: [PATCH 1/3] Add unroll pattern for StepOp

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  1 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  4 ++
 .../Vector/Transforms/VectorUnroll.cpp        | 50 ++++++++++++++++++-
 .../Dialect/Vector/vector-unroll-options.mlir | 23 +++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  6 +++
 5 files changed, 83 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 77e26cca1607f..3d43e26c6be42 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3018,6 +3018,7 @@ def Vector_ScanOp :
 
 def Vector_StepOp : Vector_Op<"step", [
     Pure,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
   ]> {
   let summary = "A linear sequence of values from 0 to N";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 85e485c28c74e..7ea13b5723ad8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7442,6 +7442,10 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), result);
 }
 
+std::optional<SmallVector<int64_t, 4>> StepOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(llvm::cast<VectorType>(getType()).getShape());
+}
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e8ecb0c0be846..0671dd1c4bea2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -809,6 +809,54 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
+  UnrollStepPattern(MLIRContext *context,
+                    const vector::UnrollVectorOptions &options,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, stepOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType vecType = stepOp.getType();
+    if (vecType.isScalable()) {
+      // Scalable vectors are not supported by this pattern.
+      return failure();
+    }
+    int64_t originalSize = vecType.getShape()[0];
+    Location loc = stepOp.getLoc();
+    SmallVector<int64_t> strides(1, 1);
+
+    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
+                                             rewriter.getZeroAttr(vecType));
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange({originalSize}, *targetShape)) {
+      int64_t tileOffset = offsets[0];
+      auto targetVecType =
+          VectorType::get(*targetShape, vecType.getElementType());
+      Value baseStep = rewriter.create<vector::StepOp>(loc, targetVecType);
+      Value offsetVal =
+          rewriter.create<arith::ConstantIndexOp>(loc, tileOffset);
+      Value bcastOffset =
+          rewriter.create<vector::BroadcastOp>(loc, targetVecType, offsetVal);
+      Value tileStep =
+          rewriter.create<arith::AddIOp>(loc, baseStep, bcastOffset);
+
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, tileStep, result, offsets, strides);
+    }
+    rewriter.replaceOp(stepOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -818,6 +866,6 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollContractionPattern, UnrollElementwisePattern,
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
-               UnrollStorePattern, UnrollBroadcastPattern>(
+               UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
       patterns.getContext(), options, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e129cd5c40b9c..777af995b4554 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -420,3 +420,26 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
   // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
   // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
   // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+
+
+func.func @vector_step() -> vector<32xindex> {
+    %0 = vector.step : vector<32xindex>
+    return %0 : vector<32xindex>
+}
+// CHECK-LABEL: func @vector_step
+// CHECK: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex>
+// CHECK: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex>
+// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
+// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex>
+// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex>
+// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex>
+// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP3]], %[[CST3]] : vector<8xindex>
+// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: return %[[INS3]] : vector<32xindex>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..1cd092cec2b81 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns
                       .setFilterConstraint([](Operation *op) {
                         return success(isa<vector::ReductionOp>(op));
                       }));
+    populateVectorUnrollPatterns(patterns,
+                                 UnrollVectorOptions()
+                                     .setNativeShape(ArrayRef<int64_t>{8})
+                                     .setFilterConstraint([](Operation *op) {
+                                       return success(isa<vector::StepOp>(op));
+                                     }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})

>From 8c6f310b09c71e6917227ae0b5e99c71fe1257a1 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 16 Sep 2025 16:04:12 +0000
Subject: [PATCH 2/3] Address review feedback: reuse one vector.step, use constant offsets

---
 .../Vector/Transforms/VectorUnroll.cpp        | 47 +++++++++++++++----
 .../Dialect/Vector/vector-unroll-options.mlir | 19 ++++----
 2 files changed, 45 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 0671dd1c4bea2..8865b96241548 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -809,6 +809,32 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// This pattern unrolls `vector.step` operations according to the provided
+/// target unroll shape. It creates a single `vector.step` of the target shape
+/// and, for each segment of the original vector, adds the segment's starting
+/// offset to it and inserts the sum at that offset in the result.
+///
+/// The pattern does not support scalable vectors and will fail to match them.
+///
+/// For each segment, it adds the base step vector and the segment's offset,
+/// then inserts the result into the output vector at the corresponding
+/// position.
+///
+/// Example:
+///   Given a step operation:
+///     %0 = vector.step : vector<8xindex>
+///
+///   and a target unroll shape of <4>, the pattern produces:
+///
+///     %base = vector.step : vector<4xindex>
+///     %zero = arith.constant dense<0> : vector<4xindex>
+///     %result0 = vector.insert_strided_slice %base, %zero
+///       {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
+///     %offset = arith.constant dense<4> : vector<4xindex>
+///     %segment1 = arith.addi %base, %offset : vector<4xindex>
+///     %result1 = vector.insert_strided_slice %segment1, %result0
+///       {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
+///
 struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
   UnrollStepPattern(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
@@ -817,7 +843,8 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
 
   LogicalResult matchAndRewrite(vector::StepOp stepOp,
                                 PatternRewriter &rewriter) const override {
-    auto targetShape = getTargetShape(options, stepOp);
+    std::optional<SmallVector<int64_t>> targetShape =
+        getTargetShape(options, stepOp);
     if (!targetShape)
       return failure();
 
@@ -833,18 +860,18 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
     Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                              rewriter.getZeroAttr(vecType));
 
+    VectorType targetVecType =
+        VectorType::get(*targetShape, vecType.getElementType());
+    Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange({originalSize}, *targetShape)) {
-      int64_t tileOffset = offsets[0];
-      auto targetVecType =
-          VectorType::get(*targetShape, vecType.getElementType());
-      Value baseStep = rewriter.create<vector::StepOp>(loc, targetVecType);
-      Value offsetVal =
-          rewriter.create<arith::ConstantIndexOp>(loc, tileOffset);
-      Value bcastOffset =
-          rewriter.create<vector::BroadcastOp>(loc, targetVecType, offsetVal);
+      Value bcastOffset = arith::ConstantOp::create(
+          rewriter, loc, targetVecType,
+          DenseElementsAttr::get(
+              targetVecType,
+              IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
       Value tileStep =
-          rewriter.create<arith::AddIOp>(loc, baseStep, bcastOffset);
+          arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
 
       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, tileStep, result, offsets, strides);
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 777af995b4554..35db14e0f7f1d 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -427,19 +427,16 @@ func.func @vector_step() -> vector<32xindex> {
     return %0 : vector<32xindex>
 }
 // CHECK-LABEL: func @vector_step
-// CHECK: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex>
-// CHECK: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex>
+// CHECK: %[[CST:.*]] = arith.constant dense<24> : vector<8xindex>
+// CHECK: %[[CST0:.*]] = arith.constant dense<16> : vector<8xindex>
 // CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
-// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex>
-// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex>
-// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
-// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex>
-// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex>
+// CHECK: %[[CST2:.*]] = arith.constant dense<0> : vector<32xindex>
+// CHECK: %[[STEP:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP]], %[[CST2]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP]], %[[CST1]] : vector<8xindex>
 // CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
-// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex>
-// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex>
+// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP]], %[[CST0]] : vector<8xindex>
 // CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
-// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex>
-// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP3]], %[[CST3]] : vector<8xindex>
+// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
 // CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
 // CHECK: return %[[INS3]] : vector<32xindex>

>From f2673413c4df547e974095ebb9ba52cde39f5de4 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 16 Sep 2025 16:19:24 +0000
Subject: [PATCH 3/3] Fix result shape of %zero in UnrollStepPattern doc example

---
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 8865b96241548..f778d3d860fe9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -827,7 +827,7 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
 ///   and a target unroll shape of <4>, the pattern produces:
 ///
 ///     %base = vector.step : vector<4xindex>
-///     %zero = arith.constant dense<0> : vector<4xindex>
+///     %zero = arith.constant dense<0> : vector<8xindex>
 ///     %result0 = vector.insert_strided_slice %base, %zero
 ///       {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
 ///     %offset = arith.constant dense<4> : vector<4xindex>



More information about the Mlir-commits mailing list