[Mlir-commits] [mlir] [mlir][vector] Add unroll patterns for vector.load and vector.store (PR #143420)

Nishant Patel llvmlistbot at llvm.org
Mon Jun 9 13:48:12 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/143420

>From 5003057b2010149a95fda72b6dd395c918329408 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 9 Jun 2025 17:56:30 +0000
Subject: [PATCH 1/2] Add unroll patterns for vector.load and vector.store

---
 .../Vector/Transforms/VectorUnroll.cpp        | 123 +++++++++++++++++-
 .../Vector/vector-load-store-unroll.mlir      |  73 +++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  40 ++++++
 3 files changed, 234 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-load-store-unroll.mlir

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 1cc477d9dca91..43abf84cd6428 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -54,6 +54,33 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
   return slicedIndices;
 }
 
+/// Computes the indices of one unrolled vector.load/store slice by adding the
+/// static `offsets` to the trailing entries of `originalIndices`.
+/// Requires m <= n, where m = offsets.size() and n = originalIndices.size().
+/// Only the last m indices are updated; leading indices are passed through.
+static SmallVector<Value> computeIndices(PatternRewriter &rewriter,
+                                         Location loc,
+                                         ArrayRef<Value> originalIndices,
+                                         ArrayRef<int64_t> offsets) {
+  assert(offsets.size() <= originalIndices.size() &&
+         "Offsets should not exceed the number of original indices");
+  SmallVector<Value> indices(originalIndices);
+  // Walk both ranges from the back so that the innermost dimensions line up.
+  auto offsetsIter = offsets.rbegin();
+  auto indicesIter = indices.rbegin();
+  while (offsetsIter != offsets.rend()) {
+    Value original = *indicesIter;
+    int64_t offset = *offsetsIter;
+    // A zero offset leaves the index untouched; no arith ops are created.
+    if (offset != 0)
+      *indicesIter = rewriter.create<arith::AddIOp>(
+          loc, original, rewriter.create<arith::ConstantIndexOp>(loc, offset));
+    ++offsetsIter;
+    ++indicesIter;
+  }
+  return indices;
+}
+
 // Clones `op` into a new operations that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
