[Mlir-commits] [mlir] [mlir][ODS] Verify type constraints in Types and Attributes (PR #102326)

Matthias Springer llvmlistbot at llvm.org
Wed Aug 7 09:16:41 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/102326

>From bb210672984e9b0f54d0d67d84d38e1086bf57f5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 7 Aug 2024 17:49:45 +0200
Subject: [PATCH] [mlir][ODS] Verify type constraints in Types and Attributes

When a type/attribute is defined in TableGen, a type constraint can be used for parameters, but the type constraint verification was missing.

Example:
```
def TestTypeVerification : Test_Type<"TestTypeVerification"> {
  let parameters = (ins AnyTypeOf<[I16, I32]>:$param);
  // ...
}
```

No verification code was generated to ensure that `$param` is I16 or I32.

When type constraints are present, a new method will be generated for types and attributes: `verifyInvariantsImpl`. (The naming is similar to op verifiers.) The user-provided verifier is called `verify` (no change). There is now a new entry point to type/attribute verification: `verifyInvariants`. This function calls both `verifyInvariantsImpl` and `verify`. If neither of those two verifications is present, the `verifyInvariants` function is not generated.

When a type/attribute is not defined in TableGen, but a verifier is needed, users can implement the `verifyInvariants` function. (This function was previously called `verify`.)
---
 mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h |  7 +-
 mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h  | 12 ++-
 mlir/include/mlir/Dialect/Quant/QuantTypes.h  | 43 ++++----
 .../mlir/Dialect/SPIRV/IR/SPIRVAttributes.h   | 14 +--
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        | 10 +-
 mlir/include/mlir/IR/CommonTypeConstraints.td |  1 +
 mlir/include/mlir/IR/Constraints.td           |  3 +
 mlir/include/mlir/IR/StorageUniquerSupport.h  | 10 +-
 mlir/include/mlir/IR/Types.h                  | 19 ++--
 mlir/include/mlir/TableGen/AttrOrTypeDef.h    |  8 ++
 mlir/include/mlir/TableGen/Class.h            |  2 +
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |  6 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp      | 12 +--
 mlir/lib/Dialect/Quant/IR/QuantTypes.cpp      | 39 ++++----
 mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp |  9 +-
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      |  9 +-
 mlir/lib/TableGen/AttrOrTypeDef.cpp           | 13 +++
 mlir/test/IR/test-verifiers-type.mlir         |  9 ++
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    |  1 -
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |  6 ++
 mlir/test/mlir-tblgen/attr-or-type-format.td  | 20 ++++
 mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp   | 99 +++++++++++++++++--
 .../tools/mlir-tblgen/AttrOrTypeFormatGen.cpp |  2 +-
 23 files changed, 254 insertions(+), 100 deletions(-)
 create mode 100644 mlir/test/IR/test-verifiers-type.mlir

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 96e1935bd0a841..57acd72610415f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -148,9 +148,10 @@ class MMAMatrixType
 
   /// Verify that shape and elementType are actually allowed for the
   /// MMAMatrixType.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              ArrayRef<int64_t> shape, Type elementType,
-                              StringRef operand);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<int64_t> shape, Type elementType,
+                   StringRef operand);
 
   /// Get number of dims.
   unsigned getNumDims() const;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 1befdfa74f67c5..2ea589a7c4c3bd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -180,11 +180,13 @@ class LLVMStructType
   ArrayRef<Type> getBody() const;
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              StringRef, bool);
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              ArrayRef<Type> types, bool);
-  using Base::verify;
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, StringRef,
+                   bool);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<Type> types, bool);
+  using Base::verifyInvariants;
 
   /// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
   /// DataLayout instance and query it instead.
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index de5aed0a91a209..57a2aa29833657 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -54,10 +54,10 @@ class QuantizedType : public Type {
   /// The maximum number of bits supported for storage types.
   static constexpr unsigned MaxStorageBits = 32;
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Support method to enable LLVM-style type casting.
   static bool classof(Type type);
@@ -214,10 +214,10 @@ class AnyQuantizedType
              int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 };
 
 /// Represents a family of uniform, quantized types.
@@ -276,11 +276,11 @@ class UniformQuantizedType
              int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, double scale,
-                              int64_t zeroPoint, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, double scale,
+                   int64_t zeroPoint, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Gets the scale term. The scale designates the difference between the real
   /// values corresponding to consecutive quantized values differing by 1.
