[Mlir-commits] [mlir] [mlir][vector] Fix parser of `vector.contract` (PR #133434)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 28 04:59:10 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
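This PR adds a check in the parser to prevent a crash when `vector.contract` lacks the `iterator_types` attribute. Previously the parser called `llvm::cast<ArrayAttr>` on the attribute unconditionally, which fails when the attribute is absent. The parser now uses `dyn_cast_or_null` and emits a parse error instead.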
Fixes #132886.
---
Full diff: https://github.com/llvm/llvm-project/pull/133434.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-1)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+8)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eccb3e578458e..5a3983699d5a3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -787,8 +787,13 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
// because tests still use the old format when 'iterator_types' attribute is
// represented as an array of strings.
// TODO: Remove this conversion once tests are fixed.
- ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
+ auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
result.attributes.get(getIteratorTypesAttrName(result.name)));
+ if (!iteratorTypes) {
+ return parser.emitError(loc)
+ << "expected " << getIteratorTypesAttrName(result.name)
+ << " array attribute";
+ }
SmallVector<Attribute> iteratorTypeAttrs;
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1b89e8eb5069b..ad18cfda6fe83 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1015,6 +1015,14 @@ func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg
// -----
+func.func @contract_missing_iterator_types(%arg0: vector<1x2xi32>, %arg1: vector<2xi32>, %arg2: vector<1xi32>) -> vector<1xi32> {
+ // expected-error@+1 {{'vector.contract' op expected "iterator_types" array attribute}}
+ %0 = vector.contract {} %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2xi32> into vector<1xi32>
+ return %0 : vector<1xi32>
+}
+
+// -----
+
func.func @create_mask_0d_no_operands() {
%c1 = arith.constant 1 : index
// expected-error@+1 {{must specify exactly one operand for 0-D create_mask}}
``````````
</details>
https://github.com/llvm/llvm-project/pull/133434
More information about the Mlir-commits mailing list