[Mlir-commits] [mlir] [mlir][vector] Add a check to ensure input vector rank equals target shape rank (PR #149239)
Longsheng Mou
llvmlistbot at llvm.org
Wed Jul 23 08:13:31 PDT 2025
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/149239
>From b46be720fc19bb9bc82edbff1df36f1cafcef22d Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 17 Jul 2025 11:33:46 +0800
Subject: [PATCH] [mlir][vector] Add a check to ensure input vector rank equals
target shape rank
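The crash occurs because, during IR transformation, the vector-unrolling
pass (which uses ExtractStridedSliceOp) attempts to slice a higher-rank
input vector using a lower-rank unroll target shape, which is not
supported: ExtractStridedSliceOp requires its input and output ranks to
match. Make the transfer_write unrolling pattern bail out with a match
failure when the source vector rank differs from the target shape rank.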
---
.../Dialect/Vector/Transforms/VectorUnroll.cpp | 10 +++++++++-
.../Dialect/Vector/vector-transfer-unroll.mlir | 16 ++++++++++++++++
2 files changed, 25 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 693f4f955994d..734a8590eedb7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -169,7 +169,7 @@ struct UnrollTransferReadPattern
auto sourceVectorType = readOp.getVectorType();
SmallVector<int64_t> strides(targetShape->size(), 1);
Location loc = readOp.getLoc();
- ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
+ ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
// Prepare the result vector;
Value result = rewriter.create<arith::ConstantOp>(
@@ -224,6 +224,14 @@ struct UnrollTransferWritePattern
SmallVector<int64_t> strides(targetShape->size(), 1);
Location loc = writeOp.getLoc();
ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+ // Bail-out if rank(source) != rank(target). The main limitation here is the
+ // fact that `ExtractStridedSlice` requires the rank for the input and
+ // output to match. If needed, we can relax this later.
+ if (originalSize.size() != targetShape->size())
+ return rewriter.notifyMatchFailure(
+ writeOp,
+ "expected source input vector rank to match target shape rank");
+
SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
writeOp.getIndices().end());
SmallVector<int64_t> loopOrder =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 5dd65ea132d08..81e2c8dbd6283 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -365,3 +365,19 @@ func.func @vector_gather_unroll(%mem : memref<?x?x?xf32>,
%res = vector.gather %mem[%c0, %c0, %c0] [%indices], %mask, %pass_thru : memref<?x?x?xf32>, vector<6x4xindex>, vector<6x4xi1>, vector<6x4xf32> into vector<6x4xf32>
return %res : vector<6x4xf32>
}
+
+// -----
+
+// Ensure that cases with mismatched target and source
+// shape ranks do not lead to a crash.
+
+// CHECK-LABEL: func @negative_vector_transfer_write
+// CHECK-NOT: vector.extract_strided_slice
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @negative_vector_transfer_write(%arg0: vector<6x34x62xi8>) {
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() : memref<6x34x62xi8>
+ vector.transfer_write %arg0, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8>
+ return
+}
More information about the Mlir-commits
mailing list