[Mlir-commits] [mlir] [mlir][llvm dialect] Verify element type of nested types (PR #148975)

James Newling llvmlistbot at llvm.org
Wed Jul 16 10:58:21 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/148975

>From b04ad6e056d076e5a41278cbd5636485159bf0cd Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 15 Jul 2025 15:00:10 -0700
Subject: [PATCH 1/4] move llvm.mlir.constant verifier tests to Dialect/LLVMIR/invalid.mlir

---
 mlir/test/Dialect/LLVMIR/invalid.mlir       | 25 +++++++++++++++++++++
 mlir/test/Target/LLVMIR/llvmir-invalid.mlir | 24 --------------------
 2 files changed, 25 insertions(+), 24 deletions(-)

diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index bd1106e304c60..e5fe78c077314 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -439,6 +439,31 @@ llvm.func @scalable_vec_requires_splat() -> vector<[4]xf64> {
   llvm.return %0 : vector<[4]xf64>
 }
 
+
+// -----
+
+llvm.func @integer_with_float_type() -> f32 {
+  // expected-error @+1 {{expected integer type}}
+  %0 = llvm.mlir.constant(1 : index) : f32
+  llvm.return %0 : f32
+}
+
+// -----
+
+llvm.func @incompatible_float_attribute_type() -> f32 {
+  // expected-error @below{{expected float type of width 64}}
+  %cst = llvm.mlir.constant(1.0 : f64) : f32
+  llvm.return %cst : f32
+}
+
+// -----
+
+llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
+  // expected-error @below{{expected integer type of width 16}}
+  %cst = llvm.mlir.constant(1.0 : f16) : i32
+  llvm.return %cst : i32
+}
+
 // -----
 
 func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index a8ef401fff27e..6c7a218d0676e 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -55,30 +55,6 @@ llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
 
 // -----
 
-llvm.func @integer_with_float_type() -> f32 {
-  // expected-error @+1 {{expected integer type}}
-  %0 = llvm.mlir.constant(1 : index) : f32
-  llvm.return %0 : f32
-}
-
-// -----
-
-llvm.func @incompatible_float_attribute_type() -> f32 {
-  // expected-error @below{{expected float type of width 64}}
-  %cst = llvm.mlir.constant(1.0 : f64) : f32
-  llvm.return %cst : f32
-}
-
-// -----
-
-llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
-  // expected-error @below{{expected integer type of width 16}}
-  %cst = llvm.mlir.constant(1.0 : f16) : i32
-  llvm.return %cst : i32
-}
-
-// -----
-
 // expected-error @below{{LLVM attribute 'readonly' does not expect a value}}
 llvm.func @passthrough_unexpected_value() attributes {passthrough = [["readonly", "42"]]}
 

>From 77f4673380f2236e7f0a4c676a5b929bd306a0ba Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 15 Jul 2025 17:50:34 -0700
Subject: [PATCH 2/4] add checks on element attributes for types

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp  | 102 +++++++++++++-------
 mlir/test/Dialect/LLVMIR/invalid.mlir       |  86 +++++++++++++++--
 mlir/test/Target/LLVMIR/llvmir-invalid.mlir |  48 ---------
 3 files changed, 147 insertions(+), 89 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 62dce32bc4531..2f9a37f214b6c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Transforms/InliningUtils.h"
 
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Type.h"
@@ -3112,6 +3113,18 @@ static int64_t getNumElements(Type t) {
   return 1;
 }
 
