[Mlir-commits] [mlir] [mlir][vector] Emit error when `kind` attribute is not a CombiningKind (PR #173659)

Longsheng Mou llvmlistbot at llvm.org
Tue Dec 30 09:07:23 PST 2025


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/173659

>From 2f3a57cedc42bafe7d8914c41569ef344000b411 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 26 Dec 2025 20:54:04 +0800
Subject: [PATCH 1/2] [mlir][vector] Emit error when `kind` attribute is not a
 CombiningKind

This PR fixes a crash by validating the type of the `kind` attribute. For
`vector.contract` and `vector.outerproduct`, the parser now emits an error
when `kind` is not a CombiningKindAttr.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 33 +++++++++++++++++-------
 mlir/test/Dialect/Vector/invalid.mlir    | 29 +++++++++++++++++++++
 2 files changed, 53 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 12bdc9646ee84..162d18ea405bc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -874,12 +874,20 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
   result.attributes.set(getIteratorTypesAttrName(result.name),
                         parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
 
-  if (!result.attributes.get(getKindAttrName(result.name))) {
+  StringAttr kindAttrName = getKindAttrName(result.name);
+  auto kindAttr = result.attributes.get(kindAttrName);
+  if (!kindAttr) {
     result.addAttribute(
-        getKindAttrName(result.name),
-        CombiningKindAttr::get(result.getContext(),
-                               ContractionOp::getDefaultKind()));
+        kindAttrName, CombiningKindAttr::get(result.getContext(),
+                                             ContractionOp::getDefaultKind()));
+  } else {
+    if (!isa<CombiningKindAttr>(kindAttr)) {
+      return parser.emitError(parser.getNameLoc())
+             << "expected " << kindAttrName
+             << " attribute of type CombiningKind(e.g. 'vector.kind<add>')";
+    }
   }
+
   if (masksInfo.empty())
     return success();
   if (masksInfo.size() != 2)
@@ -4056,11 +4064,18 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
                               scalableDimsRes);
   }
 
-  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
-    result.attributes.append(
-        OuterProductOp::getKindAttrName(result.name),
-        CombiningKindAttr::get(result.getContext(),
-                               OuterProductOp::getDefaultKind()));
+  StringAttr kindAttrName = getKindAttrName(result.name);
+  auto kindAttr = result.attributes.get(kindAttrName);
+  if (!kindAttr) {
+    result.addAttribute(
+        kindAttrName, CombiningKindAttr::get(result.getContext(),
+                                             OuterProductOp::getDefaultKind()));
+  } else {
+    if (!isa<CombiningKindAttr>(kindAttr)) {
+      return parser.emitError(parser.getNameLoc())
+             << "expected " << kindAttrName
+             << " attribute of type CombiningKind(e.g. 'vector.kind<add>')";
+    }
   }
 
   return failure(
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 79b09e172145b..de8758ecef258 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -223,6 +223,14 @@ func.func @outerproduct_non_vector_operand(%arg0: f32) {
 
 // -----
 
+func.func @outerproduct_invalid_kind_attr(%arg0 : vector<[4]xf32>, %arg1 : vector<[8]xf32>) {
+  // expected-error @+1 {{expected "kind" attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
+  %0 = vector.outerproduct %arg0, %arg1 {kind = "invalid"} : vector<[4]xf32>, vector<[8]xf32>
+  return
+}
+
+// -----
+
 func.func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
   // expected-error at +1 {{expected 1-d vector for operand #1}}
   %1 = vector.outerproduct %arg1, %arg1 : vector<4x8xf32>, vector<4x8xf32>
@@ -994,6 +1002,27 @@ func.func @contract_missing_iterator_types(%arg0: vector<1x2xi32>, %arg1: vector
 
 // -----
 
+#contraction_accesses = [
+        affine_map<(i, j, k) -> (i, k)>,
+        affine_map<(i, j, k) -> (k, j)>,
+        affine_map<(i, j, k) -> (i, j)>
+      ]
+#contraction_trait = {
+        kind = "invalid",
+        indexing_maps = #contraction_accesses,
+        iterator_types = ["parallel", "parallel", "reduction"]
+      }
+func.func @contraction_invalid_kind(%arg0: vector<4x3xf32>,
+                                    %arg1: vector<3x7xf32>,
+                                    %arg2: vector<4x7xf32>) {
+  // expected-error @+1 {{expected "kind" attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
+  %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
+    : vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
+  return
+}
+
+// -----
+
 func.func @create_mask_0d_no_operands() {
   %c1 = arith.constant 1 : index
   // expected-error at +1 {{must specify exactly one operand for 0-D create_mask}}

>From ea2da7ceb33c80e8105e181eb2e8dc5302e5002e Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Wed, 31 Dec 2025 01:06:01 +0800
Subject: [PATCH 2/2] Move the `kind` attribute type check from the parser to the verifier

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 43 +++++++++++-------------
 mlir/test/Dialect/Vector/invalid.mlir    |  4 +--
 2 files changed, 21 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 162d18ea405bc..1547cb04dfc48 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -874,20 +874,12 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
   result.attributes.set(getIteratorTypesAttrName(result.name),
                         parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
 
-  StringAttr kindAttrName = getKindAttrName(result.name);
-  auto kindAttr = result.attributes.get(kindAttrName);
-  if (!kindAttr) {
+  if (!result.attributes.get(getKindAttrName(result.name))) {
     result.addAttribute(
-        kindAttrName, CombiningKindAttr::get(result.getContext(),
-                                             ContractionOp::getDefaultKind()));
-  } else {
-    if (!isa<CombiningKindAttr>(kindAttr)) {
-      return parser.emitError(parser.getNameLoc())
-             << "expected " << kindAttrName
-             << " attribute of type CombiningKind(e.g. 'vector.kind<add>')";
-    }
+        getKindAttrName(result.name),
+        CombiningKindAttr::get(result.getContext(),
+                               ContractionOp::getDefaultKind()));
   }
-
   if (masksInfo.empty())
     return success();
   if (masksInfo.size() != 2)
@@ -1099,6 +1091,11 @@ LogicalResult ContractionOp::verify() {
                                contractingDimMap, batchDimMap)))
     return failure();
 
+  if (!getKindAttr()) {
+    return emitOpError("expected 'kind' attribute of type CombiningKind(e.g. "
+                       "'vector.kind<add>')");
+  }
+
   // Verify supported combining kind.
   auto vectorType = llvm::dyn_cast<VectorType>(resType);
   auto elementType = vectorType ? vectorType.getElementType() : resType;
@@ -4064,18 +4061,11 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
                               scalableDimsRes);
   }
 
