[Mlir-commits] [mlir] [MLIR] [Vector] Linearization patterns for vector.load and vector.store (PR #145115)
James Newling
llvmlistbot at llvm.org
Mon Jun 23 11:47:11 PDT 2025
================
@@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final
}
};
+/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
+/// It currently supports only linearization of <1xN> to <N>.
+/// For example,
+/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
+/// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType vecTy = loadOp.getType();
+ if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+ return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
+ auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+ vecTy.isScalable());
+ auto newLoad = rewriter.create<vector::LoadOp>(
+ loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+ auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+ loadOp.getLoc(), vecTy, newLoad.getResult());
+ rewriter.replaceOp(loadOp, shapeCast.getResult());
+ return success();
----------------
newling wrote:
Because this is a type conversion, we get the type and shape_cast 'for free'
(Same for store op)
https://github.com/llvm/llvm-project/pull/145115
More information about the Mlir-commits
mailing list