[Mlir-commits] [mlir] [mlir][vector] Add unroll patterns for vector.load and vector.store (PR #143420)
Nishant Patel
llvmlistbot at llvm.org
Fri Jun 20 07:32:30 PDT 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/143420
From 5003057b2010149a95fda72b6dd395c918329408 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 9 Jun 2025 17:56:30 +0000
Subject: [PATCH 1/5] Add unroll patterns for vector.load and vector.store
---
.../Vector/Transforms/VectorUnroll.cpp | 123 +++++++++++++++++-
.../Vector/vector-load-store-unroll.mlir | 73 +++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 40 ++++++
3 files changed, 234 insertions(+), 2 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 1cc477d9dca91..43abf84cd6428 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -54,6 +54,33 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
return slicedIndices;
}
+// compute the new indices for vector.load/store by adding offsets to
+// originalIndices.
+// It assumes m <= n (m = offsets.size(), n = originalIndices.size())
+// Last m of originalIndices will be updated.
+static SmallVector<Value> computeIndices(PatternRewriter &rewriter,
+ Location loc,
+ ArrayRef<Value> originalIndices,
+ ArrayRef<int64_t> offsets) {
+ assert(offsets.size() <= originalIndices.size() &&
+ "Offsets should not exceed the number of original indices");
+ SmallVector<Value> indices(originalIndices);
+ auto originalIter = originalIndices.rbegin();
+ auto offsetsIter = offsets.rbegin();
+ auto indicesIter = indices.rbegin();
+ while (offsetsIter != offsets.rend()) {
+ Value original = *originalIter;
+ int64_t offset = *offsetsIter;
+ if (offset != 0)
+ *indicesIter = rewriter.create<arith::AddIOp>(
+ loc, original, rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ originalIter++;
+ offsetsIter++;
+ indicesIter++;
+ }
+ return indices;
+};
+
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
@@ -631,6 +658,98 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
vector::UnrollVectorOptions options;
};
+struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
+ UnrollLoadPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = loadOp.getVectorType();
+ // Only unroll >1D loads
+ if (vecType.getRank() <= 1)
+ return failure();
+
+ Location loc = loadOp.getLoc();
+ ArrayRef<int64_t> originalShape = vecType.getShape();
+
+ // Target type is a 1D vector of the innermost dimension.
+ auto targetType =
+ VectorType::get(originalShape.back(), vecType.getElementType());
+
+ // Extend targetShape to the rank of the original shape by padding leading
+ // dimensions with 1s, for convenience when computing offsets.
+ SmallVector<int64_t> targetShape(originalShape.size(), 1);
+ targetShape.back() = originalShape.back();
+
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, vecType, rewriter.getZeroAttr(vecType));
+
+ SmallVector<Value> originalIndices(loadOp.getIndices().begin(),
+ loadOp.getIndices().end());
+
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, targetShape)) {
+ SmallVector<Value> indices =
+ computeIndices(rewriter, loc, originalIndices, offsets);
+ Value slice = rewriter.create<vector::LoadOp>(loc, targetType,
+ loadOp.getBase(), indices);
+ // Insert the slice into the result at the correct position.
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, slice, result, offsets, SmallVector<int64_t>({1}));
+ }
+ rewriter.replaceOp(loadOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
+struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
+ UnrollStorePattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = storeOp.getVectorType();
+ // Only unroll >1D stores.
+ if (vecType.getRank() <= 1)
+ return failure();
+
+ Location loc = storeOp.getLoc();
+ ArrayRef<int64_t> originalShape = vecType.getShape();
+
+ // Extend targetShape to the rank of the original shape by padding leading
+ // dimensions with 1s, for convenience when computing offsets.
+ SmallVector<int64_t> targetShape(originalShape.size(), 1);
+ targetShape.back() = originalShape.back();
+
+ Value base = storeOp.getBase();
+ Value vector = storeOp.getValueToStore();
+
+ SmallVector<Value> originalIndices(storeOp.getIndices().begin(),
+ storeOp.getIndices().end());
+
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, targetShape)) {
+ SmallVector<Value> indices =
+ computeIndices(rewriter, loc, originalIndices, offsets);
+ offsets.pop_back();
+ Value slice = rewriter.create<vector::ExtractOp>(loc, vector, offsets);
+ rewriter.create<vector::StoreOp>(loc, slice, base, indices);
+ }
+ rewriter.eraseOp(storeOp);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -639,6 +758,6 @@ void mlir::vector::populateVectorUnrollPatterns(
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern,
- UnrollTransposePattern, UnrollGatherPattern>(
- patterns.getContext(), options, benefit);
+ UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
+ UnrollStorePattern>(patterns.getContext(), options, benefit);
}
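
For context on the iteration above: StaticTileOffsetRange enumerates the offset of each targetShape-sized tile covering the original shape. A minimal standalone C++ sketch of that enumeration (an illustration of the behavior these patterns rely on, with a made-up helper name, not MLIR's actual implementation):

#include <cstdint>
#include <vector>

// Enumerate the offsets of `targetShape`-sized tiles covering `shape`,
// last dimension varying fastest (the default order used above).
static std::vector<std::vector<int64_t>>
enumerateTileOffsets(const std::vector<int64_t> &shape,
                     const std::vector<int64_t> &targetShape) {
  std::vector<std::vector<int64_t>> result;
  std::vector<int64_t> offsets(shape.size(), 0);
  while (true) {
    result.push_back(offsets);
    // Advance like an odometer, stepping each dimension by its tile size.
    int64_t dim = static_cast<int64_t>(shape.size()) - 1;
    for (; dim >= 0; --dim) {
      offsets[dim] += targetShape[dim];
      if (offsets[dim] < shape[dim])
        break;
      offsets[dim] = 0;
    }
    if (dim < 0)
      return result; // Wrapped past the outermost dimension: done.
  }
}

// For shape = {4, 4} and targetShape = {1, 4} (the 1-D rows used here),
// this yields {0,0}, {1,0}, {2,0}, {3,0} -- one offset per unrolled load.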
diff --git a/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
new file mode 100644
index 0000000000000..3135268b8d61b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -test-vector-load-store-unroll --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @unroll_2D_vector_load(
+// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
+func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: return %[[V7]] : vector<4x4xf16>
+ %c0 = arith.constant 0 : index
+ %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+ return %0 : vector<4x4xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_2D_vector_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
+func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ %c0 = arith.constant 0 : index
+ vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+ return
+}
+
+// CHECK-LABEL: func.func @unroll_vector_load(
+// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+ // CHECK: return %[[V3]] : vector<2x2xf16>
+ %c1 = arith.constant 1 : index
+ %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+ return %0 : vector<2x2xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_vector_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
+func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) {
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
+ // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
+ // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ %c1 = arith.constant 1 : index
+ vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+ return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index eda2594fbc7c7..b2b2b4ece22cd 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -289,6 +289,44 @@ struct TestVectorTransferUnrollingPatterns
llvm::cl::init(false)};
};
+struct TestVectorLoadStoreUnrollPatterns
+ : public PassWrapper<TestVectorLoadStoreUnrollPatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorLoadStoreUnrollPatterns)
+
+ StringRef getArgument() const final {
+ return "test-vector-load-store-unroll";
+ }
+ StringRef getDescription() const final {
+ return "Test unrolling patterns for vector.load and vector.store ops";
+ }
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<vector::VectorDialect, arith::ArithDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+
+ // Unroll all vector.load and vector.store ops with rank > 1 into 1D vectors.
+ vector::UnrollVectorOptions options;
+ options.setFilterConstraint([](Operation *op) {
+ if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+ return success(loadOp.getType().getRank() > 1);
+ if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+ return success(storeOp.getVectorType().getRank() > 1);
+ return failure();
+ });
+
+ vector::populateVectorUnrollPatterns(patterns, options);
+
+ // Apply the patterns
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestScalarVectorTransferLoweringPatterns
: public PassWrapper<TestScalarVectorTransferLoweringPatterns,
OperationPass<func::FuncOp>> {
@@ -1033,6 +1071,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorTransferUnrollingPatterns>();
+ PassRegistration<TestVectorLoadStoreUnrollPatterns>();
+
PassRegistration<TestScalarVectorTransferLoweringPatterns>();
PassRegistration<TestVectorTransferOpt>();
From 9d91abe8417b56bfb6b7e220b8fbbd050b8e03da Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 9 Jun 2025 20:43:56 +0000
Subject: [PATCH 2/5] Clean up
---
.../Vector/vector-load-store-unroll.mlir | 73 -------------------
.../Dialect/Vector/vector-unroll-options.mlir | 73 +++++++++++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 50 +++----------
3 files changed, 83 insertions(+), 113 deletions(-)
delete mode 100644 mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
diff --git a/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
deleted file mode 100644
index 3135268b8d61b..0000000000000
--- a/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
+++ /dev/null
@@ -1,73 +0,0 @@
-// RUN: mlir-opt %s -test-vector-load-store-unroll --split-input-file | FileCheck %s
-
-// CHECK-LABEL: func.func @unroll_2D_vector_load(
-// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
-func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
- // CHECK: %[[C3:.*]] = arith.constant 3 : index
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[C1:.*]] = arith.constant 1 : index
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
- // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
- // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
- // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
- // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
- // CHECK: return %[[V7]] : vector<4x4xf16>
- %c0 = arith.constant 0 : index
- %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
- return %0 : vector<4x4xf16>
-}
-
-// CHECK-LABEL: func.func @unroll_2D_vector_store(
-// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
-func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
- // CHECK: %[[C3:.*]] = arith.constant 3 : index
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[C1:.*]] = arith.constant 1 : index
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16>
- // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16>
- // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16>
- // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
- // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- %c0 = arith.constant 0 : index
- vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
- return
-}
-
-// CHECK-LABEL: func.func @unroll_vector_load(
-// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
-func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[C1:.*]] = arith.constant 1 : index
- // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
- // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
- // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
- // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
- // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
- // CHECK: return %[[V3]] : vector<2x2xf16>
- %c1 = arith.constant 1 : index
- %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
- return %0 : vector<2x2xf16>
-}
-
-// CHECK-LABEL: func.func @unroll_vector_store(
-// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
-func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) {
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[C1:.*]] = arith.constant 1 : index
- // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
- // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
- // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
- // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
- %c1 = arith.constant 1 : index
- vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
- return
-}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index fbb178fb49d87..efb709e41a69c 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -378,3 +378,76 @@ func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector
// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2x1xf32> to vector<2x2xf32>
// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK: return [[r3]] : vector<4x4xf32>
+
+
+// CHECK-LABEL: func.func @unroll_2D_vector_load(
+// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
+func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: return %[[V7]] : vector<4x4xf16>
+ %c0 = arith.constant 0 : index
+ %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+ return %0 : vector<4x4xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_2D_vector_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
+func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
+ // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+ %c0 = arith.constant 0 : index
+ vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+ return
+}
+
+// CHECK-LABEL: func.func @unroll_vector_load(
+// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+ // CHECK: return %[[V3]] : vector<2x2xf16>
+ %c1 = arith.constant 1 : index
+ %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+ return %0 : vector<2x2xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_vector_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
+func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) {
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
+ // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
+ // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ %c1 = arith.constant 1 : index
+ vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+ return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 8014362a1a6ec..023a6706b58be 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -178,6 +178,16 @@ struct TestVectorUnrollingPatterns
return success(isa<vector::TransposeOp>(op));
}));
+ populateVectorUnrollPatterns(
+ patterns, UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{2, 2})
+ .setFilterConstraint([](Operation *op) {
+ if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+ return success(loadOp.getType().getRank() > 1);
+ if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+ return success(storeOp.getVectorType().getRank() > 1);
+ return failure();
+ }));
if (unrollBasedOnType) {
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
[](Operation *op) -> std::optional<SmallVector<int64_t>> {
@@ -292,44 +302,6 @@ struct TestVectorTransferUnrollingPatterns
llvm::cl::init(false)};
};
-struct TestVectorLoadStoreUnrollPatterns
- : public PassWrapper<TestVectorLoadStoreUnrollPatterns,
- OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- TestVectorLoadStoreUnrollPatterns)
-
- StringRef getArgument() const final {
- return "test-vector-load-store-unroll";
- }
- StringRef getDescription() const final {
- return "Test unrolling patterns for vector.load and vector.store ops";
- }
-
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<vector::VectorDialect, arith::ArithDialect>();
- }
-
- void runOnOperation() override {
- MLIRContext *ctx = &getContext();
- RewritePatternSet patterns(ctx);
-
- // Unroll all vector.load and vector.store ops with rank > 1 into 1D vectors.
- vector::UnrollVectorOptions options;
- options.setFilterConstraint([](Operation *op) {
- if (auto loadOp = dyn_cast<vector::LoadOp>(op))
- return success(loadOp.getType().getRank() > 1);
- if (auto storeOp = dyn_cast<vector::StoreOp>(op))
- return success(storeOp.getVectorType().getRank() > 1);
- return failure();
- });
-
- vector::populateVectorUnrollPatterns(patterns, options);
-
- // Apply the patterns
- (void)applyPatternsGreedily(getOperation(), std::move(patterns));
- }
-};
-
struct TestScalarVectorTransferLoweringPatterns
: public PassWrapper<TestScalarVectorTransferLoweringPatterns,
OperationPass<func::FuncOp>> {
@@ -1070,8 +1042,6 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorTransferUnrollingPatterns>();
- PassRegistration<TestVectorLoadStoreUnrollPatterns>();
-
PassRegistration<TestScalarVectorTransferLoweringPatterns>();
PassRegistration<TestVectorTransferOpt>();
From 3f4094825463cc592415dc90f03013b9db5a5230 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 10 Jun 2025 18:36:24 +0000
Subject: [PATCH 3/5] Address feedback
---
.../Vector/Transforms/VectorUnroll.cpp | 31 +++++++++++--
.../Dialect/Vector/vector-unroll-options.mlir | 44 +++++++++++--------
2 files changed, 53 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e912a6ef29b21..6780a898b7fd5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -54,10 +54,10 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
return slicedIndices;
}
-// compute the new indices for vector.load/store by adding offsets to
-// originalIndices.
+// Compute the new indices for vector.load/store by adding `offsets` to
+// `originalIndices`.
// It assumes m <= n (m = offsets.size(), n = originalIndices.size())
-// Last m of originalIndices will be updated.
+// Last m of `originalIndices` will be updated.
static SmallVector<Value> computeIndices(PatternRewriter &rewriter,
Location loc,
ArrayRef<Value> originalIndices,
@@ -658,6 +658,20 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
vector::UnrollVectorOptions options;
};
+// clang-format off
+// This pattern unrolls the vector load into multiple 1D vector loads by
+// extracting slices from the base memory and inserting them into the result
+// vector using vector.insert_strided_slice.
+// For example,
+// vector.load %base[%indices] : memref<4x4xf32>, vector<4x4xf32>
+// is converted to:
+// %cst = arith.constant dense<0.0> : vector<4x4xf32>
+// %slice_0 = vector.load %base[%indices] : memref<4x4xf32>, vector<4xf32>
+// %result_0 = vector.insert_strided_slice %slice_0, %cst {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
+// %slice_1 = vector.load %base[%indices + 1] : memref<4x4xf32>, vector<4xf32>
+// %result_1 = vector.insert_strided_slice %slice_1, %result_0 {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
+// ...
+// clang-format on
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
UnrollLoadPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
@@ -707,6 +721,17 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
vector::UnrollVectorOptions options;
};
+// This pattern unrolls the vector store into multiple 1D vector stores by
+// extracting slices from the source vector and storing them into the
+// destination.
+// For example,
+// vector.store %source, %base[%indices] : vector<4x4xf32>
+// is converted to:
+// %slice_0 = vector.extract %source[0] : vector<4xf32>
+// vector.store %slice_0, %base[%indices] : vector<4xf32>
+// %slice_1 = vector.extract %source[1] : vector<4xf32>
+// vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
+// ...
struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
UnrollStorePattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index efb709e41a69c..23344a400bcc7 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -380,9 +380,14 @@ func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector
// CHECK: return [[r3]] : vector<4x4xf32>
-// CHECK-LABEL: func.func @unroll_2D_vector_load(
+func.func @vector_load_2D(%mem: memref<4x4xf16>) -> vector<4x4xf16> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.load %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+ return %0 : vector<4x4xf16>
+}
+
+// CHECK-LABEL: func.func @vector_load_2D(
// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
-func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -397,14 +402,16 @@ func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
// CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
// CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
// CHECK: return %[[V7]] : vector<4x4xf16>
+
+
+func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
%c0 = arith.constant 0 : index
- %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
- return %0 : vector<4x4xf16>
+ vector.store %v, %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+ return
}
-// CHECK-LABEL: func.func @unroll_2D_vector_store(
+// CHECK-LABEL: func.func @vector_store_2D(
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
-func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -417,14 +424,16 @@ func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
// CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- %c0 = arith.constant 0 : index
- vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
- return
+
+
+func.func @vector_load_4D_to_2D(%mem: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+ %c1 = arith.constant 1 : index
+ %0 = vector.load %mem[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+ return %0 : vector<2x2xf16>
}
-// CHECK-LABEL: func.func @unroll_vector_load(
+// CHECK-LABEL: func.func @vector_load_4D_to_2D(
// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
-func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
@@ -433,21 +442,18 @@ func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
// CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
// CHECK: return %[[V3]] : vector<2x2xf16>
+
+func.func @vector_store_2D_to_4D(%mem: memref<4x4x4x4xf16>, %v: vector<2x2xf16>) {
%c1 = arith.constant 1 : index
- %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
- return %0 : vector<2x2xf16>
+ vector.store %v, %mem[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+ return
}
-// CHECK-LABEL: func.func @unroll_vector_store(
+// CHECK-LABEL: func.func @vector_store_2D_to_4D(
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
-func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) {
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
// CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
// CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
// CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
- %c1 = arith.constant 1 : index
- vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
- return
-}
From 5a2070b794db9932e1b24fd53e20a662f8212e2a Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 11 Jun 2025 22:30:05 +0000
Subject: [PATCH 4/5] Simplify computeIndices
---
.../Vector/Transforms/VectorUnroll.cpp | 40 +++++++++----------
1 file changed, 18 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 6780a898b7fd5..57e36e91f6e5e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -54,10 +54,9 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
return slicedIndices;
}
-// Compute the new indices for vector.load/store by adding `offsets` to
-// `originalIndices`.
-// It assumes m <= n (m = offsets.size(), n = originalIndices.size())
-// Last m of `originalIndices` will be updated.
+// Compute the new indices by adding `offsets` to `originalIndices`.
+// If m < n (m = offsets.size(), n = originalIndices.size()),
+// then only the trailing m values in `originalIndices` are updated.
static SmallVector<Value> computeIndices(PatternRewriter &rewriter,
Location loc,
ArrayRef<Value> originalIndices,
@@ -65,21 +64,17 @@ static SmallVector<Value> computeIndices(PatternRewriter &rewriter,
assert(offsets.size() <= originalIndices.size() &&
"Offsets should not exceed the number of original indices");
SmallVector<Value> indices(originalIndices);
- auto originalIter = originalIndices.rbegin();
- auto offsetsIter = offsets.rbegin();
- auto indicesIter = indices.rbegin();
- while (offsetsIter != offsets.rend()) {
- Value original = *originalIter;
- int64_t offset = *offsetsIter;
- if (offset != 0)
- *indicesIter = rewriter.create<arith::AddIOp>(
- loc, original, rewriter.create<arith::ConstantIndexOp>(loc, offset));
- originalIter++;
- offsetsIter++;
- indicesIter++;
+
+ auto start = indices.size() - offsets.size();
+ for (auto [i, offset] : llvm::enumerate(offsets)) {
+ if (offset != 0) {
+ indices[start + i] = rewriter.create<arith::AddIOp>(
+ loc, originalIndices[start + i],
+ rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ }
}
return indices;
-};
+}
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
@@ -658,7 +653,6 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
vector::UnrollVectorOptions options;
};
-// clang-format off
// This pattern unrolls the vector load into multiple 1D vector loads by
// extracting slices from the base memory and inserting them into the result
// vector using vector.insert_strided_slice.
@@ -667,11 +661,13 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
// is converted to:
// %cst = arith.constant dense<0.0> : vector<4x4xf32>
// %slice_0 = vector.load %base[%indices] : memref<4x4xf32>, vector<4xf32>
-// %result_0 = vector.insert_strided_slice %slice_0, %cst {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
-// %slice_1 = vector.load %base[%indices + 1] : memref<4x4xf32>, vector<4xf32>
-// %result_1 = vector.insert_strided_slice %slice_1, %result_0 {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
+// %result_0 = vector.insert_strided_slice %slice_0, %cst
+// {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
+// %slice_1 = vector.load %base[%indices + 1]
+// : memref<4x4xf32>, vector<4xf32>
+// %result_1 = vector.insert_strided_slice %slice_1, %result_0
+// {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
// ...
-// clang-format on
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
UnrollLoadPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
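
The simplified computeIndices only touches the trailing offsets.size() entries of the index list, and it skips zero offsets so no spurious arith.addi ops are created. A self-contained C++ sketch of that arithmetic, with strings standing in for SSA index values (string concatenation stands in for the rewriter building an add; illustration only, not the MLIR API):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Trailing-dimension index update mirroring computeIndices: for m offsets
// and n >= m indices, only indices[n - m .. n - 1] may change.
static std::vector<std::string>
addOffsets(std::vector<std::string> indices,
           const std::vector<int64_t> &offsets) {
  size_t start = indices.size() - offsets.size();
  for (size_t i = 0; i < offsets.size(); ++i)
    if (offsets[i] != 0)
      indices[start + i] += " + " + std::to_string(offsets[i]);
  return indices;
}

int main() {
  // Mirrors the 4-D load test: four memref indices, 2-D tile offsets.
  for (const auto &offs :
       {std::vector<int64_t>{0, 0}, std::vector<int64_t>{1, 0}}) {
    for (const std::string &idx :
         addOffsets({"%c1", "%c1", "%c1", "%c1"}, offs))
      std::cout << idx << "  ";
    std::cout << "\n"; // %c1 %c1 %c1 %c1, then %c1 %c1 %c1 + 1 %c1
  }
  return 0;
}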
From 57cc380c625d9b1c344240d3715025f773ed9c46 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 19 Jun 2025 17:06:48 +0000
Subject: [PATCH 5/5] Use unroll options
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 8 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 ++
.../Vector/Transforms/VectorUnroll.cpp | 74 +++++++------------
.../Dialect/Vector/vector-unroll-options.mlir | 69 ++++-------------
.../Dialect/Vector/TestVectorTransforms.cpp | 12 +--
5 files changed, 56 insertions(+), 115 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8353314ed958b..d05fea3a5d755 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1673,7 +1673,9 @@ def Vector_TransferWriteOp :
let hasVerifier = 1;
}
-def Vector_LoadOp : Vector_Op<"load"> {
+def Vector_LoadOp : Vector_Op<"load", [
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+ ]> {
let summary = "reads an n-D slice of memory into an n-D vector";
let description = [{
The 'vector.load' operation reads an n-D slice of memory into an n-D
@@ -1759,7 +1761,9 @@ def Vector_LoadOp : Vector_Op<"load"> {
"$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
}
-def Vector_StoreOp : Vector_Op<"store"> {
+def Vector_StoreOp : Vector_Op<"store", [
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+ ]> {
let summary = "writes an n-D vector to an n-D slice of memory";
let description = [{
The 'vector.store' operation writes an n-D vector to an n-D slice of memory.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3179b4f975404..1d0d0ec3c2fc9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5266,6 +5266,10 @@ OpFoldResult LoadOp::fold(FoldAdaptor) {
return OpFoldResult();
}
+std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getVectorType().getShape());
+}
+
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
@@ -5301,6 +5305,10 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
return memref::foldMemRefCast(*this);
}
+std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getVectorType().getShape());
+}
+
//===----------------------------------------------------------------------===//
// MaskedLoadOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 57e36e91f6e5e..baee341f6768b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -653,21 +653,6 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
vector::UnrollVectorOptions options;
};
-// This pattern unrolls the vector load into multiple 1D vector loads by
-// extracting slices from the base memory and inserting them into the result
-// vector using vector.insert_strided_slice.
-// For example,
-// vector.load %base[%indices] : memref<4x4xf32>, vector<4x4xf32>
-// is converted to:
-// %cst = arith.constant dense<0.0> : vector<4x4xf32>
-// %slice_0 = vector.load %base[%indices] : memref<4x4xf32>, vector<4xf32>
-// %result_0 = vector.insert_strided_slice %slice_0, %cst
-// {offsets = [0, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
-// %slice_1 = vector.load %base[%indices + 1]
-// : memref<4x4xf32>, vector<4xf32>
-// %result_1 = vector.insert_strided_slice %slice_1, %result_0
-// {offsets = [1, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
-// ...
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
UnrollLoadPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
@@ -677,21 +662,16 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
LogicalResult matchAndRewrite(vector::LoadOp loadOp,
PatternRewriter &rewriter) const override {
VectorType vecType = loadOp.getVectorType();
- // Only unroll >1D loads
if (vecType.getRank() <= 1)
return failure();
+ auto targetShape = getTargetShape(options, loadOp);
+ if (!targetShape)
+ return failure();
+
Location loc = loadOp.getLoc();
ArrayRef<int64_t> originalShape = vecType.getShape();
-
- // Target type is a 1D vector of the innermost dimension.
- auto targetType =
- VectorType::get(originalShape.back(), vecType.getElementType());
-
- // Extend targetShape to the rank of the original shape by padding leading
- // dimensions with 1s, for convenience when computing offsets.
- SmallVector<int64_t> targetShape(originalShape.size(), 1);
- targetShape.back() = originalShape.back();
+ SmallVector<int64_t> strides(targetShape->size(), 1);
Value result = rewriter.create<arith::ConstantOp>(
loc, vecType, rewriter.getZeroAttr(vecType));
@@ -699,15 +679,20 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
SmallVector<Value> originalIndices(loadOp.getIndices().begin(),
loadOp.getIndices().end());
+ SmallVector<int64_t> loopOrder =
+ getUnrollOrder(originalShape.size(), loadOp, options);
+
+ auto targetVecType =
+ VectorType::get(*targetShape, vecType.getElementType());
+
for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(originalShape, targetShape)) {
+ StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
SmallVector<Value> indices =
computeIndices(rewriter, loc, originalIndices, offsets);
- Value slice = rewriter.create<vector::LoadOp>(loc, targetType,
+ Value slice = rewriter.create<vector::LoadOp>(loc, targetVecType,
loadOp.getBase(), indices);
- // Insert the slice into the result at the correct position.
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
- loc, slice, result, offsets, SmallVector<int64_t>({1}));
+ loc, slice, result, offsets, strides);
}
rewriter.replaceOp(loadOp, result);
return success();
@@ -717,17 +702,6 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
vector::UnrollVectorOptions options;
};
-// This pattern unrolls the vector store into multiple 1D vector stores by
-// extracting slices from the source vector and storing them into the
-// destination.
-// For example,
-// vector.store %source, %base[%indices] : vector<4x4xf32>
-// is converted to:
-// %slice_0 = vector.extract %source[0] : vector<4xf32>
-// vector.store %slice_0, %base[%indices] : vector<4xf32>
-// %slice_1 = vector.extract %source[1] : vector<4xf32>
-// vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
-// ...
struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
UnrollStorePattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
@@ -737,17 +711,16 @@ struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
LogicalResult matchAndRewrite(vector::StoreOp storeOp,
PatternRewriter &rewriter) const override {
VectorType vecType = storeOp.getVectorType();
- // Only unroll >1D stores.
if (vecType.getRank() <= 1)
return failure();
+ auto targetShape = getTargetShape(options, storeOp);
+ if (!targetShape)
+ return failure();
+
Location loc = storeOp.getLoc();
ArrayRef<int64_t> originalShape = vecType.getShape();
-
- // Extend targetShape to the rank of the original shape by padding leading
- // dimensions with 1s, for convenience when computing offsets.
- SmallVector<int64_t> targetShape(originalShape.size(), 1);
- targetShape.back() = originalShape.back();
+ SmallVector<int64_t> strides(targetShape->size(), 1);
Value base = storeOp.getBase();
Value vector = storeOp.getValueToStore();
@@ -755,12 +728,15 @@ struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
SmallVector<Value> originalIndices(storeOp.getIndices().begin(),
storeOp.getIndices().end());
+ SmallVector<int64_t> loopOrder =
+ getUnrollOrder(originalShape.size(), storeOp, options);
+
for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(originalShape, targetShape)) {
+ StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
SmallVector<Value> indices =
computeIndices(rewriter, loc, originalIndices, offsets);
- offsets.pop_back();
- Value slice = rewriter.create<vector::ExtractOp>(loc, vector, offsets);
+ Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, vector, offsets, *targetShape, strides);
rewriter.create<vector::StoreOp>(loc, slice, base, indices);
}
rewriter.eraseOp(storeOp);
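
A note on getTargetShape: it is the same helper the other unroll patterns use to turn UnrollVectorOptions into a per-op native shape, returning std::nullopt when the op should be left alone. A hedged C++ sketch of the shape-side check, assuming (as the early exits above suggest) that the native shape must tile the vector shape evenly; the real helper in VectorUnroll.cpp also consults the filter constraint:

#include <cstdint>
#include <optional>
#include <vector>

// Assumed divisibility check behind getTargetShape: each native dimension
// must evenly tile the corresponding vector dimension.
static std::optional<std::vector<int64_t>>
targetShapeFor(const std::vector<int64_t> &vectorShape,
               const std::vector<int64_t> &nativeShape) {
  if (nativeShape.size() != vectorShape.size())
    return std::nullopt; // Rank mismatch: skip unrolling this op.
  for (size_t i = 0; i < vectorShape.size(); ++i)
    if (vectorShape[i] % nativeShape[i] != 0)
      return std::nullopt; // Tiles would not cover the shape exactly.
  return nativeShape;
}

// With vectorShape = {4, 4} and nativeShape = {2, 2} (the test setting),
// the patterns emit four 2x2 slices at offsets (0,0), (0,2), (2,0), (2,2).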
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 23344a400bcc7..e129cd5c40b9c 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -388,19 +388,17 @@ func.func @vector_load_2D(%mem: memref<4x4xf16>) -> vector<4x4xf16> {
// CHECK-LABEL: func.func @vector_load_2D(
// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
- // CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
- // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
- // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
- // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
- // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
+ // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+ // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
+ // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
+ // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
+ // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+ // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
// CHECK: return %[[V7]] : vector<4x4xf16>
@@ -412,48 +410,13 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
// CHECK-LABEL: func.func @vector_store_2D(
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
- // CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16>
- // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16>
- // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16>
- // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
- // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
- // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-
-
-func.func @vector_load_4D_to_2D(%mem: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
- %c1 = arith.constant 1 : index
- %0 = vector.load %mem[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
- return %0 : vector<2x2xf16>
-}
-
-// CHECK-LABEL: func.func @vector_load_4D_to_2D(
-// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[C1:.*]] = arith.constant 1 : index
- // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
- // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
- // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
- // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
- // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
- // CHECK: return %[[V3]] : vector<2x2xf16>
-
-func.func @vector_store_2D_to_4D(%mem: memref<4x4x4x4xf16>, %v: vector<2x2xf16>) {
- %c1 = arith.constant 1 : index
- vector.store %v, %mem[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
- return
-}
-
-// CHECK-LABEL: func.func @vector_store_2D_to_4D(
-// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[C1:.*]] = arith.constant 1 : index
- // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
- // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
- // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
- // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+ // CHECK: %[[V0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
+ // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
+ // CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
+ // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+ // CHECK: %[[V2:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
+ // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
+ // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
+ // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 023a6706b58be..fc75b273a057b 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -163,7 +163,7 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(
isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
- vector::BroadcastOp>(op));
+ vector::BroadcastOp, vector::LoadOp, vector::StoreOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
@@ -178,16 +178,6 @@ struct TestVectorUnrollingPatterns
return success(isa<vector::TransposeOp>(op));
}));
- populateVectorUnrollPatterns(
- patterns, UnrollVectorOptions()
- .setNativeShape(ArrayRef<int64_t>{2, 2})
- .setFilterConstraint([](Operation *op) {
- if (auto loadOp = dyn_cast<vector::LoadOp>(op))
- return success(loadOp.getType().getRank() > 1);
- if (auto storeOp = dyn_cast<vector::StoreOp>(op))
- return success(storeOp.getVectorType().getRank() > 1);
- return failure();
- }));
if (unrollBasedOnType) {
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
[](Operation *op) -> std::optional<SmallVector<int64_t>> {
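
For downstream users, wiring these patterns up outside the test pass follows the registration shown above. A usage sketch against the same API (the header paths and the enclosing function are assumptions for illustration, not part of this patch):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Unroll >1-D vector.load/vector.store ops in `func` into 2x2 tiles.
static void unrollLoadsAndStores(func::FuncOp func) {
  RewritePatternSet patterns(func.getContext());
  vector::UnrollVectorOptions options;
  options.setNativeShape(ArrayRef<int64_t>{2, 2});
  options.setFilterConstraint([](Operation *op) {
    return success(isa<vector::LoadOp, vector::StoreOp>(op));
  });
  vector::populateVectorUnrollPatterns(patterns, options);
  (void)applyPatternsGreedily(func, std::move(patterns));
}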