@@ -338,12 +338,12 @@ class UniformQuantizedPerAxisType
              int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, ArrayRef<double> scales,
-                              ArrayRef<int64_t> zeroPoints,
-                              int32_t quantizedDimension,
-                              int64_t storageTypeMin, int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType,
+                   ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+                   int32_t quantizedDimension, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Gets the quantization scales. The scales designate the difference between
   /// the real values corresponding to consecutive quantized values differing
@@ -403,8 +403,9 @@ class CalibratedQuantizedType
              double min, double max);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type expressedType, double min, double max);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type expressedType, double min, double max);
   double getMin() const;
   double getMax() const;
 };
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 5ebfa9ca5ec25c..2bdd7a5bf3dd83 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -76,9 +76,10 @@ class InterfaceVarABIAttr
   /// Returns `spirv::StorageClass`.
   std::optional<StorageClass> getStorageClass();
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              IntegerAttr descriptorSet, IntegerAttr binding,
-                              IntegerAttr storageClass);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   IntegerAttr descriptorSet, IntegerAttr binding,
+                   IntegerAttr storageClass);
 
   static constexpr StringLiteral name = "spirv.interface_var_abi";
 };
@@ -128,9 +129,10 @@ class VerCapExtAttr
   /// Returns the capabilities as an integer array attribute.
   ArrayAttr getCapabilitiesAttr();
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              IntegerAttr version, ArrayAttr capabilities,
-                              ArrayAttr extensions);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   IntegerAttr version, ArrayAttr capabilities,
+                   ArrayAttr extensions);
 
   static constexpr StringLiteral name = "spirv.ver_cap_ext";
 };
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 55f0c787b44403..e2d04553d91b8b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -258,8 +258,9 @@ class SampledImageType
   static SampledImageType
   getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type imageType);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type imageType);
 
   Type getImageType() const;
 
@@ -462,8 +463,9 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
   static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
                                Type columnType, uint32_t columnCount);
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type columnType, uint32_t columnCount);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type columnType, uint32_t columnCount);
 
   /// Returns true if the matrix elements are vectors of float elements.
   static bool isValidColumnType(Type columnType);
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 5b6ec167fa2420..4d3e1428e6c40b 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -180,6 +180,7 @@ class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
         summary),
     cppClassName> {
   list<Type> allowedTypes = allowedTypeList;
+  string cppType = cppClassName;
 }
 
 // A type that satisfies the constraints of all given types.
diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td
index a026d58ccffb8e..242c850f38f309 100644
--- a/mlir/include/mlir/IR/Constraints.td
+++ b/mlir/include/mlir/IR/Constraints.td
@@ -153,6 +153,9 @@ class TypeConstraint<Pred predicate, string summary = "",
     Constraint<predicate, summary> {
   // The name of the C++ Type class if known, or Type if not.
   string cppClassName = cppClassNameParam;
+  // TODO: This field is sometimes called `cppClassName` and sometimes
+  // `cppType`. Use a single name consistently.
+  string cppType = cppClassNameParam;
 }
 
 // Subclass for constraints on an attribute.
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index fb64f15162df5b..d6ccbbd8579947 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -176,8 +176,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   template <typename... Args>
   static ConcreteT get(MLIRContext *ctx, Args &&...args) {
     // Ensure that the invariants are correct for construction.
-    assert(
-        succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
+    assert(succeeded(
+        ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
     return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
   }
 
@@ -198,7 +198,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                               MLIRContext *ctx, Args... args) {
     // If the construction invariants fail then we return a null attribute.
-    if (failed(ConcreteT::verify(emitErrorFn, args...)))
+    if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
       return ConcreteT();
     return UniquerT::template get<ConcreteT>(ctx, args...);
   }
@@ -226,7 +226,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
 
   /// Default implementation that just returns success.
   template <typename... Args>
-  static LogicalResult verify(Args... args) {
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitErrorFn,
+                   Args... args) {
     return success();
   }
 
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 60dc8fee0f4a96..91b457deeba2f6 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -34,7 +34,7 @@ class AsmState;
 /// Derived type classes are expected to implement several required
 /// implementation hooks:
 ///  * Optional:
-///    - static LogicalResult verify(
+///    - static LogicalResult verifyInvariants(
 ///                                function_ref<InFlightDiagnostic()> emitError,
 ///                                Args... args)
 ///      * This method is invoked when calling the 'TypeBase::get/getChecked'
@@ -97,20 +97,17 @@ class Type {
   bool operator!() const { return impl == nullptr; }
 
   template <typename... Tys>
-  [[deprecated("Use mlir::isa<U>() instead")]]
-  bool isa() const;
+  [[deprecated("Use mlir::isa<U>() instead")]] bool isa() const;
   template <typename... Tys>
-  [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]]
-  bool isa_and_nonnull() const;
+  [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]] bool
+  isa_and_nonnull() const;
   template <typename U>
-  [[deprecated("Use mlir::dyn_cast<U>() instead")]]
-  U dyn_cast() const;
+  [[deprecated("Use mlir::dyn_cast<U>() instead")]] U dyn_cast() const;
   template <typename U>
-  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
-  U dyn_cast_or_null() const;
+  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]] U
+  dyn_cast_or_null() const;
   template <typename U>