@@ -631,6 +658,98 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Unrolls a rank > 1 vector.load into 1-D loads of the innermost dimension
+/// and reassembles them with vector.insert_strided_slice.
+struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
+  UnrollLoadPattern(MLIRContext *context,
+                    const vector::UnrollVectorOptions &options,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = loadOp.getVectorType();
+    // Only unroll >1D loads that the user-provided filter accepts (an unset
+    // filter accepts every op).
+    if (vecType.getRank() <= 1 ||
+        (options.filterConstraint && failed(options.filterConstraint(loadOp))))
+      return failure();
+
+    Location loc = loadOp.getLoc();
+    ArrayRef<int64_t> originalShape = vecType.getShape();
+    // Target type is a 1D vector of the innermost dimension.
+    auto targetType =
+        VectorType::get(originalShape.back(), vecType.getElementType());
+    // Extend the targetShape to the same rank of original shape by padding 1s
+    // for leading dimensions for convenience of computing offsets.
+    SmallVector<int64_t> targetShape(originalShape.size(), 1);
+    targetShape.back() = originalShape.back();
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, vecType, rewriter.getZeroAttr(vecType));
+    SmallVector<Value> originalIndices(loadOp.getIndices().begin(),
+                                       loadOp.getIndices().end());
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, targetShape)) {
+      SmallVector<Value> indices =
+          computeIndices(rewriter, loc, originalIndices, offsets);
+      Value slice = rewriter.create<vector::LoadOp>(loc, targetType,
+                                                    loadOp.getBase(), indices);
+      // Insert the slice into the result at the correct position.
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, slice, result, offsets, SmallVector<int64_t>({1}));
+    }
+    rewriter.replaceOp(loadOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+/// Unrolls a rank > 1 vector.store into 1-D stores of the innermost dimension.
+struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
+  UnrollStorePattern(MLIRContext *context,
+                     const vector::UnrollVectorOptions &options,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = storeOp.getVectorType();
+    // Only unroll >1D stores that the user-provided filter accepts (an unset
+    // filter accepts every op).
+    if (vecType.getRank() <= 1 ||
+        (options.filterConstraint && failed(options.filterConstraint(storeOp))))
+      return failure();
+
+    Location loc = storeOp.getLoc();
+    ArrayRef<int64_t> originalShape = vecType.getShape();
+    // Extend the targetShape to the same rank of original shape by padding 1s
+    // for leading dimensions for convenience of computing offsets.
+    SmallVector<int64_t> targetShape(originalShape.size(), 1);
+    targetShape.back() = originalShape.back();
+    Value base = storeOp.getBase();
+    Value toStore = storeOp.getValueToStore();
+    SmallVector<Value> originalIndices(storeOp.getIndices().begin(),
+                                       storeOp.getIndices().end());
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, targetShape)) {
+      SmallVector<Value> indices =
+          computeIndices(rewriter, loc, originalIndices, offsets);
+      // Drop the innermost offset; the leading ones select the 1-D slice.
+      offsets.pop_back();
+      Value slice = rewriter.create<vector::ExtractOp>(loc, toStore, offsets);
+      rewriter.create<vector::StoreOp>(loc, slice, base, indices);
+    }
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -639,6 +758,6 @@ void mlir::vector::populateVectorUnrollPatterns(
   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
                UnrollContractionPattern, UnrollElementwisePattern,
                UnrollReductionPattern, UnrollMultiReductionPattern,
-               UnrollTransposePattern, UnrollGatherPattern>(
-      patterns.getContext(), options, benefit);
+               UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
+               UnrollStorePattern>(patterns.getContext(), options, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
new file mode 100644
index 0000000000000..3135268b8d61b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -test-vector-load-store-unroll --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @unroll_2D_vector_load(
+// CHECK-SAME:  %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
+func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+  // CHECK: %[[C3:.*]] = arith.constant 3 : index
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+  // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+  // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+  // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+  // CHECK: return %[[V7]] : vector<4x4xf16>
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+  return %0 : vector<4x4xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_2D_vector_store(
+// CHECK-SAME:  %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
+func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
+  // CHECK: %[[C3:.*]] = arith.constant 3 : index
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16>
+  // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16>
+  // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16>
+  // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
+  // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  %c0 = arith.constant 0 : index
+  vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+  return
+}
+
+// CHECK-LABEL: func.func @unroll_vector_load(
+// CHECK-SAME:  %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+  // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+  // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+  // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+  // CHECK: return %[[V3]] : vector<2x2xf16>
+  %c1 = arith.constant 1 : index
+  %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+  return %0 : vector<2x2xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_vector_store(
+// CHECK-SAME:  %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
+func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) {
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
+  // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+  // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
+  // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+  %c1 = arith.constant 1 : index
+  vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+  return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index eda2594fbc7c7..b2b2b4ece22cd 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -289,6 +289,44 @@ struct TestVectorTransferUnrollingPatterns
       llvm::cl::init(false)};
 };
 
+struct TestVectorLoadStoreUnrollPatterns
+    : public PassWrapper<TestVectorLoadStoreUnrollPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorLoadStoreUnrollPatterns)
+
+  StringRef getArgument() const final {
+    return "test-vector-load-store-unroll";
+  }
+  StringRef getDescription() const final {
+    return "Test unrolling patterns for vector.load and vector.store ops";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect, arith::ArithDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+
+    // Unroll all vector.load and vector.store ops with rank > 1 to 1D vectors
+    vector::UnrollVectorOptions options;
+    options.setFilterConstraint([](Operation *op) {
+      if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+        return success(loadOp.getType().getRank() > 1);
+      if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+        return success(storeOp.getVectorType().getRank() > 1);
+      return failure();
+    });
+
+    vector::populateVectorUnrollPatterns(patterns, options);
+
+    // Apply the patterns
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestScalarVectorTransferLoweringPatterns
     : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
                          OperationPass<func::FuncOp>> {
@@ -1033,6 +1071,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorTransferUnrollingPatterns>();
 
+  PassRegistration<TestVectorLoadStoreUnrollPatterns>();
+
   PassRegistration<TestScalarVectorTransferLoweringPatterns>();
 
   PassRegistration<TestVectorTransferOpt>();

>From 9d91abe8417b56bfb6b7e220b8fbbd050b8e03da Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 9 Jun 2025 20:43:56 +0000
Subject: [PATCH 2/2] Clean up

---
 .../Vector/vector-load-store-unroll.mlir      | 73 -------------------
 .../Dialect/Vector/vector-unroll-options.mlir | 73 +++++++++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   | 50 +++----------
 3 files changed, 83 insertions(+), 113 deletions(-)
 delete mode 100644 mlir/test/Dialect/Vector/vector-load-store-unroll.mlir

diff --git a/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir b/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
deleted file mode 100644
index 3135268b8d61b..0000000000000
--- a/mlir/test/Dialect/Vector/vector-load-store-unroll.mlir
+++ /dev/null
@@ -1,73 +0,0 @@
-// RUN: mlir-opt %s -test-vector-load-store-unroll --split-input-file | FileCheck %s
-
-// CHECK-LABEL: func.func @unroll_2D_vector_load(
-// CHECK-SAME:  %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
-func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
-  // CHECK: %[[C3:.*]] = arith.constant 3 : index
-  // CHECK: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
-  // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-  // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
-  // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-  // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
-  // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-  // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
-  // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-  // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
-  // CHECK: return %[[V7]] : vector<4x4xf16>
-  %c0 = arith.constant 0 : index
-  %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
-  return %0 : vector<4x4xf16>
-}
-
-// CHECK-LABEL: func.func @unroll_2D_vector_store(
-// CHECK-SAME:  %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
-func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
-  // CHECK: %[[C3:.*]] = arith.constant 3 : index
-  // CHECK: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16>
-  // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-  // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16>
-  // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-  // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16>
-  // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-  // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
-  // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-  %c0 = arith.constant 0 : index
-  vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
-  return
-}
-
-// CHECK-LABEL: func.func @unroll_vector_load(
-// CHECK-SAME:  %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
-func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
-  // CHECK: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
-  // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
-  // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
-  // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
-  // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
-  // CHECK: return %[[V3]] : vector<2x2xf16>
-  %c1 = arith.constant 1 : index
-  %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
-  return %0 : vector<2x2xf16>
-}
-
-// CHECK-LABEL: func.func @unroll_vector_store(
-// CHECK-SAME:  %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
-func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) {
-  // CHECK: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
-  // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
-  // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
-  // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
-  %c1 = arith.constant 1 : index
-  vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
-  return
-}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index fbb178fb49d87..efb709e41a69c 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -378,3 +378,76 @@ func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector
 //       CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2x1xf32> to vector<2x2xf32>
 //       CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
 //       CHECK: return [[r3]] : vector<4x4xf32>