+/// Peel every nested `LLVMArrayType` layer off `type`, then at most one
+/// `VectorType`/`TensorType` layer, and return what is left.
+static Type getElementType(Type type) {
+  while (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
+    type = arrayType.getElementType();
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return vecType.getElementType();
+  if (auto tenType = dyn_cast<TensorType>(type))
+    return tenType.getElementType();
+  return type;
+}
+
 /// Check if the given type is a scalable vector type or a vector/array type
 /// that contains a nested scalable vector type.
 static bool hasScalableVectorType(Type t) {
@@ -3215,22 +3228,13 @@ LogicalResult LLVM::ConstantOp::verify() {
       return emitOpError() << "expected array attribute of size "
                            << elementTypes.size();
     }
-    for (auto elementTy : elementTypes) {
-      if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) {
-        return emitOpError() << "expected struct element types to be floating "
-                                "point type or integer type";
-      }
-    }
-
-    for (size_t i = 0; i < elementTypes.size(); ++i) {
-      Attribute element = arrayAttr[i];
-      if (!isa<IntegerAttr, FloatAttr>(element)) {
-        return emitOpError()
-               << "expected struct element attribute types to be floating "
-                  "point type or integer type";
+    for (auto [i, attr, type] : llvm::enumerate(arrayAttr, elementTypes)) {
+      if (!isa<IntegerAttr, FloatAttr>(attr)) {
+        return emitOpError() << "expected element of array attribute to be "
+                                "floating point or integer";
       }
-      auto elementType = cast<TypedAttr>(element).getType();
-      if (elementType != elementTypes[i]) {
+      auto attrType = cast<TypedAttr>(attr).getType();
+      if (attrType != type) {
         return emitOpError()
                << "struct element at index " << i << " is of wrong type";
       }
@@ -3242,24 +3246,42 @@ LogicalResult LLVM::ConstantOp::verify() {
     return emitOpError() << "does not support target extension type.";
   }
 
+  // Check that an attribute whose element type has floating point semantics
+  // `attributeFloatSemantics` is compatible with a type whose element type
+  // is `constantElementType`.
+  //
+  // If `constantElementType` is a float or integer type, it must either
+  // 1) have floating point semantics identical to the attribute's, or
+  // 2) be an integer type of the same width as the float attribute. This is
+  //    to support builtin MLIR float types without LLVM equivalents; see the
+  //    comments in getLLVMConstant. Any other element type is not checked.
+  auto verifyFloatSemantics =
+      [this](const llvm::fltSemantics &attributeFloatSemantics,
+             Type constantElementType) -> LogicalResult {
+    if (auto floatType = dyn_cast<FloatType>(constantElementType)) {
+      if (&floatType.getFloatSemantics() != &attributeFloatSemantics) {
+        return emitOpError()
+               << "attribute and type have different float semantics";
+      }
+      return success();
+    }
+    unsigned floatWidth = APFloat::getSizeInBits(attributeFloatSemantics);
+    if (isa<IntegerType>(constantElementType)) {
+      if (!constantElementType.isInteger(floatWidth)) {
+        return emitOpError() << "expected integer type of width " << floatWidth;
+      }
+      return success();
+    }
+    return success();
+  };
+
   // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
-  if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
+  if (isa<IntegerAttr>(getValue())) {
     if (!llvm::isa<IntegerType>(getType()))
       return emitOpError() << "expected integer type";
   } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
-    const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
-    unsigned floatWidth = APFloat::getSizeInBits(sem);
-    if (auto floatTy = dyn_cast<FloatType>(getType())) {
-      if (floatTy.getWidth() != floatWidth) {
-        return emitOpError() << "expected float type of width " << floatWidth;
-      }
-    }
-    // See the comment for getLLVMConstant for more details about why 8-bit
-    // floats can be represented by integers.
-    if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
-      return emitOpError() << "expected integer type of width " << floatWidth;
-    }
-  } else if (isa<ElementsAttr>(getValue())) {
+    return verifyFloatSemantics(floatAttr.getValue().getSemantics(), getType());
+  } else if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
     if (hasScalableVectorType(getType())) {
       // The exact number of elements of a scalable vector is unknown, so we
       // allow only splat attributes.
@@ -3271,13 +3293,23 @@ LogicalResult LLVM::ConstantOp::verify() {
     }
     if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
       return emitOpError() << "expected vector or array type";
+
     // The number of elements of the attribute and the type must match.
-    if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
-      int64_t attrNumElements = elementsAttr.getNumElements();
-      if (getNumElements(getType()) != attrNumElements)
-        return emitOpError()
-               << "type and attribute have a different number of elements: "
-               << getNumElements(getType()) << " vs. " << attrNumElements;
+    int64_t attrNumElements = elementsAttr.getNumElements();
+    if (getNumElements(getType()) != attrNumElements) {
+      return emitOpError()
+             << "type and attribute have a different number of elements: "
+             << getNumElements(getType()) << " vs. " << attrNumElements;
+    }
+
+    Type attrElmType = getElementType(elementsAttr.getType());
+    Type resultElmType = getElementType(getType());
+    if (auto floatType = dyn_cast<FloatType>(attrElmType)) {
+      return verifyFloatSemantics(floatType.getFloatSemantics(), resultElmType);
+    }
+    if (isa<IntegerType>(attrElmType) && !isa<IntegerType>(resultElmType)) {
+      return emitOpError(
+          "expected integer element type for integer elements attribute");
     }
   } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
     auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType());
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index e5fe78c077314..f7a05743da061 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -418,7 +418,7 @@ llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
 // -----
 
 llvm.func @struct_wrong_element_types() -> !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)> {
-  // expected-error @+1 {{expected struct element types to be floating point type or integer type}}
+  // expected-error @+1 {{expected element of array attribute to be floating point or integer}}
   %0 = llvm.mlir.constant([dense<[1.0, 1.0]> : tensor<2xf64>, dense<[1.0, 1.0]> : tensor<2xf64>]) : !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)>
   llvm.return %0 : !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)>
 }
