[Mlir-commits] [mlir] [mlir][vector] Emit error when `kind` attribute is not a CombiningKind (PR #173659)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 26 04:57:04 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
This PR fixes a crash by validating the type of the `kind` attribute. For `vector.contract` and `vector.outerproduct`, the parser now emits an error when `kind` is not a `CombiningKindAttr` (e.g. `#vector.kind<add>`). Fixes #173555.
---
Full diff: https://github.com/llvm/llvm-project/pull/173659.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+24-9)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+29)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 12bdc9646ee84..162d18ea405bc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -874,12 +874,20 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
result.attributes.set(getIteratorTypesAttrName(result.name),
parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
- if (!result.attributes.get(getKindAttrName(result.name))) {
+ StringAttr kindAttrName = getKindAttrName(result.name);
+ auto kindAttr = result.attributes.get(kindAttrName);
+ if (!kindAttr) {
result.addAttribute(
- getKindAttrName(result.name),
- CombiningKindAttr::get(result.getContext(),
- ContractionOp::getDefaultKind()));
+ kindAttrName, CombiningKindAttr::get(result.getContext(),
+ ContractionOp::getDefaultKind()));
+ } else {
+ if (!isa<CombiningKindAttr>(kindAttr)) {
+ return parser.emitError(parser.getNameLoc())
+ << "expected " << kindAttrName
+ << " attribute of type CombiningKind(e.g. 'vector.kind<add>')";
+ }
}
+
if (masksInfo.empty())
return success();
if (masksInfo.size() != 2)
@@ -4056,11 +4064,18 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
scalableDimsRes);
}
- if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
- result.attributes.append(
- OuterProductOp::getKindAttrName(result.name),
- CombiningKindAttr::get(result.getContext(),
- OuterProductOp::getDefaultKind()));
+ StringAttr kindAttrName = getKindAttrName(result.name);
+ auto kindAttr = result.attributes.get(kindAttrName);
+ if (!kindAttr) {
+ result.addAttribute(
+ kindAttrName, CombiningKindAttr::get(result.getContext(),
+ OuterProductOp::getDefaultKind()));
+ } else {
+ if (!isa<CombiningKindAttr>(kindAttr)) {
+ return parser.emitError(parser.getNameLoc())
+ << "expected " << kindAttrName
+ << " attribute of type CombiningKind(e.g. 'vector.kind<add>')";
+ }
}
return failure(
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 79b09e172145b..de8758ecef258 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -223,6 +223,14 @@ func.func @outerproduct_non_vector_operand(%arg0: f32) {
// -----
+func.func @outerproduct_invalid_kind_attr(%arg0 : vector<[4]xf32>, %arg1 : vector<[8]xf32>) {
+ // expected-error @+1 {{expected "kind" attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
+ %0 = vector.outerproduct %arg0, %arg1 {kind = "invalid"} : vector<[4]xf32>, vector<[8]xf32>
+ return
+}
+
+// -----
+
func.func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
// expected-error at +1 {{expected 1-d vector for operand #1}}
%1 = vector.outerproduct %arg1, %arg1 : vector<4x8xf32>, vector<4x8xf32>
@@ -994,6 +1002,27 @@ func.func @contract_missing_iterator_types(%arg0: vector<1x2xi32>, %arg1: vector
// -----
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+ ]
+#contraction_trait = {
+ kind = "invalid",
+ indexing_maps = #contraction_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+ }
+func.func @contraction_invalid_kind(%arg0: vector<4x3xf32>,
+ %arg1: vector<3x7xf32>,
+ %arg2: vector<4x7xf32>) {
+ // expected-error @+1 {{expected "kind" attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
+ %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
+ : vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
+ return
+}
+
+// -----
+
func.func @create_mask_0d_no_operands() {
%c1 = arith.constant 1 : index
// expected-error at +1 {{must specify exactly one operand for 0-D create_mask}}
``````````
</details>
https://github.com/llvm/llvm-project/pull/173659
More information about the Mlir-commits mailing list