[Mlir-commits] [mlir] [MLIR] [Vector] Linearization patterns for vector.load and vector.store (PR #145115)
Nishant Patel
llvmlistbot at llvm.org
Fri Jul 11 07:35:42 PDT 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/145115
>From 8c3dbf88a7a190c3134992bb4cb3f4bcf133cfe8 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 19 Jun 2025 19:40:07 +0000
Subject: [PATCH 1/6] Linearization patterns for vector.load and vector.store
---
.../Vector/Transforms/VectorLinearize.cpp | 71 ++++++++++++++++++-
mlir/test/Dialect/Vector/linearize.mlir | 23 ++++++
2 files changed, 92 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..f0b77da5acd02 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final
}
};
+/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
+/// It currently supports only lineariztion of <1XN> to <N>
+/// Following,
+/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
+/// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType vecTy = loadOp.getType();
+ if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+ return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
+ auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+ vecTy.isScalable());
+ auto newLoad = rewriter.create<vector::LoadOp>(
+ loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+ auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+ loadOp.getLoc(), vecTy, newLoad.getResult());
+ rewriter.replaceOp(loadOp, shapeCast.getResult());
+ return success();
+ }
+};
+
+/// This pattern linearizes vector.store from vector<1xN> to vector<N>.
+/// It currently supports only lineariztion of <1XN> to <N>
+/// Following,
+/// vector.store %arg0, %arg1[%c0, %c0]
+/// : vector<1x4xf32>, memref<1x4xf32>
+/// is converted to:
+/// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// vector.store %arg0, %arg1[%c0, %%c0]
+/// : vector<4xf32>, memref<1x4xf32>
+struct LinearizeVectorStore final
+ : public OpConversionPattern<vector::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType vecTy = storeOp.getValueToStore().getType();
+ if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+ return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported");
+ auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+ vecTy.isScalable());
+
+ Value valueToStore = adaptor.getValueToStore();
+ if (valueToStore.getType() != linearTy) {
+ valueToStore = rewriter.create<vector::ShapeCastOp>(
+ storeOp.getLoc(), linearTy, valueToStore);
+ }
+
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices());
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
RewritePatternSet &patterns) {
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
- LinearizeVectorSplat, LinearizeVectorCreateMask>(
- typeConverter, patterns.getContext());
+ LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
+ LinearizeVectorStore>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9cbf319ffddb2..fa0436792d3f0 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
%0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
return %0 : vector<1x[16]xi1>
}
+
+// CHECK-LABEL: linearize_vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> {
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+ // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
+ // CHECK: return %[[CAST]] : vector<1x4xf32>
+ %c0 = arith.constant 0 : index
+ %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+ return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) {
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32>
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+ // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+ %c0 = arith.constant 0 : index
+ vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+ return
+}
>From 92b299f07846ad55821075f08df5cba598c1766f Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 23 Jun 2025 19:14:00 +0000
Subject: [PATCH 2/6] Address comments
---
.../Vector/Transforms/VectorLinearize.cpp | 48 ++++++++-----------
mlir/test/Dialect/Vector/linearize.mlir | 16 +++----
2 files changed, 29 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index f0b77da5acd02..890d882ea2129 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -623,9 +623,9 @@ struct LinearizeVectorCreateMask final
}
};
-/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
-/// It currently supports only lineariztion of <1XN> to <N>
-/// Following,
+/// This pattern linearizes vector.load from vector<1x1x...xN> to vector<N>
+/// It currently supports linearization where all but the last dimension are 1
+/// The following,
/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
/// is converted to:
/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
@@ -640,27 +640,27 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType vecTy = loadOp.getType();
- if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
- return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
- auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
- vecTy.isScalable());
+ if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1),
+ [](auto d) { return d == 1; }))
+ return rewriter.notifyMatchFailure(loadOp,
+ "only vector<1x1x...xN> supported");
+ auto linearTy = VectorType::get(vecTy.getShape().back(),
+ vecTy.getElementType(), vecTy.isScalable());
auto newLoad = rewriter.create<vector::LoadOp>(
loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
- auto shapeCast = rewriter.create<vector::ShapeCastOp>(
- loadOp.getLoc(), vecTy, newLoad.getResult());
- rewriter.replaceOp(loadOp, shapeCast.getResult());
+ rewriter.replaceOp(loadOp, newLoad.getResult());
return success();
}
};
-/// This pattern linearizes vector.store from vector<1xN> to vector<N>.
-/// It currently supports only lineariztion of <1XN> to <N>
-/// Following,
-/// vector.store %arg0, %arg1[%c0, %c0]
+/// This pattern linearizes vector.store from vector<1x1x...xN> to vector<N>
+/// It currently supports linearization where all but the last dimension are 1
+/// The following,
+/// vector.store %arg0, %arg1[%c0, %c0]
/// : vector<1x4xf32>, memref<1x4xf32>
/// is converted to:
/// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
-/// vector.store %arg0, %arg1[%c0, %%c0]
+/// vector.store %arg0, %arg1[%c0, %c0]
/// : vector<4xf32>, memref<1x4xf32>
struct LinearizeVectorStore final
: public OpConversionPattern<vector::StoreOp> {
@@ -673,19 +673,13 @@ struct LinearizeVectorStore final
matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType vecTy = storeOp.getValueToStore().getType();
- if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
- return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported");
- auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
- vecTy.isScalable());
-
- Value valueToStore = adaptor.getValueToStore();
- if (valueToStore.getType() != linearTy) {
- valueToStore = rewriter.create<vector::ShapeCastOp>(
- storeOp.getLoc(), linearTy, valueToStore);
- }
-
+ if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1),
+ [](auto d) { return d == 1; }))
+ return rewriter.notifyMatchFailure(storeOp,
+ "only vector<1x1x...xN> supported");
rewriter.replaceOpWithNewOp<vector::StoreOp>(
- storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices());
+ storeOp, adaptor.getValueToStore(), adaptor.getBase(),
+ adaptor.getIndices());
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index fa0436792d3f0..9a017ceedcebe 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -466,24 +466,24 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
}
// CHECK-LABEL: linearize_vector_load
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32>
-func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> {
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<2x8xf32>) -> vector<1x4xf32> {
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
- // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+ // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<4xf32>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
// CHECK: return %[[CAST]] : vector<1x4xf32>
%c0 = arith.constant 0 : index
- %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+ %0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x4xf32>
return %0 : vector<1x4xf32>
}
// CHECK-LABEL: linearize_vector_store
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
-func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) {
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<2x8xf32>, %arg1: vector<1x4xf32>) {
// CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32>
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
- // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+ // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<4xf32>
%c0 = arith.constant 0 : index
- vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+ vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x4xf32>
return
}
>From f700fe6aa6dd06d85216f147fd33108e1ae5f352 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 23 Jun 2025 19:24:14 +0000
Subject: [PATCH 3/6] Fix
---
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 890d882ea2129..99d18fec18120 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -644,8 +644,7 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
[](auto d) { return d == 1; }))
return rewriter.notifyMatchFailure(loadOp,
"only vector<1x1x...xN> supported");
- auto linearTy = VectorType::get(vecTy.getShape().back(),
- vecTy.getElementType(), vecTy.isScalable());
+ auto linearTy = typeConverter->convertType<VectorType>(loadOp.getType());
auto newLoad = rewriter.create<vector::LoadOp>(
loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
rewriter.replaceOp(loadOp, newLoad.getResult());
>From bcb5306d3a3379530ad58e4b4927ad230451b732 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 3 Jul 2025 15:16:40 +0000
Subject: [PATCH 4/6] Address feedback
---
.../Vector/Transforms/VectorLinearize.cpp | 33 ++++++++++++++++---
mlir/test/Dialect/Vector/linearize.mlir | 23 +++++++++++++
2 files changed, 51 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 99d18fec18120..fa5f9bbd2dbf6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -640,11 +640,23 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType vecTy = loadOp.getType();
- if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1),
- [](auto d) { return d == 1; }))
+ if (!vecTy)
+ return rewriter.notifyMatchFailure(loadOp, "expected vector type");
+
+ auto shape = vecTy.getShape();
+ auto scalableDims = vecTy.getScalableDims();
+ // All but the last dim must be 1, and only the last dim may be scalable (if
+ // any).
+ if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; }))
return rewriter.notifyMatchFailure(loadOp,
"only vector<1x1x...xN> supported");
- auto linearTy = typeConverter->convertType<VectorType>(loadOp.getType());
+
+ if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; }))
+ return rewriter.notifyMatchFailure(loadOp,
+ "only innermost dim may be scalable");
+
+ auto linearTy = typeConverter->convertType<VectorType>(vecTy);
+
auto newLoad = rewriter.create<vector::LoadOp>(
loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
rewriter.replaceOp(loadOp, newLoad.getResult());
@@ -672,10 +684,21 @@ struct LinearizeVectorStore final
matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType vecTy = storeOp.getValueToStore().getType();
- if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1),
- [](auto d) { return d == 1; }))
+ if (!vecTy)
+ return rewriter.notifyMatchFailure(storeOp, "expected vector type");
+
+ auto shape = vecTy.getShape();
+ auto scalableDims = vecTy.getScalableDims();
+ // All but the last dim must be 1, and only the last dim may be scalable (if
+ // any).
+ if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; }))
return rewriter.notifyMatchFailure(storeOp,
"only vector<1x1x...xN> supported");
+
+ if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; }))
+ return rewriter.notifyMatchFailure(storeOp,
+ "only innermost dim may be scalable");
+
rewriter.replaceOpWithNewOp<vector::StoreOp>(
storeOp, adaptor.getValueToStore(), adaptor.getBase(),
adaptor.getIndices());
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9a017ceedcebe..11780abfc6141 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -487,3 +487,26 @@ func.func @linearize_vector_store(%arg0: memref<2x8xf32>, %arg1: vector<1x4xf32>
vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x4xf32>
return
}
+
+// CHECK-LABEL: linearize_vector_load_scalable
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>) -> vector<1x[4]xf32>
+func.func @linearize_vector_load_scalable(%arg0: memref<2x8xf32>) -> vector<1x[4]xf32> {
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+ // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<[4]xf32>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<[4]xf32> to vector<1x[4]xf32>
+ // CHECK: return %[[CAST]] : vector<1x[4]xf32
+ %c0 = arith.constant 0 : index
+ %0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
+ return %0 : vector<1x[4]xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store_scalable
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>, %[[ARG1:.*]]: vector<1x[4]xf32>)
+func.func @linearize_vector_store_scalable(%arg0: memref<2x8xf32>, %arg1: vector<1x[4]xf32>) {
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x[4]xf32> to vector<[4]xf32>
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+ // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<[4]xf32>
+ %c0 = arith.constant 0 : index
+ vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
+ return
+}
>From 62fc4473a9feaddd4239c70cd349ced2fe220d60 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 7 Jul 2025 20:58:34 +0000
Subject: [PATCH 5/6] Missing brace
---
mlir/test/Dialect/Vector/linearize.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 11780abfc6141..1ad08b9387b08 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -494,7 +494,7 @@ func.func @linearize_vector_load_scalable(%arg0: memref<2x8xf32>) -> vector<1x[4
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<[4]xf32>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<[4]xf32> to vector<1x[4]xf32>
- // CHECK: return %[[CAST]] : vector<1x[4]xf32
+ // CHECK: return %[[CAST]] : vector<1x[4]xf32>
%c0 = arith.constant 0 : index
%0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
return %0 : vector<1x[4]xf32>
>From 351b0eeabba5fdbe015318ded44a019f8e6b1689 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 11 Jul 2025 14:33:14 +0000
Subject: [PATCH 6/6] Add comments
---
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index fa5f9bbd2dbf6..0ebb477038e86 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -630,6 +630,8 @@ struct LinearizeVectorCreateMask final
/// is converted to:
/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
/// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+/// For generic cases, the vector unroll pass should be used to unroll the load
+/// to vector<1x1x...xN> form first, which can then be linearized.
struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
@@ -673,6 +675,8 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
/// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
/// vector.store %arg0, %arg1[%c0, %c0]
/// : vector<4xf32>, memref<1x4xf32>
+/// For generic cases, the vector unroll pass should be used to unroll the store
+/// to vector<1x1x...xN> form first, which can then be linearized.
struct LinearizeVectorStore final
: public OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
More information about the Mlir-commits
mailing list