@@ -442,7 +442,7 @@ llvm.func @scalable_vec_requires_splat() -> vector<[4]xf64> {
 
 // -----
 
-llvm.func @integer_with_float_type() -> f32 {
+llvm.func @int_attr_requires_int_type() -> f32 {
   // expected-error @+1 {{expected integer type}}
   %0 = llvm.mlir.constant(1 : index) : f32
   llvm.return %0 : f32
@@ -450,10 +450,26 @@ llvm.func @integer_with_float_type() -> f32 {
 
 // -----
 
-llvm.func @incompatible_float_attribute_type() -> f32 {
-  // expected-error @below{{expected float type of width 64}}
-  %cst = llvm.mlir.constant(1.0 : f64) : f32
-  llvm.return %cst : f32
+llvm.func @vector_int_attr_requires_int_type() -> vector<2xf32> {
+  // expected-error @+1 {{expected integer element type}}
+  %0 = llvm.mlir.constant(dense<[1, 2]> : vector<2xi32>) : vector<2xf32>
+  llvm.return %0 : vector<2xf32>
+}
+
+// -----
+
+llvm.func @float_attr_and_type_required_same() -> f16 {
+  // expected-error @below{{attribute and type have different float semantics}}
+  %cst = llvm.mlir.constant(1.0 : bf16) : f16
+  llvm.return %cst : f16
+}
+
+// -----
+
+llvm.func @vector_float_attr_and_type_required_same() -> vector<2xf16> {
+  // expected-error @below{{attribute and type have different float semantics}}
+  %cst = llvm.mlir.constant(dense<[1.0, 2.0]> : vector<2xbf16>) : vector<2xf16>
+  llvm.return %cst : vector<2xf16>
 }
 
 // -----
@@ -466,6 +482,64 @@ llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
 
 // -----
 
+llvm.func @vector_incompatible_integer_type_for_float_attr() -> vector<2xi8> {
+  // expected-error @below{{expected integer type of width 16}}
+  %cst = llvm.mlir.constant(dense<[1.0, 2.0]> : vector<2xf16>) : vector<2xi8>
+  llvm.return %cst : vector<2xi8>
+}
+
+// -----
+
+llvm.func @vector_with_non_vector_type() -> f32 {
+  // expected-error @below{{expected vector or array type}}
+  %cst = llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32
+  llvm.return %cst : f32
+}
+
+// -----
+
+llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
+  // expected-error @below{{expected integer element type for integer elements attribute}}
+  %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
+  llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
+}
+
+// -----
+
+llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
+  // expected-error @below{{expected integer element type for integer elements attribute}}
+  %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
+  llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
+}
+
+// -----
+
+llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
+  // expected-error @below{{expected element of array attribute to be floating point or integer}}
+  %0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
+  llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
+}
+
+// -----
+
+llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
+  // expected-error @below{{expected element of array attribute to be floating point or integer}}
+  %0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
+  llvm.return %0 : !llvm.struct<(f64, f64)>
+}
+
+// -----
+
+llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
+  // expected-error @below{{struct element at index 0 is of wrong type}}
+  %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
+  llvm.return %0 : !llvm.struct<(f64, f64)>
+}
+
+// -----
+
+// -----
+
 func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
   // expected-error@+2 {{expected LLVM IR Dialect type}}
   llvm.insertvalue %a, %b[0] : tensor<*xi32>
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 6c7a218d0676e..b09ceeeb86cc0 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -7,54 +7,6 @@ func.func @foo() {
 
 // -----
 
-llvm.func @vector_with_non_vector_type() -> f32 {
-  // expected-error @below{{expected vector or array type}}
-  %cst = llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32
-  llvm.return %cst : f32
-}
-
-// -----
-
-llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
-  // expected-error @below{{expected an array attribute for a struct constant}}
-  %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
-  llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
-}
-
-// -----
-
-llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
-  // expected-error @below{{expected an array attribute for a struct constant}}
-  %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
-  llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
-}
-
-// -----
-
-llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
-  // expected-error @below{{expected struct element types to be floating point type or integer type}}
-  %0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
-  llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
-}
-
-// -----
-
-llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
-  // expected-error @below{{expected struct element attribute types to be floating point type or integer type}}
-  %0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
-  llvm.return %0 : !llvm.struct<(f64, f64)>
-}
-
-// -----
-
-llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
-  // expected-error @below{{struct element at index 0 is of wrong type}}
-  %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
-  llvm.return %0 : !llvm.struct<(f64, f64)>
-}
-
-// -----
-
 // expected-error @below{{LLVM attribute 'readonly' does not expect a value}}
 llvm.func @passthrough_unexpected_value() attributes {passthrough = [["readonly", "42"]]}
 

