[Mlir-commits] [mlir] 06e25d5 - [mlir][IR] Refactor the `getChecked` and `verifyConstructionInvariants` methods on Attributes/Types

River Riddle llvmlistbot at llvm.org
Mon Feb 22 17:38:13 PST 2021


Author: River Riddle
Date: 2021-02-22T17:37:49-08:00
New Revision: 06e25d56451977ef5b7052282faacfe3d42acb65

URL: https://github.com/llvm/llvm-project/commit/06e25d56451977ef5b7052282faacfe3d42acb65
DIFF: https://github.com/llvm/llvm-project/commit/06e25d56451977ef5b7052282faacfe3d42acb65.diff

LOG: [mlir][IR] Refactor the `getChecked` and `verifyConstructionInvariants` methods on Attributes/Types

`verifyConstructionInvariants` is intended to allow for verifying the invariants of an attribute/type on construction, and `getChecked` is intended to enable more graceful error handling aside from an assert. There are a few problems with the current implementation of these methods:
* `verifyConstructionInvariants` requires an mlir::Location for emitting errors, which is prohibitively costly in the situations that would most likely use them, e.g. the parser.
This creates an unfortunate code duplication between the verifier code and the parser code, given that the parser operates on llvm::SMLoc and it is an undesirable overhead to pre-emptively convert from that to an mlir::Location.
* `getChecked` effectively requires duplicating the definition of the `get` method, creating a rather clunky workflow due to the subtle difference in its signature.

This revision aims to tackle the above problems by refactoring the implementation to use a callback for error emission. Using a callback allows for deferring the costly part of error emission until it is actually necessary.

Due to the necessary signature change in each instance of these methods, this revision also takes this opportunity to clean up the definition of these methods by:
* restructuring the signature of `getChecked` such that it can be generated from the same code block as the `get` method.
* renaming `verifyConstructionInvariants` to `verify` to match the naming scheme of the rest of the compiler.

Differential Revision: https://reviews.llvm.org/D97100

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIRType.h
    flang/include/flang/Optimizer/Dialect/FIRTypes.td
    flang/lib/Optimizer/Dialect/FIRType.cpp
    mlir/docs/OpDefinitions.md
    mlir/docs/Tutorials/DefiningAttributesAndTypes.md
    mlir/include/mlir-c/BuiltinAttributes.h
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/include/mlir/Dialect/Quant/QuantTypes.h
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/StorageUniquerSupport.h
    mlir/include/mlir/IR/Types.h
    mlir/include/mlir/TableGen/TypeDef.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
    mlir/lib/Dialect/Quant/IR/TypeParser.cpp
    mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/Parser/DialectSymbolParser.cpp
    mlir/lib/TableGen/TypeDef.cpp
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/lib/Dialect/Test/TestTypes.cpp
    mlir/test/mlir-tblgen/typedefs.td
    mlir/tools/mlir-tblgen/TypeDefGen.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 7198a98837d5..e8972cfdb5e9 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -23,8 +23,7 @@
 namespace llvm {
 class raw_ostream;
 class StringRef;
-template <typename>
-class ArrayRef;
+template <typename> class ArrayRef;
 class hash_code;
 } // namespace llvm
 
@@ -149,8 +148,9 @@ class HeapType : public mlir::Type::TypeBase<HeapType, mlir::Type,
 
   mlir::Type getEleTy() const;
 
-  static mlir::LogicalResult verifyConstructionInvariants(mlir::Location,
-                                                          mlir::Type eleTy);
+  static mlir::LogicalResult
+  verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+         mlir::Type eleTy);
 };
 
 /// The type of a LEN parameter name. Implementations may defer the layout of a
@@ -174,8 +174,9 @@ class PointerType : public mlir::Type::TypeBase<PointerType, mlir::Type,
 
   mlir::Type getEleTy() const;
 
-  static mlir::LogicalResult verifyConstructionInvariants(mlir::Location,
-                                                          mlir::Type eleTy);
+  static mlir::LogicalResult
+  verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+         mlir::Type eleTy);
 };
 
 /// The type of a reference to an entity in memory.
@@ -188,8 +189,9 @@ class ReferenceType
 
   mlir::Type getEleTy() const;
 
-  static mlir::LogicalResult verifyConstructionInvariants(mlir::Location,
-                                                          mlir::Type eleTy);
+  static mlir::LogicalResult
+  verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+         mlir::Type eleTy);
 };
 
 /// A sequence type is a multi-dimensional array of values. The sequence type
@@ -239,8 +241,8 @@ class SequenceType : public mlir::Type::TypeBase<SequenceType, mlir::Type,
   static constexpr Extent getUnknownExtent() { return -1; }
 
   static mlir::LogicalResult
-  verifyConstructionInvariants(mlir::Location loc, const Shape &shape,
-                               mlir::Type eleTy, mlir::AffineMapAttr map);
+  verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+         const Shape &shape, mlir::Type eleTy, mlir::AffineMapAttr map);
 };
 
 bool operator==(const SequenceType::Shape &, const SequenceType::Shape &);
@@ -256,8 +258,9 @@ class TypeDescType : public mlir::Type::TypeBase<TypeDescType, mlir::Type,
   static TypeDescType get(mlir::Type ofType);
   mlir::Type getOfTy() const;
 
-  static mlir::LogicalResult verifyConstructionInvariants(mlir::Location,
-                                                          mlir::Type ofType);
+  static mlir::LogicalResult
+  verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+         mlir::Type ofType);
 };
 
 // Derived types
@@ -290,8 +293,9 @@ class RecordType : public mlir::Type::TypeBase<RecordType, mlir::Type,
 
   detail::RecordTypeStorage const *uniqueKey() const;
 
-  static mlir::LogicalResult verifyConstructionInvariants(mlir::Location,
-                                                          llvm::StringRef name);
+  static mlir::LogicalResult
+  verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+         llvm::StringRef name);
 };
 
 /// Is `t` a FIR Real or MLIR Float type?
@@ -318,7 +322,8 @@ class VectorType : public mlir::Type::TypeBase<fir::VectorType, mlir::Type,
   uint64_t getLen() const;
 
   static mlir::LogicalResult
-  verifyConstructionInvariants(mlir::Location, uint64_t len, mlir::Type eleTy);
+  verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, uint64_t len,
+         mlir::Type eleTy);
   static bool isValidElementType(mlir::Type t) {
     return isa_real(t) || isa_integer(t);
   }

diff  --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index ab9bbd60ac09..07c2c58c7c16 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -68,7 +68,7 @@ def fir_BoxProcType : FIR_Type<"BoxProc", "boxproc"> {
   }];
 
   let genAccessors = 1;
-  let genVerifyInvariantsDecl = 1;
+  let genVerifyDecl = 1;
 }
 
 def fir_BoxType : FIR_Type<"Box", "box"> {
@@ -91,7 +91,7 @@ def fir_BoxType : FIR_Type<"Box", "box"> {
   }];
 
   let genAccessors = 1;
