[Mlir-commits] [mlir] [mlir][LLVMIR] Check number of elements in `mlir.constant` verifier (PR #102906)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 12 07:08:39 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Check that the number of elements in the result type of an `llvm.mlir.constant` op matches the number of elements in its attribute. Also fix a broken test where that was not the case.
---
Full diff: https://github.com/llvm/llvm-project/pull/102906.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+39-10)
- (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+8)
- (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+7-1)
``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 90610118a45cd2..82dadfc98de7b6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2666,6 +2666,20 @@ OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
// ConstantOp.
//===----------------------------------------------------------------------===//
+/// Returns how many scalar elements a value of type `t` holds. Each nested
+/// `VectorType`, `LLVMArrayType` or `LLVMFixedVectorType` level multiplies the
+/// count by its own length; any other type ends the walk and counts as one.
+static int64_t getNumElements(Type t) {
+  int64_t total = 1;
+  for (Type cur = t;;) {
+    if (auto vecType = dyn_cast<VectorType>(cur)) {
+      total *= vecType.getNumElements();
+      cur = vecType.getElementType();
+    } else if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(cur)) {
+      total *= arrayType.getNumElements();
+      cur = arrayType.getElementType();
+    } else if (auto fixedVecType = dyn_cast<LLVMFixedVectorType>(cur)) {
+      total *= fixedVecType.getNumElements();
+      cur = fixedVecType.getElementType();
+    } else {
+      // Reached a non-container type: it contributes a factor of one.
+      return total;
+    }
+  }
+}
+
LogicalResult LLVM::ConstantOp::verify() {
if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
@@ -2708,14 +2722,12 @@ LogicalResult LLVM::ConstantOp::verify() {
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
return emitOpError() << "does not support target extension type.";
}
- if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue()))
- return emitOpError()
- << "only supports integer, float, string or elements attributes";
+
+ // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
if (!llvm::isa<IntegerType>(getType()))
return emitOpError() << "expected integer type";
- }
- if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
+ } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
unsigned floatWidth = APFloat::getSizeInBits(sem);
if (auto floatTy = dyn_cast<FloatType>(getType())) {
@@ -2728,13 +2740,30 @@ LogicalResult LLVM::ConstantOp::verify() {
if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
return emitOpError() << "expected integer type of width " << floatWidth;
}
- }
- if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue())) {
- if (!isa<VectorType>(getType()) && !isa<LLVM::LLVMArrayType>(getType()) &&
- !isa<LLVM::LLVMFixedVectorType>(getType()) &&
- !isa<LLVM::LLVMScalableVectorType>(getType()))
+ } else if (isa<ElementsAttr, ArrayAttr>(getValue())) {
+ if (isa<LLVM::LLVMScalableVectorType>(getType())) {
+ // The exact number of elements of a scalable vector is unknown, so there
+ // is nothing more to verify.
+ return success();
+ }
+ if (!isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType>(
+ getType()))
return emitOpError() << "expected vector or array type";
+ // The number of elements of the attribute and the type must match.
+ int64_t attrNumElements;
+ if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue()))
+ attrNumElements = elementsAttr.getNumElements();
+ else
+ attrNumElements = cast<ArrayAttr>(getValue()).size();
+ if (getNumElements(getType()) != attrNumElements)
+ return emitOpError()
+ << "type and attribute have a different number of elements: "
+ << getNumElements(getType()) << " vs. " << attrNumElements;
+ } else {
+ return emitOpError()
+ << "only supports integer, float, string or elements attributes";
}
+
return success();
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fe288dab973f5a..7edf036201e1c0 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -414,6 +414,14 @@ llvm.func @struct_wrong_element_types() -> !llvm.struct<(!llvm.array<2 x f64>, !
// -----
+llvm.func @constant_wrong_number_of_elements() -> vector<5xf64> {
+ // expected-error @+1{{type and attribute have a different number of elements: 5 vs. 2}}
+ %0 = llvm.mlir.constant(dense<[1.0, 1.0]> : tensor<2xf64>) : vector<5xf64>
+ llvm.return %0 : vector<5xf64>
+}
+
+// -----
+
func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
// expected-error at +2 {{expected LLVM IR Dialect type}}
llvm.insertvalue %a, %b[0] : tensor<*xi32>
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index fbdf725f3ec17b..8453983aa07c33 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1295,11 +1295,17 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
}
llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> {
- %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+ // The splat attribute must carry as many elements (2) as the result array.
+ %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<2xcomplex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
 // CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }]
 llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>>
}
+llvm.func @complexintconstantsingle() -> !llvm.array<1 x !llvm.struct<(i32, i32)>> {
+ // A rank-0 tensor attribute holds exactly one element, matching the 1-element array.
+ %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<1 x !llvm.struct<(i32, i32)>>
+ // CHECK: ret [1 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }]
+ llvm.return %1 : !llvm.array<1 x !llvm.struct<(i32, i32)>>
+}
+
llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
%1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
// CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]
``````````
</details>
https://github.com/llvm/llvm-project/pull/102906
More information about the Mlir-commits
mailing list