>From 68be7ae5884806302bd5b88ab33e0f32b90b902a Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 16 Jul 2025 10:10:29 -0700
Subject: [PATCH 3/4] address review comments: style, error wording, test names

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 21 ++++++++++---------
 mlir/test/Dialect/LLVMIR/invalid.mlir      | 24 ++++++++++++++--------
 2 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2f9a37f214b6c..b1a306a853245 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3219,9 +3219,8 @@ LogicalResult LLVM::ConstantOp::verify() {
   }
   if (auto structType = dyn_cast<LLVMStructType>(getType())) {
     auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
-    if (!arrayAttr) {
-      return emitOpError() << "expected array attribute for a struct constant";
-    }
+    if (!arrayAttr)
+      return emitOpError() << "expected array attribute for struct type";
 
     ArrayRef<Type> elementTypes = structType.getBody();
     if (arrayAttr.size() != elementTypes.size()) {
@@ -3234,17 +3233,15 @@ LogicalResult LLVM::ConstantOp::verify() {
                                 "floating point or integer";
       }
       auto attrType = cast<TypedAttr>(attr).getType();
-      if (attrType != type) {
+      if (attrType != type)
         return emitOpError()
                << "struct element at index " << i << " is of wrong type";
-      }
     }
 
     return success();
   }
-  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
+  if (isa<LLVMTargetExtType>(getType()))
     return emitOpError() << "does not support target extension type.";
-  }
 
   // Check that an attribute whose element type has floating point semantics
   // `attributeFloatSemantics` is compatible with a type whose element type
@@ -3267,9 +3264,9 @@ LogicalResult LLVM::ConstantOp::verify() {
     }
     unsigned floatWidth = APFloat::getSizeInBits(attributeFloatSemantics);
     if (isa<IntegerType>(constantElementType)) {
-      if (!constantElementType.isInteger(floatWidth)) {
+      if (!constantElementType.isInteger(floatWidth))
         return emitOpError() << "expected integer type of width " << floatWidth;
-      }
+
       return success();
     }
     return success();
@@ -3312,9 +3309,12 @@ LogicalResult LLVM::ConstantOp::verify() {
           "expected integer element type for integer elements attribute");
     }
   } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
+    // The case where the constant is LLVMStructType has already been handled.
     auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType());
     if (!arrayType)
-      return emitOpError() << "expected array type";
+      return emitOpError()
+             << "expected array or struct type for array attribute";
+
     // When the attribute is an ArrayAttr, check that its nesting matches the
     // corresponding ArrayType or VectorType nesting.
     return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0);
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index f7a05743da061..1c8bfea4b1e72 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -394,7 +394,7 @@ llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
 // -----
 
 llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
-  // expected-error @+1 {{expected array attribute}}
+  // expected-error @+1 {{expected array attribute for struct type}}
   %0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