-  let genVerifyInvariantsDecl = 1;
+  let genVerifyDecl = 1;
 }
 
 def fir_CharacterType : FIR_Type<"Character", "char"> {

diff  --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index 77a6d95e7406..ecfc597d820a 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -681,8 +681,7 @@ struct VectorTypeStorage : public mlir::TypeStorage {
 
 } // namespace detail
 
-template <typename A, typename B>
-bool inbounds(A v, B lb, B ub) {
+template <typename A, typename B> bool inbounds(A v, B lb, B ub) {
   return v >= lb && v < ub;
 }
 
@@ -759,8 +758,8 @@ RealType fir::RealType::get(mlir::MLIRContext *ctxt, KindTy kind) {
 KindTy fir::RealType::getFKind() const { return getImpl()->getFKind(); }
 
 mlir::LogicalResult
-fir::BoxType::verifyConstructionInvariants(mlir::Location, mlir::Type eleTy,
-                                           mlir::AffineMapAttr map) {
+fir::BoxType::verify(llvm::function_ref<mlir::InFlightDiagnostic()>,
+                     mlir::Type eleTy, mlir::AffineMapAttr map) {
   // TODO
   return mlir::success();
 }
@@ -775,15 +774,14 @@ mlir::Type fir::ReferenceType::getEleTy() const {
   return getImpl()->getElementType();
 }
 
-mlir::LogicalResult
-fir::ReferenceType::verifyConstructionInvariants(mlir::Location loc,
-                                                 mlir::Type eleTy) {
+mlir::LogicalResult fir::ReferenceType::verify(
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+    mlir::Type eleTy) {
   if (eleTy.isa<ShapeType>() || eleTy.isa<ShapeShiftType>() ||
       eleTy.isa<SliceType>() || eleTy.isa<FieldType>() ||
       eleTy.isa<LenType>() || eleTy.isa<ReferenceType>() ||
       eleTy.isa<TypeDescType>())
-    return mlir::emitError(loc, "cannot build a reference to type: ")
-           << eleTy << '\n';
+    return emitError() << "cannot build a reference to type: " << eleTy << '\n';
   return mlir::success();
 }
 
@@ -807,12 +805,11 @@ static bool canBePointerOrHeapElementType(mlir::Type eleTy) {
          eleTy.isa<ReferenceType>() || eleTy.isa<TypeDescType>();
 }
 
-mlir::LogicalResult
-fir::PointerType::verifyConstructionInvariants(mlir::Location loc,
-                                               mlir::Type eleTy) {
+mlir::LogicalResult fir::PointerType::verify(
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+    mlir::Type eleTy) {
   if (canBePointerOrHeapElementType(eleTy))
-    return mlir::emitError(loc, "cannot build a pointer to type: ")
-           << eleTy << '\n';
+    return emitError() << "cannot build a pointer to type: " << eleTy << '\n';
   return mlir::success();
 }
 
@@ -828,11 +825,11 @@ mlir::Type fir::HeapType::getEleTy() const {
 }
 
 mlir::LogicalResult
-fir::HeapType::verifyConstructionInvariants(mlir::Location loc,
-                                            mlir::Type eleTy) {
+fir::HeapType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+                      mlir::Type eleTy) {
   if (canBePointerOrHeapElementType(eleTy))
-    return mlir::emitError(loc, "cannot build a heap pointer to type: ")
-           << eleTy << '\n';
+    return emitError() << "cannot build a heap pointer to type: " << eleTy
+                       << '\n';
   return mlir::success();
 }
 
@@ -884,8 +881,9 @@ bool fir::SequenceType::hasConstantInterior() const {
   return true;
 }
 
-mlir::LogicalResult fir::SequenceType::verifyConstructionInvariants(
-    mlir::Location loc, const SequenceType::Shape &shape, mlir::Type eleTy,
+mlir::LogicalResult fir::SequenceType::verify(
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+    const SequenceType::Shape &shape, mlir::Type eleTy,
     mlir::AffineMapAttr map) {
   // DIMENSION attribute can only be applied to an intrinsic or record type
   if (eleTy.isa<BoxType>() || eleTy.isa<BoxCharType>() ||
@@ -895,8 +893,8 @@ mlir::LogicalResult fir::SequenceType::verifyConstructionInvariants(
       eleTy.isa<PointerType>() || eleTy.isa<ReferenceType>() ||
       eleTy.isa<TypeDescType>() || eleTy.isa<fir::VectorType>() ||
       eleTy.isa<SequenceType>())
-    return mlir::emitError(loc, "cannot build an array of this element type: ")
-           << eleTy << '\n';
+    return emitError() << "cannot build an array of this element type: "
+                       << eleTy << '\n';
   return mlir::success();
 }
 
@@ -955,11 +953,11 @@ detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const {
   return getImpl();
 }
 
-mlir::LogicalResult
-fir::RecordType::verifyConstructionInvariants(mlir::Location loc,
-                                              llvm::StringRef name) {
+mlir::LogicalResult fir::RecordType::verify(
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+    llvm::StringRef name) {
   if (name.size() == 0)
-    return mlir::emitError(loc, "record types must have a name");
+    return emitError() << "record types must have a name";
   return mlir::success();
 }
 
@@ -981,16 +979,16 @@ TypeDescType fir::TypeDescType::get(mlir::Type ofType) {
 
 mlir::Type fir::TypeDescType::getOfTy() const { return getImpl()->getOfType(); }
 
-mlir::LogicalResult
-fir::TypeDescType::verifyConstructionInvariants(mlir::Location loc,
-                                                mlir::Type eleTy) {
+mlir::LogicalResult fir::TypeDescType::verify(
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+    mlir::Type eleTy) {
   if (eleTy.isa<BoxType>() || eleTy.isa<BoxCharType>() ||
       eleTy.isa<BoxProcType>() || eleTy.isa<ShapeType>() ||
       eleTy.isa<ShapeShiftType>() || eleTy.isa<SliceType>() ||
       eleTy.isa<FieldType>() || eleTy.isa<LenType>() ||
       eleTy.isa<ReferenceType>() || eleTy.isa<TypeDescType>())
-    return mlir::emitError(loc, "cannot build a type descriptor of type: ")
-           << eleTy << '\n';
+    return emitError() << "cannot build a type descriptor of type: " << eleTy
+                       << '\n';
   return mlir::success();
 }
 
@@ -1006,12 +1004,11 @@ mlir::Type fir::VectorType::getEleTy() const { return getImpl()->getEleTy(); }
 
 uint64_t fir::VectorType::getLen() const { return getImpl()->getLen(); }
 
-mlir::LogicalResult
-fir::VectorType::verifyConstructionInvariants(mlir::Location loc, uint64_t len,
-                                              mlir::Type eleTy) {
+mlir::LogicalResult fir::VectorType::verify(
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError, uint64_t len,
+    mlir::Type eleTy) {
   if (!(fir::isa_real(eleTy) || fir::isa_integer(eleTy)))
-    return mlir::emitError(loc, "cannot build a vector of type ")
-           << eleTy << '\n';
+    return emitError() << "cannot build a vector of type " << eleTy << '\n';
   return mlir::success();
 }
 
@@ -1173,14 +1170,14 @@ mlir::Type BoxProcType::parse(mlir::MLIRContext *context,
 }
 
 mlir::LogicalResult
-BoxProcType::verifyConstructionInvariants(mlir::Location loc,
-                                          mlir::Type eleTy) {
+BoxProcType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+                    mlir::Type eleTy) {
   if (eleTy.isa<mlir::FunctionType>())
     return mlir::success();
   if (auto refTy = eleTy.dyn_cast<ReferenceType>())
     if (refTy.isa<mlir::FunctionType>())
       return mlir::success();
-  return mlir::emitError(loc, "invalid type for boxproc") << eleTy << '\n';
+  return emitError() << "invalid type for boxproc" << eleTy << '\n';
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index d4dbf226c425..3eb0d80154de 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1525,11 +1525,10 @@ responsible for parsing/printing the types in `Dialect::printType` and
 -   If the `genAccessors` field is 1 (the default) accessor methods will be
     generated on the Type class (e.g. `int getWidth() const` in the example
     above).
--   If the `genVerifyInvariantsDecl` field is set, a declaration for a method
-    `static LogicalResult verifyConstructionInvariants(Location, parameters...)`
-    is added to the class as well as a `getChecked(Location, parameters...)`
-    method which gets the result of `verifyConstructionInvariants` before
-    calling `get`.
+-   If the `genVerifyDecl` field is set, a declaration for a method `static
+    LogicalResult verify(emitErrorFn, parameters...)` is added to the class as
+    well as a `getChecked(emitErrorFn, parameters...)` method which checks the
+    result of `verify` before calling `get`.
 -   The `storageClass` field can be used to set the name of the storage class.
 -   The `storageNamespace` field is used to set the namespace where the storage
     class should sit. Defaults to "detail".
@@ -1555,9 +1554,9 @@ The following builders are generated:
 // given set of parameters.
 static MyType get(MLIRContext *context, int intParam);
 
-// If `genVerifyInvariantsDecl` is set to 1, the following method is also
-// generated.
-static MyType getChecked(Location loc, int intParam);
+// If `genVerifyDecl` is set to 1, the following method is also generated.
+static MyType getChecked(function_ref<InFlightDiagnostic()> emitError,
+                         MLIRContext *context, int intParam);
 ```
 
 If these autogenerated methods are not desired, such as when they conflict with

diff  --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
index 0c0de7808674..6a261da8a6c2 100644
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
+++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
@@ -161,27 +161,28 @@ public:
     return Base::get(type.getContext(), param, type);
   }
 
-  /// This method is used to get an instance of the 'ComplexType', defined at
-  /// the given location. If any of the construction invariants are invalid,
-  /// errors are emitted with the provided location and a null type is returned.
+  /// This method is used to get an instance of the 'ComplexType'. If any of the
+  /// construction invariants are invalid, errors are emitted with the provided
+  /// `emitError` function and a null type is returned.
   /// Note: This method is completely optional.
-  static ComplexType getChecked(unsigned param, Type type, Location location) {
+  static ComplexType getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                unsigned param, Type type) {
     // Call into a helper 'getChecked' method in 'TypeBase' to get a uniqued
     // instance of this type. All parameters to the storage class are passed
-    // after the location.
-    return Base::getChecked(location, param, type);
+    // after the context.
+    return Base::getChecked(emitError, type.getContext(), param, type);
   }
 
   /// This method is used to verify the construction invariants passed into the
   /// 'get' and 'getChecked' methods. Note: This method is completely optional.
-  static LogicalResult verifyConstructionInvariants(
-      Location loc, unsigned param, Type type) {
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              unsigned param, Type type) {
     // Our type only allows non-zero parameters.
     if (param == 0)
-      return emitError(loc) << "non-zero parameter passed to 'ComplexType'";
+      return emitError() << "non-zero parameter passed to 'ComplexType'";
     // Our type also expects an integer type.
     if (!type.isa<IntegerType>())
-      return emitError(loc) << "non integer-type passed to 'ComplexType'";
+      return emitError() << "non integer-type passed to 'ComplexType'";
     return success();
   }
 

diff  --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index cc6db74968fc..29df9cf60b8a 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -98,8 +98,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx,
 
 /// Same as "mlirFloatAttrDoubleGet", but if the type is not valid for a
 /// construction of a FloatAttr, returns a null MlirAttribute.
-MLIR_CAPI_EXPORTED MlirAttribute
-mlirFloatAttrDoubleGetChecked(MlirType type, double value, MlirLocation loc);
+MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc,
+                                                               MlirType type,
+                                                               double value);
 
 /// Returns the value stored in the given floating point attribute, interpreting
 /// the value as double.

diff  --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index cb942da1ae1a..a706c58efc7d 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -170,10 +170,10 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGet(intptr_t rank,
 
 /// Same as "mlirVectorTypeGet" but returns a nullptr wrapping MlirType on
 /// illegal arguments, emitting appropriate diagnostics.
-MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(intptr_t rank,
+MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
+                                                     intptr_t rank,
                                                      const int64_t *shape,
-                                                     MlirType elementType,
-                                                     MlirLocation loc);
+                                                     MlirType elementType);
 
 //===----------------------------------------------------------------------===//
 // Ranked / Unranked Tensor type.
@@ -196,10 +196,9 @@ MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGet(intptr_t rank,
 
 /// Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on
 /// illegal arguments, emitting appropriate diagnostics.
-MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(intptr_t rank,
-                                                           const int64_t *shape,
-                                                           MlirType elementType,
-                                                           MlirLocation loc);
+MLIR_CAPI_EXPORTED MlirType
+mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
+                               const int64_t *shape, MlirType elementType);
 
 /// Creates an unranked tensor type with the given element type in the same
 /// context as the element type. The type is owned by the context.
@@ -208,7 +207,7 @@ MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGet(MlirType elementType);
 /// Same as "mlirUnrankedTensorTypeGet" but returns a nullptr wrapping MlirType
 /// on illegal arguments, emitting appropriate diagnostics.
 MLIR_CAPI_EXPORTED MlirType
-mlirUnrankedTensorTypeGetChecked(MlirType elementType, MlirLocation loc);
+mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType);
 
 //===----------------------------------------------------------------------===//
 // Ranked / Unranked MemRef type.
@@ -230,8 +229,8 @@ MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(
 /// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o
 /// illegal arguments, emitting appropriate diagnostics.
 MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(
-    MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps,
-    MlirAffineMap const *affineMaps, unsigned memorySpace, MlirLocation loc);
+    MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape,
+    intptr_t numMaps, MlirAffineMap const *affineMaps, unsigned memorySpace);
 
 /// Creates a MemRef type with the given rank, shape, memory space and element
 /// type in the same context as the element type. The type has no affine maps,
@@ -245,8 +244,8 @@ MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGet(MlirType elementType,
 /// Same as "mlirMemRefTypeContiguousGet" but returns a nullptr wrapping
 /// MlirType on illegal arguments, emitting appropriate diagnostics.
 MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGetChecked(
-    MlirType elementType, intptr_t rank, const int64_t *shape,
-    unsigned memorySpace, MlirLocation loc);
+    MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape,
+    unsigned memorySpace);
 
 /// Creates an Unranked MemRef type with the given element type and in the given
 /// memory space. The type is owned by the context of element type.
@@ -256,7 +255,7 @@ MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
 /// Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping
 /// MlirType on illegal arguments, emitting appropriate diagnostics.
 MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked(
-    MlirType elementType, unsigned memorySpace, MlirLocation loc);
+    MlirLocation loc, MlirType elementType, unsigned memorySpace);
 
 /// Returns the number of affine layout maps in the given MemRef type.
 MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type);

diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td
index 16cca53aea5e..d78bd06416ea 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td
@@ -43,9 +43,7 @@ def Async_ValueType : Async_Type<"Value", "value"> {
   let parameters = (ins "Type":$valueType);
   let builders = [
     TypeBuilderWithInferredContext<(ins "Type":$valueType), [{
-      return Base::get(valueType.getContext(), valueType);
-    }], [{
-      return Base::getChecked($_loc, valueType);
+      return $_get(valueType.getContext(), valueType);
     }]>
   ];
   let skipDefaultBuilders = 1;

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 95a8e4c6950e..4a1449db04e4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -68,6 +68,7 @@ class LLVMArrayType : public Type::TypeBase<LLVMArrayType, Type,
 public:
   /// Inherit base constructors.
   using Base::Base;
+  using Base::getChecked;
 
   /// Checks if the given type can be used inside an array type.
   static bool isValidElementType(Type type);
@@ -75,8 +76,8 @@ class LLVMArrayType : public Type::TypeBase<LLVMArrayType, Type,
   /// Gets or creates an instance of LLVM dialect array type containing
   /// `numElements` of `elementType`, in the same context as `elementType`.
   static LLVMArrayType get(Type elementType, unsigned numElements);
-  static LLVMArrayType getChecked(Location loc, Type elementType,
-                                  unsigned numElements);
+  static LLVMArrayType getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                  Type elementType, unsigned numElements);
 
   /// Returns the element type of the array.
   Type getElementType();
@@ -85,9 +86,8 @@ class LLVMArrayType : public Type::TypeBase<LLVMArrayType, Type,
   unsigned getNumElements();
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    Type elementType,
-                                                    unsigned numElements);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type elementType, unsigned numElements);
 };
 
 //===----------------------------------------------------------------------===//
@@ -103,6 +103,7 @@ class LLVMFunctionType
 public:
   /// Inherit base constructors.
   using Base::Base;
+  using Base::getChecked;
 
   /// Checks if the given type can be used an argument in a function type.
   static bool isValidArgumentType(Type type);
@@ -117,9 +118,9 @@ class LLVMFunctionType
   /// as the `result` type.
   static LLVMFunctionType get(Type result, ArrayRef<Type> arguments,
                               bool isVarArg = false);
-  static LLVMFunctionType getChecked(Location loc, Type result,
-                                     ArrayRef<Type> arguments,
-                                     bool isVarArg = false);
+  static LLVMFunctionType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, Type result,
+             ArrayRef<Type> arguments, bool isVarArg = false);
 
   /// Returns the result type of the function.
   Type getReturnType();
@@ -135,9 +136,8 @@ class LLVMFunctionType
   ArrayRef<Type> params() { return getParams(); }
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verifyConstructionInvariants(Location loc, Type result,
-                                                    ArrayRef<Type> arguments,
-                                                    bool);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type result, ArrayRef<Type> arguments, bool);
 };
 
 //===----------------------------------------------------------------------===//
@@ -152,6 +152,7 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
 public:
   /// Inherit base constructors.
   using Base::Base;
+  using Base::getChecked;
 
   /// Checks if the given type can have a pointer type pointing to it.
   static bool isValidElementType(Type type);
@@ -160,8 +161,9 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
   /// object of `pointee` type in the given address space. The pointer type is
   /// created in the same context as `pointee`.
   static LLVMPointerType get(Type pointee, unsigned addressSpace = 0);
-  static LLVMPointerType getChecked(Location loc, Type pointee,
-                                    unsigned addressSpace = 0);
+  static LLVMPointerType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, Type pointee,
+             unsigned addressSpace = 0);
 
   /// Returns the pointed-to type.
   Type getElementType();
@@ -170,8 +172,8 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
   unsigned getAddressSpace();
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verifyConstructionInvariants(Location loc, Type pointee,
-                                                    unsigned);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type pointee, unsigned);
 };
 
 //===----------------------------------------------------------------------===//
@@ -217,7 +219,9 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, Type,
   /// in the context. Instead, it will just return the existing struct,
   /// similarly to the rest of MLIR type ::get methods.
   static LLVMStructType getIdentified(MLIRContext *context, StringRef name);
-  static LLVMStructType getIdentifiedChecked(Location loc, StringRef name);
+  static LLVMStructType
+  getIdentifiedChecked(function_ref<InFlightDiagnostic()> emitError,
+                       MLIRContext *context, StringRef name);
 
   /// Gets a new identified struct with the given body. The body _cannot_ be
   /// changed later. If a struct with the given name already exists, renames
@@ -231,8 +235,10 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, Type,
   /// context.
   static LLVMStructType getLiteral(MLIRContext *context, ArrayRef<Type> types,
                                    bool isPacked = false);
-  static LLVMStructType getLiteralChecked(Location loc, ArrayRef<Type> types,
-                                          bool isPacked = false);
+  static LLVMStructType
+  getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
+                    MLIRContext *context, ArrayRef<Type> types,
+                    bool isPacked = false);
 
   /// Gets or creates an intentionally-opaque identified struct. Such a struct
   /// cannot have its body set. To create an opaque struct with a mutable body,
@@ -241,7 +247,9 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, Type,
   /// already exists in the context. Instead, it will just return the existing
   /// struct, similarly to the rest of MLIR type ::get methods.
   static LLVMStructType getOpaque(StringRef name, MLIRContext *context);
-  static LLVMStructType getOpaqueChecked(Location loc, StringRef name);
+  static LLVMStructType
+  getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
+                   MLIRContext *context, StringRef name);
 
   /// Set the body of an identified struct. Returns failure if the body could
   /// not be set, e.g. if the struct already has a body or if it was marked as
@@ -270,9 +278,10 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, Type,
   ArrayRef<Type> getBody();
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verifyConstructionInvariants(Location, StringRef, bool);
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    ArrayRef<Type> types, bool);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              StringRef, bool);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              ArrayRef<Type> types, bool);
 };
 
 //===----------------------------------------------------------------------===//
@@ -300,9 +309,8 @@ class LLVMVectorType : public Type {
   llvm::ElementCount getElementCount();
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    Type elementType,
-                                                    unsigned numElements);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type elementType, unsigned numElements);
 };
 
 //===----------------------------------------------------------------------===//
@@ -317,12 +325,14 @@ class LLVMFixedVectorType
 public:
   /// Inherit base constructor.
   using Base::Base;
+  using Base::getChecked;
 
   /// Gets or creates a fixed vector type containing `numElements` of
   /// `elementType` in the same context as `elementType`.
   static LLVMFixedVectorType get(Type elementType, unsigned numElements);
-  static LLVMFixedVectorType getChecked(Location loc, Type elementType,
-                                        unsigned numElements);
+  static LLVMFixedVectorType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, Type elementType,
+             unsigned numElements);
 
   /// Checks if the given type can be used in a vector type. This type supports
   /// only a subset of LLVM dialect types that don't have a built-in
@@ -336,9 +346,8 @@ class LLVMFixedVectorType
   unsigned getNumElements();
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    Type elementType,
-                                                    unsigned numElements);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type elementType, unsigned numElements);
 };
 
 //===----------------------------------------------------------------------===//
@@ -354,12 +363,14 @@ class LLVMScalableVectorType
 public:
   /// Inherit base constructor.
   using Base::Base;
+  using Base::getChecked;
 
   /// Gets or creates a scalable vector type containing a non-zero multiple of
   /// `minNumElements` of `elementType` in the same context as `elementType`.
   static LLVMScalableVectorType get(Type elementType, unsigned minNumElements);
-  static LLVMScalableVectorType getChecked(Location loc, Type elementType,
-                                           unsigned minNumElements);
+  static LLVMScalableVectorType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, Type elementType,
+             unsigned minNumElements);
 
   /// Checks if the given type can be used in a vector type.
   static bool isValidElementType(Type type);
@@ -373,9 +384,8 @@ class LLVMScalableVectorType
   unsigned getMinNumElements();
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    Type elementType,
-                                                    unsigned minNumElements);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type elementType, unsigned minNumElements);
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index 47e70fe48354..45e791b8d4c4 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -57,10 +57,10 @@ class QuantizedType : public Type {
   /// The maximum number of bits supported for storage types.
   static constexpr unsigned MaxStorageBits = 32;
 
-  static LogicalResult
-  verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
-                               Type expressedType, int64_t storageTypeMin,
-                               int64_t storageTypeMax);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              unsigned flags, Type storageType,
+                              Type expressedType, int64_t storageTypeMin,
+                              int64_t storageTypeMax);
 
   /// Support method to enable LLVM-style type casting.
   static bool classof(Type type);
@@ -199,6 +199,7 @@ class AnyQuantizedType
                             detail::AnyQuantizedTypeStorage> {
 public:
   using Base::Base;
+  using Base::getChecked;
 
   /// Gets an instance of the type with all parameters specified but not
   /// checked.
@@ -208,15 +209,16 @@ class AnyQuantizedType
 
   /// Gets an instance of the type with all specified parameters checked.
   /// Returns a nullptr convertible type on failure.
-  static AnyQuantizedType getChecked(unsigned flags, Type storageType,
-                                     Type expressedType, int64_t storageTypeMin,
-                                     int64_t storageTypeMax, Location location);
+  static AnyQuantizedType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+             Type storageType, Type expressedType, int64_t storageTypeMin,
+             int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult
-  verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
-                               Type expressedType, int64_t storageTypeMin,
-                               int64_t storageTypeMax);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              unsigned flags, Type storageType,
+                              Type expressedType, int64_t storageTypeMin,
+                              int64_t storageTypeMax);
 };
 
 /// Represents a family of uniform, quantized types.
@@ -256,6 +258,7 @@ class UniformQuantizedType
                             detail::UniformQuantizedTypeStorage> {
 public:
   using Base::Base;
+  using Base::getChecked;
 
   /// Gets an instance of the type with all parameters specified but not
   /// checked.
@@ -267,16 +270,16 @@ class UniformQuantizedType
   /// Gets an instance of the type with all specified parameters checked.
   /// Returns a nullptr convertible type on failure.
   static UniformQuantizedType
-  getChecked(unsigned flags, Type storageType, Type expressedType, double scale,
-             int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax,
-             Location location);
+  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+             Type storageType, Type expressedType, double scale,
+             int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult
-  verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
-                               Type expressedType, double scale,
-                               int64_t zeroPoint, int64_t storageTypeMin,
-                               int64_t storageTypeMax);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              unsigned flags, Type storageType,
+                              Type expressedType, double scale,
+                              int64_t zeroPoint, int64_t storageTypeMin,
+                              int64_t storageTypeMax);
 
   /// Gets the scale term. The scale designates the difference between the real
   /// values corresponding to consecutive quantized values differing by 1.
@@ -313,6 +316,7 @@ class UniformQuantizedPerAxisType
                             detail::UniformQuantizedPerAxisTypeStorage> {
 public:
   using Base::Base;
+  using Base::getChecked;
 
   /// Gets an instance of the type with all parameters specified but not
   /// checked.
@@ -325,18 +329,18 @@ class UniformQuantizedPerAxisType
   /// Gets an instance of the type with all specified parameters checked.
   /// Returns a nullptr convertible type on failure.
   static UniformQuantizedPerAxisType
-  getChecked(unsigned flags, Type storageType, Type expressedType,
-             ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
-             int32_t quantizedDimension, int64_t storageTypeMin,
-             int64_t storageTypeMax, Location location);
+  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+             Type storageType, Type expressedType, ArrayRef<double> scales,
+             ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
+             int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult
-  verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
-                               Type expressedType, ArrayRef<double> scales,
-                               ArrayRef<int64_t> zeroPoints,
-                               int32_t quantizedDimension,
-                               int64_t storageTypeMin, int64_t storageTypeMax);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              unsigned flags, Type storageType,
+                              Type expressedType, ArrayRef<double> scales,
+                              ArrayRef<int64_t> zeroPoints,
+                              int32_t quantizedDimension,
+                              int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Gets the quantization scales. The scales designate the difference between
   /// the real values corresponding to consecutive quantized values differing
@@ -381,6 +385,7 @@ class CalibratedQuantizedType
                             detail::CalibratedQuantizedTypeStorage> {
 public:
   using Base::Base;
+  using Base::getChecked;
 
   /// Gets an instance of the type with all parameters specified but not
   /// checked.
@@ -389,13 +394,13 @@ class CalibratedQuantizedType
 
   /// Gets an instance of the type with all specified parameters checked.
   /// Returns a nullptr convertible type on failure.
-  static CalibratedQuantizedType getChecked(Type expressedType, double min,
-                                            double max, Location location);
+  static CalibratedQuantizedType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, Type expressedType,
+             double min, double max);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    Type expressedType,
-                                                    double min, double max);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type expressedType, double min, double max);
   double getMin() const;
   double getMax() const;
 };

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 1c253f9f5984..fd9f576b3bc5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -69,10 +69,9 @@ class InterfaceVarABIAttr
   /// Returns `spirv::StorageClass`.
   Optional<StorageClass> getStorageClass();
 
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    IntegerAttr descriptorSet,
-                                                    IntegerAttr binding,
-                                                    IntegerAttr storageClass);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              IntegerAttr descriptorSet, IntegerAttr binding,
+                              IntegerAttr storageClass);
 };
 
 /// An attribute that specifies the SPIR-V (version, capabilities, extensions)
@@ -120,10 +119,9 @@ class VerCapExtAttr
   /// Returns the capabilities as an integer array attribute.
   ArrayAttr getCapabilitiesAttr();
 
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    IntegerAttr version,
-                                                    ArrayAttr capabilities,
-                                                    ArrayAttr extensions);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              IntegerAttr version, ArrayAttr capabilities,
+                              ArrayAttr extensions);
 };
 
 /// An attribute that specifies the target version, allowed extensions and
@@ -174,10 +172,10 @@ class TargetEnvAttr
   /// Returns the target resource limits.
   ResourceLimitsAttr getResourceLimits() const;
 
-  static LogicalResult
-  verifyConstructionInvariants(Location loc, VerCapExtAttr triple,
-                               Vendor vendorID, DeviceType deviceType,
-                               uint32_t deviceID, DictionaryAttr limits);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              VerCapExtAttr triple, Vendor vendorID,
+                              DeviceType deviceType, uint32_t deviceID,
+                              DictionaryAttr limits);
 };
 } // namespace spirv
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 007d4ea22311..4ec6cc52db7a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -243,10 +243,11 @@ class SampledImageType
 
   static SampledImageType get(Type imageType);
 
-  static SampledImageType getChecked(Type imageType, Location location);
+  static SampledImageType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
 
-  static LogicalResult verifyConstructionInvariants(Location Loc,
-                                                    Type imageType);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type imageType);
 
   Type getImageType() const;
 
@@ -426,12 +427,11 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
 
   static MatrixType get(Type columnType, uint32_t columnCount);
 
-  static MatrixType getChecked(Type columnType, uint32_t columnCount,
-                               Location location);
+  static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
+                               Type columnType, uint32_t columnCount);
 
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    Type columnType,
-                                                    uint32_t columnCount);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type columnType, uint32_t columnCount);
 
   /// Returns true if the matrix elements are vectors of float elements.
   static bool isValidColumnType(Type columnType);

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 571c9126f163..994308eff1ec 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -182,17 +182,20 @@ class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
                                              detail::FloatAttributeStorage> {
 public:
   using Base::Base;
+  using Base::getChecked;
   using ValueType = APFloat;
 
   /// Return a float attribute for the specified value in the specified type.
   /// These methods should only be used for simple constant values, e.g 1.0/2.0,
   /// that are known-valid both as host double and the 'type' format.
   static FloatAttr get(Type type, double value);
-  static FloatAttr getChecked(Type type, double value, Location loc);
+  static FloatAttr getChecked(function_ref<InFlightDiagnostic()> emitError,
+                              Type type, double value);
 
   /// Return a float attribute for the specified value in the specified type.
   static FloatAttr get(Type type, const APFloat &value);
-  static FloatAttr getChecked(Type type, const APFloat &value, Location loc);
+  static FloatAttr getChecked(function_ref<InFlightDiagnostic()> emitError,
+                              Type type, const APFloat &value);
 
   APFloat getValue() const;
 
@@ -202,10 +205,10 @@ class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
   static double getValueAsDouble(APFloat val);
 
   /// Verify the construction invariants for a double value.
-  static LogicalResult verifyConstructionInvariants(Location loc, Type type,
-                                                    double value);
-  static LogicalResult verifyConstructionInvariants(Location loc, Type type,
-                                                    const APFloat &value);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type type, double value);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type type, const APFloat &value);
 };
 
 //===----------------------------------------------------------------------===//
@@ -234,10 +237,10 @@ class IntegerAttr
   /// an unsigned integer.
   uint64_t getUInt() const;
 
-  static LogicalResult verifyConstructionInvariants(Location loc, Type type,
-                                                    int64_t value);
-  static LogicalResult verifyConstructionInvariants(Location loc, Type type,
-                                                    const APInt &value);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type type, int64_t value);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type type, const APInt &value);
 };
 
 //===----------------------------------------------------------------------===//
@@ -290,6 +293,7 @@ class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
                                               detail::OpaqueAttributeStorage> {
 public:
   using Base::Base;
+  using Base::getChecked;
 
   /// Get or create a new OpaqueAttr with the provided dialect and string data.
   static OpaqueAttr get(MLIRContext *context, Identifier dialect,
@@ -298,8 +302,9 @@ class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
   /// Get or create a new OpaqueAttr with the provided dialect and string data.
   /// If the given identifier is not a valid namespace for a dialect, then a
   /// null attribute is returned.
-  static OpaqueAttr getChecked(Identifier dialect, StringRef attrData,
-                               Type type, Location location);
+  static OpaqueAttr getChecked(function_ref<InFlightDiagnostic()> emitError,
+                               Identifier dialect, StringRef attrData,
+                               Type type);
 
   /// Returns the dialect namespace of the opaque attribute.
   Identifier getDialectNamespace() const;
@@ -308,10 +313,9 @@ class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
   StringRef getAttrData() const;
 
   /// Verify the construction of an opaque attribute.
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    Identifier dialect,
-                                                    StringRef attrData,
-                                                    Type type);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Identifier dialect, StringRef attrData,
+                              Type type);
 };
 
 //===----------------------------------------------------------------------===//
@@ -428,10 +432,8 @@ class UnitAttr
 //===----------------------------------------------------------------------===//
 
 namespace detail {
-template <typename T>
-class ElementsAttrIterator;
-template <typename T>
-class ElementsAttrRange;
+template <typename T> class ElementsAttrIterator;
+template <typename T> class ElementsAttrRange;
 } // namespace detail
 
 /// A base attribute that represents a reference to a static shaped tensor or
@@ -439,10 +441,8 @@ class ElementsAttrRange;
 class ElementsAttr : public Attribute {
 public:
   using Attribute::Attribute;
-  template <typename T>
-  using iterator = detail::ElementsAttrIterator<T>;
-  template <typename T>
-  using iterator_range = detail::ElementsAttrRange<T>;
+  template <typename T> using iterator = detail::ElementsAttrIterator<T>;
+  template <typename T> using iterator_range = detail::ElementsAttrRange<T>;
 
   /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
   /// with static shape.
@@ -454,16 +454,14 @@ class ElementsAttr : public Attribute {
 
   /// Return the value of type 'T' at the given index, where 'T' corresponds to
   /// an Attribute type.
-  template <typename T>
-  T getValue(ArrayRef<uint64_t> index) const {
+  template <typename T> T getValue(ArrayRef<uint64_t> index) const {
     return getValue(index).template cast<T>();
   }
 
   /// Return the elements of this attribute as a value of type 'T'. Note:
   /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
   /// iteration.
-  template <typename T>
-  iterator_range<T> getValues() const;
+  template <typename T> iterator_range<T> getValues() const;
 
   /// Return if the given 'index' refers to a valid element in this attribute.
   bool isValidIndex(ArrayRef<uint64_t> index) const;
@@ -540,8 +538,7 @@ class DenseElementIndexedIteratorImpl
 };
 
 /// Type trait detector that checks if a given type T is a complex type.
-template <typename T>
-struct is_complex_t : public std::false_type {};
+template <typename T> struct is_complex_t : public std::false_type {};
 template <typename T>
 struct is_complex_t<std::complex<T>> : public std::true_type {};
 } // namespace detail
@@ -556,8 +553,7 @@ class DenseElementsAttr : public ElementsAttr {
   /// floating point type that can be used to access the underlying element
   /// types of a DenseElementsAttr.
   // TODO: Use std::disjunction when C++17 is supported.
-  template <typename T>
-  struct is_valid_cpp_fp_type {
+  template <typename T> struct is_valid_cpp_fp_type {
     /// The type is a valid floating point type if it is a builtin floating
     /// point type, or is a potentially user defined floating point type. The
     /// latter allows for supporting users that have custom types defined for
@@ -826,8 +822,7 @@ class DenseElementsAttr : public ElementsAttr {
   Attribute getValue(ArrayRef<uint64_t> index) const {
     return getValue<Attribute>(index);
   }
-  template <typename T>
-  T getValue(ArrayRef<uint64_t> index) const {
+  template <typename T> T getValue(ArrayRef<uint64_t> index) const {
     // Skip to the element corresponding to the flattened index.
     return *std::next(getValues<T>().begin(), getFlattenedIndex(index));
   }
@@ -1236,8 +1231,7 @@ class SparseElementsAttr
 
   /// Return the values of this attribute in the form of the given type 'T'. 'T'
   /// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc.
-  template <typename T>
-  llvm::iterator_range<iterator<T>> getValues() const {
+  template <typename T> llvm::iterator_range<iterator<T>> getValues() const {
     auto zeroValue = getZeroValue<T>();
     auto valueIt = getValues().getValues<T>().begin();
     const std::vector<ptr
diff _t> flatSparseIndices(getFlattenedSparseIndices());
@@ -1379,28 +1373,22 @@ class ElementsAttrIterator
   }
 
   /// Utility functors used to generically implement the iterators methods.
-  template <typename ItT>
-  struct PlusAssign {
+  template <typename ItT> struct PlusAssign {
     void operator()(ItT &it, ptr
diff _t offset) { it += offset; }
   };
-  template <typename ItT>
-  struct Minus {
+  template <typename ItT> struct Minus {
     ptr
diff _t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; }
   };
-  template <typename ItT>
-  struct MinusAssign {
+  template <typename ItT> struct MinusAssign {
     void operator()(ItT &it, ptr
diff _t offset) { it -= offset; }
   };
-  template <typename ItT>
-  struct Dereference {
+  template <typename ItT> struct Dereference {
     T operator()(ItT &it) { return *it; }
   };
-  template <typename ItT>
-  struct ConstructIter {
+  template <typename ItT> struct ConstructIter {
     void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); }
   };
-  template <typename ItT>
-  struct DestructIter {
+  template <typename ItT> struct DestructIter {
     void operator()(ItT &it) { it.~ItT(); }
   };
 

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 9064e3294e81..e3b8d597a2a7 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -167,22 +167,21 @@ class VectorType
     : public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> {
 public:
   using Base::Base;
+  using Base::getChecked;
 
   /// Get or create a new VectorType of the provided shape and element type.
   /// Assumes the arguments define a well-formed VectorType.
   static VectorType get(ArrayRef<int64_t> shape, Type elementType);
 
-  /// Get or create a new VectorType of the provided shape and element type
-  /// declared at the given, potentially unknown, location.  If the VectorType
-  /// defined by the arguments would be ill-formed, emit errors and return
-  /// nullptr-wrapping type.
-  static VectorType getChecked(Location location, ArrayRef<int64_t> shape,
-                               Type elementType);
+  /// Get or create a new VectorType of the provided shape and element type. If
+  /// the VectorType defined by the arguments would be ill-formed, an error is
+  /// emitted to `emitError` and a null type is returned.
+  static VectorType getChecked(function_ref<InFlightDiagnostic()> emitError,
+                               ArrayRef<int64_t> shape, Type elementType);
 
   /// Verify the construction of a vector type.
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    ArrayRef<int64_t> shape,
-                                                    Type elementType);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              ArrayRef<int64_t> shape, Type elementType);
 
   /// Returns true of the given type can be used as an element of a vector type.
   /// In particular, vectors can consist of integer or float primitives.
@@ -226,22 +225,23 @@ class RankedTensorType
                             detail::RankedTensorTypeStorage> {
 public:
   using Base::Base;
+  using Base::getChecked;
 
   /// Get or create a new RankedTensorType of the provided shape and element
   /// type. Assumes the arguments define a well-formed type.
   static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType);
 
   /// Get or create a new RankedTensorType of the provided shape and element
-  /// type declared at the given, potentially unknown, location.  If the
-  /// RankedTensorType defined by the arguments would be ill-formed, emit errors
-  /// and return a nullptr-wrapping type.
-  static RankedTensorType getChecked(Location location, ArrayRef<int64_t> shape,
-                                     Type elementType);
+  /// type. If the RankedTensorType defined by the arguments would be
+  /// ill-formed, an error is emitted to `emitError` and a null type is
+  /// returned.
+  static RankedTensorType
+  getChecked(function_ref<InFlightDiagnostic()> emitError,
+             ArrayRef<int64_t> shape, Type elementType);
 
   /// Verify the construction of a ranked tensor type.
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    ArrayRef<int64_t> shape,
-                                                    Type elementType);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              ArrayRef<int64_t> shape, Type elementType);
 
   ArrayRef<int64_t> getShape() const;
 };
@@ -256,20 +256,22 @@ class UnrankedTensorType
                             detail::UnrankedTensorTypeStorage> {
 public:
   using Base::Base;
+  using Base::getChecked;
 
   /// Get or create a new UnrankedTensorType of the provided shape and element
   /// type. Assumes the arguments define a well-formed type.
   static UnrankedTensorType get(Type elementType);
 
   /// Get or create a new UnrankedTensorType of the provided shape and element
-  /// type declared at the given, potentially unknown, location.  If the
-  /// UnrankedTensorType defined by the arguments would be ill-formed, emit
-  /// errors and return a nullptr-wrapping type.
-  static UnrankedTensorType getChecked(Location location, Type elementType);
+  /// type. If the UnrankedTensorType defined by the arguments would be
+  /// ill-formed, an error is emitted to `emitError` and a null type is
+  /// returned.
+  static UnrankedTensorType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, Type elementType);
 
   /// Verify the construction of a unranked tensor type.
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    Type elementType);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type elementType);
 
   ArrayRef<int64_t> getShape() const { return llvm::None; }
 };
@@ -351,6 +353,7 @@ class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
   };
 
   using Base::Base;
+  using Base::getChecked;
 
   /// Get or create a new MemRefType based on shape, element type, affine
   /// map composition, and memory space.  Assumes the arguments define a
@@ -361,13 +364,11 @@ class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
                         unsigned memorySpace = 0);
 
   /// Get or create a new MemRefType based on shape, element type, affine
-  /// map composition, and memory space declared at the given location.
-  /// If the location is unknown, the last argument should be an instance of
-  /// UnknownLoc.  If the MemRefType defined by the arguments would be
-  /// ill-formed, emits errors (to the handler registered with the context or to
-  /// the error stream) and returns nullptr.
-  static MemRefType getChecked(Location location, ArrayRef<int64_t> shape,
-                               Type elementType,
+  /// map composition, and memory space. If the MemRefType defined by the
+  /// arguments would be ill-formed, an error is emitted to `emitError` and a
+  /// null type is returned.
+  static MemRefType getChecked(function_ref<InFlightDiagnostic()> emitError,
+                               ArrayRef<int64_t> shape, Type elementType,
                                ArrayRef<AffineMap> affineMapComposition,
                                unsigned memorySpace);
 
@@ -386,11 +387,11 @@ class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
 
 private:
   /// Get or create a new MemRefType defined by the arguments.  If the resulting
-  /// type would be ill-formed, return nullptr.  If the location is provided,
-  /// emit detailed error messages.
+  /// type would be ill-formed, return nullptr.
   static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
                             ArrayRef<AffineMap> affineMapComposition,
-                            unsigned memorySpace, Optional<Location> location);
+                            unsigned memorySpace,
+                            function_ref<InFlightDiagnostic()> emitError);
   using Base::getImpl;
 };
 
@@ -404,22 +405,23 @@ class UnrankedMemRefType
                             detail::UnrankedMemRefTypeStorage> {
 public:
   using Base::Base;
+  using Base::getChecked;
 
   /// Get or create a new UnrankedMemRefType of the provided element
   /// type and memory space
   static UnrankedMemRefType get(Type elementType, unsigned memorySpace);
 
   /// Get or create a new UnrankedMemRefType of the provided element
-  /// type and memory space declared at the given, potentially unknown,
-  /// location. If the UnrankedMemRefType defined by the arguments would be
-  /// ill-formed, emit errors and return a nullptr-wrapping type.
-  static UnrankedMemRefType getChecked(Location location, Type elementType,
-                                       unsigned memorySpace);
+  /// type and memory space. If the UnrankedMemRefType defined by the arguments
+  /// would be ill-formed, an error is emitted to `emitError` and a null type is
+  /// returned.
+  static UnrankedMemRefType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, Type elementType,
+             unsigned memorySpace);
 
   /// Verify the construction of a unranked memref type.
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    Type elementType,
-                                                    unsigned memorySpace);
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              Type elementType, unsigned memorySpace);
 
   ArrayRef<int64_t> getShape() const { return llvm::None; }
 };

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 711ae45377c7..4133b056e27a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -52,13 +52,11 @@ def Builtin_Complex : Builtin_Type<"Complex"> {
   let parameters = (ins "Type":$elementType);
   let builders = [
     TypeBuilderWithInferredContext<(ins "Type":$elementType), [{
-      return Base::get(elementType.getContext(), elementType);
-    }], [{
-      return Base::getChecked($_loc, elementType);
+      return $_get(elementType.getContext(), elementType);
     }]>
   ];
   let skipDefaultBuilders = 1;
-  let genVerifyInvariantsDecl = 1;
+  let genVerifyDecl = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -137,7 +135,7 @@ def Builtin_Function : Builtin_Type<"Function"> {
   let parameters = (ins "ArrayRef<Type>":$inputs, "ArrayRef<Type>":$results);
   let builders = [
     TypeBuilder<(ins CArg<"TypeRange">:$inputs, CArg<"TypeRange">:$results), [{
-      return Base::get($_ctxt, inputs, results);
+      return $_get($_ctxt, inputs, results);
     }]>
   ];
   let skipDefaultBuilders = 1;
@@ -225,7 +223,7 @@ def Builtin_Integer : Builtin_Type<"Integer"> {
   // memory.
   let genStorageClass = 0;
   let skipDefaultBuilders = 1;
-  let genVerifyInvariantsDecl = 1;
+  let genVerifyDecl = 1;
   let extraClassDeclaration = [{
     /// Signedness semantics.
     enum SignednessSemantics : uint32_t {
@@ -295,7 +293,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque"> {
     "Identifier":$dialectNamespace,
     StringRefParameter<"">:$typeData
   );
-  let genVerifyInvariantsDecl = 1;
+  let genVerifyDecl = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -334,10 +332,10 @@ def Builtin_Tuple : Builtin_Type<"Tuple"> {
   let parameters = (ins "ArrayRef<Type>":$types);
   let builders = [
     TypeBuilder<(ins "TypeRange":$elementTypes), [{
-      return Base::get($_ctxt, elementTypes);
+      return $_get($_ctxt, elementTypes);
     }]>,
     TypeBuilder<(ins), [{
-      return Base::get($_ctxt, TypeRange());
+      return $_get($_ctxt, TypeRange());
     }]>
   ];
   let skipDefaultBuilders = 1;

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 4a5731cb2f86..417c32570adb 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2492,15 +2492,21 @@ def replaceWithValue;
 //
 // If an empty string is passed in for `body`, then *only* the builder
 // declaration will be generated; this provides a way to define complicated
-// builders entirely in C++.
+// builders entirely in C++. If a `body` string is provided, the `Base::get`
+// method should be invoked using `$_get`, e.g.:
+//
+// ```
+// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg), [{
+//   return $_get($_ctxt, integerArg, floatArg);
+// }]>
+// ```
+//
+// This is necessary because the `body` is also used to generate `getChecked`
+// methods, which have a different underlying `Base::get*` call.
 //
-// `checkedBody` is similar to `body`, but is the code block used when
-// generating a `getChecked` method.
-class TypeBuilder<dag parameters, code bodyCode = "",
-                  code checkedBodyCode = ""> {
+class TypeBuilder<dag parameters, code bodyCode = ""> {
   dag dagParams = parameters;
   code body = bodyCode;
-  code checkedBody = checkedBodyCode;
 
   // The context parameter can be inferred from one of the other parameters and
   // is not implicitly added to the parameter list.
@@ -2510,10 +2516,8 @@ class TypeBuilder<dag parameters, code bodyCode = "",
 // A class of TypeBuilder that is able to infer the MLIRContext parameter from
 // one of the other builder parameters. Instances of this builder do not have
 // `MLIRContext *` implicitly added to the parameter list.
-class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
-                                     code checkedBodyCode = "">
+class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
   : TypeBuilder<parameters, bodyCode> {
-  code checkedBody = checkedBodyCode;
   let hasInferredContextParam = 1;
 }
 
@@ -2590,9 +2594,8 @@ class TypeDef<Dialect dialect, string name,
   // Avoid generating default get/getChecked functions. Custom get methods must
   // be provided.
   bit skipDefaultBuilders = 0;
-  // Generate the verifyConstructionInvariants declaration and getChecked
-  // method.
-  bit genVerifyInvariantsDecl = 0;
+  // Generate the verify and getChecked methods.
+  bit genVerifyDecl = 0;
   // Extra code to include in the class declaration.
   code extraClassDeclaration = [{}];
 

diff  --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 719bb1a62f97..71c6703a4055 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -17,16 +17,21 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/StorageUniquer.h"
 #include "mlir/Support/TypeID.h"
+#include "llvm/ADT/FunctionExtras.h"
 
 namespace mlir {
-class AttributeStorage;
+class InFlightDiagnostic;
+class Location;
 class MLIRContext;
 
 namespace detail {
-/// Utility method to generate a raw default location for use when checking the
-/// construction invariants of a storage object. This is defined out-of-line to
-/// avoid the need to include Location.h.
-const AttributeStorage *generateUnknownStorageLocation(MLIRContext *ctx);
+/// Utility methods that create a callback for emitting a diagnostic when
+/// checking the construction invariants of a storage object. These are
+/// defined out-of-line to avoid the need to include Location.h.
+llvm::unique_function<InFlightDiagnostic()>
+getDefaultDiagnosticEmitFn(MLIRContext *ctx);
+llvm::unique_function<InFlightDiagnostic()>
+getDefaultDiagnosticEmitFn(const Location &loc);
 
 //===----------------------------------------------------------------------===//
 // StorageUserTraitBase
@@ -88,20 +93,30 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   template <typename... Args>
   static ConcreteT get(MLIRContext *ctx, Args... args) {
     // Ensure that the invariants are correct for construction.
-    assert(succeeded(ConcreteT::verifyConstructionInvariants(
-        generateUnknownStorageLocation(ctx), args...)));
+    assert(
+        succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
     return UniquerT::template get<ConcreteT>(ctx, args...);
   }
 
   /// Get or create a new ConcreteT instance within the ctx, defined at
   /// the given, potentially unknown, location. If the arguments provided are
-  /// invalid then emit errors and return a null object.
-  template <typename LocationT, typename... Args>
-  static ConcreteT getChecked(LocationT loc, Args... args) {
+  /// invalid, errors are emitted using the provided location and a null object
+  /// is returned.
+  template <typename... Args>
+  static ConcreteT getChecked(const Location &loc, Args... args) {
+    return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), args...);
+  }
+
+  /// Get or create a new ConcreteT instance within the ctx. If the arguments
+  /// provided are invalid, errors are emitted using the provided `emitError`
+  /// and a null object is returned.
+  template <typename... Args>
+  static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
+                              MLIRContext *ctx, Args... args) {
     // If the construction invariants fail then we return a null attribute.
-    if (failed(ConcreteT::verifyConstructionInvariants(loc, args...)))
+    if (failed(ConcreteT::verify(emitErrorFn, args...)))
       return ConcreteT();
-    return UniquerT::template get<ConcreteT>(loc.getContext(), args...);
+    return UniquerT::template get<ConcreteT>(ctx, args...);
   }
 
   /// Get an instance of the concrete type from a void pointer.
@@ -119,8 +134,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   }
 
   /// Default implementation that just returns success.
-  template <typename... Args>
-  static LogicalResult verifyConstructionInvariants(Args... args) {
+  template <typename... Args> static LogicalResult verify(Args... args) {
     return success();
   }
 

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 9546ff3e8e13..a95447d53097 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -32,8 +32,9 @@ namespace mlir {
 /// Derived type classes are expected to implement several required
 /// implementation hooks:
 ///  * Optional:
-///    - static LogicalResult verifyConstructionInvariants(Location loc,
-///                                                        Args... args)
+///    - static LogicalResult verify(
+///                                function_ref<InFlightDiagnostic()> emitError,
+///                                Args... args)
 ///      * This method is invoked when calling the 'TypeBase::get/getChecked'
 ///        methods to ensure that the arguments passed in are valid to construct
 ///        a type instance with.
@@ -92,8 +93,7 @@ class Type {
   bool operator!() const { return impl == nullptr; }
 
   template <typename U> bool isa() const;
-  template <typename First, typename Second, typename... Rest>
-  bool isa() const;
+  template <typename First, typename Second, typename... Rest> bool isa() const;
   template <typename U> U dyn_cast() const;
   template <typename U> U dyn_cast_or_null() const;
   template <typename U> U cast() const;

diff  --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/TypeDef.h
index 73a3a1002d0a..a82f85d48863 100644
--- a/mlir/include/mlir/TableGen/TypeDef.h
+++ b/mlir/include/mlir/TableGen/TypeDef.h
@@ -36,10 +36,6 @@ class TypeBuilder : public Builder {
 public:
   using Builder::Builder;
 
-  /// Return an optional code body used for the `getChecked` variant of this
-  /// builder.
-  Optional<StringRef> getCheckedBody() const;
-
   /// Returns true if this builder is able to infer the MLIRContext parameter.
   bool hasInferredContextParameter() const;
 };
@@ -106,9 +102,9 @@ class TypeDef {
   // generated.
   bool genAccessors() const;
 
-  // Return true if we need to generate the verifyConstructionInvariants
-  // declaration and getChecked method.
-  bool genVerifyInvariantsDecl() const;
+  // Return true if we need to generate the verify declaration and getChecked
+  // method.
+  bool genVerifyDecl() const;
 
   // Returns the dialects extra class declaration code.
   Optional<StringRef> getExtraDecls() const;

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index a840ce01273b..9152fd06d36a 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1466,8 +1466,7 @@ namespace {
 /// CRTP base class for Python MLIR values that subclass Value and should be
 /// castable from it. The value hierarchy is one level deep and is not supposed
 /// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy>
-class PyConcreteValue : public PyValue {
+template <typename DerivedTy> class PyConcreteValue : public PyValue {
 public:
   // Derived classes must define statics for:
   //   IsAFunctionTy isaFunction
@@ -1868,7 +1867,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
     c.def_static(
         "get",
         [](PyType &type, double value, DefaultingPyLocation loc) {
-          MlirAttribute attr = mlirFloatAttrDoubleGetChecked(type, value, loc);
+          MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirAttributeIsNull(attr)) {
@@ -2765,8 +2764,8 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
         "get",
         [](std::vector<int64_t> shape, PyType &elementType,
            DefaultingPyLocation loc) {
-          MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
-                                                elementType, loc);
+          MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+                                                elementType);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
@@ -2797,7 +2796,7 @@ class PyRankedTensorType
         [](std::vector<int64_t> shape, PyType &elementType,
            DefaultingPyLocation loc) {
           MlirType t = mlirRankedTensorTypeGetChecked(
-              shape.size(), shape.data(), elementType, loc);
+              loc, shape.size(), shape.data(), elementType);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
@@ -2828,7 +2827,7 @@ class PyUnrankedTensorType
     c.def_static(
         "get",
         [](PyType &elementType, DefaultingPyLocation loc) {
-          MlirType t = mlirUnrankedTensorTypeGetChecked(elementType, loc);
+          MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
@@ -2869,9 +2868,9 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
            for (PyAffineMap &map : layout)
              maps.push_back(map);
 
-           MlirType t = mlirMemRefTypeGetChecked(elementType, shape.size(),
+           MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
                                                  shape.data(), maps.size(),
-                                                 maps.data(), memorySpace, loc);
+                                                 maps.data(), memorySpace);
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirTypeIsNull(t)) {
@@ -2948,7 +2947,7 @@ class PyUnrankedMemRefType
          [](PyType &elementType, unsigned memorySpace,
             DefaultingPyLocation loc) {
            MlirType t =
-               mlirUnrankedMemRefTypeGetChecked(elementType, memorySpace, loc);
+               mlirUnrankedMemRefTypeGetChecked(loc, elementType, memorySpace);
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirTypeIsNull(t)) {

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 9e61e3a9d6e0..6d36da6297bf 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -103,9 +103,9 @@ MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
   return wrap(FloatAttr::get(unwrap(type), value));
 }
 
-MlirAttribute mlirFloatAttrDoubleGetChecked(MlirType type, double value,
-                                            MlirLocation loc) {
-  return wrap(FloatAttr::getChecked(unwrap(type), value, unwrap(loc)));
+MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type,
+                                            double value) {
+  return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value));
 }
 
 double mlirFloatAttrGetValueDouble(MlirAttribute attr) {

diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 2de2fa1afde2..10cdbc4b1658 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -169,8 +169,8 @@ MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                       unwrap(elementType)));
 }
 
-MlirType mlirVectorTypeGetChecked(intptr_t rank, const int64_t *shape,
-                                  MlirType elementType, MlirLocation loc) {
+MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
+                                  const int64_t *shape, MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
       unwrap(elementType)));
@@ -197,9 +197,9 @@ MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
       unwrap(elementType)));
 }
 
-MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, const int64_t *shape,
-                                        MlirType elementType,
-                                        MlirLocation loc) {
+MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
+                                        const int64_t *shape,
+                                        MlirType elementType) {
   return wrap(RankedTensorType::getChecked(
       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
       unwrap(elementType)));
@@ -209,8 +209,8 @@ MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
   return wrap(UnrankedTensorType::get(unwrap(elementType)));
 }
 
-MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
-                                          MlirLocation loc) {
+MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
+                                          MlirType elementType) {
   return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
 }
 
@@ -231,10 +231,11 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
                       unwrap(elementType), maps, memorySpace));
 }
 
-MlirType mlirMemRefTypeGetChecked(MlirType elementType, intptr_t rank,
-                                  const int64_t *shape, intptr_t numMaps,
+MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
+                                  intptr_t rank, const int64_t *shape,
+                                  intptr_t numMaps,
                                   MlirAffineMap const *affineMaps,
-                                  unsigned memorySpace, MlirLocation loc) {
+                                  unsigned memorySpace) {
   SmallVector<AffineMap, 1> maps;
   (void)unwrapList(numMaps, affineMaps, maps);
   return wrap(MemRefType::getChecked(
@@ -250,10 +251,10 @@ MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
                       unwrap(elementType), llvm::None, memorySpace));
 }
 
-MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank,
+MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
+                                            MlirType elementType, intptr_t rank,
                                             const int64_t *shape,
-                                            unsigned memorySpace,
-                                            MlirLocation loc) {
+                                            unsigned memorySpace) {
   return wrap(MemRefType::getChecked(
       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
       unwrap(elementType), llvm::None, memorySpace));
@@ -280,9 +281,9 @@ MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
   return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace));
 }
 
-MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
-                                          unsigned memorySpace,
-                                          MlirLocation loc) {
+MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
+                                          MlirType elementType,
+                                          unsigned memorySpace) {
   return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType),
                                              memorySpace));
 }

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 61499ee59a30..f32137a78479 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -187,7 +187,7 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
   // Function type without arguments.
   if (succeeded(parser.parseOptionalRParen())) {
     if (succeeded(parser.parseGreater()))
-      return LLVMFunctionType::getChecked(loc, returnType, {},
+      return LLVMFunctionType::getChecked(loc, returnType, llvm::None,
                                           /*isVarArg=*/false);
     return LLVMFunctionType();
   }
@@ -345,7 +345,8 @@ static LLVMStructType parseStructType(DialectAsmParser &parser) {
     if (knownStructNames.count(name)) {
       if (failed(parser.parseGreater()))
         return LLVMStructType();
-      return LLVMStructType::getIdentifiedChecked(loc, name);
+      return LLVMStructType::getIdentifiedChecked(
+          [loc] { return emitError(loc); }, loc.getContext(), name);
     }
     if (failed(parser.parseComma()))
       return LLVMStructType();
@@ -359,7 +360,8 @@ static LLVMStructType parseStructType(DialectAsmParser &parser) {
              LLVMStructType();
     if (failed(parser.parseGreater()))
       return LLVMStructType();
-    auto type = LLVMStructType::getOpaqueChecked(loc, name);
+    auto type = LLVMStructType::getOpaqueChecked(
+        [loc] { return emitError(loc); }, loc.getContext(), name);
     if (!type.isOpaque()) {
       parser.emitError(kwLoc, "redeclaring defined struct as opaque");
       return LLVMStructType();
@@ -377,8 +379,10 @@ static LLVMStructType parseStructType(DialectAsmParser &parser) {
     if (failed(parser.parseGreater()))
       return LLVMStructType();
     if (!isIdentified)
-      return LLVMStructType::getLiteralChecked(loc, {}, isPacked);
-    auto type = LLVMStructType::getIdentifiedChecked(loc, name);
+      return LLVMStructType::getLiteralChecked([loc] { return emitError(loc); },
+                                               loc.getContext(), {}, isPacked);
+    auto type = LLVMStructType::getIdentifiedChecked(
+        [loc] { return emitError(loc); }, loc.getContext(), name);
     return trySetStructBody(type, {}, isPacked, parser, kwLoc);
   }
 
@@ -402,8 +406,10 @@ static LLVMStructType parseStructType(DialectAsmParser &parser) {
 
   // Construct the struct with body.
   if (!isIdentified)
-    return LLVMStructType::getLiteralChecked(loc, subtypes, isPacked);
-  auto type = LLVMStructType::getIdentifiedChecked(loc, name);
+    return LLVMStructType::getLiteralChecked(
+        [loc] { return emitError(loc); }, loc.getContext(), subtypes, isPacked);
+  auto type = LLVMStructType::getIdentifiedChecked(
+      [loc] { return emitError(loc); }, loc.getContext(), name);
   return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
 }
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 0d5189083c75..f2abe89b54bb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -39,10 +39,12 @@ LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) {
   return Base::get(elementType.getContext(), elementType, numElements);
 }
 
-LLVMArrayType LLVMArrayType::getChecked(Location loc, Type elementType,
-                                        unsigned numElements) {
+LLVMArrayType
+LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                          Type elementType, unsigned numElements) {
   assert(elementType && "expected non-null subtype");
-  return Base::getChecked(loc, elementType, numElements);
+  return Base::getChecked(emitError, elementType.getContext(), elementType,
+                          numElements);
 }
 
 Type LLVMArrayType::getElementType() { return getImpl()->elementType; }
@@ -50,10 +52,10 @@ Type LLVMArrayType::getElementType() { return getImpl()->elementType; }
 unsigned LLVMArrayType::getNumElements() { return getImpl()->numElements; }
 
 LogicalResult
-LLVMArrayType::verifyConstructionInvariants(Location loc, Type elementType,
-                                            unsigned numElements) {
+LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
+                      Type elementType, unsigned numElements) {
   if (!isValidElementType(elementType))
-    return emitError(loc, "invalid array element type: ") << elementType;
+    return emitError() << "invalid array element type: " << elementType;
   return success();
 }
 
@@ -75,11 +77,13 @@ LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
   return Base::get(result.getContext(), result, arguments, isVarArg);
 }
 
-LLVMFunctionType LLVMFunctionType::getChecked(Location loc, Type result,
-                                              ArrayRef<Type> arguments,
-                                              bool isVarArg) {
+LLVMFunctionType
+LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                             Type result, ArrayRef<Type> arguments,
+                             bool isVarArg) {
   assert(result && "expected non-null result");
-  return Base::getChecked(loc, result, arguments, isVarArg);
+  return Base::getChecked(emitError, result.getContext(), result, arguments,
+                          isVarArg);
 }
 
 Type LLVMFunctionType::getReturnType() { return getImpl()->getReturnType(); }
@@ -99,14 +103,14 @@ ArrayRef<Type> LLVMFunctionType::getParams() {
 }
 
 LogicalResult
-LLVMFunctionType::verifyConstructionInvariants(Location loc, Type result,
-                                               ArrayRef<Type> arguments, bool) {
+LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
+                         Type result, ArrayRef<Type> arguments, bool) {
   if (!isValidResultType(result))
-    return emitError(loc, "invalid function result type: ") << result;
+    return emitError() << "invalid function result type: " << result;
 
   for (Type arg : arguments)
     if (!isValidArgumentType(arg))
-      return emitError(loc, "invalid function argument type: ") << arg;
+      return emitError() << "invalid function argument type: " << arg;
 
   return success();
 }
@@ -125,20 +129,22 @@ LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
   return Base::get(pointee.getContext(), pointee, addressSpace);
 }
 
-LLVMPointerType LLVMPointerType::getChecked(Location loc, Type pointee,
-                                            unsigned addressSpace) {
-  return Base::getChecked(loc, pointee, addressSpace);
+LLVMPointerType
+LLVMPointerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                            Type pointee, unsigned addressSpace) {
+  return Base::getChecked(emitError, pointee.getContext(), pointee,
+                          addressSpace);
 }
 
 Type LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
 
 unsigned LLVMPointerType::getAddressSpace() { return getImpl()->addressSpace; }
 
-LogicalResult LLVMPointerType::verifyConstructionInvariants(Location loc,
-                                                            Type pointee,
-                                                            unsigned) {
+LogicalResult
+LLVMPointerType::verify(function_ref<InFlightDiagnostic()> emitError,
+                        Type pointee, unsigned) {
   if (!isValidElementType(pointee))
-    return emitError(loc, "invalid pointer element type: ") << pointee;
+    return emitError() << "invalid pointer element type: " << pointee;
   return success();
 }
 
@@ -156,9 +162,10 @@ LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
   return Base::get(context, name, /*opaque=*/false);
 }
 
-LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc,
-                                                    StringRef name) {
-  return Base::getChecked(loc, name, /*opaque=*/false);
+LLVMStructType LLVMStructType::getIdentifiedChecked(
+    function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
+    StringRef name) {
+  return Base::getChecked(emitError, context, name, /*opaque=*/false);
 }
 
 LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
@@ -183,18 +190,21 @@ LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
   return Base::get(context, types, isPacked);
 }
 
-LLVMStructType LLVMStructType::getLiteralChecked(Location loc,
-                                                 ArrayRef<Type> types,
-                                                 bool isPacked) {
-  return Base::getChecked(loc, types, isPacked);
+LLVMStructType
+LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
+                                  MLIRContext *context, ArrayRef<Type> types,
+                                  bool isPacked) {
+  return Base::getChecked(emitError, context, types, isPacked);
 }
 
 LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
   return Base::get(context, name, /*opaque=*/true);
 }
 
-LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) {
-  return Base::getChecked(loc, name, /*opaque=*/true);
+LLVMStructType
+LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
+                                 MLIRContext *context, StringRef name) {
+  return Base::getChecked(emitError, context, name, /*opaque=*/true);
 }
 
 LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
@@ -217,17 +227,17 @@ ArrayRef<Type> LLVMStructType::getBody() {
                         : getImpl()->getTypeList();
 }
 
-LogicalResult LLVMStructType::verifyConstructionInvariants(Location, StringRef,
-                                                           bool) {
+LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
+                                     StringRef, bool) {
   return success();
 }
 
-LogicalResult LLVMStructType::verifyConstructionInvariants(Location loc,
-                                                           ArrayRef<Type> types,
-                                                           bool) {
+LogicalResult
+LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
+                       ArrayRef<Type> types, bool) {
   for (Type t : types)
     if (!isValidElementType(t))
-      return emitError(loc, "invalid LLVM structure element type: ") << t;
+      return emitError() << "invalid LLVM structure element type: " << t;
 
   return success();
 }
@@ -238,14 +248,14 @@ LogicalResult LLVMStructType::verifyConstructionInvariants(Location loc,
 
 /// Verifies that the type about to be constructed is well-formed.
 template <typename VecTy>
-static LogicalResult verifyVectorConstructionInvariants(Location loc,
-                                                        Type elementType,
-                                                        unsigned numElements) {
+static LogicalResult
+verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                   Type elementType, unsigned numElements) {
   if (numElements == 0)
-    return emitError(loc, "the number of vector elements must be positive");
+    return emitError() << "the number of vector elements must be positive";
 
   if (!VecTy::isValidElementType(elementType))
-    return emitError(loc, "invalid vector element type");
+    return emitError() << "invalid vector element type";
 
   return success();
 }
@@ -256,11 +266,12 @@ LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType,
   return Base::get(elementType.getContext(), elementType, numElements);
 }
 
-LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc,
-                                                    Type elementType,
-                                                    unsigned numElements) {
+LLVMFixedVectorType
+LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                Type elementType, unsigned numElements) {
   assert(elementType && "expected non-null subtype");
-  return Base::getChecked(loc, elementType, numElements);
+  return Base::getChecked(emitError, elementType.getContext(), elementType,
+                          numElements);
 }
 
 Type LLVMFixedVectorType::getElementType() {
@@ -275,10 +286,11 @@ bool LLVMFixedVectorType::isValidElementType(Type type) {
   return type.isa<LLVMPointerType, LLVMPPCFP128Type>();
 }
 
-LogicalResult LLVMFixedVectorType::verifyConstructionInvariants(
-    Location loc, Type elementType, unsigned numElements) {
+LogicalResult
+LLVMFixedVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
+                            Type elementType, unsigned numElements) {
   return verifyVectorConstructionInvariants<LLVMFixedVectorType>(
-      loc, elementType, numElements);
+      emitError, elementType, numElements);
 }
 
 //===----------------------------------------------------------------------===//
@@ -292,10 +304,11 @@ LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType,
 }
 
 LLVMScalableVectorType
