[Mlir-commits] [mlir] [MLIR] [Vector] Linearization patterns for vector.load and vector.store (PR #145115)

Nishant Patel llvmlistbot at llvm.org
Fri Jul 11 07:35:42 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/145115

>From 8c3dbf88a7a190c3134992bb4cb3f4bcf133cfe8 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 19 Jun 2025 19:40:07 +0000
Subject: [PATCH 1/6] Linearization patterns for vector.load and vector.store

---
 .../Vector/Transforms/VectorLinearize.cpp     | 71 ++++++++++++++++++-
 mlir/test/Dialect/Vector/linearize.mlir       | 23 ++++++
 2 files changed, 92 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..f0b77da5acd02 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final
   }
 };
 
+/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
+/// It currently supports only lineariztion of <1XN> to <N>
+/// Following,
+///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
+///   vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+                      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = loadOp.getType();
+    if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
+    auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+                                    vecTy.isScalable());
+    auto newLoad = rewriter.create<vector::LoadOp>(
+        loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+    auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+        loadOp.getLoc(), vecTy, newLoad.getResult());
+    rewriter.replaceOp(loadOp, shapeCast.getResult());
+    return success();
+  }
+};
+
+/// This pattern linearizes vector.store from vector<1xN> to vector<N>.
+/// It currently supports only lineariztion of <1XN> to <N>
+/// Following,
+///   vector.store %arg0, %arg1[%c0, %c0]
+///     : vector<1x4xf32>, memref<1x4xf32>
+/// is converted to:
+///   vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///   vector.store %arg0, %arg1[%c0, %%c0]
+///     : vector<4xf32>, memref<1x4xf32>
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = storeOp.getValueToStore().getType();
+    if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported");
+    auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+                                    vecTy.isScalable());
+
+    Value valueToStore = adaptor.getValueToStore();
+    if (valueToStore.getType() != linearTy) {
+      valueToStore = rewriter.create<vector::ShapeCastOp>(
+          storeOp.getLoc(), linearTy, valueToStore);
+    }
+
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices());
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
     RewritePatternSet &patterns) {
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask>(
-          typeConverter, patterns.getContext());
+           LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
+           LinearizeVectorStore>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9cbf319ffddb2..fa0436792d3f0 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
   %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
   return %0 : vector<1x[16]xi1>
 }
+
+// CHECK-LABEL: linearize_vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> {
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
+  // CHECK: return %[[CAST]] : vector<1x4xf32>
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) {
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32>
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+  %c0 = arith.constant 0 : index
+  vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+  return
+}

>From 92b299f07846ad55821075f08df5cba598c1766f Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 23 Jun 2025 19:14:00 +0000
Subject: [PATCH 2/6] Address comments

---
 .../Vector/Transforms/VectorLinearize.cpp     | 48 ++++++++-----------
 mlir/test/Dialect/Vector/linearize.mlir       | 16 +++----
 2 files changed, 29 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index f0b77da5acd02..890d882ea2129 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -623,9 +623,9 @@ struct LinearizeVectorCreateMask final
   }
 };
 
-/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
-/// It currently supports only lineariztion of <1XN> to <N>
-/// Following,
+/// This pattern linearizes vector.load from vector<1x1x...xN> to vector<N>.
+/// It requires all dimensions except the last to be 1 and non-scalable.
+/// For example,
 ///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
 /// is converted to:
 ///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
@@ -640,27 +640,27 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
   matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     VectorType vecTy = loadOp.getType();
-    if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
-      return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
-    auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
-                                    vecTy.isScalable());
+    if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1),
+                                [](auto d) { return d == 1; }))
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "only vector<1x1x...xN> supported");
+    auto linearTy = VectorType::get(vecTy.getShape().back(),
+                                    vecTy.getElementType(), vecTy.isScalable());
     auto newLoad = rewriter.create<vector::LoadOp>(
         loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
-    auto shapeCast = rewriter.create<vector::ShapeCastOp>(
-        loadOp.getLoc(), vecTy, newLoad.getResult());
-    rewriter.replaceOp(loadOp, shapeCast.getResult());
+    rewriter.replaceOp(loadOp, newLoad.getResult());
     return success();
   }
 };
 