@@ -443,7 +443,7 @@ llvm.func @scalable_vec_requires_splat() -> vector<[4]xf64> {
 // -----
 
 llvm.func @int_attr_requires_int_type() -> f32 {
-  // expected-error @+1 {{expected integer type}}
+  // expected-error @below{{expected integer type}}
   %0 = llvm.mlir.constant(1 : index) : f32
   llvm.return %0 : f32
 }
@@ -451,7 +451,7 @@ llvm.func @int_attr_requires_int_type() -> f32 {
 // -----
 
 llvm.func @vector_int_attr_requires_int_type() -> vector<2xf32> {
-  // expected-error @+1 {{expected integer element type}}
+  // expected-error @below{{expected integer element type}}
   %0 = llvm.mlir.constant(dense<[1, 2]> : vector<2xi32>) : vector<2xf32>
   llvm.return %0 : vector<2xf32>
 }
@@ -498,7 +498,15 @@ llvm.func @vector_with_non_vector_type() -> f32 {
 
 // -----
 
-llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
+llvm.func @array_attr_with_invalid_type() -> i32 {
+  // expected-error @below{{expected array or struct type for array attribute}}
+  %0 = llvm.mlir.constant([1 : i32]) : i32
+  llvm.return %0 : i32
+}
+
+// -----
+
+llvm.func @elements_attribute_incompatible_nested_array_struct1_type() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
   // expected-error @below{{expected integer element type for integer elements attribute}}
   %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
   llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
@@ -506,7 +514,7 @@ llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x st
 
 // -----
 
-llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
+llvm.func @elements_attribute_incompatible_nested_array_struct3_type() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
   // expected-error @below{{expected integer element type for integer elements attribute}}
   %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
   llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
@@ -538,8 +546,6 @@ llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
 
 // -----
 
-// -----
-
 func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
   // expected-error@+2 {{expected LLVM IR Dialect type}}
   llvm.insertvalue %a, %b[0] : tensor<*xi32>
@@ -583,13 +589,13 @@ func.func @extractvalue_invalid_type(%a : !llvm.array<4 x vector<8xf32>>) -> !ll
   return %b : !llvm.array<4 x vector<8xf32>>
 }
 
-
 // -----
 
 func.func @extractvalue_non_llvm_type(%a : i32, %b : tensor<*xi32>) {
   // expected-error@+2 {{expected LLVM IR Dialect type}}
   llvm.extractvalue %b[0] : tensor<*xi32>
 }
+
 // -----
 
 func.func @extractvalue_struct_out_of_bounds() {
@@ -758,6 +764,7 @@ func.func @atomicrmw_scalable_vector(%ptr : !llvm.ptr, %f32_vec : vector<[2]xf32
   %0 = llvm.atomicrmw fadd %ptr, %f32_vec unordered : !llvm.ptr, vector<[2]xf32>
   llvm.return
 }
+
 // -----
 
 func.func @atomicrmw_vector_expected_float(%ptr : !llvm.ptr, %i32_vec : vector<3xi32>) {
@@ -1766,7 +1773,6 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
   return
 }
 
-
 // -----
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {

>From f22984071db14d10027e4c6bce40578f41a3fa66 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 16 Jul 2025 10:59:22 -0700
Subject: [PATCH 4/4] use original error message (as per review comment)

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 9 ++++++---
 mlir/test/Dialect/LLVMIR/invalid.mlir      | 4 ++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index b1a306a853245..fc435263698d3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3228,12 +3228,15 @@ LogicalResult LLVM::ConstantOp::verify() {
                            << elementTypes.size();
     }
     for (auto [i, attr, type] : llvm::enumerate(arrayAttr, elementTypes)) {
-      if (!isa<IntegerAttr, FloatAttr>(attr)) {
+      if (!type.isSignlessIntOrIndexOrFloat()) {
+        return emitOpError() << "expected struct element types to be floating "
+                                "point type or integer type";
+      }
+      if (!isa<FloatAttr, IntegerAttr>(attr)) {
         return emitOpError() << "expected element of array attribute to be "
                                 "floating point or integer";
       }
-      auto attrType = cast<TypedAttr>(attr).getType();
-      if (attrType != type)
+      if (cast<TypedAttr>(attr).getType() != type)
         return emitOpError()
                << "struct element at index " << i << " is of wrong type";
     }
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 1c8bfea4b1e72..6918fb1636479 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -418,7 +418,7 @@ llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
 // -----
 
 llvm.func @struct_wrong_element_types() -> !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)> {
-  // expected-error @+1 {{expected element of array attribute to be floating point or integer}}
+  // expected-error @+1 {{expected struct element types to be floating point type or integer type}}
   %0 = llvm.mlir.constant([dense<[1.0, 1.0]> : tensor<2xf64>, dense<[1.0, 1.0]> : tensor<2xf64>]) : !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)>
   llvm.return %0 : !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)>
 }
@@ -523,7 +523,7 @@ llvm.func @elements_attribute_incompatible_nested_array_struct3_type() -> !llvm.
 // -----
 
 llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
-  // expected-error @below{{expected element of array attribute to be floating point or integer}}
+  // expected-error @below{{expected struct element types to be floating point type or integer type}}
   %0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
   llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
 }



More information about the Mlir-commits mailing list