-  [[deprecated("Use mlir::cast<U>() instead")]]
-  U cast() const;
+  [[deprecated("Use mlir::cast<U>() instead")]] U cast() const;
 
   /// Return a unique identifier for the concrete type. This is used to support
   /// dynamic type casting.
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 19c3a9183ec2cf..22961b24e45af4 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -17,6 +17,7 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Builder.h"
 #include "mlir/TableGen/Trait.h"
+#include "mlir/TableGen/Type.h"
 
 namespace llvm {
 class DagInit;
@@ -85,6 +86,9 @@ class AttrOrTypeParameter {
   /// Get an optional C++ parameter parser.
   std::optional<StringRef> getParser() const;
 
+  /// If this is a type constraint, return it.
+  std::optional<TypeConstraint> getTypeConstraint() const;
+
   /// Get an optional C++ parameter printer.
   std::optional<StringRef> getPrinter() const;
 
@@ -198,6 +202,10 @@ class AttrOrTypeDef {
   /// method.
   bool genVerifyDecl() const;
 
+  /// Return true if we need to generate any type constraint verification and
+  /// the getChecked method.
+  bool genVerifyInvariantsImpl() const;
+
   /// Returns the def's extra class declaration code.
   std::optional<StringRef> getExtraDecls() const;
 
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 855952d19492db..f750a34a3b2ba4 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -67,6 +67,8 @@ class MethodParameter {
 
   /// Get the C++ type.
   StringRef getType() const { return type; }
+  /// Get the C++ parameter name.
+  StringRef getName() const { return name; }
   /// Returns true if the parameter has a default value.
   bool hasDefaultValue() const { return !defaultValue.empty(); }
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7bc2668310ddb0..a1f87a637a6141 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
 }
 
 LogicalResult
-MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
-                      ArrayRef<int64_t> shape, Type elementType,
-                      StringRef operand) {
+MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                ArrayRef<int64_t> shape, Type elementType,
+                                StringRef operand) {
   if (operand != "AOp" && operand != "BOp" && operand != "COp")
     return emitError() << "operand expected to be one of AOp, BOp or COp";
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index dc7aef8ef7f850..7f10a15ff31ff9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
 
 bool LLVMStructType::isValidElementType(Type type) {
   return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
-                    LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
-      type);
+                    LLVMFunctionType, LLVMTokenType>(type);
 }
 
 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
@@ -492,14 +491,15 @@ ArrayRef<Type> LLVMStructType::getBody() const {
                         : getImpl()->getTypeList();
 }
 
-LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
-                                     StringRef, bool) {
+LogicalResult
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
+                                 bool) {
   return success();
 }
 
 LogicalResult
-LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
-                       ArrayRef<Type> types, bool) {
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                 ArrayRef<Type> types, bool) {
   for (Type t : types)
     if (!isValidElementType(t))
       return emitError() << "invalid LLVM structure element type: " << t;
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 81e3b914755be2..c2ba9c04e8771d 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -29,9 +29,10 @@ bool QuantizedType::classof(Type type) {
 }
 
 LogicalResult
-QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
-                      unsigned flags, Type storageType, Type expressedType,
-                      int64_t storageTypeMin, int64_t storageTypeMax) {
+QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                unsigned flags, Type storageType,
+                                Type expressedType, int64_t storageTypeMin,
+                                int64_t storageTypeMax) {
   // Verify that the storage type is integral.
   // This restriction may be lifted at some point in favor of using bf16
   // or f16 as exact representations on hardware where that is advantageous.
@@ -233,11 +234,13 @@ AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
 }
 
 LogicalResult
-AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
-                         unsigned flags, Type storageType, Type expressedType,
-                         int64_t storageTypeMin, int64_t storageTypeMax) {
-  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
-                                   storageTypeMin, storageTypeMax))) {
+AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                   unsigned flags, Type storageType,
+                                   Type expressedType, int64_t storageTypeMin,
+                                   int64_t storageTypeMax) {
+  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+                                             expressedType, storageTypeMin,
+                                             storageTypeMax))) {
     return failure();
   }
 
@@ -268,12 +271,13 @@ UniformQuantizedType UniformQuantizedType::getChecked(
                           storageTypeMin, storageTypeMax);
 }
 
-LogicalResult UniformQuantizedType::verify(
+LogicalResult UniformQuantizedType::verifyInvariants(
     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
     Type storageType, Type expressedType, double scale, int64_t zeroPoint,
     int64_t storageTypeMin, int64_t storageTypeMax) {
-  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
-                                   storageTypeMin, storageTypeMax))) {
+  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+                                             expressedType, storageTypeMin,
+                                             storageTypeMax))) {
     return failure();
   }
 
@@ -321,13 +325,14 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
                           quantizedDimension, storageTypeMin, storageTypeMax);
 }
 
-LogicalResult UniformQuantizedPerAxisType::verify(
+LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
     Type storageType, Type expressedType, ArrayRef<double> scales,
     ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
     int64_t storageTypeMin, int64_t storageTypeMax) {
-  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
-                                   storageTypeMin, storageTypeMax))) {
+  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+                                             expressedType, storageTypeMin,
+                                             storageTypeMax))) {
     return failure();
   }
 
@@ -380,9 +385,9 @@ CalibratedQuantizedType CalibratedQuantizedType::getChecked(
                           min, max);
 }
 
-LogicalResult
-CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
-                                Type expressedType, double min, double max) {
+LogicalResult CalibratedQuantizedType::verifyInvariants(
+    function_ref<InFlightDiagnostic()> emitError, Type expressedType,
+    double min, double max) {
   // Verify that the expressed type is floating point.
   // If this restriction is ever eliminated, the parser/printer must be
   // extended.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
index 8a0ee7a3d81367..b71be23fdf47d0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
@@ -162,7 +162,7 @@ spirv::InterfaceVarABIAttr::getStorageClass() {
   return std::nullopt;
 }
 
-LogicalResult spirv::InterfaceVarABIAttr::verify(
+LogicalResult spirv::InterfaceVarABIAttr::verifyInvariants(
     function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet,
     IntegerAttr binding, IntegerAttr storageClass) {
   if (!descriptorSet.getType().isSignlessInteger(32))
@@ -257,10 +257,9 @@ ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
   return llvm::cast<ArrayAttr>(getImpl()->capabilities);
 }
 
-LogicalResult
-spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                             IntegerAttr version, ArrayAttr capabilities,
-                             ArrayAttr extensions) {
+LogicalResult spirv::VerCapExtAttr::verifyInvariants(
+    function_ref<InFlightDiagnostic()> emitError, IntegerAttr version,
+    ArrayAttr capabilities, ArrayAttr extensions) {
   if (!version.getType().isSignlessInteger(32))
     return emitError() << "expected 32-bit integer for version";
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 3808620bdffa6d..c5590905b75045 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -817,8 +817,8 @@ SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
 
 LogicalResult
-SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
-                         Type imageType) {
+SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                   Type imageType) {
   if (!llvm::isa<ImageType>(imageType))
     return emitError() << "expected image type";
 
@@ -1181,8 +1181,9 @@ MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                           columnCount);
 }
 
-LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
-                                 Type columnType, uint32_t columnCount) {
+LogicalResult
+MatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                             Type columnType, uint32_t columnCount) {
   if (columnCount < 2 || columnCount > 4)
     return emitError() << "matrix can have 2, 3, or 4 columns only";
 
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index c9dbb3bc76b1fa..ed727a834e34d0 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -184,6 +184,12 @@ bool AttrOrTypeDef::genVerifyDecl() const {
   return def->getValueAsBit("genVerifyDecl");
 }
 
+bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
+  return any_of(parameters, [](const AttrOrTypeParameter &p) {
+    return p.getTypeConstraint() != std::nullopt;
+  });
+}
+
 std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
   auto value = def->getValueAsString("extraClassDeclaration");
   return value.empty() ? std::optional<StringRef>() : value;
@@ -331,6 +337,13 @@ std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
 
 llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
 
+std::optional<TypeConstraint> AttrOrTypeParameter::getTypeConstraint() const {
+  if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
+    if (param->getDef()->isSubClassOf("TypeConstraint"))
+      return TypeConstraint(param);
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // AttributeSelfTypeParameter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir
new file mode 100644
index 00000000000000..96d0005eb7a19d
--- /dev/null
+++ b/mlir/test/IR/test-verifiers-type.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// CHECK: "test.type_producer"() : () -> !test.type_verification<i16>
+"test.type_producer"() : () -> !test.type_verification<i16>
+
+// -----
+
+// expected-error @below{{failed to verify 'param': 16-bit signless integer or 32-bit signless integer}}
+"test.type_producer"() : () -> !test.type_verification<f16>
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 8f109f8ce5e6dd..b3b94bd0ffea31 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -270,7 +270,6 @@ def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
   let assemblyFormat = "`<` $a `>`";
 
   let skipDefaultBuilders = 1;
-  let genVerifyDecl = 1;
   let builders = [AttrBuilder<(ins "int":$a), [{
     return ::mlir::IntegerAttr::get(::mlir::IndexType::get($_ctxt), a);
   }], "::mlir::Attribute">];
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index d96152a0826f96..830475bed4e444 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -392,4 +392,10 @@ def TestRecursiveAlias
   }];
 }
 
