[Mlir-commits] [mlir] 9c1ce31 - [mlir][vector] Add unroll patterns for vector.load and vector.store (#143420)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 20 13:50:29 PDT 2025


Author: Nishant Patel
Date: 2025-06-20T13:50:25-07:00
New Revision: 9c1ce31f546368d6296bc881e9f576ad25c20c73

URL: https://github.com/llvm/llvm-project/commit/9c1ce31f546368d6296bc881e9f576ad25c20c73
DIFF: https://github.com/llvm/llvm-project/commit/9c1ce31f546368d6296bc881e9f576ad25c20c73.diff

LOG: [mlir][vector] Add unroll patterns for vector.load and vector.store (#143420)

This PR adds unroll patterns for vector.load and vector.store. It is a follow-up to #137558.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
    mlir/test/Dialect/Vector/vector-unroll-options.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index aef156c5f1d05..926a92eff2ebb 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1736,7 +1736,9 @@ def Vector_TransferWriteOp :
   let hasVerifier = 1;
 }
 
-def Vector_LoadOp : Vector_Op<"load"> {
+def Vector_LoadOp : Vector_Op<"load", [
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+  ]> {
   let summary = "reads an n-D slice of memory into an n-D vector";
   let description = [{
     The 'vector.load' operation reads an n-D slice of memory into an n-D
@@ -1822,7 +1824,9 @@ def Vector_LoadOp : Vector_Op<"load"> {
       "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
 }
 
-def Vector_StoreOp : Vector_Op<"store"> {
+def Vector_StoreOp : Vector_Op<"store", [
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+  ]> {
   let summary = "writes an n-D vector to an n-D slice of memory";
   let description = [{
     The 'vector.store' operation writes an n-D vector to an n-D slice of memory.

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6f0ac6bb58282..ee9ab61b670c4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5371,6 +5371,10 @@ OpFoldResult LoadOp::fold(FoldAdaptor) {
   return OpFoldResult();
 }
 
+std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getVectorType().getShape()); // Copy of the shape of this op's vector type.
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -5406,6 +5410,10 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
   return memref::foldMemRefCast(*this);
 }
 
+std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getVectorType().getShape()); // Copy of the shape of this op's vector type.
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedLoadOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fc443ab0d138e..693f4f955994d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -54,6 +54,28 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
   return slicedIndices;
 }
 
+// Returns a copy of `originalIndices` in which the trailing `offsets.size()`
+// entries are advanced by the matching entries of `offsets`; when there are
+// more indices than offsets, the leading indices are passed through unchanged.
+static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter,
+                                                Location loc,
+                                                OperandRange originalIndices,
+                                                ArrayRef<int64_t> offsets) {
+  assert(offsets.size() <= originalIndices.size() &&
+         "Offsets should not exceed the number of original indices");
+  SmallVector<Value> indices(originalIndices);
+
+  auto start = indices.size() - offsets.size(); // Position of the first index that has an offset.
+  for (auto [i, offset] : llvm::enumerate(offsets)) {
+    if (offset != 0) { // A zero offset keeps the original index value; no ops are created.
+      indices[start + i] = rewriter.create<arith::AddIOp>(
+          loc, originalIndices[start + i],
+          rewriter.create<arith::ConstantIndexOp>(loc, offset));
+    }
+  }
+  return indices;
+}
+
 // Clones `op` into a new operations that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
@@ -631,6 +653,90 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> { // Rewrites one vector.load as one smaller load per tile.
+  UnrollLoadPattern(MLIRContext *context,
+                    const vector::UnrollVectorOptions &options,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = loadOp.getVectorType();
+
+    auto targetShape = getTargetShape(options, loadOp);
+    if (!targetShape) // No target shape for this op: report failure and leave it as is.
+      return failure();
+
+    Location loc = loadOp.getLoc();
+    ArrayRef<int64_t> originalShape = vecType.getShape();
+    SmallVector<int64_t> strides(targetShape->size(), 1); // Unit strides for every insert_strided_slice.
+
+    Value result = rewriter.create<arith::ConstantOp>( // Zero-filled vector of the original type; tiles are inserted into it.
+        loc, vecType, rewriter.getZeroAttr(vecType));
+
+    SmallVector<int64_t> loopOrder =
+        getUnrollOrder(originalShape.size(), loadOp, options);
+
+    auto targetVecType =
+        VectorType::get(*targetShape, vecType.getElementType()); // Type of each per-tile load.
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
+      SmallVector<Value> indices = // Original indices shifted by this tile's offsets.
+          sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
+      Value slicedLoad = rewriter.create<vector::LoadOp>(
+          loc, targetVecType, loadOp.getBase(), indices);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>( // Chain the inserts: each one consumes the previous result.
+          loc, slicedLoad, result, offsets, strides);
+    }
+    rewriter.replaceOp(loadOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options; // Copy of the options given at construction.
+};
+
+struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> { // Rewrites one vector.store as one smaller store per tile.
+  UnrollStorePattern(MLIRContext *context,
+                     const vector::UnrollVectorOptions &options,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = storeOp.getVectorType();
+
+    auto targetShape = getTargetShape(options, storeOp);
+    if (!targetShape) // No target shape for this op: report failure and leave it as is.
+      return failure();
+
+    Location loc = storeOp.getLoc();
+    ArrayRef<int64_t> originalShape = vecType.getShape();
+    SmallVector<int64_t> strides(targetShape->size(), 1); // Unit strides for every extract_strided_slice.
+
+    Value base = storeOp.getBase();
+    Value vector = storeOp.getValueToStore();
+
+    SmallVector<int64_t> loopOrder =
+        getUnrollOrder(originalShape.size(), storeOp, options);
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
+      SmallVector<Value> indices = // Original indices shifted by this tile's offsets.
+          sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
+      Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>( // Tile of the stored value at these offsets.
+          loc, vector, offsets, *targetShape, strides);
+      rewriter.create<vector::StoreOp>(loc, slice, base, indices);
+    }
+    rewriter.eraseOp(storeOp); // The per-tile stores created above take the place of the original store.
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options; // Copy of the options given at construction.
+};
+
 struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   UnrollBroadcastPattern(MLIRContext *context,
                          const vector::UnrollVectorOptions &options,
@@ -699,10 +805,10 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options,
     PatternBenefit benefit) {
-  patterns
-      .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
-           UnrollContractionPattern, UnrollElementwisePattern,
-           UnrollReductionPattern, UnrollMultiReductionPattern,
-           UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
-          patterns.getContext(), options, benefit);
+  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
+               UnrollContractionPattern, UnrollElementwisePattern,
+               UnrollReductionPattern, UnrollMultiReductionPattern,
+               UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
+               UnrollStorePattern, UnrollBroadcastPattern>(
+      patterns.getContext(), options, benefit);
 }

diff  --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index fbb178fb49d87..e129cd5c40b9c 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -378,3 +378,45 @@ func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector
 //       CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2x1xf32> to vector<2x2xf32>
 //       CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
 //       CHECK: return [[r3]] : vector<4x4xf32>
+
+
+func.func @vector_load_2D(%mem: memref<4x4xf16>) -> vector<4x4xf16> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> // Test input: a single 4x4 load starting at [0, 0].
+  return %0 : vector<4x4xf16>
+}
+
+// CHECK-LABEL: func.func @vector_load_2D(
+// CHECK-SAME:  %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
+  // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
+  // CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+  // CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
+  // CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
+  // CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
+  // CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+  // CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
+  // CHECK: return %[[V7]] : vector<4x4xf16>
+
+
+func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
+  %c0 = arith.constant 0 : index
+  vector.store %v, %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> // Test input: a single 4x4 store starting at [0, 0].
+  return
+}
+
+// CHECK-LABEL: func.func @vector_store_2D(
+// CHECK-SAME:  %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[V0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
+  // CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
+  // CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
+  // CHECK: vector.store %[[V1]], %[[ARG0]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+  // CHECK: %[[V2:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
+  // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
+  // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
+  // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 54aa96ba89a00..71000b98fb8f9 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -163,7 +163,8 @@ struct TestVectorUnrollingPatterns
             .setFilterConstraint([](Operation *op) {
               return success(
                   isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
-                      vector::BroadcastOp>(op));
+                      vector::BroadcastOp, vector::LoadOp, vector::StoreOp>(
+                      op));
             }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()


        


More information about the Mlir-commits mailing list