-/// This pattern linearizes vector.store from vector<1xN> to vector<N>.
-/// It currently supports only lineariztion of <1XN> to <N>
-/// Following,
-///   vector.store %arg0, %arg1[%c0, %c0]
+/// This pattern linearizes vector.store from vector<1x1x...xN> to vector<N>.
+/// It requires all dimensions except the last to be 1 and non-scalable.
+/// For example,
+///   vector.store %arg0, %arg1[%c0, %c0]
 ///     : vector<1x4xf32>, memref<1x4xf32>
 /// is converted to:
 ///   vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
-///   vector.store %arg0, %arg1[%c0, %%c0]
+///   vector.store %arg0, %arg1[%c0, %c0]
 ///     : vector<4xf32>, memref<1x4xf32>
 struct LinearizeVectorStore final
     : public OpConversionPattern<vector::StoreOp> {
@@ -673,19 +673,13 @@ struct LinearizeVectorStore final
   matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     VectorType vecTy = storeOp.getValueToStore().getType();
-    if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
-      return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported");
-    auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
-                                    vecTy.isScalable());
-
-    Value valueToStore = adaptor.getValueToStore();
-    if (valueToStore.getType() != linearTy) {
-      valueToStore = rewriter.create<vector::ShapeCastOp>(
-          storeOp.getLoc(), linearTy, valueToStore);
-    }
-
+    if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1),
+                                [](auto d) { return d == 1; }))
+      return rewriter.notifyMatchFailure(storeOp,
+                                         "only vector<1x1x...xN> supported");
     rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices());
+        storeOp, adaptor.getValueToStore(), adaptor.getBase(),
+        adaptor.getIndices());
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index fa0436792d3f0..9a017ceedcebe 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -466,24 +466,24 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
 }
 
 // CHECK-LABEL: linearize_vector_load
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32>
-func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> {
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<2x8xf32>) -> vector<1x4xf32> {
   // CHECK: %[[CST0:.*]] = arith.constant 0 : index
-  // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+  // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<4xf32>
   // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
   // CHECK: return %[[CAST]] : vector<1x4xf32>
   %c0 = arith.constant 0 : index
-  %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+  %0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x4xf32>
   return %0 : vector<1x4xf32>
 }
 
 // CHECK-LABEL: linearize_vector_store
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
-func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) {
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<2x8xf32>, %arg1: vector<1x4xf32>) {
   // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32>
   // CHECK: %[[CST0:.*]] = arith.constant 0 : index
-  // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+  // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<4xf32>
   %c0 = arith.constant 0 : index
-  vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+  vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x4xf32>
   return
 }

>From f700fe6aa6dd06d85216f147fd33108e1ae5f352 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 23 Jun 2025 19:24:14 +0000
Subject: [PATCH 3/6] Fix

---
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 890d882ea2129..99d18fec18120 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -644,8 +644,7 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
                                 [](auto d) { return d == 1; }))
       return rewriter.notifyMatchFailure(loadOp,
                                          "only vector<1x1x...xN> supported");
-    auto linearTy = VectorType::get(vecTy.getShape().back(),
-                                    vecTy.getElementType(), vecTy.isScalable());
+    auto linearTy = typeConverter->convertType<VectorType>(loadOp.getType());
     auto newLoad = rewriter.create<vector::LoadOp>(
         loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
     rewriter.replaceOp(loadOp, newLoad.getResult());

>From bcb5306d3a3379530ad58e4b4927ad230451b732 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 3 Jul 2025 15:16:40 +0000
Subject: [PATCH 4/6] Address feedback

---
 .../Vector/Transforms/VectorLinearize.cpp     | 33 ++++++++++++++++---
 mlir/test/Dialect/Vector/linearize.mlir       | 23 +++++++++++++
 2 files changed, 51 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 99d18fec18120..fa5f9bbd2dbf6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -640,11 +640,23 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
   matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     VectorType vecTy = loadOp.getType();
-    if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1),
-                                [](auto d) { return d == 1; }))
+    if (!vecTy)
+      return rewriter.notifyMatchFailure(loadOp, "expected vector type");
+
+    auto shape = vecTy.getShape();
+    auto scalableDims = vecTy.getScalableDims();
+    // All but the last dim must be 1, and only the last dim may be scalable (if
+    // any).
+    if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; }))
       return rewriter.notifyMatchFailure(loadOp,
                                          "only vector<1x1x...xN> supported");