+
+
+// CHECK-LABEL: func.func @unroll_2D_vector_load(
+// CHECK-SAME:  %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
+func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+  // CHECK: %[[C3:.*]] = arith.constant 3 : index
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+  // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+  // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+  // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
+  // CHECK: return %[[V7]] : vector<4x4xf16>
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+  return %0 : vector<4x4xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_2D_vector_store(
+// CHECK-SAME:  %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
+func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
+  // CHECK: %[[C3:.*]] = arith.constant 3 : index
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16>
+  // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16>
+  // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16>
+  // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  // CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
+  // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+  %c0 = arith.constant 0 : index
+  vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+  return
+}
+
+// CHECK-LABEL: func.func @unroll_vector_load(
+// CHECK-SAME:  %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+  // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+  // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+  // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
+  // CHECK: return %[[V3]] : vector<2x2xf16>
+  %c1 = arith.constant 1 : index
+  %0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+  return %0 : vector<2x2xf16>
+}
+
+// CHECK-LABEL: func.func @unroll_vector_store(
+// CHECK-SAME:  %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
+func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) {
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
+  // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+  // CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
+  // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
+  %c1 = arith.constant 1 : index
+  vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
+  return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 8014362a1a6ec..023a6706b58be 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -178,6 +178,16 @@ struct TestVectorUnrollingPatterns
                         return success(isa<vector::TransposeOp>(op));
                       }));
 
+    populateVectorUnrollPatterns(
+        patterns, UnrollVectorOptions()
+                      .setNativeShape(ArrayRef<int64_t>{2, 2})
+                      .setFilterConstraint([](Operation *op) {
+                        if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+                          return success(loadOp.getType().getRank() > 1);
+                        if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+                          return success(storeOp.getVectorType().getRank() > 1);
+                        return failure();
+                      }));
     if (unrollBasedOnType) {
       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
           [](Operation *op) -> std::optional<SmallVector<int64_t>> {
@@ -292,44 +302,6 @@ struct TestVectorTransferUnrollingPatterns
       llvm::cl::init(false)};
 };
 
-struct TestVectorLoadStoreUnrollPatterns
-    : public PassWrapper<TestVectorLoadStoreUnrollPatterns,
-                         OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      TestVectorLoadStoreUnrollPatterns)
-
-  StringRef getArgument() const final {
-    return "test-vector-load-store-unroll";
-  }
-  StringRef getDescription() const final {
-    return "Test unrolling patterns for vector.load and vector.store ops";
-  }
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect, arith::ArithDialect>();
-  }
-
-  void runOnOperation() override {
-    MLIRContext *ctx = &getContext();
-    RewritePatternSet patterns(ctx);
-
-    // Unroll all vector.load and vector.store ops with rank > 1 to 1D vectors
-    vector::UnrollVectorOptions options;
-    options.setFilterConstraint([](Operation *op) {
-      if (auto loadOp = dyn_cast<vector::LoadOp>(op))
-        return success(loadOp.getType().getRank() > 1);
-      if (auto storeOp = dyn_cast<vector::StoreOp>(op))
-        return success(storeOp.getVectorType().getRank() > 1);
-      return failure();
-    });
-
-    vector::populateVectorUnrollPatterns(patterns, options);
-
-    // Apply the patterns
-    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
-  }
-};
-
 struct TestScalarVectorTransferLoweringPatterns
     : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
                          OperationPass<func::FuncOp>> {
@@ -1070,8 +1042,6 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorTransferUnrollingPatterns>();
 
-  PassRegistration<TestVectorLoadStoreUnrollPatterns>();
-
   PassRegistration<TestScalarVectorTransferLoweringPatterns>();
 
   PassRegistration<TestVectorTransferOpt>();



More information about the Mlir-commits mailing list