[Mlir-commits] [mlir] adfc1a9 - [MLIR][VectorOps] Fix crash in ShuffleOp inferReturnTypes (#185714)

llvmlistbot at llvm.org
Wed Mar 11 04:00:10 PDT 2026


Author: tudinhh
Date: 2026-03-11T11:00:05Z
New Revision: adfc1a95ad729b451f9163279f053248a1e17742

URL: https://github.com/llvm/llvm-project/commit/adfc1a95ad729b451f9163279f053248a1e17742
DIFF: https://github.com/llvm/llvm-project/commit/adfc1a95ad729b451f9163279f053248a1e17742.diff

LOG: [MLIR][VectorOps] Fix crash in ShuffleOp inferReturnTypes (#185714)

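Validate the operand type in ShuffleOp::inferReturnTypes to prevent a
crash when the operation is parsed with scalar operands instead of vectors.
The unchecked llvm::cast is replaced with llvm::dyn_cast, so the op now
emits an "expected vector type" error instead of crashing.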

Fixes #185587
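
The checked-cast pattern described in the log can be illustrated with a
standalone sketch. It does not use the MLIR API: ScalarType, VectorType,
InferResult, and inferShuffleType are hypothetical stand-ins invented for
this example, and std::get_if plays the role of llvm::dyn_cast. The shape
rule follows the comment in the patched function (leading dimension equals
the mask length, trailing dimensions come from the operand).

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins for MLIR types; these are NOT the real MLIR classes.
struct ScalarType {
  std::string name;
};
struct VectorType {
  std::vector<int64_t> shape;
  std::string elementName;
};
using Type = std::variant<ScalarType, VectorType>;

// Hypothetical result holder: either an inferred type or an error message.
struct InferResult {
  std::optional<VectorType> type;
  std::string error;
  bool succeeded() const { return type.has_value(); }
};

// Check the operand kind first and report an error on mismatch, rather than
// assuming the operand is a vector (which is what an unchecked cast does).
InferResult inferShuffleType(const Type &v1, std::size_t maskLength) {
  // std::get_if returns nullptr on a kind mismatch, much like dyn_cast.
  const auto *vec = std::get_if<VectorType>(&v1);
  if (!vec)
    return {std::nullopt, "expected vector type"};

  VectorType result;
  result.elementName = vec->elementName;
  // Leading dimension matches the mask length.
  result.shape.push_back(static_cast<int64_t>(maskLength));
  // Trailing dimensions (if any) are copied from the operand.
  if (vec->shape.size() > 1)
    result.shape.insert(result.shape.end(), vec->shape.begin() + 1,
                        vec->shape.end());
  return {result, ""};
}
```

In this sketch, passing a ScalarType operand yields a failed InferResult
carrying the error text, while a VectorType operand yields the inferred
shape. The caller decides how to surface the error, which loosely mirrors
how the patch forwards the optional location to emitOptionalError.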

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7afcd7d3f88fb..927d35342cfdd 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3272,10 +3272,13 @@ LogicalResult ShuffleOp::verify() {
 }
 
 LogicalResult
-ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
+ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
                             ShuffleOp::Adaptor adaptor,
                             SmallVectorImpl<Type> &inferredReturnTypes) {
-  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
+  auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
+  if (!v1Type) {
+    return emitOptionalError(loc, "expected vector type");
+  }
   auto v1Rank = v1Type.getRank();
   // Construct resulting type: leading dimension matches mask
   // length, all trailing dimensions match the operands.

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 333d342d76103..3957455ccc76e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -119,6 +119,14 @@ func.func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
 
 // -----
 
+func.func @shuffle_scalar_input(%a: i8, %b: i8) {
+  // expected-error @+1 {{expected vector type}}
+  %shuffle = vector.shuffle %a, %b [0] : i8, i8
+  return
+}
+
+// -----
+
 func.func @extract_vector_type(%arg0: index) {
   // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'index'}}
   %1 = vector.extract %arg0[] : index from index


        

