[Mlir-commits] [mlir] [mlir][ODS] Verify type constraints in Types and Attributes (PR #102326)
Matthias Springer
llvmlistbot at llvm.org
Wed Aug 7 09:05:00 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/102326
>From 2134ae98f878c2b5b7a696a6960da66e35a71d2a Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 7 Aug 2024 17:49:45 +0200
Subject: [PATCH] [mlir][ODS] Verify type constraints in Types and Attributes
When a type/attribute is defined in TableGen, a type constraint can be used for a parameter, but no code was generated to verify that constraint.
Example:
```
def TestTypeVerification : Test_Type<"TestTypeVerification"> {
let parameters = (ins AnyTypeOf<[I16, I32]>:$param);
// ...
}
```
No verification code was generated to ensure that `$param` is I16 or I32.
When type constraints are present, a new method, `verifyInvariantsImpl`, is generated for types and attributes. (The naming is similar to op verifiers.) The user-provided verifier is still called `verify` (no change). There is now a new entry point for type/attribute verification: `verifyInvariants`. This function calls `verifyInvariantsImpl` (if it was generated) and then `verify` (if it was declared). If neither of those two verifiers is present, the `verifyInvariants` function is not generated.
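When a type/attribute is not defined in TableGen but needs a verifier, users must implement the `verifyInvariants` function. (This function was previously called `verify`.) Because `get` and `getChecked` now call `verifyInvariants`, a hand-written `verify` would no longer be invoked.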
---
mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h | 7 +-
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 12 ++-
mlir/include/mlir/Dialect/Quant/QuantTypes.h | 43 ++++----
.../mlir/Dialect/SPIRV/IR/SPIRVAttributes.h | 14 +--
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 10 +-
mlir/include/mlir/IR/CommonTypeConstraints.td | 1 +
mlir/include/mlir/IR/Constraints.td | 3 +
mlir/include/mlir/IR/StorageUniquerSupport.h | 10 +-
mlir/include/mlir/IR/Types.h | 19 ++--
mlir/include/mlir/TableGen/AttrOrTypeDef.h | 8 ++
mlir/include/mlir/TableGen/Class.h | 2 +
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 6 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 12 +--
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 39 ++++----
mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp | 9 +-
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 9 +-
mlir/lib/TableGen/AttrOrTypeDef.cpp | 13 +++
mlir/test/IR/test-verifiers-type.mlir | 9 ++
mlir/test/lib/Dialect/Test/TestAttrDefs.td | 1 -
mlir/test/lib/Dialect/Test/TestTypeDefs.td | 6 ++
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 99 +++++++++++++++++--
.../tools/mlir-tblgen/AttrOrTypeFormatGen.cpp | 2 +-
22 files changed, 234 insertions(+), 100 deletions(-)
create mode 100644 mlir/test/IR/test-verifiers-type.mlir
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 96e1935bd0a841..57acd72610415f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -148,9 +148,10 @@ class MMAMatrixType
/// Verify that shape and elementType are actually allowed for the
/// MMAMatrixType.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<int64_t> shape, Type elementType,
- StringRef operand);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType,
+ StringRef operand);
/// Get number of dims.
unsigned getNumDims() const;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 1befdfa74f67c5..2ea589a7c4c3bd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -180,11 +180,13 @@ class LLVMStructType
ArrayRef<Type> getBody() const;
/// Verifies that the type about to be constructed is well-formed.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- StringRef, bool);
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<Type> types, bool);
- using Base::verify;
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError, StringRef,
+ bool);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<Type> types, bool);
+ using Base::verifyInvariants;
/// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
/// DataLayout instance and query it instead.
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index de5aed0a91a209..57a2aa29833657 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -54,10 +54,10 @@ class QuantizedType : public Type {
/// The maximum number of bits supported for storage types.
static constexpr unsigned MaxStorageBits = 32;
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- unsigned flags, Type storageType,
- Type expressedType, int64_t storageTypeMin,
- int64_t storageTypeMax);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+ Type storageType, Type expressedType, int64_t storageTypeMin,
+ int64_t storageTypeMax);
/// Support method to enable LLVM-style type casting.
static bool classof(Type type);
@@ -214,10 +214,10 @@ class AnyQuantizedType
int64_t storageTypeMax);
/// Verifies construction invariants and issues errors/warnings.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- unsigned flags, Type storageType,
- Type expressedType, int64_t storageTypeMin,
- int64_t storageTypeMax);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+ Type storageType, Type expressedType, int64_t storageTypeMin,
+ int64_t storageTypeMax);
};
/// Represents a family of uniform, quantized types.
@@ -276,11 +276,11 @@ class UniformQuantizedType
int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
/// Verifies construction invariants and issues errors/warnings.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- unsigned flags, Type storageType,
- Type expressedType, double scale,
- int64_t zeroPoint, int64_t storageTypeMin,
- int64_t storageTypeMax);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+ Type storageType, Type expressedType, double scale,
+ int64_t zeroPoint, int64_t storageTypeMin,
+ int64_t storageTypeMax);
/// Gets the scale term. The scale designates the difference between the real
/// values corresponding to consecutive quantized values differing by 1.
@@ -338,12 +338,12 @@ class UniformQuantizedPerAxisType
int64_t storageTypeMin, int64_t storageTypeMax);
/// Verifies construction invariants and issues errors/warnings.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- unsigned flags, Type storageType,
- Type expressedType, ArrayRef<double> scales,
- ArrayRef<int64_t> zeroPoints,
- int32_t quantizedDimension,
- int64_t storageTypeMin, int64_t storageTypeMax);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+ Type storageType, Type expressedType,
+ ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+ int32_t quantizedDimension, int64_t storageTypeMin,
+ int64_t storageTypeMax);
/// Gets the quantization scales. The scales designate the difference between
/// the real values corresponding to consecutive quantized values differing
@@ -403,8 +403,9 @@ class CalibratedQuantizedType
double min, double max);
/// Verifies construction invariants and issues errors/warnings.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- Type expressedType, double min, double max);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ Type expressedType, double min, double max);
double getMin() const;
double getMax() const;
};
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 5ebfa9ca5ec25c..2bdd7a5bf3dd83 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -76,9 +76,10 @@ class InterfaceVarABIAttr
/// Returns `spirv::StorageClass`.
std::optional<StorageClass> getStorageClass();
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- IntegerAttr descriptorSet, IntegerAttr binding,
- IntegerAttr storageClass);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ IntegerAttr descriptorSet, IntegerAttr binding,
+ IntegerAttr storageClass);
static constexpr StringLiteral name = "spirv.interface_var_abi";
};
@@ -128,9 +129,10 @@ class VerCapExtAttr
/// Returns the capabilities as an integer array attribute.
ArrayAttr getCapabilitiesAttr();
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- IntegerAttr version, ArrayAttr capabilities,
- ArrayAttr extensions);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ IntegerAttr version, ArrayAttr capabilities,
+ ArrayAttr extensions);
static constexpr StringLiteral name = "spirv.ver_cap_ext";
};
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 55f0c787b44403..e2d04553d91b8b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -258,8 +258,9 @@ class SampledImageType
static SampledImageType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- Type imageType);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ Type imageType);
Type getImageType() const;
@@ -462,8 +463,9 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- Type columnType, uint32_t columnCount);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ Type columnType, uint32_t columnCount);
/// Returns true if the matrix elements are vectors of float elements.
static bool isValidColumnType(Type columnType);
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 5b6ec167fa2420..4d3e1428e6c40b 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -180,6 +180,7 @@ class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
summary),
cppClassName> {
list<Type> allowedTypes = allowedTypeList;
+ string cppType = cppClassName;
}
// A type that satisfies the constraints of all given types.
diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td
index a026d58ccffb8e..242c850f38f309 100644
--- a/mlir/include/mlir/IR/Constraints.td
+++ b/mlir/include/mlir/IR/Constraints.td
@@ -153,6 +153,9 @@ class TypeConstraint<Pred predicate, string summary = "",
Constraint<predicate, summary> {
// The name of the C++ Type class if known, or Type if not.
string cppClassName = cppClassNameParam;
+ // TODO: This field is sometimes called `cppClassName` and sometimes
+ // `cppType`. Use a single name consistently.
+ string cppType = cppClassNameParam;
}
// Subclass for constraints on an attribute.
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index fb64f15162df5b..d6ccbbd8579947 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -176,8 +176,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
template <typename... Args>
static ConcreteT get(MLIRContext *ctx, Args &&...args) {
// Ensure that the invariants are correct for construction.
- assert(
- succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
+ assert(succeeded(
+ ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
}
@@ -198,7 +198,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
MLIRContext *ctx, Args... args) {
// If the construction invariants fail then we return a null attribute.
- if (failed(ConcreteT::verify(emitErrorFn, args...)))
+ if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
return ConcreteT();
return UniquerT::template get<ConcreteT>(ctx, args...);
}
@@ -226,7 +226,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
/// Default implementation that just returns success.
template <typename... Args>
- static LogicalResult verify(Args... args) {
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitErrorFn,
+ Args... args) {
return success();
}
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 60dc8fee0f4a96..91b457deeba2f6 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -34,7 +34,7 @@ class AsmState;
/// Derived type classes are expected to implement several required
/// implementation hooks:
/// * Optional:
-/// - static LogicalResult verify(
+/// - static LogicalResult verifyInvariants(
/// function_ref<InFlightDiagnostic()> emitError,
/// Args... args)
/// * This method is invoked when calling the 'TypeBase::get/getChecked'
@@ -97,20 +97,17 @@ class Type {
bool operator!() const { return impl == nullptr; }
template <typename... Tys>
- [[deprecated("Use mlir::isa<U>() instead")]]
- bool isa() const;
+ [[deprecated("Use mlir::isa<U>() instead")]] bool isa() const;
template <typename... Tys>
- [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]]
- bool isa_and_nonnull() const;
+ [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]] bool
+ isa_and_nonnull() const;
template <typename U>
- [[deprecated("Use mlir::dyn_cast<U>() instead")]]
- U dyn_cast() const;
+ [[deprecated("Use mlir::dyn_cast<U>() instead")]] U dyn_cast() const;
template <typename U>
- [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
- U dyn_cast_or_null() const;
+ [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]] U
+ dyn_cast_or_null() const;
template <typename U>
- [[deprecated("Use mlir::cast<U>() instead")]]
- U cast() const;
+ [[deprecated("Use mlir::cast<U>() instead")]] U cast() const;
/// Return a unique identifier for the concrete type. This is used to support
/// dynamic type casting.
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 19c3a9183ec2cf..22961b24e45af4 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -17,6 +17,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Builder.h"
#include "mlir/TableGen/Trait.h"
+#include "mlir/TableGen/Type.h"
namespace llvm {
class DagInit;
@@ -85,6 +86,9 @@ class AttrOrTypeParameter {
/// Get an optional C++ parameter parser.
std::optional<StringRef> getParser() const;
+ /// If this is a type constraint, return it.
+ std::optional<TypeConstraint> getTypeConstraint() const;
+
/// Get an optional C++ parameter printer.
std::optional<StringRef> getPrinter() const;
@@ -198,6 +202,10 @@ class AttrOrTypeDef {
/// method.
bool genVerifyDecl() const;
+ /// Return true if we need to generate any type constraint verification and
+ /// the getChecked method.
+ bool genVerifyInvariantsImpl() const;
+
/// Returns the def's extra class declaration code.
std::optional<StringRef> getExtraDecls() const;
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 855952d19492db..f750a34a3b2ba4 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -67,6 +67,8 @@ class MethodParameter {
/// Get the C++ type.
StringRef getType() const { return type; }
+ /// Get the C++ parameter name.
+ StringRef getName() const { return name; }
/// Returns true if the parameter has a default value.
bool hasDefaultValue() const { return !defaultValue.empty(); }
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7bc2668310ddb0..a1f87a637a6141 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
}
LogicalResult
-MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<int64_t> shape, Type elementType,
- StringRef operand) {
+MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType,
+ StringRef operand) {
if (operand != "AOp" && operand != "BOp" && operand != "COp")
return emitError() << "operand expected to be one of AOp, BOp or COp";
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index dc7aef8ef7f850..7f10a15ff31ff9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
bool LLVMStructType::isValidElementType(Type type) {
return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
- LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
- type);
+ LLVMFunctionType, LLVMTokenType>(type);
}
LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
@@ -492,14 +491,15 @@ ArrayRef<Type> LLVMStructType::getBody() const {
: getImpl()->getTypeList();
}
-LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
- StringRef, bool) {
+LogicalResult
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
+ bool) {
return success();
}
LogicalResult
-LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<Type> types, bool) {
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<Type> types, bool) {
for (Type t : types)
if (!isValidElementType(t))
return emitError() << "invalid LLVM structure element type: " << t;
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 81e3b914755be2..c2ba9c04e8771d 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -29,9 +29,10 @@ bool QuantizedType::classof(Type type) {
}
LogicalResult
-QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
- unsigned flags, Type storageType, Type expressedType,
- int64_t storageTypeMin, int64_t storageTypeMax) {
+QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ unsigned flags, Type storageType,
+ Type expressedType, int64_t storageTypeMin,
+ int64_t storageTypeMax) {
// Verify that the storage type is integral.
// This restriction may be lifted at some point in favor of using bf16
// or f16 as exact representations on hardware where that is advantageous.
@@ -233,11 +234,13 @@ AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
}
LogicalResult
-AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
- unsigned flags, Type storageType, Type expressedType,
- int64_t storageTypeMin, int64_t storageTypeMax) {
- if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
- storageTypeMin, storageTypeMax))) {
+AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ unsigned flags, Type storageType,
+ Type expressedType, int64_t storageTypeMin,
+ int64_t storageTypeMax) {
+ if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+ expressedType, storageTypeMin,
+ storageTypeMax))) {
return failure();
}
@@ -268,12 +271,13 @@ UniformQuantizedType UniformQuantizedType::getChecked(
storageTypeMin, storageTypeMax);
}
-LogicalResult UniformQuantizedType::verify(
+LogicalResult UniformQuantizedType::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
- if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
- storageTypeMin, storageTypeMax))) {
+ if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+ expressedType, storageTypeMin,
+ storageTypeMax))) {
return failure();
}
@@ -321,13 +325,14 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
quantizedDimension, storageTypeMin, storageTypeMax);
}
-LogicalResult UniformQuantizedPerAxisType::verify(
+LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax) {
- if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
- storageTypeMin, storageTypeMax))) {
+ if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+ expressedType, storageTypeMin,
+ storageTypeMax))) {
return failure();
}
@@ -380,9 +385,9 @@ CalibratedQuantizedType CalibratedQuantizedType::getChecked(
min, max);
}
-LogicalResult
-CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
- Type expressedType, double min, double max) {
+LogicalResult CalibratedQuantizedType::verifyInvariants(
+ function_ref<InFlightDiagnostic()> emitError, Type expressedType,
+ double min, double max) {
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
index 8a0ee7a3d81367..b71be23fdf47d0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
@@ -162,7 +162,7 @@ spirv::InterfaceVarABIAttr::getStorageClass() {
return std::nullopt;
}
-LogicalResult spirv::InterfaceVarABIAttr::verify(
+LogicalResult spirv::InterfaceVarABIAttr::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet,
IntegerAttr binding, IntegerAttr storageClass) {
if (!descriptorSet.getType().isSignlessInteger(32))
@@ -257,10 +257,9 @@ ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
return llvm::cast<ArrayAttr>(getImpl()->capabilities);
}
-LogicalResult
-spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- IntegerAttr version, ArrayAttr capabilities,
- ArrayAttr extensions) {
+LogicalResult spirv::VerCapExtAttr::verifyInvariants(
+ function_ref<InFlightDiagnostic()> emitError, IntegerAttr version,
+ ArrayAttr capabilities, ArrayAttr extensions) {
if (!version.getType().isSignlessInteger(32))
return emitError() << "expected 32-bit integer for version";
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 3808620bdffa6d..c5590905b75045 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -817,8 +817,8 @@ SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type SampledImageType::getImageType() const { return getImpl()->imageType; }
LogicalResult
-SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
- Type imageType) {
+SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ Type imageType) {
if (!llvm::isa<ImageType>(imageType))
return emitError() << "expected image type";
@@ -1181,8 +1181,9 @@ MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
columnCount);
}
-LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
- Type columnType, uint32_t columnCount) {
+LogicalResult
+MatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ Type columnType, uint32_t columnCount) {
if (columnCount < 2 || columnCount > 4)
return emitError() << "matrix can have 2, 3, or 4 columns only";
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index c9dbb3bc76b1fa..ed727a834e34d0 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -184,6 +184,12 @@ bool AttrOrTypeDef::genVerifyDecl() const {
return def->getValueAsBit("genVerifyDecl");
}
+bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
+ return any_of(parameters, [](const AttrOrTypeParameter &p) {
+ return p.getTypeConstraint() != std::nullopt;
+ });
+}
+
std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? std::optional<StringRef>() : value;
@@ -331,6 +337,13 @@ std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
+std::optional<TypeConstraint> AttrOrTypeParameter::getTypeConstraint() const {
+ if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
+ if (param->getDef()->isSubClassOf("TypeConstraint"))
+ return TypeConstraint(param);
+ return std::nullopt;
+}
+
//===----------------------------------------------------------------------===//
// AttributeSelfTypeParameter
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir
new file mode 100644
index 00000000000000..96d0005eb7a19d
--- /dev/null
+++ b/mlir/test/IR/test-verifiers-type.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// CHECK: "test.type_producer"() : () -> !test.type_verification<i16>
+"test.type_producer"() : () -> !test.type_verification<i16>
+
+// -----
+
+// expected-error @below{{failed to verify 'param': 16-bit signless integer or 32-bit signless integer}}
+"test.type_producer"() : () -> !test.type_verification<f16>
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 8f109f8ce5e6dd..b3b94bd0ffea31 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -270,7 +270,6 @@ def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
let assemblyFormat = "`<` $a `>`";
let skipDefaultBuilders = 1;
- let genVerifyDecl = 1;
let builders = [AttrBuilder<(ins "int":$a), [{
return ::mlir::IntegerAttr::get(::mlir::IndexType::get($_ctxt), a);
}], "::mlir::Attribute">];
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index d96152a0826f96..830475bed4e444 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -392,4 +392,10 @@ def TestRecursiveAlias
}];
}
+def TestTypeVerification : Test_Type<"TestTypeVerification"> {
+ let parameters = (ins AnyTypeOf<[I16, I32]>:$param);
+ let mnemonic = "type_verification";
+ let assemblyFormat = "`<` $param `>`";
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 8cc8314418104c..7d4745cbe8d51c 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -93,8 +93,14 @@ class DefGen {
void emitDialectName();
/// Emit attribute or type builders.
void emitBuilders();
- /// Emit a verifier for the def.
- void emitVerifier();
+ /// Emit a verifier declaration for custom verification (impl. provided by
+ /// the users).
+ void emitVerifierDecl();
+ /// Emit a verifier that checks type constraints.
+ void emitInvariantsVerifierImpl();
+ /// Emit an entry point for verification that calls the invariants and
+ /// custom verifier.
+ void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier);
/// Emit parsers and printers.
void emitParserPrinter();
/// Emit parameter accessors, if required.
@@ -188,9 +194,17 @@ DefGen::DefGen(const AttrOrTypeDef &def)
emitName();
// Emit the dialect name.
emitDialectName();
- // Emit the verifier.
- if (storageCls && def.genVerifyDecl())
- emitVerifier();
+ // Emit verification of type constraints.
+ bool genVerifyInvariantsImpl = def.genVerifyInvariantsImpl();
+ if (storageCls && genVerifyInvariantsImpl)
+ emitInvariantsVerifierImpl();
+ // Emit the custom verifier (written by the user).
+ bool genVerifyDecl = def.genVerifyDecl();
+ if (storageCls && genVerifyDecl)
+ emitVerifierDecl();
+ // Emit the "verifyInvariants" function if there is any verification at all.
+ if (storageCls)
+ emitInvariantsVerifier(genVerifyInvariantsImpl, genVerifyDecl);
// Emit the mnemonic, if there is one, and any associated parser and printer.
if (def.getMnemonic())
emitParserPrinter();
@@ -295,24 +309,91 @@ void DefGen::emitDialectName() {
void DefGen::emitBuilders() {
if (!def.skipDefaultBuilders()) {
emitDefaultBuilder();
- if (def.genVerifyDecl())
+ if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
emitCheckedBuilder();
}
for (auto &builder : def.getBuilders()) {
emitCustomBuilder(builder);
- if (def.genVerifyDecl())
+ if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
emitCheckedCustomBuilder(builder);
}
}
-void DefGen::emitVerifier() {
- defCls.declare<UsingDeclaration>("Base::getChecked");
+void DefGen::emitVerifierDecl() {
defCls.declareStaticMethod(
"::llvm::LogicalResult", "verify",
getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
"emitError"}}));
}
+static const char *const patternParameterVerificationCode = R"(
+if (!({0})) {
+ emitError() << "failed to verify '{1}': {2}";
+ return ::mlir::failure();
+}
+)";
+
+void DefGen::emitInvariantsVerifierImpl() {
+ SmallVector<MethodParameter> builderParams = getBuilderParams(
+ {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
+ Method *verifier =
+ defCls.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl",
+ Method::Static, builderParams);
+ verifier->body().indent();
+
+ // Generate verification for each parameter that is a type constraint.
+ for (auto it : llvm::enumerate(def.getParameters())) {
+ const AttrOrTypeParameter &param = it.value();
+ std::optional<TypeConstraint> constraint = param.getTypeConstraint();
+ // No verification needed for parameters that are not type constraints.
+ if (!constraint.has_value())
+ continue;
+ FmtContext ctx;
+ // Note: Skip over the first method parameter (`emitError`).
+ ctx.withSelf(builderParams[it.index() + 1].getName());
+ std::string condition = tgfmt(constraint->getConditionTemplate(), &ctx);
+ verifier->body() << formatv(patternParameterVerificationCode, condition,
+ param.getName(), constraint->getSummary())
+ << "\n";
+ }
+ verifier->body() << "return ::mlir::success();";
+}
+
+void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) {
+ if (!hasImpl && !hasCustomVerifier)
+ return;
+ defCls.declare<UsingDeclaration>("Base::getChecked");
+ SmallVector<MethodParameter> builderParams = getBuilderParams(
+ {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
+ Method *verifier =
+ defCls.addMethod("::llvm::LogicalResult", "verifyInvariants",
+ Method::Static, builderParams);
+ verifier->body().indent();
+ if (hasImpl) {
+ // Call the verifier that checks the type constraints.
+ verifier->body() << "if (::mlir::failed(verifyInvariantsImpl(";
+ for (int i = 0, e = builderParams.size(); i < e; ++i) {
+ if (i > 0)
+ verifier->body() << ", ";
+ verifier->body() << builderParams[i].getName();
+ }
+ verifier->body() << ")))\n";
+ verifier->body() << " return ::mlir::failure();\n";
+ }
+ if (hasCustomVerifier) {
+ // Call the custom verifier that is provided by the user.
+ verifier->body() << "if (::mlir::failed(verify(";
+ for (int i = 0, e = builderParams.size(); i < e; ++i) {
+ if (i > 0)
+ verifier->body() << ", ";
+ verifier->body() << builderParams[i].getName();
+ }
+ verifier->body() << ")))\n";
+ verifier->body() << " return ::mlir::failure();\n";
+ }
+ verifier->body() << "return ::mlir::success();";
+}
+
void DefGen::emitParserPrinter() {
auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
"::llvm::StringLiteral", "getMnemonic");
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index dacc20b6ba2086..a4ae271edb6bd2 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -323,7 +323,7 @@ void DefFormat::genParser(MethodBody &os) {
// Generate call to the attribute or type builder. Use the checked getter
// if one was generated.
- if (def.genVerifyDecl()) {
+ if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) {
os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
&ctx, def.getCppClassName());
} else {
More information about the Mlir-commits
mailing list