[Mlir-commits] [mlir] [mlir][vector] Fix parser of `vector.contract` (PR #133434)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 28 04:59:10 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

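This PR adds a check to the `vector.contract` parser so that a missing `iterator_types` attribute produces a parse error instead of a crash. The unchecked `llvm::cast<ArrayAttr>` is replaced with `dyn_cast_or_null<ArrayAttr>`, followed by an explicit null check.
Fixes #132886.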
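
For readers unfamiliar with the pattern, here is a minimal standalone sketch of the checked lookup, written without any MLIR dependency. The types `Attr`, `ArrayAttr`, `StringAttr`, `AttrDict` and the function `checkIteratorTypes` are hypothetical stand-ins introduced only for illustration; they are not the MLIR API, and the real change is the one in the diff.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for an attribute hierarchy (NOT the MLIR classes).
struct Attr {
  virtual ~Attr() = default;
};
struct ArrayAttr : Attr {
  std::vector<std::string> values;
};
struct StringAttr : Attr {
  std::string value;
};

// Hypothetical attribute dictionary keyed by attribute name.
using AttrDict = std::map<std::string, std::shared_ptr<Attr>>;

// Returns an error message when "iterator_types" is missing or is not an
// array, and std::nullopt when the attribute is present and well-typed.
std::optional<std::string> checkIteratorTypes(const AttrDict &attrs) {
  auto it = attrs.find("iterator_types");
  const Attr *raw = (it == attrs.end()) ? nullptr : it->second.get();
  // dynamic_cast on a null pointer yields null, so a single null check covers
  // both "missing" and "wrong type". This mirrors the role that
  // dyn_cast_or_null plays in the diff.
  const auto *array = dynamic_cast<const ArrayAttr *>(raw);
  if (!array)
    return std::string("expected \"iterator_types\" array attribute");
  return std::nullopt;
}

// Small usage example covering the three cases.
void exampleUsage() {
  AttrDict missing;
  assert(checkIteratorTypes(missing).has_value());

  AttrDict wrongType;
  wrongType["iterator_types"] = std::make_shared<StringAttr>();
  assert(checkIteratorTypes(wrongType).has_value());

  AttrDict ok;
  ok["iterator_types"] = std::make_shared<ArrayAttr>();
  assert(!checkIteratorTypes(ok).has_value());
}
```

The point of the sketch is that the failure is reported to the caller as an ordinary error value rather than tripping an assertion inside an unconditional cast.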

---
Full diff: https://github.com/llvm/llvm-project/pull/133434.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-1) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+8) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eccb3e578458e..5a3983699d5a3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -787,8 +787,13 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
   // because tests still use the old format when 'iterator_types' attribute is
   // represented as an array of strings.
   // TODO: Remove this conversion once tests are fixed.
-  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
+  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
       result.attributes.get(getIteratorTypesAttrName(result.name)));
+  if (!iteratorTypes) {
+    return parser.emitError(loc)
+           << "expected " << getIteratorTypesAttrName(result.name)
+           << " array attribute";
+  }
 
   SmallVector<Attribute> iteratorTypeAttrs;
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1b89e8eb5069b..ad18cfda6fe83 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1015,6 +1015,14 @@ func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg
 
 // -----
 
+func.func @contract_missing_iterator_types(%arg0: vector<1x2xi32>, %arg1: vector<2xi32>, %arg2: vector<1xi32>) -> vector<1xi32> {
+  // expected-error @+1 {{'vector.contract' op expected "iterator_types" array attribute}}
+  %0 = vector.contract {} %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2xi32> into vector<1xi32>
+  return %0 : vector<1xi32>
+}
+
+// -----
+
 func.func @create_mask_0d_no_operands() {
   %c1 = arith.constant 1 : index
   // expected-error @+1 {{must specify exactly one operand for 0-D create_mask}}

``````````

</details>


https://github.com/llvm/llvm-project/pull/133434

