[Mlir-commits] [mlir] [mlir][vector] Fix parser of `vector.contract` (PR #133434)

Longsheng Mou llvmlistbot at llvm.org
Fri Mar 28 05:44:57 PDT 2025


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/133434

>From 91dc8709eadd26857efe2525e475b4b85803e1a5 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 28 Mar 2025 19:47:27 +0800
Subject: [PATCH] [mlir][vector] Fix parser of `vector.contract`

This PR adds a check in the parser to prevent a crash when
`vector.contract` lacks the `iterator_types` attribute.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 7 ++++++-
 mlir/test/Dialect/Vector/invalid.mlir    | 8 ++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eccb3e578458e..5a3983699d5a3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -787,8 +787,13 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
   // because tests still use the old format when 'iterator_types' attribute is
   // represented as an array of strings.
   // TODO: Remove this conversion once tests are fixed.
-  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
+  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
       result.attributes.get(getIteratorTypesAttrName(result.name)));
+  if (!iteratorTypes) {
+    return parser.emitError(loc)
+           << "expected " << getIteratorTypesAttrName(result.name)
+           << " array attribute";
+  }
 
   SmallVector<Attribute> iteratorTypeAttrs;
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1b89e8eb5069b..ea6d0021391fb 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1015,6 +1015,14 @@ func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg
 
 // -----
 
+func.func @contract_missing_iterator_types(%arg0: vector<1x2xi32>, %arg1: vector<2xi32>, %arg2: vector<1xi32>) -> vector<1xi32> {
+  // expected-error @+1 {{'vector.contract' expected "iterator_types" array attribute}}
+  %0 = vector.contract {} %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2xi32> into vector<1xi32>
+  return %0 : vector<1xi32>
+}
+
+// -----
+
 func.func @create_mask_0d_no_operands() {
   %c1 = arith.constant 1 : index
   // expected-error @+1 {{must specify exactly one operand for 0-D create_mask}}



More information about the Mlir-commits mailing list