-LLVMScalableVectorType::getChecked(Location loc, Type elementType,
-                                   unsigned minNumElements) {
+LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                   Type elementType, unsigned minNumElements) {
   assert(elementType && "expected non-null subtype");
-  return Base::getChecked(loc, elementType, minNumElements);
+  return Base::getChecked(emitError, elementType.getContext(), elementType,
+                          minNumElements);
 }
 
 Type LLVMScalableVectorType::getElementType() {
@@ -313,10 +326,11 @@ bool LLVMScalableVectorType::isValidElementType(Type type) {
   return isCompatibleFloatingPointType(type) || type.isa<LLVMPointerType>();
 }
 
-LogicalResult LLVMScalableVectorType::verifyConstructionInvariants(
-    Location loc, Type elementType, unsigned numElements) {
+LogicalResult
+LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
+                               Type elementType, unsigned numElements) {
   return verifyVectorConstructionInvariants<LLVMScalableVectorType>(
-      loc, elementType, numElements);
+      emitError, elementType, numElements);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 195847d5d602..7901b6d01e50 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -28,20 +28,21 @@ bool QuantizedType::classof(Type type) {
   return llvm::isa<QuantizationDialect>(type.getDialect());
 }
 
-LogicalResult QuantizedType::verifyConstructionInvariants(
-    Location loc, unsigned flags, Type storageType, Type expressedType,
-    int64_t storageTypeMin, int64_t storageTypeMax) {
+LogicalResult
+QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
+                      unsigned flags, Type storageType, Type expressedType,
+                      int64_t storageTypeMin, int64_t storageTypeMax) {
   // Verify that the storage type is integral.
   // This restriction may be lifted at some point in favor of using bf16
   // or f16 as exact representations on hardware where that is advantageous.
   auto intStorageType = storageType.dyn_cast<IntegerType>();
   if (!intStorageType)
-    return emitError(loc, "storage type must be integral");
+    return emitError() << "storage type must be integral";
   unsigned integralWidth = intStorageType.getWidth();
 
   // Verify storage width.
   if (integralWidth == 0 || integralWidth > MaxStorageBits)
-    return emitError(loc, "illegal storage type size: ") << integralWidth;
+    return emitError() << "illegal storage type size: " << integralWidth;
 
   // Verify storageTypeMin and storageTypeMax.
   bool isSigned =
@@ -53,8 +54,8 @@ LogicalResult QuantizedType::verifyConstructionInvariants(
   if (storageTypeMax - storageTypeMin <= 0 ||
       storageTypeMin < defaultIntegerMin ||
       storageTypeMax > defaultIntegerMax) {
-    return emitError(loc, "illegal storage min and storage max: (")
-           << storageTypeMin << ":" << storageTypeMax << ")";
+    return emitError() << "illegal storage min and storage max: ("
+                       << storageTypeMin << ":" << storageTypeMax << ")";
   }
   return success();
 }
@@ -208,21 +209,22 @@ AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
                    storageTypeMin, storageTypeMax);
 }
 
-AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
-                                              Type expressedType,
-                                              int64_t storageTypeMin,
-                                              int64_t storageTypeMax,
-                                              Location location) {
-  return Base::getChecked(location, flags, storageType, expressedType,
-                          storageTypeMin, storageTypeMax);
+AnyQuantizedType
+AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                             unsigned flags, Type storageType,
+                             Type expressedType, int64_t storageTypeMin,
+                             int64_t storageTypeMax) {
+  return Base::getChecked(emitError, storageType.getContext(), flags,
+                          storageType, expressedType, storageTypeMin,
+                          storageTypeMax);
 }
 
-LogicalResult AnyQuantizedType::verifyConstructionInvariants(
-    Location loc, unsigned flags, Type storageType, Type expressedType,
-    int64_t storageTypeMin, int64_t storageTypeMax) {
-  if (failed(QuantizedType::verifyConstructionInvariants(
-          loc, flags, storageType, expressedType, storageTypeMin,
-          storageTypeMax))) {
+LogicalResult
+AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
+                         unsigned flags, Type storageType, Type expressedType,
+                         int64_t storageTypeMin, int64_t storageTypeMax) {
+  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
+                                   storageTypeMin, storageTypeMax))) {
     return failure();
   }
 
@@ -230,7 +232,7 @@ LogicalResult AnyQuantizedType::verifyConstructionInvariants(
   // If this restriction is ever eliminated, the parser/printer must be
   // extended.
   if (expressedType && !expressedType.isa<FloatType>())
-    return emitError(loc, "expressed type must be floating point");
+    return emitError() << "expressed type must be floating point";
 
   return success();
 }
