[Mlir-commits] [mlir] adfc1a9 - [MLIR][VectorOps] Fix crash in ShuffleOp inferReturnTypes (#185714)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 11 04:00:10 PDT 2026
Author: tudinhh
Date: 2026-03-11T11:00:05Z
New Revision: adfc1a95ad729b451f9163279f053248a1e17742
URL: https://github.com/llvm/llvm-project/commit/adfc1a95ad729b451f9163279f053248a1e17742
DIFF: https://github.com/llvm/llvm-project/commit/adfc1a95ad729b451f9163279f053248a1e17742.diff
LOG: [MLIR][VectorOps] Fix crash in ShuffleOp inferReturnTypes (#185714)
Validate the operand type in ShuffleOp::inferReturnTypes to prevent a
crash when the operation is parsed with scalar operands instead of vectors.
Such input is now rejected with an "expected vector type" diagnostic
instead of crashing.
Fixes #185587
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7afcd7d3f88fb..927d35342cfdd 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3272,10 +3272,13 @@ LogicalResult ShuffleOp::verify() {
}
LogicalResult
-ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
+ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
ShuffleOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
- auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
+ auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
+ if (!v1Type) {
+ return emitOptionalError(loc, "expected vector type");
+ }
auto v1Rank = v1Type.getRank();
// Construct resulting type: leading dimension matches mask
// length, all trailing dimensions match the operands.
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 333d342d76103..3957455ccc76e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -119,6 +119,14 @@ func.func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
// -----
+func.func @shuffle_scalar_input(%a: i8, %b: i8) {
+ // expected-error @+1 {{expected vector type}}
+ %shuffle = vector.shuffle %a, %b [0] : i8, i8
+ return
+}
+
+// -----
+
func.func @extract_vector_type(%arg0: index) {
// expected-error at +1 {{invalid kind of type specified: expected builtin.vector, but found 'index'}}
%1 = vector.extract %arg0[] : index from index
More information about the Mlir-commits
mailing list