[Mlir-commits] [mlir] [MLIR] [Vector] Linearization patterns for vector.load and vector.store (PR #145115)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 23 09:56:49 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/145115.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+69-2)
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+23)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..f0b77da5acd02 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final
}
};
+/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
+/// It currently supports only linearization of <1xN> to <N>
+/// Following,
+/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
+/// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType vecTy = loadOp.getType();
+ if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+ return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
+ auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+ vecTy.isScalable());
+ auto newLoad = rewriter.create<vector::LoadOp>(
+ loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+ auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+ loadOp.getLoc(), vecTy, newLoad.getResult());
+ rewriter.replaceOp(loadOp, shapeCast.getResult());
+ return success();
+ }
+};
+
+/// This pattern linearizes vector.store from vector<1xN> to vector<N>.
+/// It currently supports only linearization of <1xN> to <N>
+/// Following,
+/// vector.store %arg0, %arg1[%c0, %c0]
+/// : vector<1x4xf32>, memref<1x4xf32>
+/// is converted to:
+/// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// vector.store %arg0, %arg1[%c0, %c0]
+/// : vector<4xf32>, memref<1x4xf32>
+struct LinearizeVectorStore final
+ : public OpConversionPattern<vector::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType vecTy = storeOp.getValueToStore().getType();
+ if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+ return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported");
+ auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+ vecTy.isScalable());
+
+ Value valueToStore = adaptor.getValueToStore();
+ if (valueToStore.getType() != linearTy) {
+ valueToStore = rewriter.create<vector::ShapeCastOp>(
+ storeOp.getLoc(), linearTy, valueToStore);
+ }
+
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices());
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
RewritePatternSet &patterns) {
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
- LinearizeVectorSplat, LinearizeVectorCreateMask>(
- typeConverter, patterns.getContext());
+ LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
+ LinearizeVectorStore>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9cbf319ffddb2..fa0436792d3f0 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
%0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
return %0 : vector<1x[16]xi1>
}
+
+// CHECK-LABEL: linearize_vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> {
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+ // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
+ // CHECK: return %[[CAST]] : vector<1x4xf32>
+ %c0 = arith.constant 0 : index
+ %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+ return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) {
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[ARG1]] : vector<1x4xf32> to vector<4xf32>
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+ // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+ %c0 = arith.constant 0 : index
+ vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/145115
More information about the Mlir-commits
mailing list