@@ -244,39 +246,38 @@ UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
                    scale, zeroPoint, storageTypeMin, storageTypeMax);
 }
 
-UniformQuantizedType
-UniformQuantizedType::getChecked(unsigned flags, Type storageType,
-                                 Type expressedType, double scale,
-                                 int64_t zeroPoint, int64_t storageTypeMin,
-                                 int64_t storageTypeMax, Location location) {
-  return Base::getChecked(location, flags, storageType, expressedType, scale,
-                          zeroPoint, storageTypeMin, storageTypeMax);
+UniformQuantizedType UniformQuantizedType::getChecked(
+    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+    Type storageType, Type expressedType, double scale, int64_t zeroPoint,
+    int64_t storageTypeMin, int64_t storageTypeMax) {
+  return Base::getChecked(emitError, storageType.getContext(), flags,
+                          storageType, expressedType, scale, zeroPoint,
+                          storageTypeMin, storageTypeMax);
 }
 
-LogicalResult UniformQuantizedType::verifyConstructionInvariants(
-    Location loc, unsigned flags, Type storageType, Type expressedType,
-    double scale, int64_t zeroPoint, int64_t storageTypeMin,
-    int64_t storageTypeMax) {
-  if (failed(QuantizedType::verifyConstructionInvariants(
-          loc, flags, storageType, expressedType, storageTypeMin,
-          storageTypeMax))) {
+LogicalResult UniformQuantizedType::verify(
+    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+    Type storageType, Type expressedType, double scale, int64_t zeroPoint,
+    int64_t storageTypeMin, int64_t storageTypeMax) {
+  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
+                                   storageTypeMin, storageTypeMax))) {
     return failure();
   }
 
   // Uniform quantization requires fully expressed parameters, including
   // expressed type.
   if (!expressedType)
-    return emitError(loc, "uniform quantization requires expressed type");
+    return emitError() << "uniform quantization requires expressed type";
 
   // Verify that the expressed type is floating point.
   // If this restriction is ever eliminated, the parser/printer must be
   // extended.
   if (!expressedType.isa<FloatType>())
-    return emitError(loc, "expressed type must be floating point");
+    return emitError() << "expressed type must be floating point";
 
   // Verify scale.
   if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
-    return emitError(loc, "illegal scale: ") << scale;
+    return emitError() << "illegal scale: " << scale;
 
   return success();
 }