-  StringAttr kindAttrName = getKindAttrName(result.name);
-  auto kindAttr = result.attributes.get(kindAttrName);
-  if (!kindAttr) {
-    result.addAttribute(
-        kindAttrName, CombiningKindAttr::get(result.getContext(),
-                                             OuterProductOp::getDefaultKind()));
-  } else {
-    if (!isa<CombiningKindAttr>(kindAttr)) {
-      return parser.emitError(parser.getNameLoc())
-             << "expected " << kindAttrName
-             << " attribute of type CombiningKind(e.g. 'vector.kind<add>')";
-    }
+  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
+    result.attributes.append(
+        OuterProductOp::getKindAttrName(result.name),
+        CombiningKindAttr::get(result.getContext(),
+                               OuterProductOp::getDefaultKind()));
   }
 
   return failure(
@@ -4122,6 +4112,11 @@ LogicalResult OuterProductOp::verify() {
   if (vACC && vACC != vRES)
     return emitOpError("expected operand #3 of same type as result type");
 
+  if (!getKindAttr()) {
+    return emitOpError("expected 'kind' attribute of type CombiningKind(e.g. "
+                       "'vector.kind<add>')");
+  }
+
   // Verify supported combining kind.
   if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
     return emitOpError("unsupported outerproduct type");
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index de8758ecef258..844edd2541132 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -224,7 +224,7 @@ func.func @outerproduct_non_vector_operand(%arg0: f32) {
 // -----
 
 func.func @outerproduct_invalid_kind_attr(%arg0 : vector<[4]xf32>, %arg1 : vector<[8]xf32>) {
-  // expected-error @+1 {{expected "kind" attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
+  // expected-error at +1 {{expected 'kind' attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
   %0 = vector.outerproduct %arg0, %arg1 {kind = "invalid"} : vector<[4]xf32>, vector<[8]xf32>
   return
 }
@@ -1015,7 +1015,7 @@ func.func @contract_missing_iterator_types(%arg0: vector<1x2xi32>, %arg1: vector
 func.func @contraction_invalid_kind(%arg0: vector<4x3xf32>,
                                     %arg1: vector<3x7xf32>,
                                     %arg2: vector<4x7xf32>) {
-  // expected-error @+1 {{expected "kind" attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
+  // expected-error at +1 {{expected 'kind' attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
   %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
     : vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
   return



More information about the Mlir-commits mailing list