[Mlir-commits] [mlir] [MLIR][Vector] Add unrolling pattern for vector StepOp (PR #157752)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 9 14:48:29 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
This PR adds an unrolling pattern for the vector.step op to the VectorUnroll transform.
---
Full diff: https://github.com/llvm/llvm-project/pull/157752.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+49-1)
- (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+23)
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 77e26cca1607f..3d43e26c6be42 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3018,6 +3018,7 @@ def Vector_ScanOp :
def Vector_StepOp : Vector_Op<"step", [
Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]> {
let summary = "A linear sequence of values from 0 to N";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 85e485c28c74e..7ea13b5723ad8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7442,6 +7442,10 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), result);
}
+std::optional<SmallVector<int64_t, 4>> StepOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getType().getShape()); // Unroll the result shape.
+}
+
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e8ecb0c0be846..0671dd1c4bea2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -809,6 +809,54 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
vector::UnrollVectorOptions options;
};
+/// Unrolls a 1-D `vector.step` into native-shape `vector.step` tiles; the
+/// tile at offset `off` is `step + broadcast(off)`, inserted into the result.
+struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
+  UnrollStepPattern(MLIRContext *context,
+                    const vector::UnrollVectorOptions &options,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, stepOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType vecType = stepOp.getType();
+    // Scalable vectors are not supported by this pattern.
+    if (vecType.isScalable())
+      return failure();
+    int64_t originalSize = vecType.getShape()[0];
+    Location loc = stepOp.getLoc();
+    SmallVector<int64_t> strides(1, 1);
+    // Every tile has the native shape, so its type is built once.
+    auto targetVecType =
+        VectorType::get(*targetShape, vecType.getElementType());
+    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
+                                             rewriter.getZeroAttr(vecType));
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange({originalSize}, *targetShape)) {
+      int64_t tileOffset = offsets[0];
+      Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
+      Value offsetVal =
+          arith::ConstantIndexOp::create(rewriter, loc, tileOffset);
+      Value bcastOffset = vector::BroadcastOp::create(rewriter, loc,
+                                                      targetVecType, offsetVal);
+      Value tileStep =
+          arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, tileStep, result, offsets, strides);
+    }
+    rewriter.replaceOp(stepOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -818,6 +866,6 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
- UnrollStorePattern, UnrollBroadcastPattern>(
+ UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
patterns.getContext(), options, benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e129cd5c40b9c..777af995b4554 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -420,3 +420,26 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+
+
+func.func @vector_step() -> vector<32xindex> {
+  %0 = vector.step : vector<32xindex>
+  return %0 : vector<32xindex>
+}
+// CHECK-LABEL: func @vector_step
+// CHECK-DAG: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex>
+// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex>
+// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex>
+// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex>
+// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex>
+// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP3]], %[[CST3]] : vector<8xindex>
+// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: return %[[INS3]] : vector<32xindex>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..1cd092cec2b81 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::ReductionOp>(op));
}));
+ populateVectorUnrollPatterns(patterns,
+ UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{8})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::StepOp>(op));
+ }));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
``````````
</details>
https://github.com/llvm/llvm-project/pull/157752
More information about the Mlir-commits mailing list