@@ -298,46 +299,45 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
 }
 
 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
-    unsigned flags, Type storageType, Type expressedType,
-    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
-    int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
-    Location location) {
-  return Base::getChecked(location, flags, storageType, expressedType, scales,
-                          zeroPoints, quantizedDimension, storageTypeMin,
-                          storageTypeMax);
+    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+    Type storageType, Type expressedType, ArrayRef<double> scales,
+    ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
+    int64_t storageTypeMin, int64_t storageTypeMax) {
+  return Base::getChecked(emitError, storageType.getContext(), flags,
+                          storageType, expressedType, scales, zeroPoints,
+                          quantizedDimension, storageTypeMin, storageTypeMax);
 }
 
-LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
-    Location loc, unsigned flags, Type storageType, Type expressedType,
-    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
-    int32_t quantizedDimension, int64_t storageTypeMin,
-    int64_t storageTypeMax) {
-  if (failed(QuantizedType::verifyConstructionInvariants(
-          loc, flags, storageType, expressedType, storageTypeMin,
-          storageTypeMax))) {
+LogicalResult UniformQuantizedPerAxisType::verify(
+    function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+    Type storageType, Type expressedType, ArrayRef<double> scales,
+    ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
+    int64_t storageTypeMin, int64_t storageTypeMax) {
+  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
+                                   storageTypeMin, storageTypeMax))) {
     return failure();
   }
 
   // Uniform quantization requires fully expressed parameters, including
   // expressed type.
   if (!expressedType)
-    return emitError(loc, "uniform quantization requires expressed type");
+    return emitError() << "uniform quantization requires expressed type";
 
   // Verify that the expressed type is floating point.
   // If this restriction is ever eliminated, the parser/printer must be
   // extended.
   if (!expressedType.isa<FloatType>())
-    return emitError(loc, "expressed type must be floating point");
+    return emitError() << "expressed type must be floating point";
 
   // Ensure that the number of scales and zeroPoints match.
   if (scales.size() != zeroPoints.size())
-    return emitError(loc, "illegal number of scales and zeroPoints: ")
-           << scales.size() << ", " << zeroPoints.size();
+    return emitError() << "illegal number of scales and zeroPoints: "
+                       << scales.size() << ", " << zeroPoints.size();
 
   // Verify scale.
   for (double scale : scales) {
     if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
-      return emitError(loc, "illegal scale: ") << scale;
+      return emitError() << "illegal scale: " << scale;
   }
 
   return success();
@@ -360,22 +360,23 @@ CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
   return Base::get(expressedType.getContext(), expressedType, min, max);
 }
 
-CalibratedQuantizedType CalibratedQuantizedType::getChecked(Type expressedType,
-                                                            double min,
-                                                            double max,
-                                                            Location location) {
-  return Base::getChecked(location, expressedType, min, max);
+CalibratedQuantizedType CalibratedQuantizedType::getChecked(
+    function_ref<InFlightDiagnostic()> emitError, Type expressedType,
+    double min, double max) {
+  return Base::getChecked(emitError, expressedType.getContext(), expressedType,
+                          min, max);
 }
 
-LogicalResult CalibratedQuantizedType::verifyConstructionInvariants(
-    Location loc, Type expressedType, double min, double max) {
+LogicalResult
+CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
+                                Type expressedType, double min, double max) {
   // Verify that the expressed type is floating point.
   // If this restriction is ever eliminated, the parser/printer must be
   // extended.
   if (!expressedType.isa<FloatType>())
-    return emitError(loc, "expressed type must be floating point");
+    return emitError() << "expressed type must be floating point";
   if (max <= min)
-    return emitError(loc, "illegal min and max: (") << min << ":" << max << ")";
+    return emitError() << "illegal min and max: (" << min << ":" << max << ")";
 
   return success();
 }

diff  --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 4a62bfe8c9fc..636cf7ddb96c 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -155,8 +155,9 @@ static Type parseAnyType(DialectAsmParser &parser, Location loc) {
     return nullptr;
   }
 
-  return AnyQuantizedType::getChecked(typeFlags, storageType, expressedType,
-                                      storageTypeMin, storageTypeMax, loc);
+  return AnyQuantizedType::getChecked(loc, typeFlags, storageType,
+                                      expressedType, storageTypeMin,
+                                      storageTypeMax);
 }
 
 static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
@@ -279,13 +280,13 @@ static Type parseUniformType(DialectAsmParser &parser, Location loc) {
     ArrayRef<double> scalesRef(scales.begin(), scales.end());
     ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
     return UniformQuantizedPerAxisType::getChecked(
-        typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
-        quantizedDimension, storageTypeMin, storageTypeMax, loc);
+        loc, typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
+        quantizedDimension, storageTypeMin, storageTypeMax);
   }
 
-  return UniformQuantizedType::getChecked(typeFlags, storageType, expressedType,
-                                          scales.front(), zeroPoints.front(),
-                                          storageTypeMin, storageTypeMax, loc);
+  return UniformQuantizedType::getChecked(
+      loc, typeFlags, storageType, expressedType, scales.front(),
+      zeroPoints.front(), storageTypeMin, storageTypeMax);
 }
 
 /// Parses an CalibratedQuantizedType.
@@ -313,7 +314,7 @@ static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
     return nullptr;
   }
 
-  return CalibratedQuantizedType::getChecked(expressedType, min, max, loc);
+  return CalibratedQuantizedType::getChecked(loc, expressedType, min, max);
 }
 
 /// Parse a type registered to this dialect.

diff  --git a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
index dd1d0585b414..7750ad478f9d 100644
--- a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
@@ -123,17 +123,17 @@ mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
   // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
   // points and dequantized to 0.0.
   if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
-    return UniformQuantizedType::getChecked(flags, storageType, expressedType,
-                                            1.0, qmin, qmin, qmax, loc);
+    return UniformQuantizedType::getChecked(
+        loc, flags, storageType, expressedType, 1.0, qmin, qmin, qmax);
   }
 
   double scale;
   int64_t nudgedZeroPoint;
   getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
 
-  return UniformQuantizedType::getChecked(flags, storageType, expressedType,
-                                          scale, nudgedZeroPoint, qmin, qmax,
-                                          loc);
+  return UniformQuantizedType::getChecked(loc, flags, storageType,
+                                          expressedType, scale, nudgedZeroPoint,
+                                          qmin, qmax);
 }
 
 UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
@@ -179,6 +179,6 @@ UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
 
   unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
   return UniformQuantizedPerAxisType::getChecked(
-      flags, storageType, expressedType, scales, zeroPoints, quantizedDimension,
-      qmin, qmax, loc);
+      loc, flags, storageType, expressedType, scales, zeroPoints,
+      quantizedDimension, qmin, qmax);
 }

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
index 98e0b94c8bc7..c74c34d88dd7 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
@@ -162,23 +162,23 @@ Optional<spirv::StorageClass> spirv::InterfaceVarABIAttr::getStorageClass() {
   return llvm::None;
 }
 
-LogicalResult spirv::InterfaceVarABIAttr::verifyConstructionInvariants(
-    Location loc, IntegerAttr descriptorSet, IntegerAttr binding,
-    IntegerAttr storageClass) {
+LogicalResult spirv::InterfaceVarABIAttr::verify(
+    function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet,
+    IntegerAttr binding, IntegerAttr storageClass) {
   if (!descriptorSet.getType().isSignlessInteger(32))
-    return emitError(loc, "expected 32-bit integer for descriptor set");
+    return emitError() << "expected 32-bit integer for descriptor set";
 
   if (!binding.getType().isSignlessInteger(32))
-    return emitError(loc, "expected 32-bit integer for binding");
+    return emitError() << "expected 32-bit integer for binding";
 
   if (storageClass) {
     if (auto storageClassAttr = storageClass.cast<IntegerAttr>()) {
       auto storageClassValue =
           spirv::symbolizeStorageClass(storageClassAttr.getInt());
       if (!storageClassValue)
-        return emitError(loc, "unknown storage class");
+        return emitError() << "unknown storage class";
     } else {
-      return emitError(loc, "expected valid storage class");
+      return emitError() << "expected valid storage class";
     }
   }
 
@@ -257,11 +257,12 @@ ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
   return getImpl()->capabilities.cast<ArrayAttr>();
 }
 
-LogicalResult spirv::VerCapExtAttr::verifyConstructionInvariants(
-    Location loc, IntegerAttr version, ArrayAttr capabilities,
-    ArrayAttr extensions) {
+LogicalResult
+spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                             IntegerAttr version, ArrayAttr capabilities,
+                             ArrayAttr extensions) {
   if (!version.getType().isSignlessInteger(32))
-    return emitError(loc, "expected 32-bit integer for version");
+    return emitError() << "expected 32-bit integer for version";
 
   if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
         if (auto intAttr = attr.dyn_cast<IntegerAttr>())
@@ -269,7 +270,7 @@ LogicalResult spirv::VerCapExtAttr::verifyConstructionInvariants(
             return true;
         return false;
       }))
-    return emitError(loc, "unknown capability in capability list");
+    return emitError() << "unknown capability in capability list";
 
   if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
         if (auto strAttr = attr.dyn_cast<StringAttr>())
@@ -277,7 +278,7 @@ LogicalResult spirv::VerCapExtAttr::verifyConstructionInvariants(
             return true;
         return false;
       }))
-    return emitError(loc, "unknown extension in extension list");
+    return emitError() << "unknown extension in extension list";
 
   return success();
 }
@@ -338,12 +339,14 @@ spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const {
   return getImpl()->limits.cast<spirv::ResourceLimitsAttr>();
 }
 
-LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
-    Location loc, spirv::VerCapExtAttr /*triple*/, spirv::Vendor /*vendorID*/,
-    spirv::DeviceType /*deviceType*/, uint32_t /*deviceID*/,
-    DictionaryAttr limits) {
+LogicalResult
+spirv::TargetEnvAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                             spirv::VerCapExtAttr /*triple*/,
+                             spirv::Vendor /*vendorID*/,
+                             spirv::DeviceType /*deviceType*/,
+                             uint32_t /*deviceID*/, DictionaryAttr limits) {
   if (!limits.isa<spirv::ResourceLimitsAttr>())
-    return emitError(loc, "expected spirv::ResourceLimitsAttr for limits");
+    return emitError() << "expected spirv::ResourceLimitsAttr for limits";
 
   return success();
 }

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 7e80c8d47a53..2bfd9b8f084f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -260,42 +260,33 @@ void CooperativeMatrixNVType::getCapabilities(
 // ImageType
 //===----------------------------------------------------------------------===//
 
-template <typename T>
-static constexpr unsigned getNumBits() {
-  return 0;
-}
-template <>
-constexpr unsigned getNumBits<Dim>() {
+template <typename T> static constexpr unsigned getNumBits() { return 0; }
+template <> constexpr unsigned getNumBits<Dim>() {
   static_assert((1 << 3) > getMaxEnumValForDim(),
                 "Not enough bits to encode Dim value");
   return 3;
 }
-template <>
-constexpr unsigned getNumBits<ImageDepthInfo>() {
+template <> constexpr unsigned getNumBits<ImageDepthInfo>() {
   static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
                 "Not enough bits to encode ImageDepthInfo value");
   return 2;
 }
-template <>
-constexpr unsigned getNumBits<ImageArrayedInfo>() {
+template <> constexpr unsigned getNumBits<ImageArrayedInfo>() {
   static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
                 "Not enough bits to encode ImageArrayedInfo value");
   return 1;
 }
-template <>
-constexpr unsigned getNumBits<ImageSamplingInfo>() {
+template <> constexpr unsigned getNumBits<ImageSamplingInfo>() {
   static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
                 "Not enough bits to encode ImageSamplingInfo value");
   return 1;
 }
-template <>
-constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
+template <> constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
   static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
                 "Not enough bits to encode ImageSamplerUseInfo value");
   return 2;
 }
