[Mlir-commits] [mlir] [mlir][vector] Add linearization patterns for vector.load & vector.store (PR #137558)
Nishant Patel
llvmlistbot at llvm.org
Sun Apr 27 15:41:23 PDT 2025
https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/137558
This PR is part 1 of 4 of the breakdown of PR #136193.
This PR adds linearization patterns for the vector.load and vector.store ops. For now, the patterns support only 2D vectors.
>From bc3a47ec691cf401069d119cd65196e43806023c Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Sat, 26 Apr 2025 17:29:44 +0000
Subject: [PATCH] Add linearization patterns for vector.load & vector.store
---
.../Vector/Transforms/VectorLinearize.cpp | 160 ++++++++++++++++--
mlir/test/Dialect/Vector/linearize.mlir | 110 ++++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 3 +-
3 files changed, 261 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..e4d88de2cf4ae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -26,7 +26,12 @@
using namespace mlir;
+/// Default value of the `targetVectBitWidth` constructor argument of the
+/// linearization patterns in this file: the largest `unsigned` value,
+/// presumably meaning "no bit-width limit" (TODO confirm against the full
+/// body of isLessThanTargetBitWidth).
+constexpr unsigned defaultTargetVectorBitWidth =
+ std::numeric_limits<unsigned>::max();
+
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+ if (targetBitWidth == 0)
+ return false;
auto resultTypes = op->getResultTypes();
for (auto resType : resultTypes) {
VectorType vecType = dyn_cast<VectorType>(resType);
@@ -82,7 +87,7 @@ struct LinearizeConstantLike final
LinearizeConstantLike(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpTraitConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
@@ -136,7 +141,7 @@ struct LinearizeVectorizable final
public:
LinearizeVectorizable(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpTraitConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
@@ -175,7 +180,7 @@ struct LinearizeVectorExtractStridedSlice final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtractStridedSlice(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
@@ -289,7 +294,7 @@ struct LinearizeVectorShuffle final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorShuffle(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
@@ -362,13 +367,17 @@ struct LinearizeVectorExtract final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtract(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ // Skip if result is not a vector type
+ if (!isa<VectorType>(extractOp.getType()))
+ return rewriter.notifyMatchFailure(extractOp,
+ "scalar extract is not supported.");
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
if (!dstTy)
return rewriter.notifyMatchFailure(extractOp,
@@ -425,7 +434,7 @@ struct LinearizeVectorInsert final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorInsert(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
@@ -506,7 +515,7 @@ struct LinearizeVectorBitCast final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorBitCast(
const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
@@ -531,12 +540,139 @@ struct LinearizeVectorBitCast final
unsigned targetVectorBitWidth;
};
+// clang-format off
+/// This pattern unrolls a 2D vector.load into one 1D vector.load per row and
+/// inserts each loaded row into an initially zero-filled 2D vector.
+/// Following,
+///   %v = vector.load %base[%i, %j] : memref<?x?xf32>, vector<4x4xf32>
+/// is converted to:
+///   %result = arith.constant dense<0.0> : vector<4x4xf32>
+///   %slice_0 = vector.load %base[%i, %j] : memref<?x?xf32>, vector<4xf32>
+///   %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
+///   %c1 = arith.constant 1 : index
+///   %i_1 = arith.addi %i, %c1 : index
+///   %slice_1 = vector.load %base[%i_1, %j] : memref<?x?xf32>, vector<4xf32>
+///   %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+///   ...
+/// The two vector dimensions map onto the two trailing memref dimensions, so
+/// the row offset is added to the second-to-last index (this matters for
+/// memrefs of rank > 2). Only fixed-size (non-scalable) 2D vectors loaded from
+/// memrefs with a scalar element type are supported.
+// clang-format on
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
+
+  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+                      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
+                      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = loadOp->getLoc();
+    VectorType vecType = loadOp.getVectorType();
+    MemRefType memrefType = loadOp.getMemRefType();
+
+    if (vecType.getRank() != 2)
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    // Unrolling needs a static row count, and the 1D row type built below
+    // would silently drop any scalable flag.
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "Scalable vectors are not supported.");
+    // Rows advance along the second-to-last memref dimension, which only
+    // exists for scalar-element memrefs of rank >= 2.
+    if (isa<VectorType>(memrefType.getElementType()) ||
+        memrefType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          loc, "Expected a memref of rank >= 2 with a scalar element type.");
+    if (!isLessThanTargetBitWidth(loadOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          loc, "Vector is not a linearization candidate for the target bit "
+               "width.");
+
+    int64_t unrollCount = vecType.getDimSize(0);
+    VectorType rowType =
+        VectorType::get({vecType.getDimSize(1)}, vecType.getElementType());
+
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    // Position of the index that the vector rows advance along.
+    const size_t rowDim = indices.size() - 2;
+    Value rowBaseIndex = indices[rowDim];
+
+    // Start from a zero-filled 2D vector and insert each loaded row into it.
+    Value resultVec =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecType));
+    for (int64_t row = 0; row < unrollCount; ++row) {
+      Value rowIndex = rowBaseIndex;
+      if (row != 0) {
+        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, row);
+        rowIndex = rewriter.create<arith::AddIOp>(loc, rowBaseIndex, offset);
+      }
+      indices[rowDim] = rowIndex;
+      Value rowVec = rewriter.create<vector::LoadOp>(
+          loc, rowType, adaptor.getBase(), indices);
+      resultVec = rewriter.create<vector::InsertOp>(loc, rowVec, resultVec, row);
+    }
+
+    rewriter.replaceOp(loadOp, resultVec);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+// clang-format off
+/// This pattern unrolls a 2D vector.store into one vector.extract plus one 1D
+/// vector.store per row of the stored vector.
+/// Following,
+///   vector.store %source, %base[%i, %j] : memref<?x?xf32>, vector<4x4xf32>
+/// is converted to:
+///   %slice_0 = vector.extract %source[0] : vector<4xf32> from vector<4x4xf32>
+///   vector.store %slice_0, %base[%i, %j] : memref<?x?xf32>, vector<4xf32>
+///   %slice_1 = vector.extract %source[1] : vector<4xf32> from vector<4x4xf32>
+///   %c1 = arith.constant 1 : index
+///   %i_1 = arith.addi %i, %c1 : index
+///   vector.store %slice_1, %base[%i_1, %j] : memref<?x?xf32>, vector<4xf32>
+///   ...
+/// The two vector dimensions map onto the two trailing memref dimensions, so
+/// the row offset is added to the second-to-last index (this matters for
+/// memrefs of rank > 2). Only fixed-size (non-scalable) 2D vectors stored into
+/// memrefs with a scalar element type are supported.
+// clang-format on
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
+
+  LinearizeVectorStore(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = storeOp->getLoc();
+    VectorType vecType = storeOp.getVectorType();
+    MemRefType memrefType = storeOp.getMemRefType();
+
+    if (vecType.getRank() != 2)
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    // Unrolling needs a static row count; extracting rows of a scalable
+    // leading dimension with a fixed count would be wrong.
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "Scalable vectors are not supported.");
+    // Rows advance along the second-to-last memref dimension, which only
+    // exists for scalar-element memrefs of rank >= 2.
+    if (isa<VectorType>(memrefType.getElementType()) ||
+        memrefType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          loc, "Expected a memref of rank >= 2 with a scalar element type.");
+    // A zero target bit width disables linearization. This mirrors the early
+    // exit of isLessThanTargetBitWidth, which cannot be reused here because it
+    // inspects op result types and vector.store has no results.
+    if (targetVectorBitWidth == 0)
+      return rewriter.notifyMatchFailure(loc, "Linearization is disabled.");
+
+    int64_t unrollCount = vecType.getDimSize(0);
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    // Position of the index that the vector rows advance along.
+    const size_t rowDim = indices.size() - 2;
+    Value rowBaseIndex = indices[rowDim];
+
+    // The adaptor's value may already be linearized by the type converter;
+    // cast it back to the original 2D shape so rows can be extracted by
+    // position.
+    Value source = rewriter.create<vector::ShapeCastOp>(
+        loc, vecType, adaptor.getValueToStore());
+
+    for (int64_t row = 0; row < unrollCount; ++row) {
+      Value rowVec = rewriter.create<vector::ExtractOp>(loc, source, row);
+      Value rowIndex = rowBaseIndex;
+      if (row != 0) {
+        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, row);
+        rowIndex = rewriter.create<arith::AddIOp>(loc, rowBaseIndex, offset);
+      }
+      indices[rowDim] = rowIndex;
+      rewriter.create<vector::StoreOp>(loc, rowVec, adaptor.getBase(), indices);
+    }
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth) {
+ typeConverter.addConversion([](Type type) -> Type { return type; });
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
return type;
@@ -555,9 +691,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
};
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
+ target.addLegalOp<vector::ShapeCastOp>();
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<vector::BitCastOp>(op) ||
+ if ((isa<vector::BitCastOp, vector::LoadOp, vector::StoreOp>(op) ||
op->hasTrait<OpTrait::ConstantLike>() ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -567,9 +704,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return std::nullopt;
});
- patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
- targetBitWidth);
+ patterns
+ .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+ LinearizeVectorLoad, LinearizeVectorStore>(
+ typeConverter, patterns.getContext(), targetBitWidth);
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..9e793c5dc8233 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -399,3 +399,113 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
return %1 : vector<[4]x4xf16>
}
+
+// -----
+// ALL-LABEL: linearize_2D_vector_load
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>)
+func.func @linearize_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
+  // DEFAULT: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // DEFAULT: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // DEFAULT: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C3:.*]] = arith.constant 3 : index
+  // DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // DEFAULT: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
+  // DEFAULT: return %[[CAST]] : vector<4x4xf16>
+
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C2:.*]] = arith.constant 2 : index
+  // BW-128: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
+  // BW-128: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[C1_0:.*]] = arith.constant 1 : index
+  // BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // BW-128: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[C2_1:.*]] = arith.constant 2 : index
+  // BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // BW-128: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[C3:.*]] = arith.constant 3 : index
+  // BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // BW-128: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
+  // BW-128: return %[[CAST]] : vector<4x4xf16>
+
+  // BW-0: %[[C1:.*]] = arith.constant 1 : index
+  // BW-0: %[[C2:.*]] = arith.constant 2 : index
+  // BW-0: %[[LOAD:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
+  // BW-0: return %[[LOAD]] : vector<4x4xf16>
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = vector.load %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16>
+  return %0 : vector<4x4xf16>
+}
+
+// -----
+// ALL-LABEL: linearize_2D_vector_store
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>, %[[ARG_1:.*]]: vector<4x4xf16>) {
+func.func @linearize_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
+  // DEFAULT: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16>
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16>
+  // DEFAULT: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16>
+  // DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // DEFAULT: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // DEFAULT: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C3:.*]] = arith.constant 3 : index
+  // DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // DEFAULT: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: return
+
+  // BW-128: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16>
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C2:.*]] = arith.constant 2 : index
+  // BW-128: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16>
+  // BW-128: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16>
+  // BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16>
+  // BW-128: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[C1_0:.*]] = arith.constant 1 : index
+  // BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // BW-128: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[C2_1:.*]] = arith.constant 2 : index
+  // BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // BW-128: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[C3:.*]] = arith.constant 3 : index
+  // BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // BW-128: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: return
+
+  // BW-0: %[[C1:.*]] = arith.constant 1 : index
+  // BW-0: %[[C2:.*]] = arith.constant 2 : index
+  // BW-0: vector.store %[[ARG_1]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
+  // BW-0: return
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  vector.store %arg1, %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16>
+  return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 03f907e46c2c6..14c7e9d554cd9 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -852,7 +852,8 @@ struct TestVectorLinearize final
return "Linearizes ND vectors for N >= 2 into 1D vectors";
}
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<vector::VectorDialect>();
+ registry.insert<vector::VectorDialect, memref::MemRefDialect,
+ arith::ArithDialect>();
}
Option<unsigned> targetVectorBitwidth{
More information about the Mlir-commits
mailing list