[Mlir-commits] [mlir] [MLIR] [Vector] Linearization patterns for vector.load and vector.store (PR #145115)

llvmlistbot at llvm.org
Mon Jun 23 09:56:49 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-vector

Author: Nishant Patel (nbpatel)




---
Full diff: https://github.com/llvm/llvm-project/pull/145115.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+69-2) 
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+23) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..f0b77da5acd02 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final
   }
 };
 
+/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
+/// It currently supports only linearization of vector<1xN> to vector<N>.
+/// For example:
+///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
+///   vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+                      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = loadOp.getType();
+    if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
+    auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+                                    vecTy.isScalable());
+    auto newLoad = rewriter.create<vector::LoadOp>(
+        loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+    auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+        loadOp.getLoc(), vecTy, newLoad.getResult());
+    rewriter.replaceOp(loadOp, shapeCast.getResult());
+    return success();
+  }
+};
+
+/// This pattern linearizes vector.store from vector<1xN> to vector<N>.
+/// It currently supports only linearization of vector<1xN> to vector<N>.
+/// For example:
+///   vector.store %arg0, %arg1[%c0, %c0]
+///     : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+///   %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///   vector.store %0, %arg1[%c0, %c0]
+///     : memref<1x4xf32>, vector<4xf32>
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = storeOp.getValueToStore().getType();
+    if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported");
+    auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+                                    vecTy.isScalable());
+
+    Value valueToStore = adaptor.getValueToStore();
+    if (valueToStore.getType() != linearTy) {
+      valueToStore = rewriter.create<vector::ShapeCastOp>(
+          storeOp.getLoc(), linearTy, valueToStore);
+    }
+
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices());
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
     RewritePatternSet &patterns) {
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask>(
-          typeConverter, patterns.getContext());
+           LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
+           LinearizeVectorStore>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9cbf319ffddb2..fa0436792d3f0 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
   %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
   return %0 : vector<1x[16]xi1>
 }
+
+// CHECK-LABEL: linearize_vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> {
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
+  // CHECK: return %[[CAST]] : vector<1x4xf32>
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) {
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[ARG1]] : vector<1x4xf32> to vector<4xf32>
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+  %c0 = arith.constant 0 : index
+  vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+  return
+}

``````````
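For context on how these two patterns get exercised: they are registered via `populateVectorLinearizeBasePatterns` and run under a dialect conversion whose type converter flattens the vector types. Below is a minimal sketch of that wiring, not taken from this PR: the two-argument populate signature, the header path, and the `runLinearization` helper are assumptions inferred from the hunk context above; the in-tree `-test-vector-linearize` pass is the authoritative reference.

```cpp
// Sketch only: run the vector-linearize patterns (including the new
// LinearizeVectorLoad/LinearizeVectorStore) as a partial conversion.
#include <optional>

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" // assumed header
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical helper, not part of the PR.
static LogicalResult runLinearization(Operation *root) {
  MLIRContext *ctx = root->getContext();

  // Flatten rank-2 vectors with a leading unit dim: vector<1xN> -> vector<N>.
  // (Scalable dims are left untouched here for brevity.)
  TypeConverter typeConverter;
  typeConverter.addConversion([](Type type) { return type; });
  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
    if (type.getRank() == 2 && type.getShape()[0] == 1 && !type.isScalable())
      return VectorType::get(type.getShape()[1], type.getElementType());
    return type;
  });

  // An op counts as legal once all of its operand/result vector types are
  // already linear; the boundary shape_casts the patterns emit stay legal.
  ConversionTarget target(*ctx);
  target.addLegalOp<vector::ShapeCastOp>();
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  RewritePatternSet patterns(ctx);
  // Assumed two-argument form, matching the hunk context in the diff above.
  vector::populateVectorLinearizeBasePatterns(typeConverter, patterns);

  return applyPartialConversion(root, target, std::move(patterns));
}
```

Run over the `linearize_vector_load` / `linearize_vector_store` functions added above, this should produce a flat `vector<4xf32>` load/store plus a `vector.shape_cast` at the boundary, exactly as the CHECK lines expect.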



https://github.com/llvm/llvm-project/pull/145115