-template <>
-constexpr unsigned getNumBits<ImageFormat>() {
+template <> constexpr unsigned getNumBits<ImageFormat>() {
   static_assert((1 << 6) > getMaxEnumValForImageFormat(),
                 "Not enough bits to encode ImageFormat value");
   return 6;
@@ -730,17 +721,19 @@ SampledImageType SampledImageType::get(Type imageType) {
   return Base::get(imageType.getContext(), imageType);
 }
 
-SampledImageType SampledImageType::getChecked(Type imageType,
-                                              Location location) {
-  return Base::getChecked(location, imageType);
+SampledImageType
+SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                             Type imageType) {
+  return Base::getChecked(emitError, imageType.getContext(), imageType);
 }
 
 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
 
-LogicalResult SampledImageType::verifyConstructionInvariants(Location loc,
-                                                             Type imageType) {
+LogicalResult
+SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
+                         Type imageType) {
   if (!imageType.isa<ImageType>())
-    return emitError(loc, "expected image type");
+    return emitError() << "expected image type";
 
   return success();
 }
@@ -1095,27 +1088,27 @@ MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
   return Base::get(columnType.getContext(), columnType, columnCount);
 }
 
-MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
-                                  Location location) {
-  return Base::getChecked(location, columnType, columnCount);
+MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                  Type columnType, uint32_t columnCount) {
+  return Base::getChecked(emitError, columnType.getContext(), columnType,
+                          columnCount);
 }
 
-LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
-                                                       Type columnType,
-                                                       uint32_t columnCount) {
+LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 Type columnType, uint32_t columnCount) {
   if (columnCount < 2 || columnCount > 4)
-    return emitError(loc, "matrix can have 2, 3, or 4 columns only");
+    return emitError() << "matrix can have 2, 3, or 4 columns only";
 
   if (!isValidColumnType(columnType))
-    return emitError(loc, "matrix columns must be vectors of floats");
+    return emitError() << "matrix columns must be vectors of floats";
 
   /// The underlying vectors (columns) must be of size 2, 3, or 4
   ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
   if (columnShape.size() != 1)
-    return emitError(loc, "matrix columns must be 1D vectors");
+    return emitError() << "matrix columns must be 1D vectors";
 
   if (columnShape[0] < 2 || columnShape[0] > 4)
-    return emitError(loc, "matrix columns must be of size 2, 3, or 4");
+    return emitError() << "matrix columns must be of size 2, 3, or 4";
 
   return success();
 }

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 58a5b3370364..f3b329164cd1 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -211,16 +211,20 @@ FloatAttr FloatAttr::get(Type type, double value) {
   return Base::get(type.getContext(), type, value);
 }
 
-FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
-  return Base::getChecked(loc, type, value);
+/// Checked variant of `get`: invalid arguments are reported via `emitError`.
+FloatAttr FloatAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                Type type, double value) {
+  return Base::getChecked(emitError, type.getContext(), type, value);
 }
 
 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
   return Base::get(type.getContext(), type, value);
 }
 
-FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
-  return Base::getChecked(loc, type, value);
+/// Checked variant of `get`: invalid arguments are reported via `emitError`.
+FloatAttr FloatAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                Type type, const APFloat &value) {
+  return Base::getChecked(emitError, type.getContext(), type, value);
 }
 
 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
@@ -238,27 +242,31 @@ double FloatAttr::getValueAsDouble(APFloat value) {
 }
 
 /// Verify construction invariants.
-static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
+static LogicalResult
+verifyFloatTypeInvariants(function_ref<InFlightDiagnostic()> emitError,
+                          Type type) {
   if (!type.isa<FloatType>())
-    return emitError(loc, "expected floating point type");
+    return emitError() << "expected floating point type";
   return success();
 }
 
-LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
-                                                      double value) {
-  return verifyFloatTypeInvariants(loc, type);
+/// Verify a FloatAttr built from a `double`; only `type` is checked.
+LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                Type type, double value) {
+  return verifyFloatTypeInvariants(emitError, type);
 }
 
-LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
-                                                      const APFloat &value) {
+/// Verify a FloatAttr built from an APFloat; `type` must match its semantics.
+LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                Type type, const APFloat &value) {
   // Verify that the type is correct.
-  if (failed(verifyFloatTypeInvariants(loc, type)))
+  if (failed(verifyFloatTypeInvariants(emitError, type)))
     return failure();
 
   // Verify that the type semantics match that of the value.
   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
-    return emitError(
-        loc, "FloatAttr type doesn't match the type implied by its value");
+    return emitError()
+           << "FloatAttr type doesn't match the type implied by its value";
   }
   return success();
 }
@@ -326,26 +334,30 @@ uint64_t IntegerAttr::getUInt() const {
   return getValue().getZExtValue();
 }
 
-static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
+static LogicalResult
+verifyIntegerTypeInvariants(function_ref<InFlightDiagnostic()> emitError,
+                            Type type) {
   if (type.isa<IntegerType, IndexType>())
     return success();
-  return emitError(loc, "expected integer or index type");
+  return emitError() << "expected integer or index type";
 }
 
-LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
-                                                        int64_t value) {
-  return verifyIntegerTypeInvariants(loc, type);
+/// Verify an IntegerAttr built from an `int64_t`; only `type` is checked.
+LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                  Type type, int64_t value) {
+  return verifyIntegerTypeInvariants(emitError, type);
 }
 
-LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
-                                                        const APInt &value) {
-  if (failed(verifyIntegerTypeInvariants(loc, type)))
+/// Verify an APInt IntegerAttr; an IntegerType must match its bit width.
+LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                  Type type, const APInt &value) {
+  if (failed(verifyIntegerTypeInvariants(emitError, type)))
     return failure();
   if (auto integerType = type.dyn_cast<IntegerType>())
     if (integerType.getWidth() != value.getBitWidth())
-      return emitError(loc, "integer type bit width (")
-             << integerType.getWidth() << ") doesn't match value bit width ("
-             << value.getBitWidth() << ")";
+      return emitError() << "integer type bit width (" << integerType.getWidth()
+                         << ") doesn't match value bit width ("
+                         << value.getBitWidth() << ")";
   return success();
 }
 
@@ -381,9 +393,12 @@ OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect,
   return Base::get(context, dialect, attrData, type);
 }
 
-OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
-                                  Type type, Location location) {
-  return Base::getChecked(location, dialect, attrData, type);
+/// Checked variant of `get`: invalid arguments are reported via `emitError`.
+OpaqueAttr OpaqueAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                  Identifier dialect, StringRef attrData,
+                                  Type type) {
+  return Base::getChecked(emitError, dialect.getContext(), dialect, attrData,
+                          type);
 }
 
 /// Returns the dialect namespace of the opaque attribute.
@@ -395,12 +410,11 @@ Identifier OpaqueAttr::getDialectNamespace() const {
 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
 
 /// Verify the construction of an opaque attribute.
-LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
-                                                       Identifier dialect,
-                                                       StringRef attrData,
-                                                       Type type) {
+LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 Identifier dialect, StringRef attrData,
+                                 Type type) {
   if (!Dialect::isValidNamespace(dialect.strref()))
-    return emitError(loc, "invalid dialect namespace '") << dialect << "'";
+    return emitError() << "invalid dialect namespace '" << dialect << "'";
   return success();
 }
 

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index cedc6ad3c2d8..9b15854919e0 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -32,10 +32,10 @@ using namespace mlir::detail;
 //===----------------------------------------------------------------------===//
 
 /// Verify the construction of an integer type.
-LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
-                                                        Type elementType) {
+LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
+                                  Type elementType) {
   if (!elementType.isIntOrFloat())
-    return emitError(loc, "invalid element type for complex");
+    return emitError() << "invalid element type for complex";
   return success();
 }
 
@@ -47,12 +47,12 @@ LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
 constexpr unsigned IntegerType::kMaxWidth;
 
 /// Verify the construction of an integer type.
-LogicalResult
-IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
-                                          SignednessSemantics signedness) {
+LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
+                                  unsigned width,
+                                  SignednessSemantics signedness) {
   if (width > IntegerType::kMaxWidth) {
-    return emitError(loc) << "integer bitwidth is limited to "
-                          << IntegerType::kMaxWidth << " bits";
+    return emitError() << "integer bitwidth is limited to "
+                       << IntegerType::kMaxWidth << " bits";
   }
   return success();
 }
@@ -183,11 +183,10 @@ FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
 //===----------------------------------------------------------------------===//
 
 /// Verify the construction of an opaque type.
-LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
-                                                       Identifier dialect,
-                                                       StringRef typeData) {
+LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 Identifier dialect, StringRef typeData) {
   if (!Dialect::isValidNamespace(dialect.strref()))
-    return emitError(loc, "invalid dialect namespace '") << dialect << "'";
+    return emitError() << "invalid dialect namespace '" << dialect << "'";
   return success();
 }
 
@@ -362,22 +361,22 @@ VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
   return Base::get(elementType.getContext(), shape, elementType);
 }
 
-VectorType VectorType::getChecked(Location location, ArrayRef<int64_t> shape,
-                                  Type elementType) {
-  return Base::getChecked(location, shape, elementType);
+VectorType VectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                  ArrayRef<int64_t> shape, Type elementType) {
+  return Base::getChecked(emitError, elementType.getContext(), shape,
+                          elementType);
 }
 
-LogicalResult VectorType::verifyConstructionInvariants(Location loc,
-                                                       ArrayRef<int64_t> shape,
-                                                       Type elementType) {
+LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 ArrayRef<int64_t> shape, Type elementType) {
   if (shape.empty())
-    return emitError(loc, "vector types must have at least one dimension");
+    return emitError() << "vector types must have at least one dimension";
 
   if (!isValidElementType(elementType))
-    return emitError(loc, "vector elements must be int or float type");
+    return emitError() << "vector elements must be int or float type";
 
   if (any_of(shape, [](int64_t i) { return i <= 0; }))
-    return emitError(loc, "vector types must have positive constant sizes");
+    return emitError() << "vector types must have positive constant sizes";
 
   return success();
 }
@@ -400,12 +399,13 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
 // TensorType
 //===----------------------------------------------------------------------===//
 
-// Check if "elementType" can be an element type of a tensor. Emit errors if
-// location is not nullptr.  Returns failure if check failed.
-static LogicalResult checkTensorElementType(Location location,
-                                            Type elementType) {
+// Check if "elementType" can be an element type of a tensor, and report an
+// error through `emitError` if it cannot.
+static LogicalResult
+checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
+                       Type elementType) {
   if (!TensorType::isValidElementType(elementType))
-    return emitError(location, "invalid tensor element type: ") << elementType;
+    return emitError() << "invalid tensor element type: " << elementType;
   return success();
 }
 
@@ -428,19 +428,21 @@ RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
   return Base::get(elementType.getContext(), shape, elementType);
 }
 
-RankedTensorType RankedTensorType::getChecked(Location location,
-                                              ArrayRef<int64_t> shape,
-                                              Type elementType) {
-  return Base::getChecked(location, shape, elementType);
+RankedTensorType
+RankedTensorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                             ArrayRef<int64_t> shape, Type elementType) {
+  return Base::getChecked(emitError, elementType.getContext(), shape,
+                          elementType);
 }
 
-LogicalResult RankedTensorType::verifyConstructionInvariants(
-    Location loc, ArrayRef<int64_t> shape, Type elementType) {
+LogicalResult
+RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
+                         ArrayRef<int64_t> shape, Type elementType) {
   for (int64_t s : shape) {
     if (s < -1)
-      return emitError(loc, "invalid tensor dimension size");
+      return emitError() << "invalid tensor dimension size";
   }
-  return checkTensorElementType(loc, elementType);
+  return checkTensorElementType(emitError, elementType);
 }
 
 ArrayRef<int64_t> RankedTensorType::getShape() const {
@@ -455,15 +457,16 @@ UnrankedTensorType UnrankedTensorType::get(Type elementType) {
   return Base::get(elementType.getContext(), elementType);
 }
 
-UnrankedTensorType UnrankedTensorType::getChecked(Location location,
-                                                  Type elementType) {
-  return Base::getChecked(location, elementType);
+UnrankedTensorType
+UnrankedTensorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                               Type elementType) {
+  return Base::getChecked(emitError, elementType.getContext(), elementType);
 }
 
 LogicalResult
-UnrankedTensorType::verifyConstructionInvariants(Location loc,
-                                                 Type elementType) {
-  return checkTensorElementType(loc, elementType);
+UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
+                           Type elementType) {
+  return checkTensorElementType(emitError, elementType);
 }
 
 //===----------------------------------------------------------------------===//
@@ -485,8 +488,11 @@ unsigned BaseMemRefType::getMemorySpace() const {
 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                            ArrayRef<AffineMap> affineMapComposition,
                            unsigned memorySpace) {
-  auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
-                        /*location=*/llvm::None);
+  // `get` has no location to report errors at, so diagnostics use UnknownLoc.
+  auto result =
+      getImpl(shape, elementType, affineMapComposition, memorySpace, [=] {
+        return emitError(UnknownLoc::get(elementType.getContext()));
+      });
   assert(result && "Failed to construct instance of MemRefType.");
   return result;
 }
@@ -497,12 +503,12 @@ MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
 /// UnknownLoc.  If the MemRefType defined by the arguments would be
 /// ill-formed, emits errors (to the handler registered with the context or to
 /// the error stream) and returns nullptr.
-MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape,
-                                  Type elementType,
+MemRefType MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                  ArrayRef<int64_t> shape, Type elementType,
                                   ArrayRef<AffineMap> affineMapComposition,
                                   unsigned memorySpace) {
   return getImpl(shape, elementType, affineMapComposition, memorySpace,
-                 location);
+                 emitError);
 }
 
 /// Get or create a new MemRefType defined by the arguments.  If the resulting
@@ -512,18 +518,16 @@ MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape,
 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
                                ArrayRef<AffineMap> affineMapComposition,
                                unsigned memorySpace,
