[Mlir-commits] [mlir] [mlir][vector] Emit error when `kind` attribute is not a CombiningKind (PR #173659)

Longsheng Mou llvmlistbot at llvm.org
Fri Dec 26 04:56:37 PST 2025


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/173659

This PR fixes a parser crash by validating the type of the `kind` attribute. For `vector.contract` and `vector.outerproduct`, the parser now emits an error instead of crashing when `kind` is present but is not a `CombiningKindAttr` (for example, `kind = "invalid"`). Fixes #173555.

From 2f3a57cedc42bafe7d8914c41569ef344000b411 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 26 Dec 2025 20:54:04 +0800
Subject: [PATCH] [mlir][vector] Emit error when `kind` attribute is not a
 CombiningKind

This PR fixes a crash by validating the type of the `kind` attribute. For
`vector.contract` and `vector.outerproduct`, the parser now emits an error
when `kind` is not a CombiningKindAttr.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 33 +++++++++++++++++-------
 mlir/test/Dialect/Vector/invalid.mlir    | 29 +++++++++++++++++++++
 2 files changed, 53 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 12bdc9646ee84..162d18ea405bc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -874,12 +874,20 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
   result.attributes.set(getIteratorTypesAttrName(result.name),
                         parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
 
-  if (!result.attributes.get(getKindAttrName(result.name))) {
+  StringAttr kindAttrName = getKindAttrName(result.name);
+  auto kindAttr = result.attributes.get(kindAttrName);
+  if (!kindAttr) {
     result.addAttribute(
-        getKindAttrName(result.name),
-        CombiningKindAttr::get(result.getContext(),
-                               ContractionOp::getDefaultKind()));
+        kindAttrName, CombiningKindAttr::get(result.getContext(),
+                                             ContractionOp::getDefaultKind()));
+  } else {
+    if (!isa<CombiningKindAttr>(kindAttr)) {
+      return parser.emitError(parser.getNameLoc())
+             << "expected " << kindAttrName
+             << " attribute of type CombiningKind(e.g. 'vector.kind<add>')";
+    }
   }
+
   if (masksInfo.empty())
     return success();
   if (masksInfo.size() != 2)
@@ -4056,11 +4064,18 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
                               scalableDimsRes);
   }
 
-  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
-    result.attributes.append(
-        OuterProductOp::getKindAttrName(result.name),
-        CombiningKindAttr::get(result.getContext(),
-                               OuterProductOp::getDefaultKind()));
+  StringAttr kindAttrName = getKindAttrName(result.name);
+  auto kindAttr = result.attributes.get(kindAttrName);
+  if (!kindAttr) {
+    result.addAttribute(
+        kindAttrName, CombiningKindAttr::get(result.getContext(),
+                                             OuterProductOp::getDefaultKind()));
+  } else {
+    if (!isa<CombiningKindAttr>(kindAttr)) {
+      return parser.emitError(parser.getNameLoc())
+             << "expected " << kindAttrName
+             << " attribute of type CombiningKind(e.g. 'vector.kind<add>')";
+    }
   }
 
   return failure(
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 79b09e172145b..de8758ecef258 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -223,6 +223,14 @@ func.func @outerproduct_non_vector_operand(%arg0: f32) {
 
 // -----
 
+func.func @outerproduct_invalid_kind_attr(%arg0 : vector<[4]xf32>, %arg1 : vector<[8]xf32>) {
+  // expected-error @+1 {{expected "kind" attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
+  %0 = vector.outerproduct %arg0, %arg1 {kind = "invalid"} : vector<[4]xf32>, vector<[8]xf32>
+  return
+}
+
+// -----
+
 func.func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
   // expected-error@+1 {{expected 1-d vector for operand #1}}
   %1 = vector.outerproduct %arg1, %arg1 : vector<4x8xf32>, vector<4x8xf32>
@@ -994,6 +1002,27 @@ func.func @contract_missing_iterator_types(%arg0: vector<1x2xi32>, %arg1: vector
 
 // -----
 
+#contraction_accesses = [
+        affine_map<(i, j, k) -> (i, k)>,
+        affine_map<(i, j, k) -> (k, j)>,
+        affine_map<(i, j, k) -> (i, j)>
+      ]
+#contraction_trait = {
+        kind = "invalid",
+        indexing_maps = #contraction_accesses,
+        iterator_types = ["parallel", "parallel", "reduction"]
+      }
+func.func @contraction_invalid_kind(%arg0: vector<4x3xf32>,
+                                    %arg1: vector<3x7xf32>,
+                                    %arg2: vector<4x7xf32>) {
+  // expected-error @+1 {{expected "kind" attribute of type CombiningKind(e.g. 'vector.kind<add>')}}
+  %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
+    : vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
+  return
+}
+
+// -----
+
 func.func @create_mask_0d_no_operands() {
   %c1 = arith.constant 1 : index
   // expected-error@+1 {{must specify exactly one operand for 0-D create_mask}}



More information about the Mlir-commits mailing list