-    auto linearTy = typeConverter->convertType<VectorType>(loadOp.getType());
+
+    if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; }))
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "only innermost dim may be scalable");
+
+    auto linearTy = typeConverter->convertType<VectorType>(vecTy);
+
     auto newLoad = rewriter.create<vector::LoadOp>(
         loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
     rewriter.replaceOp(loadOp, newLoad.getResult());
@@ -672,10 +684,21 @@ struct LinearizeVectorStore final
   matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     VectorType vecTy = storeOp.getValueToStore().getType();
-    if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1),
-                                [](auto d) { return d == 1; }))
+    if (!vecTy)
+      return rewriter.notifyMatchFailure(storeOp, "expected vector type");
+
+    auto shape = vecTy.getShape();
+    auto scalableDims = vecTy.getScalableDims();
+    // All but the last dim must be 1, and only the last dim may be scalable (if
+    // any).
+    if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; }))
       return rewriter.notifyMatchFailure(storeOp,
                                          "only vector<1x1x...xN> supported");
+
+    if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; }))
+      return rewriter.notifyMatchFailure(storeOp,
+                                         "only innermost dim may be scalable");
+
     rewriter.replaceOpWithNewOp<vector::StoreOp>(
         storeOp, adaptor.getValueToStore(), adaptor.getBase(),
         adaptor.getIndices());
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9a017ceedcebe..11780abfc6141 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -487,3 +487,26 @@ func.func @linearize_vector_store(%arg0: memref<2x8xf32>, %arg1: vector<1x4xf32>
   vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x4xf32>
   return
 }
+
+// CHECK-LABEL: linearize_vector_load_scalable
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>) -> vector<1x[4]xf32>
+func.func @linearize_vector_load_scalable(%arg0: memref<2x8xf32>) -> vector<1x[4]xf32> {
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<[4]xf32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<[4]xf32> to vector<1x[4]xf32>
+  // CHECK: return %[[CAST]] : vector<1x[4]xf32
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
+  return %0 : vector<1x[4]xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store_scalable
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>, %[[ARG1:.*]]: vector<1x[4]xf32>)
+func.func @linearize_vector_store_scalable(%arg0: memref<2x8xf32>, %arg1: vector<1x[4]xf32>) {
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x[4]xf32> to vector<[4]xf32>
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<[4]xf32>
+  %c0 = arith.constant 0 : index
+  vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
+  return
+}

>From 62fc4473a9feaddd4239c70cd349ced2fe220d60 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 7 Jul 2025 20:58:34 +0000
Subject: [PATCH 5/6] Missing brace

---
 mlir/test/Dialect/Vector/linearize.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 11780abfc6141..1ad08b9387b08 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -494,7 +494,7 @@ func.func @linearize_vector_load_scalable(%arg0: memref<2x8xf32>) -> vector<1x[4
   // CHECK: %[[CST0:.*]] = arith.constant 0 : index
   // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<[4]xf32>
   // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<[4]xf32> to vector<1x[4]xf32>
-  // CHECK: return %[[CAST]] : vector<1x[4]xf32
+  // CHECK: return %[[CAST]] : vector<1x[4]xf32>
   %c0 = arith.constant 0 : index
   %0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
   return %0 : vector<1x[4]xf32>

>From 351b0eeabba5fdbe015318ded44a019f8e6b1689 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 11 Jul 2025 14:33:14 +0000
Subject: [PATCH 6/6] Add comments

---
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index fa5f9bbd2dbf6..0ebb477038e86 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -630,6 +630,8 @@ struct LinearizeVectorCreateMask final
 /// is converted to:
 ///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
 ///   vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+/// For generic cases, the vector unroll pass should be used to unroll the load
+/// into vector<1x1x...xN> form, which can then be linearized.
 struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
@@ -673,6 +675,8 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
 ///   vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
 ///   vector.store %arg0, %arg1[%c0, %c0]
 ///     : vector<4xf32>, memref<1x4xf32>
+/// For generic cases, the vector unroll pass should be used to unroll the store
+/// into vector<1x1x...xN> form, which can then be linearized.
 struct LinearizeVectorStore final
     : public OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;



More information about the Mlir-commits mailing list