-                               Optional<Location> location) {
+                               function_ref<InFlightDiagnostic()> emitError) {
   auto *context = elementType.getContext();
 
   if (!BaseMemRefType::isValidElementType(elementType))
-    return (void)emitOptionalError(location, "invalid memref element type"),
-           MemRefType();
+    return (emitError() << "invalid memref element type", MemRefType());
 
   for (int64_t s : shape) {
     // Negative sizes are not allowed except for `-1` that means dynamic size.
     if (s < -1)
-      return (void)emitOptionalError(location, "invalid memref size"),
-             MemRefType();
+      return (emitError() << "invalid memref size", MemRefType());
   }
 
   // Check that the structure of the composition is valid, i.e. that each
@@ -533,12 +537,10 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
   unsigned i = 0;
   for (const auto &affineMap : affineMapComposition) {
     if (affineMap.getNumDims() != dim) {
-      if (location)
-        emitError(*location)
-            << "memref affine map dimension mismatch between "
-            << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
-            << " and affine map" << i + 1 << ": " << dim
-            << " != " << affineMap.getNumDims();
+      emitError() << "memref affine map dimension mismatch between "
+                  << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
+                  << " and affine map" << i + 1 << ": " << dim
+                  << " != " << affineMap.getNumDims();
       return nullptr;
     }
 
@@ -575,17 +577,18 @@ UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
   return Base::get(elementType.getContext(), elementType, memorySpace);
 }
 
-UnrankedMemRefType UnrankedMemRefType::getChecked(Location location,
-                                                  Type elementType,
-                                                  unsigned memorySpace) {
-  return Base::getChecked(location, elementType, memorySpace);
+UnrankedMemRefType
+UnrankedMemRefType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                               Type elementType, unsigned memorySpace) {
+  return Base::getChecked(emitError, elementType.getContext(), elementType,
+                          memorySpace);
 }
 
 LogicalResult
-UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
-                                                 unsigned memorySpace) {
+UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
+                           Type elementType, unsigned memorySpace) {
   if (!BaseMemRefType::isValidElementType(elementType))
-    return emitError(loc, "invalid memref element type");
+    return emitError() << "invalid memref element type";
   return success();
 }
 

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f637bb261958..d1bd017c6441 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -856,12 +856,14 @@ IntegerType IntegerType::get(MLIRContext *context, unsigned width,
   return Base::get(context, width, signedness);
 }
 
-IntegerType IntegerType::getChecked(Location location, unsigned width,
-                                    SignednessSemantics signedness) {
-  if (auto cached =
-          getCachedIntegerType(width, signedness, location->getContext()))
+IntegerType
+IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                        MLIRContext *context, unsigned width,
+                        SignednessSemantics signedness) {
+  // Return a cached instance if one exists; otherwise construct a checked one.
+  if (auto cached = getCachedIntegerType(width, signedness, context))
     return cached;
-  return Base::getChecked(location, width, signedness);
+  return Base::getChecked(emitError, context, width, signedness);
 }
 
 /// Get an instance of the NoneType.
@@ -1005,11 +1007,15 @@ IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
 // StorageUniquerSupport
 //===----------------------------------------------------------------------===//
 
-/// Utility method to generate a default location for use when checking the
-/// construction invariants of a storage object. This is defined out-of-line to
-/// avoid the need to include Location.h.
-const AttributeStorage *
-mlir::detail::generateUnknownStorageLocation(MLIRContext *ctx) {
-  return reinterpret_cast<const AttributeStorage *>(
-      ctx->getImpl().unknownLocAttr.getAsOpaquePointer());
+/// Utility method to generate a callback that can be used to generate a
+/// diagnostic when checking the construction invariants of a storage object.
+/// This is defined out-of-line to avoid the need to include Location.h.
+llvm::unique_function<InFlightDiagnostic()>
+mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) {
+  return [ctx] { return emitError(UnknownLoc::get(ctx)); };
+}
+/// Overload whose callback emits diagnostics at `loc` instead of UnknownLoc.
+llvm::unique_function<InFlightDiagnostic()>
+mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) {
+  return [=] { return emitError(loc); };
 }

diff  --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 7cf0596e42d7..16efd0aedf5e 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -524,9 +524,10 @@ Attribute Parser::parseExtendedAttr(Type type) {
 
         // Otherwise, form a new opaque attribute.
         return OpaqueAttr::getChecked(
+            // Any verification error is reported at `loc`.
+            getEncodedSourceLocation(loc),
             Identifier::get(dialectName, state.context), symbolData,
-            attrType ? attrType : NoneType::get(state.context),
-            getEncodedSourceLocation(loc));
+            attrType ? attrType : NoneType::get(state.context));
       });
 
   // Ensure that the attribute has the same type as requested.
@@ -563,7 +564,8 @@ Type Parser::parseExtendedType() {
 
         // Otherwise, form a new opaque type.
         return OpaqueType::getChecked(
-            getEncodedSourceLocation(loc),
+            // Any verification error is reported at `loc`.
+            getEncodedSourceLocation(loc), state.context,
             Identifier::get(dialectName, state.context), symbolData);
       });
 }

diff  --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp
index d6adbc20ef76..d76748d85835 100644
--- a/mlir/lib/TableGen/TypeDef.cpp
+++ b/mlir/lib/TableGen/TypeDef.cpp
@@ -23,13 +23,6 @@ using namespace mlir::tblgen;
 // TypeBuilder
 //===----------------------------------------------------------------------===//
 
-/// Return an optional code body used for the `getChecked` variant of this
-/// builder.
-Optional<StringRef> TypeBuilder::getCheckedBody() const {
-  Optional<StringRef> body = def->getValueAsOptionalString("checkedBody");
-  return body && !body->empty() ? body : llvm::None;
-}
-
 /// Returns true if this builder is able to infer the MLIRContext parameter.
 bool TypeBuilder::hasInferredContextParameter() const {
   return def->getValueAsBit("hasInferredContextParam");
@@ -111,8 +104,9 @@ llvm::Optional<StringRef> TypeDef::getParserCode() const {
 bool TypeDef::genAccessors() const {
   return def->getValueAsBit("genAccessors");
 }
-bool TypeDef::genVerifyInvariantsDecl() const {
-  return def->getValueAsBit("genVerifyInvariantsDecl");
+/// Returns true if the `verify` and `getChecked` methods should be declared.
+bool TypeDef::genVerifyDecl() const {
+  return def->getValueAsBit("genVerifyDecl");
 }
 llvm::Optional<StringRef> TypeDef::getExtraDecls() const {
   auto value = def->getValueAsString("extraClassDeclaration");

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 0e2c11a2ecb0..5d23bb5e2240 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -48,7 +48,8 @@ def CompoundTypeA : Test_Type<"CompoundA"> {
 // An example of how one could implement a standard integer.
 def IntegerType : Test_Type<"TestInteger"> {
   let mnemonic = "int";
-  let genVerifyInvariantsDecl = 1;
+  // Generate declarations for `verify` and an emitError-based `getChecked`.
+  let genVerifyDecl = 1;
   let parameters = (
     ins
     "unsigned":$width,
@@ -67,9 +68,7 @@ def IntegerType : Test_Type<"TestInteger"> {
   let builders = [
     TypeBuilder<(ins "unsigned":$width,
                      CArg<"SignednessSemantics", "Signless">:$signedness), [{
-      return Base::get($_ctxt, width, signedness);
-    }], [{
-      return Base::getChecked($_loc, width, signedness);
+      return $_get($_ctxt, width, signedness);
     }]>
   ];
   let skipDefaultBuilders = 1;
@@ -84,7 +83,7 @@ def IntegerType : Test_Type<"TestInteger"> {
     if ($_parser.parseInteger(width)) return Type();
     if ($_parser.parseGreater()) return Type();
     Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
-    return getChecked(loc, width, signedness);
+    return getChecked(loc, loc.getContext(), width, signedness);
   }];
 
   // Any extra code one wants in the type's class declaration.

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 094e5c9fc631..792ccfb0084d 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -112,8 +112,11 @@ static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT
 }
 
 // Example type validity checker.
-LogicalResult TestIntegerType::verifyConstructionInvariants(
-    Location loc, unsigned width, TestIntegerType::SignednessSemantics ss) {
+LogicalResult
+TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
+                        unsigned width,
+                        TestIntegerType::SignednessSemantics ss) {
+  // NOTE(review): widths above 8 are rejected without emitting a diagnostic.
   if (width > 8)
     return failure();
   return success();

diff  --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index 5471519ab60a..98d97eb0852d 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -54,11 +54,12 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
       RTLValueType:$inner
   );
 
-  let genVerifyInvariantsDecl = 1;
+  // Declares `verify` plus an emitError-based `getChecked`; see checks below.
+  let genVerifyDecl = 1;
 
 // DECL-LABEL: class CompoundAType : public ::mlir::Type
-// DECL: static CompoundAType getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
-// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
+// DECL: static CompoundAType getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
+// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
 // DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
 // DECL: static ::mlir::Type parse(::mlir::MLIRContext *context,
 // DECL-NEXT: ::mlir::DialectAsmParser &parser);
@@ -95,7 +96,7 @@ def D_SingleParameterType : TestType<"SingleParameter"> {
 
 def E_IntegerType : TestType<"Integer"> {
     let mnemonic = "int";
-    let genVerifyInvariantsDecl = 1;
+    let genVerifyDecl = 1;
     let parameters = (
         ins
         "SignednessSemantics":$signedness,

diff  --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
index 9cbd53270983..f26d6a2fe07e 100644
--- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
@@ -182,27 +182,29 @@ static const char *const typeDefParsePrint = R"(
     void print(::mlir::DialectAsmPrinter &printer) const;
 )";
 
-/// The code block for the verifyConstructionInvariants and getChecked.
+/// Declaration code block for `verify`, plus `using Base::getChecked;`.
 ///
-/// {0}: The name of the typeDef class.
-/// {1}: List of parameters, parameters style.
+/// {0}: List of parameters, parameters style.
 static const char *const typeDefDeclVerifyStr = R"(
-    static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{1});
+    using Base::getChecked;
+    static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError{0});
 )";
 
 /// Emit the builders for the given type.
 static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os,
                                  TypeParamCommaFormatter &paramTypes) {
   StringRef typeClass = typeDef.getCppClassName();
-  bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
+  bool genCheckedMethods = typeDef.genVerifyDecl();
   if (!typeDef.skipDefaultBuilders()) {
     os << llvm::formatv(
         "    static {0} get(::mlir::MLIRContext *context{1});\n", typeClass,
         paramTypes);
     if (genCheckedMethods) {
-      os << llvm::formatv(
-          "    static {0} getChecked(::mlir::Location loc{1});\n", typeClass,
-          paramTypes);
+      os << llvm::formatv("    static {0} "
+                          "getChecked(llvm::function_ref<::mlir::"
+                          "InFlightDiagnostic()> emitError, "
+                          "::mlir::MLIRContext *context{1});\n",
+                          typeClass, paramTypes);
     }
   }
 
@@ -231,10 +233,14 @@ static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os,
 
     // Generate the `getChecked` variant of the builder.
     if (genCheckedMethods) {
-      os << "    static " << typeClass << " getChecked(::mlir::Location loc";
+      os << "    static " << typeClass
+         << " getChecked(llvm::function_ref<mlir::InFlightDiagnostic()> "
+            "emitError";
+      if (!builder.hasInferredContextParameter())
+        os << ", ::mlir::MLIRContext *context";
       if (!paramStr.empty())
-        os << ", " << paramStr;
-      os << ");\n";
+        os << ", ";
+      os << paramStr << ");\n";
     }
   }
 }
@@ -265,9 +271,8 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
     emitTypeBuilderDecls(typeDef, os, emitTypeNamePairsAfterComma);
 
     // Emit the verify invariants declaration.
-    if (typeDef.genVerifyInvariantsDecl())
-      os << llvm::formatv(typeDefDeclVerifyStr, typeDef.getCppClassName(),
-                          emitTypeNamePairsAfterComma);
+    if (typeDef.genVerifyDecl())
+      os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma);
   }
 
   // Emit the mnenomic, if specified.
@@ -515,10 +520,18 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
   }
 }
 
+/// Replace each occurrence of `from` (must be non-empty) in `str` with `to`.
+static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
+  for (size_t pos = 0; (pos = str.find(from.data(), pos, from.size())) !=
+                       std::string::npos; pos += to.size()) // Skip past `to`.
+    str.replace(pos, from.size(), to.data(), to.size());
+  return str;
+}
+
 /// Emit the builders for the given type.
 static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
                                 ArrayRef<TypeParameter> typeDefParams) {
-  bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
+  bool genCheckedMethods = typeDef.genVerifyDecl();
   StringRef typeClass = typeDef.getCppClassName();
   if (!typeDef.skipDefaultBuilders()) {
     os << llvm::formatv(
@@ -531,8 +544,10 @@ static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
                                 typeDefParams));
     if (genCheckedMethods) {
       os << llvm::formatv(
-          "{0} {0}::getChecked(::mlir::Location loc{1}) {{\n"
-          "  return Base::getChecked(loc{2});\n}\n",
+          "{0} {0}::getChecked("
+          "llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, "
+          "::mlir::MLIRContext *context{1}) {{\n"
+          "  return Base::getChecked(emitError, context{2});\n}\n",
           typeClass,
           TypeParamCommaFormatter(
               TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
@@ -542,16 +557,15 @@ static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
     }
   }
 
+  auto builderFmtCtx =
+      FmtContext().addSubst("_ctxt", "context").addSubst("_get", "Base::get");
+  auto inferredCtxBuilderFmtCtx = FmtContext().addSubst("_get", "Base::get");
+  auto checkedBuilderFmtCtx = FmtContext().addSubst("_ctxt", "context");
+
   // Generate the builders specified by the user.
-  auto builderFmtCtx = FmtContext().addSubst("_ctxt", "context");
-  auto checkedBuilderFmtCtx = FmtContext()
-                                  .addSubst("_loc", "loc")
-                                  .addSubst("_ctxt", "loc.getContext()");
   for (const TypeBuilder &builder : typeDef.getBuilders()) {
     Optional<StringRef> body = builder.getBody();
-    Optional<StringRef> checkedBody =
-        genCheckedMethods ? builder.getCheckedBody() : llvm::None;
-    if (!body && !checkedBody)
+    if (!body)
       continue;
     std::string paramStr;
     llvm::raw_string_ostream paramOS(paramStr);
@@ -565,27 +579,33 @@ static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
     paramOS.flush();
 
     // Emit the `get` variant of the builder.
-    if (body) {
-      os << llvm::formatv("{0} {0}::get(", typeClass);
-      if (!builder.hasInferredContextParameter()) {
-        os << "::mlir::MLIRContext *context";
-        if (!paramStr.empty())
-          os << ", ";
-        os << llvm::formatv("{0}) {{\n  {1};\n}\n", paramStr,
-                            tgfmt(*body, &builderFmtCtx).str());
-      } else {
-        os << llvm::formatv("{0}) {{\n  {1};\n}\n", paramStr, *body);
-      }
+    os << llvm::formatv("{0} {0}::get(", typeClass);
+    if (!builder.hasInferredContextParameter()) {
+      os << "::mlir::MLIRContext *context";
+      if (!paramStr.empty())
+        os << ", ";
+      os << llvm::formatv("{0}) {{\n  {1};\n}\n", paramStr,
+                          tgfmt(*body, &builderFmtCtx).str());
+    } else {
+      os << llvm::formatv("{0}) {{\n  {1};\n}\n", paramStr,
+                          tgfmt(*body, &inferredCtxBuilderFmtCtx).str());
     }
 
     // Emit the `getChecked` variant of the builder.
-    if (checkedBody) {
-      os << llvm::formatv("{0} {0}::getChecked(::mlir::Location loc",
+    if (genCheckedMethods) {
+      os << llvm::formatv("{0} "
+                          "{0}::getChecked(llvm::function_ref<::mlir::"
+                          "InFlightDiagnostic()> emitErrorFn",
                           typeClass);
+      std::string checkedBody =
+          replaceInStr(body->str(), "$_get(", "Base::getChecked(emitErrorFn, ");
+      if (!builder.hasInferredContextParameter()) {
+        os << ", ::mlir::MLIRContext *context";
+        checkedBody = tgfmt(checkedBody, &checkedBuilderFmtCtx).str();
+      }
       if (!paramStr.empty())
-        os << ", " << paramStr;
-      os << llvm::formatv(") {{\n  {0};\n}\n",
-                          tgfmt(*checkedBody, &checkedBuilderFmtCtx));
+        os << ", ";
+      os << llvm::formatv("{0}) {{\n  {1};\n}\n", paramStr, checkedBody);
     }
   }
 }


        


More information about the Mlir-commits mailing list