[Mlir-commits] [mlir] [MLIR][Vector] Add unrolling pattern for vector StepOp (PR #157752)
Nishant Patel
llvmlistbot at llvm.org
Tue Sep 16 11:34:42 PDT 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/157752
>From 983c12b51e8b03830f9e04b72f7ded5fab87be86 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 9 Sep 2025 18:45:29 +0000
Subject: [PATCH 1/5] Add unroll pattern for StepOp
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 1 +
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ++
.../Vector/Transforms/VectorUnroll.cpp | 50 ++++++++++++++++++-
.../Dialect/Vector/vector-unroll-options.mlir | 23 +++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 6 +++
5 files changed, 83 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 77e26cca1607f..3d43e26c6be42 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3018,6 +3018,7 @@ def Vector_ScanOp :
def Vector_StepOp : Vector_Op<"step", [
Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]> {
let summary = "A linear sequence of values from 0 to N";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 85e485c28c74e..7ea13b5723ad8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7442,6 +7442,10 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), result);
}
+std::optional<SmallVector<int64_t, 4>> StepOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(llvm::cast<VectorType>(getType()).getShape());
+}
+
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e8ecb0c0be846..0671dd1c4bea2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -809,6 +809,54 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
vector::UnrollVectorOptions options;
};
+struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
+ UnrollStepPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(vector::StepOp stepOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, stepOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType vecType = stepOp.getType();
+ if (vecType.isScalable()) {
+ // Scalable vectors are not supported by this pattern.
+ return failure();
+ }
+ int64_t originalSize = vecType.getShape()[0];
+ Location loc = stepOp.getLoc();
+ SmallVector<int64_t> strides(1, 1);
+
+ Value result = arith::ConstantOp::create(rewriter, loc, vecType,
+ rewriter.getZeroAttr(vecType));
+
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange({originalSize}, *targetShape)) {
+ int64_t tileOffset = offsets[0];
+ auto targetVecType =
+ VectorType::get(*targetShape, vecType.getElementType());
+ Value baseStep = rewriter.create<vector::StepOp>(loc, targetVecType);
+ Value offsetVal =
+ rewriter.create<arith::ConstantIndexOp>(loc, tileOffset);
+ Value bcastOffset =
+ rewriter.create<vector::BroadcastOp>(loc, targetVecType, offsetVal);
+ Value tileStep =
+ rewriter.create<arith::AddIOp>(loc, baseStep, bcastOffset);
+
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, tileStep, result, offsets, strides);
+ }
+ rewriter.replaceOp(stepOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -818,6 +866,6 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
- UnrollStorePattern, UnrollBroadcastPattern>(
+ UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
patterns.getContext(), options, benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e129cd5c40b9c..777af995b4554 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -420,3 +420,26 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+
+
+func.func @vector_step() -> vector<32xindex> {
+ %0 = vector.step : vector<32xindex>
+ return %0 : vector<32xindex>
+}
+// CHECK-LABEL: func @vector_step
+// CHECK: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex>
+// CHECK: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex>
+// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
+// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex>
+// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex>
+// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex>
+// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP3]], %[[CST3]] : vector<8xindex>
+// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: return %[[INS3]] : vector<32xindex>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..1cd092cec2b81 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::ReductionOp>(op));
}));
+ populateVectorUnrollPatterns(patterns,
+ UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{8})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::StepOp>(op));
+ }));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
>From 8c6f310b09c71e6917227ae0b5e99c71fe1257a1 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 16 Sep 2025 16:04:12 +0000
Subject: [PATCH 2/5] Address Feedback
---
.../Vector/Transforms/VectorUnroll.cpp | 47 +++++++++++++++----
.../Dialect/Vector/vector-unroll-options.mlir | 19 ++++----
2 files changed, 45 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 0671dd1c4bea2..8865b96241548 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -809,6 +809,32 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
vector::UnrollVectorOptions options;
};
+/// This pattern unrolls `vector.step` operations according to the provided
+/// target unroll shape. It decomposes a large step vector into smaller step
+/// vectors (segments) and assembles the result by inserting each computed
+/// segment at its offset in a zero-initialized vector of the original type.
+///
+/// The pattern does not support scalable vectors and will fail to match them.
+///
+/// For each segment, it adds the base step vector and the segment's offset,
+/// then inserts the result into the output vector at the corresponding
+/// position.
+///
+/// Example:
+/// Given a step operation:
+/// %0 = vector.step : vector<8xindex>
+///
+/// and a target unroll shape of <4>, the pattern produces:
+///
+/// %base = vector.step : vector<4xindex>
+/// %zero = arith.constant dense<0> : vector<4xindex>
+/// %result0 = vector.insert_strided_slice %base, %zero
+/// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
+/// %offset = arith.constant dense<4> : vector<4xindex>
+/// %segment1 = arith.addi %base, %offset : vector<4xindex>
+/// %result1 = vector.insert_strided_slice %segment1, %result0
+/// {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
+///
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
UnrollStepPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
@@ -817,7 +843,8 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
LogicalResult matchAndRewrite(vector::StepOp stepOp,
PatternRewriter &rewriter) const override {
- auto targetShape = getTargetShape(options, stepOp);
+ std::optional<SmallVector<int64_t>> targetShape =
+ getTargetShape(options, stepOp);
if (!targetShape)
return failure();
@@ -833,18 +860,18 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
Value result = arith::ConstantOp::create(rewriter, loc, vecType,
rewriter.getZeroAttr(vecType));
+ VectorType targetVecType =
+ VectorType::get(*targetShape, vecType.getElementType());
+ Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange({originalSize}, *targetShape)) {
- int64_t tileOffset = offsets[0];
- auto targetVecType =
- VectorType::get(*targetShape, vecType.getElementType());
- Value baseStep = rewriter.create<vector::StepOp>(loc, targetVecType);
- Value offsetVal =
- rewriter.create<arith::ConstantIndexOp>(loc, tileOffset);
- Value bcastOffset =
- rewriter.create<vector::BroadcastOp>(loc, targetVecType, offsetVal);
+ Value bcastOffset = arith::ConstantOp::create(
+ rewriter, loc, targetVecType,
+ DenseElementsAttr::get(
+ targetVecType,
+ IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
Value tileStep =
- rewriter.create<arith::AddIOp>(loc, baseStep, bcastOffset);
+ arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, tileStep, result, offsets, strides);
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 777af995b4554..35db14e0f7f1d 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -427,19 +427,16 @@ func.func @vector_step() -> vector<32xindex> {
return %0 : vector<32xindex>
}
// CHECK-LABEL: func @vector_step
-// CHECK: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex>
-// CHECK: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex>
+// CHECK: %[[CST:.*]] = arith.constant dense<24> : vector<8xindex>
+// CHECK: %[[CST0:.*]] = arith.constant dense<16> : vector<8xindex>
// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
-// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex>
-// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex>
-// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
-// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex>
-// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex>
+// CHECK: %[[CST2:.*]] = arith.constant dense<0> : vector<32xindex>
+// CHECK: %[[STEP:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP]], %[[CST2]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP]], %[[CST1]] : vector<8xindex>
// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
-// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex>
-// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex>
+// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP]], %[[CST0]] : vector<8xindex>
// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
-// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex>
-// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP3]], %[[CST3]] : vector<8xindex>
+// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
// CHECK: return %[[INS3]] : vector<32xindex>
>From f2673413c4df547e974095ebb9ba52cde39f5de4 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 16 Sep 2025 16:19:24 +0000
Subject: [PATCH 3/5] Fix shape
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 8865b96241548..f778d3d860fe9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -827,7 +827,7 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
/// and a target unroll shape of <4>, the pattern produces:
///
/// %base = vector.step : vector<4xindex>
-/// %zero = arith.constant dense<0> : vector<4xindex>
+/// %zero = arith.constant dense<0> : vector<8xindex>
/// %result0 = vector.insert_strided_slice %base, %zero
/// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
/// %offset = arith.constant dense<4> : vector<4xindex>
>From 4d8a2865e981800e64122c37136e55824d9d0fa3 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 16 Sep 2025 17:28:01 +0000
Subject: [PATCH 4/5] Remove getShapeForUnroll for step
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ----
2 files changed, 1 insertion(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3d43e26c6be42..c60fa3b85b396 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3018,7 +3018,7 @@ def Vector_ScanOp :
def Vector_StepOp : Vector_Op<"step", [
Pure,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]> {
let summary = "A linear sequence of values from 0 to N";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7ea13b5723ad8..85e485c28c74e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7442,10 +7442,6 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), result);
}
-std::optional<SmallVector<int64_t, 4>> StepOp::getShapeForUnroll() {
- return llvm::to_vector<4>(llvm::cast<VectorType>(getType()).getShape());
-}
-
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
>From 80359dd4aa1c792de0987354be91c8737b507d81 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 16 Sep 2025 18:28:33 +0000
Subject: [PATCH 5/5] Feedback
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index f778d3d860fe9..79786f33a2d46 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -860,10 +860,10 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
Value result = arith::ConstantOp::create(rewriter, loc, vecType,
rewriter.getZeroAttr(vecType));
- VectorType targetVecType =
+ auto targetVecType =
VectorType::get(*targetShape, vecType.getElementType());
Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
- for (SmallVector<int64_t> offsets :
+ for (const SmallVector<int64_t> &offsets :
StaticTileOffsetRange({originalSize}, *targetShape)) {
Value bcastOffset = arith::ConstantOp::create(
rewriter, loc, targetVecType,
More information about the Mlir-commits
mailing list