[Mlir-commits] [mlir] [mlir][vector] Emit error when `kind` attribute is not a CombiningKind (PR #173659)
Longsheng Mou
llvmlistbot at llvm.org
Tue Dec 30 09:07:23 PST 2025
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/173659
>From 2f3a57cedc42bafe7d8914c41569ef344000b411 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 26 Dec 2025 20:54:04 +0800
Subject: [PATCH 1/2] [mlir][vector] Emit error when `kind` attribute is not a
CombiningKind
This PR fixes a crash by validating the type of the `kind` attribute. For
`vector.contract` and `vector.outerproduct`, the parser now emits an error
when `kind` is not a CombiningKindAttr.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 33 +++++++++++++++++-------
mlir/test/Dialect/Vector/invalid.mlir | 29 +++++++++++++++++++++
2 files changed, 53 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 12bdc9646ee84..162d18ea405bc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -874,12 +874,20 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
result.attributes.set(getIteratorTypesAttrName(result.name),
parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
- if (!result.attributes.get(getKindAttrName(result.name))) {
+ StringAttr kindAttrName = getKindAttrName(result.name);
+ auto kindAttr = result.attributes.get(kindAttrName);
+ if (!kindAttr) {
result.addAttribute(
- getKindAttrName(result.name),
- CombiningKindAttr::get(result.getContext(),
- ContractionOp::getDefaultKind()));
+ kindAttrName, CombiningKindAttr::get(result.getContext(),
+ ContractionOp::getDefaultKind()));
+ } else {
+ if (!isa<CombiningKindAttr>(kindAttr)) {
+ return parser.emitError(parser.getNameLoc())
+ << "expected " << kindAttrName
+ << " attribute of type CombiningKind(e.g. 'vector.kind<add>')";
+ }
}
+
if (masksInfo.empty())
return success();
if (masksInfo.size() != 2)
@@ -4056,11 +4064,18 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
scalableDimsRes);
}
- if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
- result.attributes.append(
- OuterProductOp::getKindAttrName(result.name),
- CombiningKindAttr::get(result.getContext(),
- OuterProductOp::getDefaultKind()));
+ StringAttr kindAttrName = getKindAttrName(result.name);
+ auto kindAttr = result.attributes.get(kindAttrName);
+ if (!kindAttr) {
+ result.addAttribute(
+ kindAttrName, CombiningKindAttr::get(result.getContext(),
+ OuterProductOp::getDefaultKind()));
+ } else {
+ if (!isa<CombiningKindAttr>(kindAttr)) {
+ return parser.emitError(parser.getNameLoc())
+ << "expected " << kindAttrName
+ << " attribute of type CombiningKind(e.g. 'vector.kind<add>')";
+ }
}
return failure(
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 79b09e172145b..de8758ecef258 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -223,6 +223,14 @@ func.func @outerproduct_non_vector_operand(%arg0: f32) {
// -----
+func.func @outerproduct_invalid_kind_attr(%arg0 : vector<[4]xf32>, %arg1 : vector<[8]xf32>) {
+ // expected-error @+1 {{expected "kind" attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
+ %0 = vector.outerproduct %arg0, %arg1 {kind = "invalid"} : vector<[4]xf32>, vector<[8]xf32>
+ return
+}
+
+// -----
+
func.func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
// expected-error@+1 {{expected 1-d vector for operand #1}}
%1 = vector.outerproduct %arg1, %arg1 : vector<4x8xf32>, vector<4x8xf32>
@@ -994,6 +1002,27 @@ func.func @contract_missing_iterator_types(%arg0: vector<1x2xi32>, %arg1: vector
// -----
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+ ]
+#contraction_trait = {
+ kind = "invalid",
+ indexing_maps = #contraction_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+ }
+func.func @contraction_invalid_kind(%arg0: vector<4x3xf32>,
+ %arg1: vector<3x7xf32>,
+ %arg2: vector<4x7xf32>) {
+ // expected-error @+1 {{expected "kind" attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
+ %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
+ : vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
+ return
+}
+
+// -----
+
func.func @create_mask_0d_no_operands() {
%c1 = arith.constant 1 : index
// expected-error@+1 {{must specify exactly one operand for 0-D create_mask}}
>From ea2da7ceb33c80e8105e181eb2e8dc5302e5002e Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Wed, 31 Dec 2025 01:06:01 +0800
Subject: [PATCH 2/2] emit error in verifier
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 43 +++++++++++-------------
mlir/test/Dialect/Vector/invalid.mlir | 4 +--
2 files changed, 21 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 162d18ea405bc..1547cb04dfc48 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -874,20 +874,12 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
result.attributes.set(getIteratorTypesAttrName(result.name),
parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
- StringAttr kindAttrName = getKindAttrName(result.name);
- auto kindAttr = result.attributes.get(kindAttrName);
- if (!kindAttr) {
+ if (!result.attributes.get(getKindAttrName(result.name))) {
result.addAttribute(
- kindAttrName, CombiningKindAttr::get(result.getContext(),
- ContractionOp::getDefaultKind()));
- } else {
- if (!isa<CombiningKindAttr>(kindAttr)) {
- return parser.emitError(parser.getNameLoc())
- << "expected " << kindAttrName
- << " attribute of type CombiningKind(e.g. 'vector.kind<add>')";
- }
+ getKindAttrName(result.name),
+ CombiningKindAttr::get(result.getContext(),
+ ContractionOp::getDefaultKind()));
}
-
if (masksInfo.empty())
return success();
if (masksInfo.size() != 2)
@@ -1099,6 +1091,11 @@ LogicalResult ContractionOp::verify() {
contractingDimMap, batchDimMap)))
return failure();
+ if (!getKindAttr()) {
+ return emitOpError("expected 'kind' attribute of type CombiningKind(e.g. "
+ "'vector.kind<add>')");
+ }
+
// Verify supported combining kind.
auto vectorType = llvm::dyn_cast<VectorType>(resType);
auto elementType = vectorType ? vectorType.getElementType() : resType;
@@ -4064,18 +4061,11 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
scalableDimsRes);
}
- StringAttr kindAttrName = getKindAttrName(result.name);
- auto kindAttr = result.attributes.get(kindAttrName);
- if (!kindAttr) {
- result.addAttribute(
- kindAttrName, CombiningKindAttr::get(result.getContext(),
- OuterProductOp::getDefaultKind()));
- } else {
- if (!isa<CombiningKindAttr>(kindAttr)) {
- return parser.emitError(parser.getNameLoc())
- << "expected " << kindAttrName
- << " attribute of type CombiningKind(e.g. 'vector.kind<add>')";
- }
+ if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
+ result.attributes.append(
+ OuterProductOp::getKindAttrName(result.name),
+ CombiningKindAttr::get(result.getContext(),
+ OuterProductOp::getDefaultKind()));
}
return failure(
@@ -4122,6 +4112,11 @@ LogicalResult OuterProductOp::verify() {
if (vACC && vACC != vRES)
return emitOpError("expected operand #3 of same type as result type");
+ if (!getKindAttr()) {
+ return emitOpError("expected 'kind' attribute of type CombiningKind(e.g. "
+ "'vector.kind<add>')");
+ }
+
// Verify supported combining kind.
if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
return emitOpError("unsupported outerproduct type");
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index de8758ecef258..844edd2541132 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -224,7 +224,7 @@ func.func @outerproduct_non_vector_operand(%arg0: f32) {
// -----
func.func @outerproduct_invalid_kind_attr(%arg0 : vector<[4]xf32>, %arg1 : vector<[8]xf32>) {
- // expected-error @+1 {{expected "kind" attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
+ // expected-error@+1 {{expected 'kind' attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
%0 = vector.outerproduct %arg0, %arg1 {kind = "invalid"} : vector<[4]xf32>, vector<[8]xf32>
return
}
@@ -1015,7 +1015,7 @@ func.func @contract_missing_iterator_types(%arg0: vector<1x2xi32>, %arg1: vector
func.func @contraction_invalid_kind(%arg0: vector<4x3xf32>,
%arg1: vector<3x7xf32>,
%arg2: vector<4x7xf32>) {
- // expected-error @+1 {{expected "kind" attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
+ // expected-error@+1 {{expected 'kind' attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
: vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
return
More information about the Mlir-commits
mailing list