[Mlir-commits] [mlir] [mlir][vector] Add a check to ensure input vector rank equals target shape rank (PR #149239)

Longsheng Mou llvmlistbot at llvm.org
Wed Jul 23 18:57:06 PDT 2025


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/149239

>From b46be720fc19bb9bc82edbff1df36f1cafcef22d Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 17 Jul 2025 11:33:46 +0800
Subject: [PATCH 1/2] [mlir][vector] Add a check to ensure input vector rank
 equals target shape rank

The crash occurs because, during IR transformation, the vector-unrolling
pass (via ExtractStridedSliceOp) attempts to slice an input vector of
higher rank with a target shape of lower rank. ExtractStridedSliceOp does
not support this, because it requires its input and output ranks to match.
---
 .../Dialect/Vector/Transforms/VectorUnroll.cpp   | 10 +++++++++-
 .../Dialect/Vector/vector-transfer-unroll.mlir   | 16 ++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 693f4f955994d..734a8590eedb7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -169,7 +169,7 @@ struct UnrollTransferReadPattern
     auto sourceVectorType = readOp.getVectorType();
     SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = readOp.getLoc();
-    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
+    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
 
     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
@@ -224,6 +224,14 @@ struct UnrollTransferWritePattern
     SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = writeOp.getLoc();
     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+    // Bail-out if rank(source) != rank(target). The main limitation here is the
+    // fact that `ExtractStridedSlice` requires the rank for the input and
+    // output to match. If needed, we can relax this later.
+    if (originalSize.size() != targetShape->size())
+      return rewriter.notifyMatchFailure(
+          writeOp,
+          "expected source input vector rank to match target shape rank");
+
     SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                        writeOp.getIndices().end());
     SmallVector<int64_t> loopOrder =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 5dd65ea132d08..81e2c8dbd6283 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -365,3 +365,19 @@ func.func @vector_gather_unroll(%mem : memref<?x?x?xf32>,
   %res = vector.gather %mem[%c0, %c0, %c0] [%indices], %mask, %pass_thru : memref<?x?x?xf32>, vector<6x4xindex>, vector<6x4xi1>, vector<6x4xf32> into vector<6x4xf32>
   return %res : vector<6x4xf32>
 }
+
+// -----
+
+// Ensure that cases with mismatched target and source
+// shape ranks do not lead to a crash.
+
+// CHECK-LABEL: func @negative_vector_transfer_write
+//   CHECK-NOT:   vector.extract_strided_slice
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @negative_vector_transfer_write(%arg0: vector<6x34x62xi8>) {
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() : memref<6x34x62xi8>
+  vector.transfer_write %arg0, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8>
+  return
+}

>From 02290d865b91267c762d4c181e5223a10987c594 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 24 Jul 2025 09:56:58 +0800
Subject: [PATCH 2/2] Rename negative test and its argument; move it next to the transfer_write tests

---
 .../Vector/vector-transfer-unroll.mlir        | 32 +++++++++----------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 81e2c8dbd6283..c7025044a0e1b 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -68,6 +68,22 @@ func.func @transfer_write_unroll(%mem : memref<4x4xf32>, %vec : vector<4x4xf32>)
 
 // -----
 
+// Ensure that cases with mismatched target and source
+// shape ranks do not lead to a crash.
+
+// CHECK-LABEL: func @negative_transfer_write
+//   CHECK-NOT:   vector.extract_strided_slice
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @negative_transfer_write(%vec: vector<6x34x62xi8>) {
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() : memref<6x34x62xi8>
+  vector.transfer_write %vec, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @transfer_readwrite_unroll
 // CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
@@ -365,19 +381,3 @@ func.func @vector_gather_unroll(%mem : memref<?x?x?xf32>,
   %res = vector.gather %mem[%c0, %c0, %c0] [%indices], %mask, %pass_thru : memref<?x?x?xf32>, vector<6x4xindex>, vector<6x4xi1>, vector<6x4xf32> into vector<6x4xf32>
   return %res : vector<6x4xf32>
 }
-
-// -----
-
-// Ensure that cases with mismatched target and source
-// shape ranks do not lead to a crash.
-
-// CHECK-LABEL: func @negative_vector_transfer_write
-//   CHECK-NOT:   vector.extract_strided_slice
-//       CHECK:   vector.transfer_write
-//       CHECK:   return
-func.func @negative_vector_transfer_write(%arg0: vector<6x34x62xi8>) {
-  %c0 = arith.constant 0 : index
-  %alloc = memref.alloc() : memref<6x34x62xi8>
-  vector.transfer_write %arg0, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8>
-  return
-}



More information about the Mlir-commits mailing list