+def TestTypeVerification : Test_Type<"TestTypeVerification"> {
+  let parameters = (ins AnyTypeOf<[I16, I32]>:$param);
+  let mnemonic = "type_verification";
+  let assemblyFormat = "`<` $param `>`";
+}
+
 #endif // TEST_TYPEDEFS
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 2884c4ed6a9081..d07b067824919a 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -663,6 +663,26 @@ def TypeO : TestType<"TestQ"> {
   let assemblyFormat = "(custom<AB>($a)^ `x`) : (`y`)?";
 }
 
+// TYPE: ::llvm::LogicalResult TestPType::verifyInvariantsImpl(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::Type a) {
+// TYPE:   if (!(((a.isSignlessInteger(16))) || ((a.isSignlessInteger(32))))) {
+// TYPE:     emitError() << "failed to verify 'a': 16-bit signless integer or 32-bit signless integer";
+// TYPE:     return ::mlir::failure();
+// TYPE:   }
+// TYPE:   return ::mlir::success();
+// TYPE: }
+
+// TYPE: ::llvm::LogicalResult TestPType::verifyInvariants(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::Type a) {
+// TYPE:   if (::mlir::failed(verifyInvariantsImpl(emitError, a)))
+// TYPE:     return ::mlir::failure();
+// TYPE:   return ::mlir::success();
+// TYPE: }
+
+def TypeP : TestType<"TestP"> {
+  let parameters = (ins AnyTypeOf<[I16, I32]>:$a);
+  let mnemonic = "type_p";
+  let assemblyFormat = "$a";
+}
+
 // DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
 // DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
 // DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 8cc8314418104c..7d4745cbe8d51c 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -93,8 +93,14 @@ class DefGen {
   void emitDialectName();
   /// Emit attribute or type builders.
   void emitBuilders();
-  /// Emit a verifier for the def.
-  void emitVerifier();
+  /// Emit a verifier declaration for custom verification (impl. provided by
+  /// the users).
+  void emitVerifierDecl();
+  /// Emit a verifier that checks type constraints.
+  void emitInvariantsVerifierImpl();
+  /// Emit an entry point for verification that calls the invariants and
+  /// custom verifier.
+  void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier);
   /// Emit parsers and printers.
   void emitParserPrinter();
   /// Emit parameter accessors, if required.
@@ -188,9 +194,17 @@ DefGen::DefGen(const AttrOrTypeDef &def)
   emitName();
   // Emit the dialect name.
   emitDialectName();
-  // Emit the verifier.
-  if (storageCls && def.genVerifyDecl())
-    emitVerifier();
+  // Emit verification of type constraints.
+  bool genVerifyInvariantsImpl = def.genVerifyInvariantsImpl();
+  if (storageCls && genVerifyInvariantsImpl)
+    emitInvariantsVerifierImpl();
+  // Emit the custom verifier (written by the user).
+  bool genVerifyDecl = def.genVerifyDecl();
+  if (storageCls && genVerifyDecl)
+    emitVerifierDecl();
+  // Emit the "verifyInvariants" function if there is any verification at all.
+  if (storageCls)
+    emitInvariantsVerifier(genVerifyInvariantsImpl, genVerifyDecl);
   // Emit the mnemonic, if there is one, and any associated parser and printer.
   if (def.getMnemonic())
     emitParserPrinter();
