[Mlir-commits] [mlir] 6edc1fa - [mlir][llvm dialect] Verify element type of nested types (#148975)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 21 05:40:07 PDT 2025
Author: James Newling
Date: 2025-07-21T05:40:02-07:00
New Revision: 6edc1faf3b9238a231f1aca10d447be8ab826816
URL: https://github.com/llvm/llvm-project/commit/6edc1faf3b9238a231f1aca10d447be8ab826816
DIFF: https://github.com/llvm/llvm-project/commit/6edc1faf3b9238a231f1aca10d447be8ab826816.diff
LOG: [mlir][llvm dialect] Verify element type of nested types (#148975)
Before this PR, this was valid:
```
%0 = llvm.mlir.constant(dense<[1, 2]> : vector<2xi32>) : vector<2xf32>
```
but this was not:
```
%0 = llvm.mlir.constant(1 : i32) : f32
```
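because only scalar types were checked for compatibility, not the element types of nested types (vectors, tensors, and LLVM arrays). This PR also verifies float semantics: when both the attribute and the type are floating point, their semantics must be identical, not merely their bit widths. Before this PR,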
```
%cst = llvm.mlir.constant(1.0 : bf16) : f16
```
was considered valid (because bf16 and f16 both have 16 bits), but with this PR it is rejected. This PR also moves all tests of the `llvm.mlir.constant` verifier into a single file (`mlir/test/Dialect/LLVMIR/invalid.mlir`). To summarize the state after this PR:
Invalid:
```mlir
%0 = llvm.mlir.constant(dense<[128, 1024]> : vector<2xi32>) :
vector<2xf32>
%0 = llvm.mlir.constant(dense<[128., 1024.]> : vector<2xbf16>) :
vector<2xf16>
```
Valid:
```mlir
%0 = llvm.mlir.constant(dense<[128., 1024.]> : vector<2xf32>) :
vector<2xi32>
%0 = llvm.mlir.constant(dense<[128, 1024]> : vector<2xi64>) :
vector<2xi8>
```
The corresponding scalar cases are valid or invalid in the same way.
Added:
Modified:
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Target/LLVMIR/llvmir-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 34ffd1e98a767..5b01596eb522f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -24,6 +24,7 @@
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
@@ -3187,6 +3188,18 @@ static int64_t getNumElements(Type t) {
return 1;
}
+/// Determine the element type of `type`. Supported types are `VectorType`,
+/// `TensorType`, and `LLVMArrayType`. Everything else is treated as a scalar.
+static Type getElementType(Type type) {
+ while (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
+ type = arrayType.getElementType();
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return vecType.getElementType();
+ if (auto tenType = dyn_cast<TensorType>(type))
+ return tenType.getElementType();
+ return type;
+}
+
/// Check if the given type is a scalable vector type or a vector/array type
/// that contains a nested scalable vector type.
static bool hasScalableVectorType(Type t) {
@@ -3281,60 +3294,69 @@ LogicalResult LLVM::ConstantOp::verify() {
}
if (auto structType = dyn_cast<LLVMStructType>(getType())) {
auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
- if (!arrayAttr) {
- return emitOpError() << "expected array attribute for a struct constant";
- }
+ if (!arrayAttr)
+ return emitOpError() << "expected array attribute for struct type";
ArrayRef<Type> elementTypes = structType.getBody();
if (arrayAttr.size() != elementTypes.size()) {
return emitOpError() << "expected array attribute of size "
<< elementTypes.size();
}
- for (auto elementTy : elementTypes) {
- if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) {
+ for (auto [i, attr, type] : llvm::enumerate(arrayAttr, elementTypes)) {
+ if (!type.isSignlessIntOrIndexOrFloat()) {
return emitOpError() << "expected struct element types to be floating "
"point type or integer type";
}
- }
-
- for (size_t i = 0; i < elementTypes.size(); ++i) {
- Attribute element = arrayAttr[i];
- if (!isa<IntegerAttr, FloatAttr>(element)) {
- return emitOpError()
- << "expected struct element attribute types to be floating "
- "point type or integer type";
+ if (!isa<FloatAttr, IntegerAttr>(attr)) {
+ return emitOpError() << "expected element of array attribute to be "
+ "floating point or integer";
}
- auto elementType = cast<TypedAttr>(element).getType();
- if (elementType != elementTypes[i]) {
+ if (cast<TypedAttr>(attr).getType() != type)
return emitOpError()
<< "struct element at index " << i << " is of wrong type";
- }
}
return success();
}
- if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
+ if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
return emitOpError() << "does not support target extension type.";
- }
+
+ // Check that an attribute whose element type has floating point semantics
+ // `attributeFloatSemantics` is compatible with a type whose element type
+ // is `constantElementType`.
+ //
+ // Requirement is that either
+ // 1) They have identical floating point types.
+ // 2) `constantElementType` is an integer type of the same width as the float
+ // attribute. This is to support builtin MLIR float types without LLVM
+ // equivalents, see comments in getLLVMConstant for more details.
+ auto verifyFloatSemantics =
+ [this](const llvm::fltSemantics &attributeFloatSemantics,
+ Type constantElementType) -> LogicalResult {
+ if (auto floatType = dyn_cast<FloatType>(constantElementType)) {
+ if (&floatType.getFloatSemantics() != &attributeFloatSemantics) {
+ return emitOpError()
+ << "attribute and type have
diff erent float semantics";
+ }
+ return success();
+ }
+ unsigned floatWidth = APFloat::getSizeInBits(attributeFloatSemantics);
+ if (isa<IntegerType>(constantElementType)) {
+ if (!constantElementType.isInteger(floatWidth))
+ return emitOpError() << "expected integer type of width " << floatWidth;
+
+ return success();
+ }
+ return success();
+ };
// Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
- if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
+ if (isa<IntegerAttr>(getValue())) {
if (!llvm::isa<IntegerType>(getType()))
return emitOpError() << "expected integer type";
} else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
- const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
- unsigned floatWidth = APFloat::getSizeInBits(sem);
- if (auto floatTy = dyn_cast<FloatType>(getType())) {
- if (floatTy.getWidth() != floatWidth) {
- return emitOpError() << "expected float type of width " << floatWidth;
- }
- }
- // See the comment for getLLVMConstant for more details about why 8-bit
- // floats can be represented by integers.
- if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
- return emitOpError() << "expected integer type of width " << floatWidth;
- }
- } else if (isa<ElementsAttr>(getValue())) {
+ return verifyFloatSemantics(floatAttr.getValue().getSemantics(), getType());
+ } else if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
if (hasScalableVectorType(getType())) {
// The exact number of elements of a scalable vector is unknown, so we
// allow only splat attributes.
@@ -3346,18 +3368,32 @@ LogicalResult LLVM::ConstantOp::verify() {
}
if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
return emitOpError() << "expected vector or array type";
+
// The number of elements of the attribute and the type must match.
- if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
- int64_t attrNumElements = elementsAttr.getNumElements();
- if (getNumElements(getType()) != attrNumElements)
- return emitOpError()
- << "type and attribute have a
diff erent number of elements: "
- << getNumElements(getType()) << " vs. " << attrNumElements;
+ int64_t attrNumElements = elementsAttr.getNumElements();
+ if (getNumElements(getType()) != attrNumElements) {
+ return emitOpError()
+ << "type and attribute have a
diff erent number of elements: "
+ << getNumElements(getType()) << " vs. " << attrNumElements;
+ }
+
+ Type attrElmType = getElementType(elementsAttr.getType());
+ Type resultElmType = getElementType(getType());
+ if (auto floatType = dyn_cast<FloatType>(attrElmType))
+ return verifyFloatSemantics(floatType.getFloatSemantics(), resultElmType);
+
+ if (isa<IntegerType>(attrElmType) && !isa<IntegerType>(resultElmType)) {
+ return emitOpError(
+ "expected integer element type for integer elements attribute");
}
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
+
+ // The case where the constant is LLVMStructType has already been handled.
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType());
if (!arrayType)
- return emitOpError() << "expected array type";
+ return emitOpError()
+ << "expected array or struct type for array attribute";
+
// When the attribute is an ArrayAttr, check that its nesting matches the
// corresponding ArrayType or VectorType nesting.
return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0);
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 7f2c8c72e5cf9..ac1737444fcf0 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -394,7 +394,7 @@ llvm.func @array_attribute_two_
diff erent_types() -> !llvm.struct<(f64, f64)> {
// -----
llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
- // expected-error @+1 {{expected array attribute}}
+ // expected-error @+1 {{expected array attribute for struct type}}
%0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
@@ -439,6 +439,111 @@ llvm.func @scalable_vec_requires_splat() -> vector<[4]xf64> {
llvm.return %0 : vector<[4]xf64>
}
+
+// -----
+
+llvm.func @int_attr_requires_int_type() -> f32 {
+ // expected-error @below{{expected integer type}}
+ %0 = llvm.mlir.constant(1 : index) : f32
+ llvm.return %0 : f32
+}
+
+// -----
+
+llvm.func @vector_int_attr_requires_int_type() -> vector<2xf32> {
+ // expected-error @below{{expected integer element type}}
+ %0 = llvm.mlir.constant(dense<[1, 2]> : vector<2xi32>) : vector<2xf32>
+ llvm.return %0 : vector<2xf32>
+}
+
+// -----
+
+llvm.func @float_attr_and_type_required_same() -> f16 {
+ // expected-error @below{{attribute and type have
diff erent float semantics}}
+ %cst = llvm.mlir.constant(1.0 : bf16) : f16
+ llvm.return %cst : f16
+}
+
+// -----
+
+llvm.func @vector_float_attr_and_type_required_same() -> vector<2xf16> {
+ // expected-error @below{{attribute and type have
diff erent float semantics}}
+ %cst = llvm.mlir.constant(dense<[1.0, 2.0]> : vector<2xbf16>) : vector<2xf16>
+ llvm.return %cst : vector<2xf16>
+}
+
+// -----
+
+llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
+ // expected-error @below{{expected integer type of width 16}}
+ %cst = llvm.mlir.constant(1.0 : f16) : i32
+ llvm.return %cst : i32
+}
+
+// -----
+
+llvm.func @vector_incompatible_integer_type_for_float_attr() -> vector<2xi8> {
+ // expected-error @below{{expected integer type of width 16}}
+ %cst = llvm.mlir.constant(dense<[1.0, 2.0]> : vector<2xf16>) : vector<2xi8>
+ llvm.return %cst : vector<2xi8>
+}
+
+// -----
+
+llvm.func @vector_with_non_vector_type() -> f32 {
+ // expected-error @below{{expected vector or array type}}
+ %cst = llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32
+ llvm.return %cst : f32
+}
+
+// -----
+
+llvm.func @array_attr_with_invalid_type() -> i32 {
+ // expected-error @below{{expected array or struct type for array attribute}}
+ %0 = llvm.mlir.constant([1 : i32]) : i32
+ llvm.return %0 : i32
+}
+
+// -----
+
+llvm.func @elements_attribute_incompatible_nested_array_struct1_type() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
+ // expected-error @below{{expected integer element type for integer elements attribute}}
+ %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
+ llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
+}
+
+// -----
+
+llvm.func @elements_attribute_incompatible_nested_array_struct3_type() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
+ // expected-error @below{{expected integer element type for integer elements attribute}}
+ %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
+ llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
+}
+
+// -----
+
+llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
+ // expected-error @below{{expected struct element types to be floating point type or integer type}}
+ %0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
+ llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
+}
+
+// -----
+
+llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
+ // expected-error @below{{expected element of array attribute to be floating point or integer}}
+ %0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
+ llvm.return %0 : !llvm.struct<(f64, f64)>
+}
+
+// -----
+
+llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
+ // expected-error @below{{struct element at index 0 is of wrong type}}
+ %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
+ llvm.return %0 : !llvm.struct<(f64, f64)>
+}
+
// -----
func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
@@ -484,13 +589,13 @@ func.func @extractvalue_invalid_type(%a : !llvm.array<4 x vector<8xf32>>) -> !ll
return %b : !llvm.array<4 x vector<8xf32>>
}
-
// -----
func.func @extractvalue_non_llvm_type(%a : i32, %b : tensor<*xi32>) {
// expected-error at +2 {{expected LLVM IR Dialect type}}
llvm.extractvalue %b[0] : tensor<*xi32>
}
+
// -----
func.func @extractvalue_struct_out_of_bounds() {
@@ -659,6 +764,7 @@ func.func @atomicrmw_scalable_vector(%ptr : !llvm.ptr, %f32_vec : vector<[2]xf32
%0 = llvm.atomicrmw fadd %ptr, %f32_vec unordered : !llvm.ptr, vector<[2]xf32>
llvm.return
}
+
// -----
func.func @atomicrmw_vector_expected_float(%ptr : !llvm.ptr, %i32_vec : vector<3xi32>) {
@@ -1667,7 +1773,6 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
return
}
-
// -----
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index a8ef401fff27e..b09ceeeb86cc0 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -7,78 +7,6 @@ func.func @foo() {
// -----
-llvm.func @vector_with_non_vector_type() -> f32 {
- // expected-error @below{{expected vector or array type}}
- %cst = llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32
- llvm.return %cst : f32
-}
-
-// -----
-
-llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
- // expected-error @below{{expected an array attribute for a struct constant}}
- %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
- llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
-}
-
-// -----
-
-llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
- // expected-error @below{{expected an array attribute for a struct constant}}
- %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
- llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
-}
-
-// -----
-
-llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
- // expected-error @below{{expected struct element types to be floating point type or integer type}}
- %0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
- llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
-}
-
-// -----
-
-llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
- // expected-error @below{{expected struct element attribute types to be floating point type or integer type}}
- %0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
- llvm.return %0 : !llvm.struct<(f64, f64)>
-}
-
-// -----
-
-llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
- // expected-error @below{{struct element at index 0 is of wrong type}}
- %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
- llvm.return %0 : !llvm.struct<(f64, f64)>
-}
-
-// -----
-
-llvm.func @integer_with_float_type() -> f32 {
- // expected-error @+1 {{expected integer type}}
- %0 = llvm.mlir.constant(1 : index) : f32
- llvm.return %0 : f32
-}
-
-// -----
-
-llvm.func @incompatible_float_attribute_type() -> f32 {
- // expected-error @below{{expected float type of width 64}}
- %cst = llvm.mlir.constant(1.0 : f64) : f32
- llvm.return %cst : f32
-}
-
-// -----
-
-llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
- // expected-error @below{{expected integer type of width 16}}
- %cst = llvm.mlir.constant(1.0 : f16) : i32
- llvm.return %cst : i32
-}
-
-// -----
-
// expected-error @below{{LLVM attribute 'readonly' does not expect a value}}
llvm.func @passthrough_unexpected_value() attributes {passthrough = [["readonly", "42"]]}
More information about the Mlir-commits
mailing list