[Mlir-commits] [mlir] [mlir][vector] Add a check to ensure input vector rank equals target shape rank (PR #149239)
Longsheng Mou
llvmlistbot at llvm.org
Thu Jul 24 06:46:31 PDT 2025
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/149239
>From b46be720fc19bb9bc82edbff1df36f1cafcef22d Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 17 Jul 2025 11:33:46 +0800
Subject: [PATCH 1/3] [mlir][vector] Add a check to ensure input vector rank
equals target shape rank
The crash occurs because, during IR transformation, the
vector-unrolling pass (which uses ExtractStridedSliceOp) attempts to slice
an input vector of higher rank using a target shape of lower rank, which
is not supported: ExtractStridedSliceOp requires the input and output
ranks to match.
---
.../Dialect/Vector/Transforms/VectorUnroll.cpp | 10 +++++++++-
.../Dialect/Vector/vector-transfer-unroll.mlir | 16 ++++++++++++++++
2 files changed, 25 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 693f4f955994d..734a8590eedb7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -169,7 +169,7 @@ struct UnrollTransferReadPattern
auto sourceVectorType = readOp.getVectorType();
SmallVector<int64_t> strides(targetShape->size(), 1);
Location loc = readOp.getLoc();
- ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
+ ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
// Prepare the result vector;
Value result = rewriter.create<arith::ConstantOp>(
@@ -224,6 +224,14 @@ struct UnrollTransferWritePattern
SmallVector<int64_t> strides(targetShape->size(), 1);
Location loc = writeOp.getLoc();
ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+ // Bail-out if rank(source) != rank(target). The main limitation here is the
+ // fact that `ExtractStridedSlice` requires the rank for the input and
+ // output to match. If needed, we can relax this later.
+ if (originalSize.size() != targetShape->size())
+ return rewriter.notifyMatchFailure(
+ writeOp,
+ "expected source input vector rank to match target shape rank");
+
SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
writeOp.getIndices().end());
SmallVector<int64_t> loopOrder =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 5dd65ea132d08..81e2c8dbd6283 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -365,3 +365,19 @@ func.func @vector_gather_unroll(%mem : memref<?x?x?xf32>,
%res = vector.gather %mem[%c0, %c0, %c0] [%indices], %mask, %pass_thru : memref<?x?x?xf32>, vector<6x4xindex>, vector<6x4xi1>, vector<6x4xf32> into vector<6x4xf32>
return %res : vector<6x4xf32>
}
+
+// -----
+
+// Ensure that cases with mismatched target and source
+// shape ranks do not lead to a crash.
+
+// CHECK-LABEL: func @negative_vector_transfer_write
+// CHECK-NOT: vector.extract_strided_slice
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @negative_vector_transfer_write(%arg0: vector<6x34x62xi8>) {
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() : memref<6x34x62xi8>
+ vector.transfer_write %arg0, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8>
+ return
+}
>From ca644fd07ff1dbf4b0d60026ad19d0ec1c3ec763 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 24 Jul 2025 21:42:07 +0800
Subject: [PATCH 2/3] Add a test note about the hard-coded [2, 2] unroll target shape
---
mlir/test/Dialect/Vector/vector-transfer-unroll.mlir | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 81e2c8dbd6283..181e0609fb219 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -368,8 +368,10 @@ func.func @vector_gather_unroll(%mem : memref<?x?x?xf32>,
// -----
-// Ensure that cases with mismatched target and source
-// shape ranks do not lead to a crash.
+// Ensure that cases with mismatched target and source shape ranks
+// do not lead to a crash.
+// Note: The vector unrolling target shape in `test-vector-transfer-unrolling-patterns`
+// is currently hard-coded to [2, 2].
// CHECK-LABEL: func @negative_vector_transfer_write
// CHECK-NOT: vector.extract_strided_slice
>From 543522d22ac9132d180e08005f2da2e8f46fd4dc Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 24 Jul 2025 21:46:05 +0800
Subject: [PATCH 3/3] Rename the negative test and its argument, and move it next to the transfer_write tests
---
.../Vector/vector-transfer-unroll.mlir | 36 +++++++++----------
1 file changed, 18 insertions(+), 18 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 181e0609fb219..44601a4a47dda 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -68,6 +68,24 @@ func.func @transfer_write_unroll(%mem : memref<4x4xf32>, %vec : vector<4x4xf32>)
// -----
+// Ensure that cases with mismatched target and source shape ranks
+// do not lead to a crash.
+// Note: The vector unrolling target shape in `test-vector-transfer-unrolling-patterns`
+// is currently hard-coded to [2, 2].
+
+// CHECK-LABEL: func @negative_transfer_write
+// CHECK-NOT: vector.extract_strided_slice
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @negative_transfer_write(%vec: vector<6x34x62xi8>) {
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() : memref<6x34x62xi8>
+ vector.transfer_write %vec, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8>
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @transfer_readwrite_unroll
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -365,21 +383,3 @@ func.func @vector_gather_unroll(%mem : memref<?x?x?xf32>,
%res = vector.gather %mem[%c0, %c0, %c0] [%indices], %mask, %pass_thru : memref<?x?x?xf32>, vector<6x4xindex>, vector<6x4xi1>, vector<6x4xf32> into vector<6x4xf32>
return %res : vector<6x4xf32>
}
-
-// -----
-
-// Ensure that cases with mismatched target and source shape ranks
-// do not lead to a crash.
-// Note: The vector unrolling target shape in `test-vector-transfer-unrolling-patterns`
-// is currently hard-coded to [2, 2].
-
-// CHECK-LABEL: func @negative_vector_transfer_write
-// CHECK-NOT: vector.extract_strided_slice
-// CHECK: vector.transfer_write
-// CHECK: return
-func.func @negative_vector_transfer_write(%arg0: vector<6x34x62xi8>) {
- %c0 = arith.constant 0 : index
- %alloc = memref.alloc() : memref<6x34x62xi8>
- vector.transfer_write %arg0, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8>
- return
-}
More information about the Mlir-commits
mailing list