[Mlir-commits] [mlir] [mlir][ODS] Verify type constraints in Types and Attributes (PR #102326)

Matthias Springer llvmlistbot at llvm.org
Fri Aug 9 11:52:08 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/102326

>From 9968cd55e4f606728b90854f0ff6f0ec8b04a1ae Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 9 Aug 2024 20:21:44 +0200
Subject: [PATCH 1/2] [mlir][ODS] Consistent `cppType` / `cppClassName` usage

Make sure that the usage of `cppType` and `cppClassName` of type and attribute definitions/constraints is consistent in TableGen.

- `cppClassName`: The C++ class name of the type or attribute.
- `cppType`: The fully qualified C++ class name: C++ namespace and C++ class name.
---
 mlir/include/mlir/IR/AttrTypeBase.td          | 18 +++++-----
 mlir/include/mlir/IR/CommonAttrConstraints.td |  6 ++--
 mlir/include/mlir/IR/CommonTypeConstraints.td | 34 +++++++++----------
 mlir/include/mlir/IR/Constraints.td           |  4 +--
 mlir/include/mlir/TableGen/Type.h             |  4 +--
 mlir/lib/TableGen/Type.cpp                    | 17 ++--------
 mlir/lib/Tools/PDLL/Parser/Parser.cpp         |  5 ++-
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp   | 19 ++---------
 mlir/tools/mlir-tblgen/OpFormatGen.cpp        |  6 ++--
 9 files changed, 44 insertions(+), 69 deletions(-)

diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 91c9283de8bd41..d176b36068f7a5 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -256,7 +256,7 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
       AttrOrTypeDef<"Attr", name, traits, baseCppClass> {
   // The name of the C++ Attribute class.
   string cppClassName = name # "Attr";
-  let storageType = dialect.cppNamespace # "::" # name # "Attr";
+  let storageType = dialect.cppNamespace # "::" # cppClassName;
 
   // The underlying C++ value type
   let returnType = dialect.cppNamespace # "::" # cppClassName;
@@ -275,12 +275,10 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
   //
   // For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will
   // expand to `getAttrOfType<IntegerAttr>("val").getValue().getSExtValue()`.
-  let convertFromStorage = "::llvm::cast<" # dialect.cppNamespace #
-                                 "::" # cppClassName # ">($_self)";
+  let convertFromStorage = "::llvm::cast<" # cppType # ">($_self)";
 
   // The predicate for when this def is used as a constraint.
-  let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace #
-                                 "::" # cppClassName # ">($_self)">;
+  let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
 }
 
 // Define a new type, named `name`, belonging to `dialect` that inherits from
@@ -289,6 +287,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
               string baseCppClass = "::mlir::Type">
     : DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
       AttrOrTypeDef<"Type", name, traits, baseCppClass> {
+  // The name of the C++ Type class.
+  string cppClassName = name # "Type";
+
   // Make it possible to use such type as parameters for other types.
   string cppType = dialect.cppNamespace # "::" # cppClassName;
 
@@ -297,12 +298,11 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
 
   // A constant builder provided when the type has no parameters.
   let builderCall = !if(!empty(parameters),
-                           "$_builder.getType<" # dialect.cppNamespace #
-                               "::" # cppClassName # ">()",
+                           "$_builder.getType<" # cppType # ">()",
                            "");
+
   // The predicate for when this def is used as a constraint.
-  let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace #
-                                 "::" # cppClassName # ">($_self)">;
+  let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index d99bde1f87ef00..6774a7c568315d 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -169,14 +169,14 @@ def AnyAttr : Attr<CPred<"true">, "any attribute"> {
 
 // Any attribute from the given list
 class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
-                string cppClassName = "::mlir::Attribute",
+                string cppType = "::mlir::Attribute",
                 string fromStorage = "$_self"> : Attr<
     // Satisfy any of the allowed attribute's condition
     Or<!foreach(allowedattr, allowedAttrs, allowedattr.predicate)>,
     !if(!eq(summary, ""),
         !interleave(!foreach(t, allowedAttrs, t.summary), " or "),
         summary)> {
-    let returnType = cppClassName;
+    let returnType = cppType;
     let convertFromStorage = fromStorage;
 }
 
@@ -369,7 +369,7 @@ def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> {
 }
 
 class TypeAttrOf<Type ty>
-   : TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary,
+   : TypeAttrBase<ty.cppType, "type attribute of " # ty.summary,
                     ty.predicate> {
   let constBuilderCall = "::mlir::TypeAttr::get($0)";
 }
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 5b6ec167fa2420..65e1b1cbdc905a 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -98,23 +98,23 @@ def HasValueSemanticsPred : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">;
 
 // A type, carries type constraints.
 class Type<Pred condition, string descr = "",
-           string cppClassName = "::mlir::Type"> :
-    TypeConstraint<condition, descr, cppClassName> {
+           string cppType = "::mlir::Type"> :
+    TypeConstraint<condition, descr, cppType> {
   string description = "";
   string builderCall = "";
 }
 
 // Allows providing an alternative name and summary to an existing type def.
 class TypeAlias<Type t, string summary = t.summary> :
-    Type<t.predicate, summary, t.cppClassName> {
+    Type<t.predicate, summary, t.cppType> {
   let description = t.description;
   let builderCall = t.builderCall;
 }
 
 // A type of a specific dialect.
 class DialectType<Dialect d, Pred condition, string descr = "",
-                  string cppClassName = "::mlir::Type"> :
-    Type<condition, descr, cppClassName> {
+                  string cppType = "::mlir::Type"> :
+    Type<condition, descr, cppType> {
   Dialect dialect = d;
 }
 
@@ -122,7 +122,7 @@ class DialectType<Dialect d, Pred condition, string descr = "",
 // class is used for supporting variadic operands/results.
 class Variadic<Type type> : TypeConstraint<type.predicate,
                                            "variadic of " # type.summary,
-                                           type.cppClassName> {
+                                           type.cppType> {
   Type baseType = type;
   int minSize = 0;
 }
@@ -140,7 +140,7 @@ class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
 // An optional type constraint. It expands to either zero or one of the base
 // type. This class is used for supporting optional operands/results.
 class Optional<Type type> : TypeConstraint<type.predicate, type.summary,
-                                           type.cppClassName> {
+                                           type.cppType> {
   Type baseType = type;
 }
 
@@ -172,33 +172,33 @@ def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
 
 // Any type from the given list
 class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
-                string cppClassName = "::mlir::Type"> : Type<
+                string cppType = "::mlir::Type"> : Type<
     // Satisfy any of the allowed types' conditions.
     Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
     !if(!eq(summary, ""),
         !interleave(!foreach(t, allowedTypeList, t.summary), " or "),
         summary),
-    cppClassName> {
+    cppType> {
   list<Type> allowedTypes = allowedTypeList;
 }
 
 // A type that satisfies the constraints of all given types.
 class AllOfType<list<Type> allowedTypeList, string summary = "",
-                string cppClassName = "::mlir::Type"> : Type<
+                string cppType = "::mlir::Type"> : Type<
     // Satisfy all of the allowed types' conditions.
     And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
     !if(!eq(summary, ""),
         !interleave(!foreach(t, allowedTypeList, t.summary), " and "),
         summary),
-    cppClassName> {
+    cppType> {
   list<Type> allowedTypes = allowedTypeList;
 }
 
 // A type that satisfies additional predicates.
 class ConfinedType<Type type, list<Pred> predicates, string summary = "",
-                   string cppClassName = type.cppClassName> : Type<
+                   string cppType = type.cppType> : Type<
     And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
-    summary, cppClassName>;
+    summary, cppType>;
 
 // Integer types.
 
@@ -375,23 +375,23 @@ def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
 
 // A container type is a type that has another type embedded within it.
 class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
-                    string descr, string cppClassName = "::mlir::Type"> :
+                    string descr, string cppType = "::mlir::Type"> :
     // First, check the container predicate.  Then, substitute the extracted
     // element into the element type checker.
     Type<And<[containerPred,
                 SubstLeaves<"$_self", !cast<string>(elementTypeCall),
                 etype.predicate>]>,
-         descr # " of " # etype.summary # " values", cppClassName>;
+         descr # " of " # etype.summary # " values", cppType>;
 
 class ShapedContainerType<list<Type> allowedTypes,
                           Pred containerPred, string descr,
-                          string cppClassName = "::mlir::Type"> :
+                          string cppType = "::mlir::Type"> :
     Type<And<[containerPred,
               Concat<"[](::mlir::Type elementType) { return ",
                 SubstLeaves<"$_self", "elementType",
                 AnyTypeOf<allowedTypes>.predicate>,
                 "; }(::llvm::cast<::mlir::ShapedType>($_self).getElementType())">]>,
-         descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppClassName>;
+         descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppType>;
 
 // Whether a shaped type is ranked.
 def HasRankPred : CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasRank()">;
diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td
index a026d58ccffb8e..39bc55db63da1a 100644
--- a/mlir/include/mlir/IR/Constraints.td
+++ b/mlir/include/mlir/IR/Constraints.td
@@ -149,10 +149,10 @@ class Constraint<Pred pred, string desc = ""> {
 
 // Subclass for constraints on a type.
 class TypeConstraint<Pred predicate, string summary = "",
-                     string cppClassNameParam = "::mlir::Type"> :
+                     string cppTypeParam = "::mlir::Type"> :
     Constraint<predicate, summary> {
   // The name of the C++ Type class if known, or Type if not.
-  string cppClassName = cppClassNameParam;
+  string cppType = cppTypeParam;
 }
 
 // Subclass for constraints on an attribute.
diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h
index 06cf4f5730d565..b974ac281041bc 100644
--- a/mlir/include/mlir/TableGen/Type.h
+++ b/mlir/include/mlir/TableGen/Type.h
@@ -56,8 +56,8 @@ class TypeConstraint : public Constraint {
   // returns std::nullopt otherwise.
   std::optional<StringRef> getBuilderCall() const;
 
-  // Return the C++ class name for this type (which may just be ::mlir::Type).
-  std::string getCPPClassName() const;
+  // Return the C++ type for this type (which may just be ::mlir::Type).
+  StringRef getCppType() const;
 };
 
 // Wrapper class with helper methods for accessing Types defined in TableGen.
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index e9f2394dd540af..cda752297988bb 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -59,20 +59,9 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
       .Default([](auto *) { return std::nullopt; });
 }
 
-// Return the C++ class name for this type (which may just be ::mlir::Type).
-std::string TypeConstraint::getCPPClassName() const {
-  StringRef className = def->getValueAsString("cppClassName");
-
-  // If the class name is already namespace resolved, use it.
-  if (className.contains("::"))
-    return className.str();
-
-  // Otherwise, check to see if there is a namespace from a dialect to prepend.
-  if (const llvm::RecordVal *value = def->getValue("dialect")) {
-    Dialect dialect(cast<const llvm::DefInit>(value->getValue())->getDef());
-    return (dialect.getCppNamespace() + "::" + className).str();
-  }
-  return className.str();
+// Return the C++ type for this type (which may just be ::mlir::Type).
+StringRef TypeConstraint::getCppType() const {
+  return def->getValueAsString("cppType");
 }
 
 Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 1f0df033d43398..01c78e280080ee 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -879,8 +879,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
       -> const ods::TypeConstraint & {
     return odsContext.insertTypeConstraint(
         cst.constraint.getUniqueDefName(),
-        processDoc(cst.constraint.getSummary()),
-        cst.constraint.getCPPClassName());
+        processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
   };
   auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
     return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
@@ -944,7 +943,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
     tblgen::TypeConstraint constraint(def);
     decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
         constraint, convertLocToRange(def->getLoc().front()), typeTy,
-        constraint.getCPPClassName()));
+        constraint.getCppType()));
   }
   /// OpInterfaces.
   ast::Type opTy = ast::OperationType::get(ctx);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index a2ceefb34db453..66dbb16760ebb0 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2085,21 +2085,8 @@ static void generateValueRangeStartAndEnd(
 }
 
 static std::string generateTypeForGetter(const NamedTypeConstraint &value) {
-  std::string str = "::mlir::Value";
-  /// If the CPPClassName is not a fully qualified type. Uses of types
-  /// across Dialect fail because they are not in the correct namespace. So we
-  /// dont generate TypedValue unless the type is fully qualified.
-  /// getCPPClassName doesn't return the fully qualified path for
-  /// `mlir::pdl::OperationType` see
-  /// https://github.com/llvm/llvm-project/issues/57279.
-  /// Adaptor will have values that are not from the type of their operation and
-  /// this is expected, so we dont generate TypedValue for Adaptor
-  if (value.constraint.getCPPClassName() != "::mlir::Type" &&
-      StringRef(value.constraint.getCPPClassName()).starts_with("::"))
-    str = llvm::formatv("::mlir::TypedValue<{0}>",
-                        value.constraint.getCPPClassName())
-              .str();
-  return str;
+  return llvm::formatv("::mlir::TypedValue<{0}>", value.constraint.getCppType())
+      .str();
 }
 
 // Generates the named operand getter methods for the given Operator `op` and
@@ -3944,7 +3931,7 @@ void OpEmitter::genTraits() {
   // For single result ops with a known specific type, generate a OneTypedResult
   // trait.
   if (numResults == 1 && numVariadicResults == 0) {
-    auto cppName = op.getResults().begin()->constraint.getCPPClassName();
+    auto cppName = op.getResults().begin()->constraint.getCppType();
     opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
   }
 
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 27ad79a5c1efed..9a95f495b77658 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1657,7 +1657,7 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
       TypeSwitch<FormatElement *>(dir->getArg())
           .Case<OperandVariable, ResultVariable>([&](auto operand) {
             body << formatv(parserCode,
-                            operand->getVar()->constraint.getCPPClassName(),
+                            operand->getVar()->constraint.getCppType(),
                             listName);
           })
           .Default([&](auto operand) {
@@ -2603,7 +2603,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
     }
     if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
         !var->isOptional()) {
-      std::string cppClass = var->constraint.getCPPClassName();
+      StringRef cppType = var->constraint.getCppType();
       if (dir->shouldBeQualified()) {
         body << "   _odsPrinter << " << op.getGetterName(var->name)
              << "().getType();\n";
@@ -2612,7 +2612,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
       body << "  {\n"
            << "    auto type = " << op.getGetterName(var->name)
            << "().getType();\n"
-           << "    if (auto validType = ::llvm::dyn_cast<" << cppClass
+           << "    if (auto validType = ::llvm::dyn_cast<" << cppType
            << ">(type))\n"
            << "      _odsPrinter.printStrippedAttrOrType(validType);\n"
            << "   else\n"

>From 27f3ffae411d6b0f5227850ab11a0ad69090daca Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 7 Aug 2024 17:49:45 +0200
Subject: [PATCH 2/2] [mlir][ODS] Verify type constraints in Types and
 Attributes

When a type/attribute is defined in TableGen, a type constraint can be used for its parameters, but no code was generated to verify that constraint.

Example:
```
def TestTypeVerification : Test_Type<"TestTypeVerification"> {
  let parameters = (ins AnyTypeOf<[I16, I32]>:$param);
  // ...
}
```

No verification code was generated to ensure that `$param` is I16 or I32.

When type constraints are present, a new method will be generated for types and attributes: `verifyInvariantsImpl`. (The naming is similar to op verifiers.) The user-provided verifier is called `verify` (no change). There is now a new entry point to type/attribute verification: `verifyInvariants`. This function calls both `verifyInvariantsImpl` and `verify`. If neither of those two verifiers is present, the `verifyInvariants` function is not generated.

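When a type/attribute is not defined in TableGen but a verifier is needed, users must now implement the `verifyInvariants` function. (This function was previously called `verify`.) `get` and `getChecked` now call `verifyInvariants`, so a hand-written `verify` is no longer invoked automatically.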
---
 mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h |  7 +-
 mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h  | 12 ++-
 mlir/include/mlir/Dialect/Quant/QuantTypes.h  | 43 +++++----
 .../mlir/Dialect/SPIRV/IR/SPIRVAttributes.h   | 14 +--
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        | 10 +-
 mlir/include/mlir/IR/StorageUniquerSupport.h  | 10 +-
 mlir/include/mlir/IR/Types.h                  |  2 +-
 mlir/include/mlir/TableGen/AttrOrTypeDef.h    |  8 ++
 mlir/include/mlir/TableGen/Class.h            |  2 +
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |  6 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp      | 12 +--
 mlir/lib/Dialect/Quant/IR/QuantTypes.cpp      | 39 ++++----
 mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp |  9 +-
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      |  9 +-
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  8 ++
 mlir/lib/TableGen/AttrOrTypeDef.cpp           | 13 +++
 mlir/test/IR/test-verifiers-type.mlir         |  9 ++
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    |  1 -
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |  6 ++
 mlir/test/mlir-tblgen/attr-or-type-format.td  | 79 ++++++++++++++-
 mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp   | 96 +++++++++++++++++--
 .../tools/mlir-tblgen/AttrOrTypeFormatGen.cpp |  2 +-
 22 files changed, 304 insertions(+), 93 deletions(-)
 create mode 100644 mlir/test/IR/test-verifiers-type.mlir

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 96e1935bd0a841..57acd72610415f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -148,9 +148,10 @@ class MMAMatrixType
 
   /// Verify that shape and elementType are actually allowed for the
   /// MMAMatrixType.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              ArrayRef<int64_t> shape, Type elementType,
-                              StringRef operand);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<int64_t> shape, Type elementType,
+                   StringRef operand);
 
   /// Get number of dims.
   unsigned getNumDims() const;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 1befdfa74f67c5..2ea589a7c4c3bd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -180,11 +180,13 @@ class LLVMStructType
   ArrayRef<Type> getBody() const;
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              StringRef, bool);
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              ArrayRef<Type> types, bool);
-  using Base::verify;
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, StringRef,
+                   bool);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<Type> types, bool);
+  using Base::verifyInvariants;
 
   /// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
   /// DataLayout instance and query it instead.
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index de5aed0a91a209..57a2aa29833657 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -54,10 +54,10 @@ class QuantizedType : public Type {
   /// The maximum number of bits supported for storage types.
   static constexpr unsigned MaxStorageBits = 32;
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Support method to enable LLVM-style type casting.
   static bool classof(Type type);
@@ -214,10 +214,10 @@ class AnyQuantizedType
              int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 };
 
 /// Represents a family of uniform, quantized types.
@@ -276,11 +276,11 @@ class UniformQuantizedType
              int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, double scale,
-                              int64_t zeroPoint, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, double scale,
+                   int64_t zeroPoint, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Gets the scale term. The scale designates the difference between the real
   /// values corresponding to consecutive quantized values differing by 1.
@@ -338,12 +338,12 @@ class UniformQuantizedPerAxisType
              int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, ArrayRef<double> scales,
-                              ArrayRef<int64_t> zeroPoints,
-                              int32_t quantizedDimension,
-                              int64_t storageTypeMin, int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType,
+                   ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+                   int32_t quantizedDimension, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Gets the quantization scales. The scales designate the difference between
   /// the real values corresponding to consecutive quantized values differing
@@ -403,8 +403,9 @@ class CalibratedQuantizedType
              double min, double max);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type expressedType, double min, double max);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type expressedType, double min, double max);
   double getMin() const;
   double getMax() const;
 };
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 5ebfa9ca5ec25c..2bdd7a5bf3dd83 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -76,9 +76,10 @@ class InterfaceVarABIAttr
   /// Returns `spirv::StorageClass`.
   std::optional<StorageClass> getStorageClass();
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              IntegerAttr descriptorSet, IntegerAttr binding,
-                              IntegerAttr storageClass);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   IntegerAttr descriptorSet, IntegerAttr binding,
+                   IntegerAttr storageClass);
 
   static constexpr StringLiteral name = "spirv.interface_var_abi";
 };
@@ -128,9 +129,10 @@ class VerCapExtAttr
   /// Returns the capabilities as an integer array attribute.
   ArrayAttr getCapabilitiesAttr();
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              IntegerAttr version, ArrayAttr capabilities,
-                              ArrayAttr extensions);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   IntegerAttr version, ArrayAttr capabilities,
+                   ArrayAttr extensions);
 
   static constexpr StringLiteral name = "spirv.ver_cap_ext";
 };
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 55f0c787b44403..e2d04553d91b8b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -258,8 +258,9 @@ class SampledImageType
   static SampledImageType
   getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type imageType);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type imageType);
 
   Type getImageType() const;
 
@@ -462,8 +463,9 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
   static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
                                Type columnType, uint32_t columnCount);
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type columnType, uint32_t columnCount);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type columnType, uint32_t columnCount);
 
   /// Returns true if the matrix elements are vectors of float elements.
   static bool isValidColumnType(Type columnType);
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index fb64f15162df5b..d6ccbbd8579947 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -176,8 +176,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   template <typename... Args>
   static ConcreteT get(MLIRContext *ctx, Args &&...args) {
     // Ensure that the invariants are correct for construction.
-    assert(
-        succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
+    assert(succeeded(
+        ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
     return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
   }
 
@@ -198,7 +198,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                               MLIRContext *ctx, Args... args) {
     // If the construction invariants fail then we return a null attribute.
-    if (failed(ConcreteT::verify(emitErrorFn, args...)))
+    if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
       return ConcreteT();
     return UniquerT::template get<ConcreteT>(ctx, args...);
   }
@@ -226,7 +226,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
 
   /// Default implementation that just returns success.
   template <typename... Args>
-  static LogicalResult verify(Args... args) {
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitErrorFn,
+                   Args... args) {
     return success();
   }
 
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 60dc8fee0f4a96..dfc7e69472c891 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -34,7 +34,7 @@ class AsmState;
 /// Derived type classes are expected to implement several required
 /// implementation hooks:
 ///  * Optional:
-///    - static LogicalResult verify(
+///    - static LogicalResult verifyInvariants(
 ///                                function_ref<InFlightDiagnostic()> emitError,
 ///                                Args... args)
 ///      * This method is invoked when calling the 'TypeBase::get/getChecked'
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 19c3a9183ec2cf..36744c85bc7086 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Builder.h"
+#include "mlir/TableGen/Constraint.h"
 #include "mlir/TableGen/Trait.h"
 
 namespace llvm {
@@ -85,6 +86,9 @@ class AttrOrTypeParameter {
   /// Get an optional C++ parameter parser.
   std::optional<StringRef> getParser() const;
 
+  /// If this parameter is a constraint (e.g. a type constraint), return it.
+  std::optional<Constraint> getConstraint() const;
+
   /// Get an optional C++ parameter printer.
   std::optional<StringRef> getPrinter() const;
 
@@ -198,6 +202,10 @@ class AttrOrTypeDef {
   /// method.
   bool genVerifyDecl() const;
 
+  /// Return true if any parameter is a constraint, i.e. constraint
+  /// verification code and the getChecked methods must be generated.
+  bool genVerifyInvariantsImpl() const;
+
   /// Returns the def's extra class declaration code.
   std::optional<StringRef> getExtraDecls() const;
 
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 855952d19492db..f750a34a3b2ba4 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -67,6 +67,8 @@ class MethodParameter {
 
   /// Get the C++ type.
   StringRef getType() const { return type; }
+  /// Get the C++ parameter name.
+  StringRef getName() const { return name; }
   /// Returns true if the parameter has a default value.
   bool hasDefaultValue() const { return !defaultValue.empty(); }
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7bc2668310ddb0..a1f87a637a6141 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
 }
 
 LogicalResult
-MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
-                      ArrayRef<int64_t> shape, Type elementType,
-                      StringRef operand) {
+MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                ArrayRef<int64_t> shape, Type elementType,
+                                StringRef operand) {
   if (operand != "AOp" && operand != "BOp" && operand != "COp")
     return emitError() << "operand expected to be one of AOp, BOp or COp";
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index dc7aef8ef7f850..7f10a15ff31ff9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
 
 bool LLVMStructType::isValidElementType(Type type) {
   return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
-                    LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
-      type);
+                    LLVMFunctionType, LLVMTokenType>(type);
 }
 
 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
@@ -492,14 +491,15 @@ ArrayRef<Type> LLVMStructType::getBody() const {
                         : getImpl()->getTypeList();
 }
 
-LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
-                                     StringRef, bool) {
+LogicalResult
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
+                                 bool) {
   return success();
 }
 
 LogicalResult
-LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
-                       ArrayRef<Type> types, bool) {
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                 ArrayRef<Type> types, bool) {
   for (Type t : types)
     if (!isValidElementType(t))
       return emitError() << "invalid LLVM structure element type: " << t;
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 81e3b914755be2..c2ba9c04e8771d 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -29,9 +29,10 @@ bool QuantizedType::classof(Type type) {
 }
 
 LogicalResult
-QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
-                      unsigned flags, Type storageType, Type expressedType,
-                      int64_t storageTypeMin, int64_t storageTypeMax) {
+QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                unsigned flags, Type storageType,
+                                Type expressedType, int64_t storageTypeMin,
+                                int64_t storageTypeMax) {
   // Verify that the storage type is integral.
   // This restriction may be lifted at some point in favor of using bf16
   // or f16 as exact representations on hardware where that is advantageous.
@@ -233,11 +234,13 @@ AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
 }
 
 LogicalResult
-AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
-                         unsigned flags, Type storageType, Type expressedType,
-                         int64_t storageTypeMin, int64_t storageTypeMax) {
-  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
-                                   storageTypeMin, storageTypeMax))) {
+AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                   unsigned flags, Type storageType,
+                                   Type expressedType, int64_t storageTypeMin,
+                                   int64_t storageTypeMax) {
+  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+                                             expressedType, storageTypeMin,
+                                             storageTypeMax))) {
     return failure();
   }
 
@@ -268,12 +271,13 @@ UniformQuantizedType UniformQuantizedType::getChecked(
                           storageTypeMin, storageTypeMax);
 }
 
-LogicalResult UniformQuantizedType::verify(
+LogicalResult UniformQuantizedType::verifyInvariants(
     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
     Type storageType, Type expressedType, double scale, int64_t zeroPoint,
     int64_t storageTypeMin, int64_t storageTypeMax) {
-  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
-                                   storageTypeMin, storageTypeMax))) {
+  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+                                             expressedType, storageTypeMin,
+                                             storageTypeMax))) {
     return failure();
   }
 
@@ -321,13 +325,14 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
                           quantizedDimension, storageTypeMin, storageTypeMax);
 }
 
-LogicalResult UniformQuantizedPerAxisType::verify(
+LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
     Type storageType, Type expressedType, ArrayRef<double> scales,
     ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
     int64_t storageTypeMin, int64_t storageTypeMax) {
-  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
-                                   storageTypeMin, storageTypeMax))) {
+  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+                                             expressedType, storageTypeMin,
+                                             storageTypeMax))) {
     return failure();
   }
 
@@ -380,9 +385,9 @@ CalibratedQuantizedType CalibratedQuantizedType::getChecked(
                           min, max);
 }
 
-LogicalResult
-CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
-                                Type expressedType, double min, double max) {
+LogicalResult CalibratedQuantizedType::verifyInvariants(
+    function_ref<InFlightDiagnostic()> emitError, Type expressedType,
+    double min, double max) {
   // Verify that the expressed type is floating point.
   // If this restriction is ever eliminated, the parser/printer must be
   // extended.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
index 8a0ee7a3d81367..b71be23fdf47d0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
@@ -162,7 +162,7 @@ spirv::InterfaceVarABIAttr::getStorageClass() {
   return std::nullopt;
 }
 
-LogicalResult spirv::InterfaceVarABIAttr::verify(
+LogicalResult spirv::InterfaceVarABIAttr::verifyInvariants(
     function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet,
     IntegerAttr binding, IntegerAttr storageClass) {
   if (!descriptorSet.getType().isSignlessInteger(32))
@@ -257,10 +257,9 @@ ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
   return llvm::cast<ArrayAttr>(getImpl()->capabilities);
 }
 
-LogicalResult
-spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                             IntegerAttr version, ArrayAttr capabilities,
-                             ArrayAttr extensions) {
+LogicalResult spirv::VerCapExtAttr::verifyInvariants(
+    function_ref<InFlightDiagnostic()> emitError, IntegerAttr version,
+    ArrayAttr capabilities, ArrayAttr extensions) {
   if (!version.getType().isSignlessInteger(32))
     return emitError() << "expected 32-bit integer for version";
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 3808620bdffa6d..c5590905b75045 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -817,8 +817,8 @@ SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
 
 LogicalResult
-SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
-                         Type imageType) {
+SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                   Type imageType) {
   if (!llvm::isa<ImageType>(imageType))
     return emitError() << "expected image type";
 
@@ -1181,8 +1181,9 @@ MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                           columnCount);
 }
 
-LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
-                                 Type columnType, uint32_t columnCount) {
+LogicalResult
+MatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                             Type columnType, uint32_t columnCount) {
   if (columnCount < 2 || columnCount > 4)
     return emitError() << "matrix can have 2, 3, or 4 columns only";
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 1135ea32fe1abb..a284aa2f1f020f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1227,6 +1227,14 @@ StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
   return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
 }
 
+StorageSpecifierType
+StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                 MLIRContext *ctx,
+                                 SparseTensorEncodingAttr encoding) {
+  return Base::getChecked(emitError, ctx,
+                          getNormalizedEncodingForSpecifier(encoding));
+}
+
 //===----------------------------------------------------------------------===//
 // SparseTensorDialect Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index c9dbb3bc76b1fa..9b9d9fd2317d99 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -184,6 +184,12 @@ bool AttrOrTypeDef::genVerifyDecl() const {
   return def->getValueAsBit("genVerifyDecl");
 }
 
+bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
+  return any_of(parameters, [](const AttrOrTypeParameter &p) {
+    return p.getConstraint() != std::nullopt;
+  });
+}
+
 std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
   auto value = def->getValueAsString("extraClassDeclaration");
   return value.empty() ? std::optional<StringRef>() : value;
@@ -331,6 +337,13 @@ std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
 
 llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
 
+std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
+  if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
+    if (param->getDef()->isSubClassOf("Constraint"))
+      return Constraint(param->getDef());
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // AttributeSelfTypeParameter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir
new file mode 100644
index 00000000000000..96d0005eb7a19d
--- /dev/null
+++ b/mlir/test/IR/test-verifiers-type.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// CHECK: "test.type_producer"() : () -> !test.type_verification<i16>
+"test.type_producer"() : () -> !test.type_verification<i16>
+
+// -----
+
+// expected-error @below{{failed to verify 'param': 16-bit signless integer or 32-bit signless integer}}
+"test.type_producer"() : () -> !test.type_verification<f16>
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 8f109f8ce5e6dd..b3b94bd0ffea31 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -270,7 +270,6 @@ def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
   let assemblyFormat = "`<` $a `>`";
 
   let skipDefaultBuilders = 1;
-  let genVerifyDecl = 1;
   let builders = [AttrBuilder<(ins "int":$a), [{
     return ::mlir::IntegerAttr::get(::mlir::IndexType::get($_ctxt), a);
   }], "::mlir::Attribute">];
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index d96152a0826f96..830475bed4e444 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -392,4 +392,10 @@ def TestRecursiveAlias
   }];
 }
 
+def TestTypeVerification : Test_Type<"TestTypeVerification"> {
+  let parameters = (ins AnyTypeOf<[I16, I32]>:$param);
+  let mnemonic = "type_verification";
+  let assemblyFormat = "`<` $param `>`";
+}
+
 #endif // TEST_TYPEDEFS
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 2884c4ed6a9081..c5348409e8e44f 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -1,8 +1,9 @@
-// RUN: sed 's/DEFAULT_TYPE_PARSER/0/' %s | mlir-tblgen -gen-attrdef-defs -I %S/../../include | FileCheck %s --check-prefix=ATTR
-// RUN: sed 's/DEFAULT_TYPE_PARSER/0/' %s | mlir-tblgen -gen-typedef-defs -I %S/../../include | FileCheck %s --check-prefix=TYPE
-// RUN: sed 's/DEFAULT_TYPE_PARSER/1/' %s | mlir-tblgen -gen-typedef-defs -I %S/../../include | FileCheck %s --check-prefix=TYPE --check-prefix=DEFAULT_TYPE_PARSER
+// RUN: sed 's/DEFAULT_TYPE_PARSER/0/' %s | mlir-tblgen -gen-attrdef-defs -attrdefs-dialect=TestDialect -I %S/../../include | FileCheck %s --check-prefix=ATTR
+// RUN: sed 's/DEFAULT_TYPE_PARSER/0/' %s | mlir-tblgen -gen-typedef-defs -typedefs-dialect=TestDialect -I %S/../../include | FileCheck %s --check-prefix=TYPE
+// RUN: sed 's/DEFAULT_TYPE_PARSER/1/' %s | mlir-tblgen -gen-typedef-defs -typedefs-dialect=TestDialect -I %S/../../include | FileCheck %s --check-prefix=TYPE --check-prefix=DEFAULT_TYPE_PARSER
 
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinAttributes.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 
@@ -663,6 +664,78 @@ def TypeO : TestType<"TestQ"> {
   let assemblyFormat = "(custom<AB>($a)^ `x`) : (`y`)?";
 }
 
+// Test attr / type verification.
+
+// TYPE: ::llvm::LogicalResult TestPType::verifyInvariantsImpl(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::Type a) {
+// TYPE:   if (!(((a.isSignlessInteger(16))) || ((a.isSignlessInteger(32))))) {
+// TYPE:     emitError() << "failed to verify 'a': 16-bit signless integer or 32-bit signless integer";
+// TYPE:     return ::mlir::failure();
+// TYPE:   }
+// TYPE:   return ::mlir::success();
+// TYPE: }
+
+// TYPE: ::llvm::LogicalResult TestPType::verifyInvariants(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::Type a) {
+// TYPE:   if (::mlir::failed(verifyInvariantsImpl(emitError, a)))
+// TYPE:     return ::mlir::failure();
+// TYPE:   if (::mlir::failed(verify(emitError, a)))
+// TYPE:     return ::mlir::failure();
+// TYPE:   return ::mlir::success();
+// TYPE: }
+
+def TypeP : TestType<"TestP"> {
+  let parameters = (ins AnyTypeOf<[I16, I32]>:$a);
+  let mnemonic = "type_p";
+  let genVerifyDecl = 1;
+  let assemblyFormat = "$a";
+}
+
+// ATTR: ::llvm::LogicalResult TestRAttr::verifyInvariantsImpl(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::IntegerType a) {
+// ATTR:   if (!((a.isSignlessInteger(32)))) {
+// ATTR:     emitError() << "failed to verify 'a': 32-bit signless integer";
+// ATTR:     return ::mlir::failure();
+// ATTR:   }
+// ATTR:   return ::mlir::success();
+// ATTR: }
+
+// ATTR: ::llvm::LogicalResult TestRAttr::verifyInvariants(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::IntegerType a) {
+// ATTR:   if (::mlir::failed(verifyInvariantsImpl(emitError, a)))
+// ATTR:     return ::mlir::failure();
+// ATTR:   if (::mlir::failed(verify(emitError, a)))
+// ATTR:     return ::mlir::failure();
+// ATTR:   return ::mlir::success();
+// ATTR: }
+
+def AttrR : TestAttr<"TestR"> {
+  let parameters = (ins I32:$a);
+  let mnemonic = "attr_r";
+  let genVerifyDecl = 1;
+  let assemblyFormat = "$a";
+}
+
+// TYPE: ::llvm::LogicalResult TestSType::verifyInvariantsImpl(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::ArrayAttr a) {
+// TYPE:   if (!((::llvm::isa<::mlir::ArrayAttr>(a)))) {
+// TYPE:     emitError() << "failed to verify 'a': A collection of other Attribute values";
+// TYPE:     return ::mlir::failure();
+// TYPE:   }
+// TYPE:   return ::mlir::success();
+// TYPE: }
+
+// TYPE: ::llvm::LogicalResult TestSType::verifyInvariants(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::ArrayAttr a) {
+// TYPE:   if (::mlir::failed(verifyInvariantsImpl(emitError, a)))
+// TYPE:     return ::mlir::failure();
+// TYPE:   if (::mlir::failed(verify(emitError, a)))
+// TYPE:     return ::mlir::failure();
+// TYPE:   return ::mlir::success();
+// TYPE: }
+
+def TypeS : TestType<"TestS"> {
+  // TODO: Support attribute constraints as parameters.
+  let parameters = (ins Builtin_ArrayAttr:$a);
+  let mnemonic = "type_s";
+  let genVerifyDecl = 1;
+  let assemblyFormat = "$a";
+}
+
 // DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
 // DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
 // DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 8cc8314418104c..71ba6a5c73da9e 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -93,8 +93,14 @@ class DefGen {
   void emitDialectName();
   /// Emit attribute or type builders.
   void emitBuilders();
-  /// Emit a verifier for the def.
-  void emitVerifier();
+  /// Emit a verifier declaration for custom verification (implementation
+  /// provided by the user).
+  void emitVerifierDecl();
+  /// Emit a verifier that checks type constraints.
+  void emitInvariantsVerifierImpl();
+  /// Emit an entry point for verification that calls the invariants verifier
+  /// and the custom verifier.
+  void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier);
   /// Emit parsers and printers.
   void emitParserPrinter();
   /// Emit parameter accessors, if required.
@@ -188,9 +194,17 @@ DefGen::DefGen(const AttrOrTypeDef &def)
   emitName();
   // Emit the dialect name.
   emitDialectName();
-  // Emit the verifier.
-  if (storageCls && def.genVerifyDecl())
-    emitVerifier();
+  // Emit verification of type constraints.
+  bool genVerifyInvariantsImpl = def.genVerifyInvariantsImpl();
+  if (storageCls && genVerifyInvariantsImpl)
+    emitInvariantsVerifierImpl();
+  // Emit the custom verifier (written by the user).
+  bool genVerifyDecl = def.genVerifyDecl();
+  if (storageCls && genVerifyDecl)
+    emitVerifierDecl();
+  // Emit the "verifyInvariants" function if there is any verification at all.
+  if (storageCls)
+    emitInvariantsVerifier(genVerifyInvariantsImpl, genVerifyDecl);
   // Emit the mnemonic, if there is one, and any associated parser and printer.
   if (def.getMnemonic())
     emitParserPrinter();
@@ -295,24 +309,88 @@ void DefGen::emitDialectName() {
 void DefGen::emitBuilders() {
   if (!def.skipDefaultBuilders()) {
     emitDefaultBuilder();
-    if (def.genVerifyDecl())
+    if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
       emitCheckedBuilder();
   }
   for (auto &builder : def.getBuilders()) {
     emitCustomBuilder(builder);
-    if (def.genVerifyDecl())
+    if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
       emitCheckedCustomBuilder(builder);
   }
 }
 
-void DefGen::emitVerifier() {
-  defCls.declare<UsingDeclaration>("Base::getChecked");
+void DefGen::emitVerifierDecl() {
   defCls.declareStaticMethod(
       "::llvm::LogicalResult", "verify",
       getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
                          "emitError"}}));
 }
 
+static const char *const patternParameterVerificationCode = R"(
+if (!({0})) {
+  emitError() << "failed to verify '{1}': {2}";
+  return ::mlir::failure();
+}
+)";
+
+void DefGen::emitInvariantsVerifierImpl() {
+  SmallVector<MethodParameter> builderParams = getBuilderParams(
+      {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
+  Method *verifier =
+      defCls.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl",
+                       Method::Static, builderParams);
+  verifier->body().indent();
+
+  // Generate verification for each parameter that is a type constraint.
+  for (auto it : llvm::enumerate(def.getParameters())) {
+    const AttrOrTypeParameter &param = it.value();
+    std::optional<Constraint> constraint = param.getConstraint();
+    // No verification needed for parameters that are not type constraints.
+    if (!constraint.has_value())
+      continue;
+    FmtContext ctx;
+    // Note: Skip over the first method parameter (`emitError`).
+    ctx.withSelf(builderParams[it.index() + 1].getName());
+    std::string condition = tgfmt(constraint->getConditionTemplate(), &ctx);
+    verifier->body() << formatv(patternParameterVerificationCode, condition,
+                                param.getName(), constraint->getSummary())
+                     << "\n";
+  }
+  verifier->body() << "return ::mlir::success();";
+}
+
+void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) {
+  if (!hasImpl && !hasCustomVerifier)
+    return;
+  defCls.declare<UsingDeclaration>("Base::getChecked");
+  SmallVector<MethodParameter> builderParams = getBuilderParams(
+      {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
+  Method *verifier =
+      defCls.addMethod("::llvm::LogicalResult", "verifyInvariants",
+                       Method::Static, builderParams);
+  verifier->body().indent();
+
+  auto emitVerifierCall = [&](StringRef name) {
+    verifier->body() << strfmt("if (::mlir::failed({0}(", name);
+    llvm::interleaveComma(
+        llvm::map_range(builderParams,
+                        [](auto &param) { return param.getName(); }),
+        verifier->body());
+    verifier->body() << ")))\n";
+    verifier->body() << "  return ::mlir::failure();\n";
+  };
+
+  if (hasImpl) {
+    // Call the verifier that checks the type constraints.
+    emitVerifierCall("verifyInvariantsImpl");
+  }
+  if (hasCustomVerifier) {
+    // Call the custom verifier that is provided by the user.
+    emitVerifierCall("verify");
+  }
+  verifier->body() << "return ::mlir::success();";
+}
+
 void DefGen::emitParserPrinter() {
   auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
       "::llvm::StringLiteral", "getMnemonic");
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index dacc20b6ba2086..a4ae271edb6bd2 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -323,7 +323,7 @@ void DefFormat::genParser(MethodBody &os) {
 
   // Generate call to the attribute or type builder. Use the checked getter
   // if one was generated.
-  if (def.genVerifyDecl()) {
+  if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) {
     os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
                 &ctx, def.getCppClassName());
   } else {



More information about the Mlir-commits mailing list