@@ -295,24 +309,91 @@ void DefGen::emitDialectName() {
 void DefGen::emitBuilders() {
   if (!def.skipDefaultBuilders()) {
     emitDefaultBuilder();
-    if (def.genVerifyDecl())
+    if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
       emitCheckedBuilder();
   }
   for (auto &builder : def.getBuilders()) {
     emitCustomBuilder(builder);
-    if (def.genVerifyDecl())
+    if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
       emitCheckedCustomBuilder(builder);
   }
 }
 
-void DefGen::emitVerifier() {
-  defCls.declare<UsingDeclaration>("Base::getChecked");
+void DefGen::emitVerifierDecl() {
   defCls.declareStaticMethod(
       "::llvm::LogicalResult", "verify",
       getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
                          "emitError"}}));
 }
 
+static const char *const patternParameterVerificationCode = R"(
+if (!({0})) {
+  emitError() << "failed to verify '{1}': {2}";
+  return ::mlir::failure();
+}
+)";
+
+void DefGen::emitInvariantsVerifierImpl() {
+  SmallVector<MethodParameter> builderParams = getBuilderParams(
+      {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
+  Method *verifier =
+      defCls.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl",
+                       Method::Static, builderParams);
+  verifier->body().indent();
+
+  // Generate verification for each parameter that is a type constraint.
+  for (auto it : llvm::enumerate(def.getParameters())) {
+    const AttrOrTypeParameter &param = it.value();
+    std::optional<TypeConstraint> constraint = param.getTypeConstraint();
+    // No verification needed for parameters that are not type constraints.
+    if (!constraint.has_value())
+      continue;
+    FmtContext ctx;
+    // Note: Skip over the first method parameter (`emitError`).
+    ctx.withSelf(builderParams[it.index() + 1].getName());
+    std::string condition = tgfmt(constraint->getConditionTemplate(), &ctx);
+    verifier->body() << formatv(patternParameterVerificationCode, condition,
+                                param.getName(), constraint->getSummary())
+                     << "\n";
+  }
+  verifier->body() << "return ::mlir::success();";
+}
+
+void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) {
+  if (!hasImpl && !hasCustomVerifier)
+    return;
+  defCls.declare<UsingDeclaration>("Base::getChecked");
+  SmallVector<MethodParameter> builderParams = getBuilderParams(
+      {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
+  Method *verifier =
+      defCls.addMethod("::llvm::LogicalResult", "verifyInvariants",
+                       Method::Static, builderParams);
+  verifier->body().indent();
+  if (hasImpl) {
+    // Call the verifier that checks the type constraints.
+    verifier->body() << "if (::mlir::failed(verifyInvariantsImpl(";
+    for (int i = 0, e = builderParams.size(); i < e; ++i) {
+      if (i > 0)
+        verifier->body() << ", ";
+      verifier->body() << builderParams[i].getName();
+    }
+    verifier->body() << ")))\n";
+    verifier->body() << "  return ::mlir::failure();\n";
+  }
+  if (hasCustomVerifier) {
+    // Call the custom verifier that is provided by the user.
+    verifier->body() << "if (::mlir::failed(verify(";
+    for (int i = 0, e = builderParams.size(); i < e; ++i) {
+      if (i > 0)
+        verifier->body() << ", ";
+      verifier->body() << builderParams[i].getName();
+    }
+    verifier->body() << ")))\n";
+    verifier->body() << "  return ::mlir::failure();\n";
+  }
+  verifier->body() << "return ::mlir::success();";
+}
+
 void DefGen::emitParserPrinter() {
   auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
       "::llvm::StringLiteral", "getMnemonic");
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index dacc20b6ba2086..a4ae271edb6bd2 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -323,7 +323,7 @@ void DefFormat::genParser(MethodBody &os) {
 
   // Generate call to the attribute or type builder. Use the checked getter
   // if one was generated.
-  if (def.genVerifyDecl()) {
+  if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) {
     os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
                 &ctx, def.getCppClassName());
   } else {



More information about the Mlir-commits mailing list