[Mlir-commits] [mlir] [mlir][vector] Add unroll patterns for vector.load and vector.store (PR #143420)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 9 11:49:09 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
This PR adds unroll patterns for vector.load and vector.store ops with rank > 1, unrolling them into 1-D loads and stores. This PR is a follow-up to #<!-- -->137558.
---
Full diff: https://github.com/llvm/llvm-project/pull/143420.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+125-6)
- (added) mlir/test/Dialect/Vector/vector-load-store-unroll.mlir (+73)
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+40)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fc443ab0d138e..e912a6ef29b21 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -54,6 +54,33 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
return slicedIndices;
}
+// compute the new indices for vector.load/store by adding offsets to
+// originalIndices.
+// It assumes m <= n (m = offsets.size(), n = originalIndices.size())
+// Last m of originalIndices will be updated.
+static SmallVector<Value> computeIndices(PatternRewriter &rewriter,
+ Location loc,
+ ArrayRef<Value> originalIndices,
+ ArrayRef<int64_t> offsets) {
+ assert(offsets.size() <= originalIndices.size() &&
+ "Offsets should not exceed the number of original indices");
+ SmallVector<Value> indices(originalIndices);
+ auto originalIter = originalIndices.rbegin();
+ auto offsetsIter = offsets.rbegin();
+ auto indicesIter = indices.rbegin();
+ while (offsetsIter != offsets.rend()) {
+ Value original = *originalIter;
+ int64_t offset = *offsetsIter;
+ if (offset != 0)
+ *indicesIter = rewriter.create<arith::AddIOp>(
+ loc, original, rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ originalIter++;
+ offsetsIter++;
+ indicesIter++;
+ }
+ return indices;
+};
+
// Clones `op` into a new operations that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
@@ -631,6 +658,98 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
vector::UnrollVectorOptions options;
};
+struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
+ UnrollLoadPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = loadOp.getVectorType();
+ // Only unroll >1D loads
+ if (vecType.getRank() <= 1)
+ return failure();
+
+ Location loc = loadOp.getLoc();
+ ArrayRef<int64_t> originalShape = vecType.getShape();
+
+ // Target type is a 1D vector of the innermost dimension.
+ auto targetType =
+ VectorType::get(originalShape.back(), vecType.getElementType());
+
+ // Extend the targetShape to the same rank of original shape by padding 1s
+ // for leading dimensions for convenience of computing offsets
+ SmallVector<int64_t> targetShape(originalShape.size(), 1);
+ targetShape.back() = originalShape.back();
+
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, vecType, rewriter.getZeroAttr(vecType));
+
+ SmallVector<Value> originalIndices(loadOp.getIndices().begin(),
+ loadOp.getIndices().end());
+
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, targetShape)) {
+ SmallVector<Value> indices =
+ computeIndices(rewriter, loc, originalIndices, offsets);
+ Value slice = rewriter.create<vector::LoadOp>(loc, targetType,
+ loadOp.getBase(), indices);
+ // Insert the slice into the result at the correct position.
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, slice, result, offsets, SmallVector<int64_t>({1}));
+ }
+ rewriter.replaceOp(loadOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
+struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
+ UnrollStorePattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = storeOp.getVectorType();
+ // Only unroll >1D stores.
+ if (vecType.getRank() <= 1)
+ return failure();
+
+ Location loc = storeOp.getLoc();
+ ArrayRef<int64_t> originalShape = vecType.getShape();
+
+ // Extend the targetShape to the same rank of original shape by padding 1s
+ // for leading dimensions for convenience of computing offsets
+ SmallVector<int64_t> targetShape(originalShape.size(), 1);
+ targetShape.back() = originalShape.back();
+
+ Value base = storeOp.getBase();
+ Value vector = storeOp.getValueToStore();
+
+ SmallVector<Value> originalIndices(storeOp.getIndices().begin(),
+ storeOp.getIndices().end());
+
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, targetShape)) {
+ SmallVector<Value> indices =
+ computeIndices(rewriter, loc, originalIndices, offsets);
+ offsets.pop_back();
+ Value slice = rewriter.create<vector::ExtractOp>(loc, vector, offsets);
+ rewriter.create<vector::StoreOp>(loc, slice, base, indices);
+ }
+ rewriter.eraseOp(storeOp);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
UnrollBroadcastPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
@@ -699,10 +818,10 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options,
PatternBenefit benefit) {
- patterns
- .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
- UnrollContractionPattern, UnrollElementwisePattern,
- UnrollReductionPattern, UnrollMultiReductionPattern,
- UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
- patterns.getContext(), options, benefit);
+ patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
+ UnrollContractionPattern, UnrollElementwisePattern,
+ UnrollReductionPattern, UnrollMultiReductionPattern,
+ UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
+ UnrollStorePattern, UnrollBroadcastPattern>(
+ patterns.getContext(), options, benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
new file mode 100644
index 0000000000000..3135268b8d61b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -test-vector-load-store-unroll --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @unroll_2D_vector_load(
+// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
+func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: return %[[V7]] : vector<4x4xf16>
+ %c0 = arith.constant 0 : index
+ %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+ return %0 : vector<4x4xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_2D_vector_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
+func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ %c0 = arith.constant 0 : index
+ vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+ return
+}
+
+// CHECK-LABEL: func.func @unroll_vector_load(
+// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+ // CHECK: return %[[V3]] : vector<2x2xf16>
+ %c1 = arith.constant 1 : index
+ %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+ return %0 : vector<2x2xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_vector_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
+func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) {
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
+ // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
+ // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ %c1 = arith.constant 1 : index
+ vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+ return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 54aa96ba89a00..8014362a1a6ec 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -292,6 +292,44 @@ struct TestVectorTransferUnrollingPatterns
llvm::cl::init(false)};
};
+struct TestVectorLoadStoreUnrollPatterns
+ : public PassWrapper<TestVectorLoadStoreUnrollPatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorLoadStoreUnrollPatterns)
+
+ StringRef getArgument() const final {
+ return "test-vector-load-store-unroll";
+ }
+ StringRef getDescription() const final {
+ return "Test unrolling patterns for vector.load and vector.store ops";
+ }
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<vector::VectorDialect, arith::ArithDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+
+ // Unroll all vector.load and vector.store ops with rank > 1 to 1D vectors
+ vector::UnrollVectorOptions options;
+ options.setFilterConstraint([](Operation *op) {
+ if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+ return success(loadOp.getType().getRank() > 1);
+ if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+ return success(storeOp.getVectorType().getRank() > 1);
+ return failure();
+ });
+
+ vector::populateVectorUnrollPatterns(patterns, options);
+
+ // Apply the patterns
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestScalarVectorTransferLoweringPatterns
: public PassWrapper<TestScalarVectorTransferLoweringPatterns,
OperationPass<func::FuncOp>> {
@@ -1032,6 +1070,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorTransferUnrollingPatterns>();
+ PassRegistration<TestVectorLoadStoreUnrollPatterns>();
+
PassRegistration<TestScalarVectorTransferLoweringPatterns>();
PassRegistration<TestVectorTransferOpt>();
``````````
</details>
https://github.com/llvm/llvm-project/pull/143420
More information about the Mlir-commits
mailing list