[Mlir-commits] [mlir] 70d8fec - [mlir] Refactor the structure of the 'verifyConstructionInvariants' methods.

River Riddle llvmlistbot at llvm.org
Thu Feb 20 10:38:29 PST 2020


Author: River Riddle
Date: 2020-02-20T10:37:52-08:00
New Revision: 70d8fec7c947996e857aa136aa22c22a555b02fa

URL: https://github.com/llvm/llvm-project/commit/70d8fec7c947996e857aa136aa22c22a555b02fa
DIFF: https://github.com/llvm/llvm-project/commit/70d8fec7c947996e857aa136aa22c22a555b02fa.diff

LOG: [mlir] Refactor the structure of the 'verifyConstructionInvariants' methods.

Summary:
The current structure suffers from several problems, but the main one is that a construction failure is impossible to debug when using the 'get' methods. This is because errors are only optionally emitted, so the user is given no context about the problem. This revision restructures this so that errors are always emitted, and the 'get' methods simply pass in an UnknownLoc to emit to. This allows removing usages of the more constrained "emitOptionalError", as well as the need for the explicit context parameter (the context is now obtained from the location).

Fixes [PR#44964](https://bugs.llvm.org/show_bug.cgi?id=44964)

Differential Revision: https://reviews.llvm.org/D74876

Added: 
    

Modified: 
    mlir/docs/DefiningAttributesAndTypes.md
    mlir/include/mlir/Dialect/QuantOps/QuantTypes.h
    mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/Location.h
    mlir/include/mlir/IR/StandardTypes.h
    mlir/include/mlir/IR/StorageUniquerSupport.h
    mlir/include/mlir/IR/Types.h
    mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
    mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
    mlir/lib/IR/Attributes.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/StandardTypes.cpp
    mlir/lib/IR/Types.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DefiningAttributesAndTypes.md b/mlir/docs/DefiningAttributesAndTypes.md
index be8d550bcf3a..dfa0357c583d 100644
--- a/mlir/docs/DefiningAttributesAndTypes.md
+++ b/mlir/docs/DefiningAttributesAndTypes.md
@@ -194,42 +194,34 @@ public:
   /// This method is used to get an instance of the 'ComplexType'. This method
   /// asserts that all of the construction invariants were satisfied. To
   /// gracefully handle failed construction, getChecked should be used instead.
-  static ComplexType get(MLIRContext *context, unsigned param, Type type) {
+  static ComplexType get(unsigned param, Type type) {
     // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
     // of this type. All parameters to the storage class are passed after the
     // type kind.
-    return Base::get(context, MyTypes::Complex, param, type);
+    return Base::get(type.getContext(), MyTypes::Complex, param, type);
   }
 
   /// This method is used to get an instance of the 'ComplexType', defined at
   /// the given location. If any of the construction invariants are invalid,
   /// errors are emitted with the provided location and a null type is returned.
   /// Note: This method is completely optional.
-  static ComplexType getChecked(MLIRContext *context, unsigned param, Type type,
-                                Location location) {
+  static ComplexType getChecked(unsigned param, Type type, Location location) {
     // Call into a helper 'getChecked' method in 'TypeBase' to get a uniqued
     // instance of this type. All parameters to the storage class are passed
     // after the type kind.
-    return Base::getChecked(location, context, MyTypes::Complex, param, type);
+    return Base::getChecked(location, MyTypes::Complex, param, type);
   }
 
   /// This method is used to verify the construction invariants passed into the
   /// 'get' and 'getChecked' methods. Note: This method is completely optional.
   static LogicalResult verifyConstructionInvariants(
-      llvm::Optional<Location> loc, MLIRContext *context, unsigned param,
-      Type type) {
+      Location loc, unsigned param, Type type) {
     // Our type only allows non-zero parameters.
-    if (param == 0) {
-      if (loc)
-        context->emitError(loc) << "non-zero parameter passed to 'ComplexType'";
-      return failure();
-    }
+    if (param == 0)
+      return emitError(loc) << "non-zero parameter passed to 'ComplexType'";
     // Our type also expects an integer type.
-    if (!type.isa<IntegerType>()) {
-      if (loc)
-        context->emitError(loc) << "non integer-type passed to 'ComplexType'";
-      return failure();
-    }
+    if (!type.isa<IntegerType>())
+      return emitError(loc) << "non integer-type passed to 'ComplexType'";
     return success();
   }
 

diff  --git a/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h b/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h
index c3a619ece0ef..b533198ef078 100644
--- a/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h
@@ -66,8 +66,7 @@ class QuantizedType : public Type {
   static constexpr unsigned MaxStorageBits = 32;
 
   static LogicalResult
-  verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
-                               unsigned flags, Type storageType,
+  verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
                                Type expressedType, int64_t storageTypeMin,
                                int64_t storageTypeMax);
 
@@ -229,8 +228,7 @@ class AnyQuantizedType
 
   /// Verifies construction invariants and issues errors/warnings.
   static LogicalResult
-  verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
-                               unsigned flags, Type storageType,
+  verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
                                Type expressedType, int64_t storageTypeMin,
                                int64_t storageTypeMax);
 };
@@ -288,10 +286,11 @@ class UniformQuantizedType
              Location location);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verifyConstructionInvariants(
-      Optional<Location> loc, MLIRContext *context, unsigned flags,
-      Type storageType, Type expressedType, double scale, int64_t zeroPoint,
-      int64_t storageTypeMin, int64_t storageTypeMax);
+  static LogicalResult
+  verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
+                               Type expressedType, double scale,
+                               int64_t zeroPoint, int64_t storageTypeMin,
+                               int64_t storageTypeMax);
 
   /// Support method to enable LLVM-style type casting.
   static bool kindof(unsigned kind) {
@@ -351,11 +350,12 @@ class UniformQuantizedPerAxisType
              int64_t storageTypeMax, Location location);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verifyConstructionInvariants(
-      Optional<Location> loc, MLIRContext *context, unsigned flags,
-      Type storageType, Type expressedType, ArrayRef<double> scales,
-      ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
-      int64_t storageTypeMin, int64_t storageTypeMax);
+  static LogicalResult
+  verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
+                               Type expressedType, ArrayRef<double> scales,
+                               ArrayRef<int64_t> zeroPoints,
+                               int32_t quantizedDimension,
+                               int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Support method to enable LLVM-style type casting.
   static bool kindof(unsigned kind) {

diff  --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
index 073e0f509cba..a20fe6e207dd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
+++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
@@ -90,10 +90,11 @@ class TargetEnvAttr
 
   static bool kindof(unsigned kind) { return kind == AttrKind::TargetEnv; }
 
-  static LogicalResult
-  verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
-                               IntegerAttr version, ArrayAttr extensions,
-                               ArrayAttr capabilities, DictionaryAttr limits);
+  static LogicalResult verifyConstructionInvariants(Location loc,
+                                                    IntegerAttr version,
+                                                    ArrayAttr extensions,
+                                                    ArrayAttr capabilities,
+                                                    DictionaryAttr limits);
 };
 
 /// Returns the attribute name for specifying argument ABI information.

diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 34e0152e17d0..8a3b6cfa98b5 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -330,11 +330,9 @@ class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
   }
 
   /// Verify the construction invariants for a double value.
-  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
-                                                    MLIRContext *ctx, Type type,
+  static LogicalResult verifyConstructionInvariants(Location loc, Type type,
                                                     double value);
-  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
-                                                    MLIRContext *ctx, Type type,
+  static LogicalResult verifyConstructionInvariants(Location loc, Type type,
                                                     const APFloat &value);
 };
 
@@ -361,11 +359,9 @@ class IntegerAttr
     return kind == StandardAttributes::Integer;
   }
 
-  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
-                                                    MLIRContext *ctx, Type type,
+  static LogicalResult verifyConstructionInvariants(Location loc, Type type,
                                                     int64_t value);
-  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
-                                                    MLIRContext *ctx, Type type,
+  static LogicalResult verifyConstructionInvariants(Location loc, Type type,
                                                     const APInt &value);
 };
 
@@ -419,8 +415,7 @@ class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
   StringRef getAttrData() const;
 
   /// Verify the construction of an opaque attribute.
-  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
-                                                    MLIRContext *context,
+  static LogicalResult verifyConstructionInvariants(Location loc,
                                                     Identifier dialect,
                                                     StringRef attrData,
                                                     Type type);

diff  --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index 57e38e838b34..902bcb609395 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -54,6 +54,12 @@ class Location {
   Location(LocationAttr loc) : impl(loc) {
     assert(loc && "location should never be null.");
   }
+  Location(const LocationAttr::ImplType *impl) : impl(impl) {
+    assert(impl && "location should never be null.");
+  }
+
+  /// Return the context this location is uniqued in.
+  MLIRContext *getContext() const { return impl.getContext(); }
 
   /// Access the impl location attribute.
   operator LocationAttr() const { return impl; }

diff  --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index e789ab729e15..7356c7be75b1 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -96,8 +96,7 @@ class IntegerType
                                 Location location);
 
   /// Verify the construction of an integer type.
-  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
-                                                    MLIRContext *context,
+  static LogicalResult verifyConstructionInvariants(Location loc,
                                                     unsigned width);
 
   /// Return the bitwidth of this integer type.
@@ -162,8 +161,7 @@ class ComplexType
   static ComplexType getChecked(Type elementType, Location location);
 
   /// Verify the construction of an integer type.
-  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
-                                                    MLIRContext *context,
+  static LogicalResult verifyConstructionInvariants(Location loc,
                                                     Type elementType);
 
   Type getElementType();
@@ -270,8 +268,7 @@ class VectorType
                                Location location);
 
   /// Verify the construction of a vector type.
-  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
-                                                    MLIRContext *context,
+  static LogicalResult verifyConstructionInvariants(Location loc,
                                                     ArrayRef<int64_t> shape,
                                                     Type elementType);
 
@@ -329,8 +326,7 @@ class RankedTensorType
                                      Location location);
 
   /// Verify the construction of a ranked tensor type.
-  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
-                                                    MLIRContext *context,
+  static LogicalResult verifyConstructionInvariants(Location loc,
                                                     ArrayRef<int64_t> shape,
                                                     Type elementType);
 
@@ -360,8 +356,7 @@ class UnrankedTensorType
   static UnrankedTensorType getChecked(Type elementType, Location location);
 
   /// Verify the construction of a unranked tensor type.
-  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
-                                                    MLIRContext *context,
+  static LogicalResult verifyConstructionInvariants(Location loc,
                                                     Type elementType);
 
   ArrayRef<int64_t> getShape() const { return llvm::None; }
@@ -505,8 +500,7 @@ class UnrankedMemRefType
                                        Location location);
 
   /// Verify the construction of a unranked memref type.
-  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
-                                                    MLIRContext *context,
+  static LogicalResult verifyConstructionInvariants(Location loc,
                                                     Type elementType,
                                                     unsigned memorySpace);
 

diff  --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 7f92eb1b2fe0..7520f8e053ed 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -18,10 +18,15 @@
 #include "mlir/Support/StorageUniquer.h"
 
 namespace mlir {
-class Location;
+class AttributeStorage;
 class MLIRContext;
 
 namespace detail {
+/// Utility method to generate a raw default location for use when checking the
+/// construction invariants of a storage object. This is defined out-of-line to
+/// avoid the need to include Location.h.
+const AttributeStorage *generateUnknownStorageLocation(MLIRContext *ctx);
+
 /// Utility class for implementing users of storage classes uniqued by a
 /// StorageUniquer. Clients are not expected to interact with this class
 /// directly.
@@ -53,21 +58,20 @@ class StorageUserBase : public BaseT {
   template <typename... Args>
   static ConcreteT get(MLIRContext *ctx, unsigned kind, Args... args) {
     // Ensure that the invariants are correct for construction.
-    assert(succeeded(
-        ConcreteT::verifyConstructionInvariants(llvm::None, ctx, args...)));
+    assert(succeeded(ConcreteT::verifyConstructionInvariants(
+        generateUnknownStorageLocation(ctx), args...)));
     return UniquerT::template get<ConcreteT>(ctx, kind, args...);
   }
 
   /// Get or create a new ConcreteT instance within the ctx, defined at
   /// the given, potentially unknown, location. If the arguments provided are
   /// invalid then emit errors and return a null object.
-  template <typename... Args>
-  static ConcreteT getChecked(const Location &loc, MLIRContext *ctx,
-                              unsigned kind, Args... args) {
+  template <typename LocationT, typename... Args>
+  static ConcreteT getChecked(LocationT loc, unsigned kind, Args... args) {
     // If the construction invariants fail then we return a null attribute.
-    if (failed(ConcreteT::verifyConstructionInvariants(loc, ctx, args...)))
+    if (failed(ConcreteT::verifyConstructionInvariants(loc, args...)))
       return ConcreteT();
-    return UniquerT::template get<ConcreteT>(ctx, kind, args...);
+    return UniquerT::template get<ConcreteT>(loc.getContext(), kind, args...);
   }
 
   /// Default implementation that just returns success.

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index ac5f0dbe2d04..abbc282b35d1 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -46,10 +46,8 @@ struct OpaqueTypeStorage;
 ///        current type. Used for isa/dyn_cast casting functionality.
 ///
 ///  * Optional:
-///    - static LogicalResult verifyConstructionInvariants(
-///                                               Optional<Location> loc,
-///                                               MLIRContext *context,
-///                                               Args... args)
+///    - static LogicalResult verifyConstructionInvariants(Location loc,
+///                                                        Args... args)
 ///      * This method is invoked when calling the 'TypeBase::get/getChecked'
 ///        methods to ensure that the arguments passed in are valid to construct
 ///        a type instance with.
@@ -238,8 +236,7 @@ class OpaqueType
   StringRef getTypeData() const;
 
   /// Verify the construction of an opaque type.
-  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
-                                                    MLIRContext *context,
+  static LogicalResult verifyConstructionInvariants(Location loc,
                                                     Identifier dialect,
                                                     StringRef typeData);
 

diff  --git a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
index bdbfc2a30fc3..d8cee1afe255 100644
--- a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
@@ -24,20 +24,19 @@ unsigned QuantizedType::getFlags() const {
 }
 
 LogicalResult QuantizedType::verifyConstructionInvariants(
-    Optional<Location> loc, MLIRContext *context, unsigned flags,
-    Type storageType, Type expressedType, int64_t storageTypeMin,
-    int64_t storageTypeMax) {
+    Location loc, unsigned flags, Type storageType, Type expressedType,
+    int64_t storageTypeMin, int64_t storageTypeMax) {
   // Verify that the storage type is integral.
   // This restriction may be lifted at some point in favor of using bf16
   // or f16 as exact representations on hardware where that is advantageous.
   auto intStorageType = storageType.dyn_cast<IntegerType>();
   if (!intStorageType)
-    return emitOptionalError(loc, "storage type must be integral");
+    return emitError(loc, "storage type must be integral");
   unsigned integralWidth = intStorageType.getWidth();
 
   // Verify storage width.
   if (integralWidth == 0 || integralWidth > MaxStorageBits)
-    return emitOptionalError(loc, "illegal storage type size: ", integralWidth);
+    return emitError(loc, "illegal storage type size: ") << integralWidth;
 
   // Verify storageTypeMin and storageTypeMax.
   bool isSigned =
@@ -49,8 +48,8 @@ LogicalResult QuantizedType::verifyConstructionInvariants(
   if (storageTypeMax - storageTypeMin <= 0 ||
       storageTypeMin < defaultIntegerMin ||
       storageTypeMax > defaultIntegerMax) {
-    return emitOptionalError(loc, "illegal storage min and storage max: (",
-                             storageTypeMin, ":", storageTypeMax, ")");
+    return emitError(loc, "illegal storage min and storage max: (")
+           << storageTypeMin << ":" << storageTypeMax << ")";
   }
   return success();
 }
@@ -209,17 +208,15 @@ AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
                                               int64_t storageTypeMin,
                                               int64_t storageTypeMax,
                                               Location location) {
-  return Base::getChecked(location, storageType.getContext(),
-                          QuantizationTypes::Any, flags, storageType,
+  return Base::getChecked(location, QuantizationTypes::Any, flags, storageType,
                           expressedType, storageTypeMin, storageTypeMax);
 }
 
 LogicalResult AnyQuantizedType::verifyConstructionInvariants(
-    Optional<Location> loc, MLIRContext *context, unsigned flags,
-    Type storageType, Type expressedType, int64_t storageTypeMin,
-    int64_t storageTypeMax) {
+    Location loc, unsigned flags, Type storageType, Type expressedType,
+    int64_t storageTypeMin, int64_t storageTypeMax) {
   if (failed(QuantizedType::verifyConstructionInvariants(
-          loc, context, flags, storageType, expressedType, storageTypeMin,
+          loc, flags, storageType, expressedType, storageTypeMin,
           storageTypeMax))) {
     return failure();
   }
@@ -228,7 +225,7 @@ LogicalResult AnyQuantizedType::verifyConstructionInvariants(
   // If this restriction is ever eliminated, the parser/printer must be
   // extended.
   if (expressedType && !expressedType.isa<FloatType>())
-    return emitOptionalError(loc, "expressed type must be floating point");
+    return emitError(loc, "expressed type must be floating point");
 
   return success();
 }
@@ -249,18 +246,17 @@ UniformQuantizedType::getChecked(unsigned flags, Type storageType,
                                  Type expressedType, double scale,
                                  int64_t zeroPoint, int64_t storageTypeMin,
                                  int64_t storageTypeMax, Location location) {
-  return Base::getChecked(location, storageType.getContext(),
-                          QuantizationTypes::UniformQuantized, flags,
+  return Base::getChecked(location, QuantizationTypes::UniformQuantized, flags,
                           storageType, expressedType, scale, zeroPoint,
                           storageTypeMin, storageTypeMax);
 }
 
 LogicalResult UniformQuantizedType::verifyConstructionInvariants(
-    Optional<Location> loc, MLIRContext *context, unsigned flags,
-    Type storageType, Type expressedType, double scale, int64_t zeroPoint,
-    int64_t storageTypeMin, int64_t storageTypeMax) {
+    Location loc, unsigned flags, Type storageType, Type expressedType,
+    double scale, int64_t zeroPoint, int64_t storageTypeMin,
+    int64_t storageTypeMax) {
   if (failed(QuantizedType::verifyConstructionInvariants(
-          loc, context, flags, storageType, expressedType, storageTypeMin,
+          loc, flags, storageType, expressedType, storageTypeMin,
           storageTypeMax))) {
     return failure();
   }
@@ -268,18 +264,17 @@ LogicalResult UniformQuantizedType::verifyConstructionInvariants(
   // Uniform quantization requires fully expressed parameters, including
   // expressed type.
   if (!expressedType)
-    return emitOptionalError(loc,
-                             "uniform quantization requires expressed type");
+    return emitError(loc, "uniform quantization requires expressed type");
 
   // Verify that the expressed type is floating point.
   // If this restriction is ever eliminated, the parser/printer must be
   // extended.
   if (!expressedType.isa<FloatType>())
-    return emitOptionalError(loc, "expressed type must be floating point");
+    return emitError(loc, "expressed type must be floating point");
 
   // Verify scale.
   if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
-    return emitOptionalError(loc, "illegal scale: ", scale);
+    return emitError(loc, "illegal scale: ") << scale;
 
   return success();
 }
@@ -306,19 +301,18 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
     ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
     int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
     Location location) {
-  return Base::getChecked(location, storageType.getContext(),
-                          QuantizationTypes::UniformQuantizedPerAxis, flags,
-                          storageType, expressedType, scales, zeroPoints,
+  return Base::getChecked(location, QuantizationTypes::UniformQuantizedPerAxis,
+                          flags, storageType, expressedType, scales, zeroPoints,
                           quantizedDimension, storageTypeMin, storageTypeMax);
 }
 
 LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
-    Optional<Location> loc, MLIRContext *context, unsigned flags,
-    Type storageType, Type expressedType, ArrayRef<double> scales,
-    ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
-    int64_t storageTypeMin, int64_t storageTypeMax) {
+    Location loc, unsigned flags, Type storageType, Type expressedType,
+    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+    int32_t quantizedDimension, int64_t storageTypeMin,
+    int64_t storageTypeMax) {
   if (failed(QuantizedType::verifyConstructionInvariants(
-          loc, context, flags, storageType, expressedType, storageTypeMin,
+          loc, flags, storageType, expressedType, storageTypeMin,
           storageTypeMax))) {
     return failure();
   }
@@ -326,24 +320,23 @@ LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
   // Uniform quantization requires fully expressed parameters, including
   // expressed type.
   if (!expressedType)
-    return emitOptionalError(loc,
-                             "uniform quantization requires expressed type");
+    return emitError(loc, "uniform quantization requires expressed type");
 
   // Verify that the expressed type is floating point.
   // If this restriction is ever eliminated, the parser/printer must be
   // extended.
   if (!expressedType.isa<FloatType>())
-    return emitOptionalError(loc, "expressed type must be floating point");
+    return emitError(loc, "expressed type must be floating point");
 
   // Ensure that the number of scales and zeroPoints match.
   if (scales.size() != zeroPoints.size())
-    return emitOptionalError(loc, "illegal number of scales and zeroPoints: ",
-                             scales.size(), ", ", zeroPoints.size());
+    return emitError(loc, "illegal number of scales and zeroPoints: ")
+           << scales.size() << ", " << zeroPoints.size();
 
   // Verify scale.
   for (double scale : scales) {
     if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
-      return emitOptionalError(loc, "illegal scale: ", scale);
+      return emitError(loc, "illegal scale: ") << scale;
   }
 
   return success();

diff  --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
index fbb8a93956d1..5c44d60b548b 100644
--- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
@@ -103,10 +103,10 @@ DictionaryAttr spirv::TargetEnvAttr::getResourceLimits() {
 }
 
 LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
-    Optional<Location> loc, MLIRContext *context, IntegerAttr version,
-    ArrayAttr extensions, ArrayAttr capabilities, DictionaryAttr limits) {
+    Location loc, IntegerAttr version, ArrayAttr extensions,
+    ArrayAttr capabilities, DictionaryAttr limits) {
   if (!version.getType().isInteger(32))
-    return emitOptionalError(loc, "expected 32-bit integer for version");
+    return emitError(loc, "expected 32-bit integer for version");
 
   if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
         if (auto strAttr = attr.dyn_cast<StringAttr>())
@@ -114,7 +114,7 @@ LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
             return true;
         return false;
       }))
-    return emitOptionalError(loc, "unknown extension in extension list");
+    return emitError(loc, "unknown extension in extension list");
 
   if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
         if (auto intAttr = attr.dyn_cast<IntegerAttr>())
@@ -122,11 +122,10 @@ LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
             return true;
         return false;
       }))
-    return emitOptionalError(loc, "unknown capability in capability list");
+    return emitError(loc, "unknown capability in capability list");
 
   if (!limits.isa<spirv::ResourceLimitsAttr>())
-    return emitOptionalError(loc,
-                             "expected spirv::ResourceLimitsAttr for limits");
+    return emitError(loc, "expected spirv::ResourceLimitsAttr for limits");
 
   return success();
 }

diff  --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 229738aaa058..4833dee5a951 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -182,8 +182,7 @@ FloatAttr FloatAttr::get(Type type, double value) {
 }
 
 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
-  return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
-                          type, value);
+  return Base::getChecked(loc, StandardAttributes::Float, type, value);
 }
 
 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
@@ -191,8 +190,7 @@ FloatAttr FloatAttr::get(Type type, const APFloat &value) {
 }
 
 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
-  return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
-                          type, value);
+  return Base::getChecked(loc, StandardAttributes::Float, type, value);
 }
 
 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
@@ -210,22 +208,18 @@ double FloatAttr::getValueAsDouble(APFloat value) {
 }
 
 /// Verify construction invariants.
-static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc,
-                                               Type type) {
+static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
   if (!type.isa<FloatType>())
-    return emitOptionalError(loc, "expected floating point type");
+    return emitError(loc, "expected floating point type");
   return success();
 }
 
-LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
-                                                      MLIRContext *ctx,
-                                                      Type type, double value) {
+LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
+                                                      double value) {
   return verifyFloatTypeInvariants(loc, type);
 }
 
-LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
-                                                      MLIRContext *ctx,
-                                                      Type type,
+LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
                                                       const APFloat &value) {
   // Verify that the type is correct.
   if (failed(verifyFloatTypeInvariants(loc, type)))
@@ -233,7 +227,7 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
 
   // Verify that the type semantics match that of the value.
   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
-    return emitOptionalError(
+    return emitError(
         loc, "FloatAttr type doesn't match the type implied by its value");
   }
   return success();
@@ -286,31 +280,26 @@ APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
 
 int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
 
-static LogicalResult verifyIntegerTypeInvariants(Optional<Location> loc,
-                                                 Type type) {
+static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
   if (type.isa<IntegerType>() || type.isa<IndexType>())
     return success();
-  return emitOptionalError(loc, "expected integer or index type");
+  return emitError(loc, "expected integer or index type");
 }
 
-LogicalResult IntegerAttr::verifyConstructionInvariants(Optional<Location> loc,
-                                                        MLIRContext *ctx,
-                                                        Type type,
+LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
                                                         int64_t value) {
   return verifyIntegerTypeInvariants(loc, type);
 }
 
-LogicalResult IntegerAttr::verifyConstructionInvariants(Optional<Location> loc,
-                                                        MLIRContext *ctx,
-                                                        Type type,
+LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
                                                         const APInt &value) {
   if (failed(verifyIntegerTypeInvariants(loc, type)))
     return failure();
   if (auto integerType = type.dyn_cast<IntegerType>())
     if (integerType.getWidth() != value.getBitWidth())
-      return emitOptionalError(
-          loc, "integer type bit width (", integerType.getWidth(),
-          ") doesn't match value bit width (", value.getBitWidth(), ")");
+      return emitError(loc, "integer type bit width (")
+             << integerType.getWidth() << ") doesn't match value bit width ("
+             << value.getBitWidth() << ")";
   return success();
 }
 
@@ -337,8 +326,8 @@ OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
 
 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
                                   Type type, Location location) {
-  return Base::getChecked(location, type.getContext(),
-                          StandardAttributes::Opaque, dialect, attrData, type);
+  return Base::getChecked(location, StandardAttributes::Opaque, dialect,
+                          attrData, type);
 }
 
 /// Returns the dialect namespace of the opaque attribute.
@@ -350,13 +339,12 @@ Identifier OpaqueAttr::getDialectNamespace() const {
 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
 
 /// Verify the construction of an opaque attribute.
-LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc,
-                                                       MLIRContext *context,
+LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
                                                        Identifier dialect,
                                                        StringRef attrData,
                                                        Type type) {
   if (!Dialect::isValidNamespace(dialect.strref()))
-    return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'");
+    return emitError(loc, "invalid dialect namespace '") << dialect << "'";
   return success();
 }
 

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index aa82726bf10f..311fcc35ebdd 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -518,7 +518,7 @@ IntegerType IntegerType::getChecked(unsigned width, MLIRContext *context,
                                     Location location) {
   if (auto cached = getCachedIntegerType(width, context))
     return cached;
-  return Base::getChecked(location, context, StandardTypes::Integer, width);
+  return Base::getChecked(location, StandardTypes::Integer, width);
 }
 
 /// Get an instance of the NoneType.
@@ -639,3 +639,16 @@ IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
   llvm::sys::SmartScopedWriter<true> affineLock(impl.affineMutex);
   return constructorFn();
 }
+
+//===----------------------------------------------------------------------===//
+// StorageUniquerSupport
+//===----------------------------------------------------------------------===//
+
+/// Utility method to generate a default location for use when checking the
+/// construction invariants of a storage object. This is defined out-of-line to
+/// avoid the need to include Location.h.
+const AttributeStorage *
+mlir::detail::generateUnknownStorageLocation(MLIRContext *ctx) {
+  return reinterpret_cast<const AttributeStorage *>(
+      ctx->getImpl().unknownLocAttr.getAsOpaquePointer());
+}

diff  --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index c4269af9030e..cb9386febde8 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -52,12 +52,11 @@ bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
 constexpr unsigned IntegerType::kMaxWidth;
 
 /// Verify the construction of an integer type.
-LogicalResult IntegerType::verifyConstructionInvariants(Optional<Location> loc,
-                                                        MLIRContext *context,
+LogicalResult IntegerType::verifyConstructionInvariants(Location loc,
                                                         unsigned width) {
   if (width > IntegerType::kMaxWidth) {
-    return emitOptionalError(loc, "integer bitwidth is limited to ",
-                             IntegerType::kMaxWidth, " bits");
+    return emitError(loc) << "integer bitwidth is limited to "
+                          << IntegerType::kMaxWidth << " bits";
   }
   return success();
 }
@@ -203,24 +202,20 @@ VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
 
 VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
                                   Location location) {
-  return Base::getChecked(location, elementType.getContext(),
-                          StandardTypes::Vector, shape, elementType);
+  return Base::getChecked(location, StandardTypes::Vector, shape, elementType);
 }
 
-LogicalResult VectorType::verifyConstructionInvariants(Optional<Location> loc,
-                                                       MLIRContext *context,
+LogicalResult VectorType::verifyConstructionInvariants(Location loc,
                                                        ArrayRef<int64_t> shape,
                                                        Type elementType) {
   if (shape.empty())
-    return emitOptionalError(loc,
-                             "vector types must have at least one dimension");
+    return emitError(loc, "vector types must have at least one dimension");
 
   if (!isValidElementType(elementType))
-    return emitOptionalError(loc, "vector elements must be int or float type");
+    return emitError(loc, "vector elements must be int or float type");
 
   if (any_of(shape, [](int64_t i) { return i <= 0; }))
-    return emitOptionalError(loc,
-                             "vector types must have positive constant sizes");
+    return emitError(loc, "vector types must have positive constant sizes");
 
   return success();
 }
@@ -233,11 +228,10 @@ ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
 
 // Check if "elementType" can be an element type of a tensor. Emit errors if
 // location is not nullptr.  Returns failure if check failed.
-static inline LogicalResult checkTensorElementType(Optional<Location> location,
-                                                   MLIRContext *context,
+static inline LogicalResult checkTensorElementType(Location location,
                                                    Type elementType) {
   if (!TensorType::isValidElementType(elementType))
-    return emitOptionalError(location, "invalid tensor element type");
+    return emitError(location, "invalid tensor element type");
   return success();
 }
 
@@ -254,18 +248,17 @@ RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
 RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
                                               Type elementType,
                                               Location location) {
-  return Base::getChecked(location, elementType.getContext(),
-                          StandardTypes::RankedTensor, shape, elementType);
+  return Base::getChecked(location, StandardTypes::RankedTensor, shape,
+                          elementType);
 }
 
 LogicalResult RankedTensorType::verifyConstructionInvariants(
-    Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
-    Type elementType) {
+    Location loc, ArrayRef<int64_t> shape, Type elementType) {
   for (int64_t s : shape) {
     if (s < -1)
-      return emitOptionalError(loc, "invalid tensor dimension size");
+      return emitError(loc, "invalid tensor dimension size");
   }
-  return checkTensorElementType(loc, context, elementType);
+  return checkTensorElementType(loc, elementType);
 }
 
 ArrayRef<int64_t> RankedTensorType::getShape() const {
@@ -283,13 +276,13 @@ UnrankedTensorType UnrankedTensorType::get(Type elementType) {
 
 UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
                                                   Location location) {
-  return Base::getChecked(location, elementType.getContext(),
-                          StandardTypes::UnrankedTensor, elementType);
+  return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType);
 }
 
-LogicalResult UnrankedTensorType::verifyConstructionInvariants(
-    Optional<Location> loc, MLIRContext *context, Type elementType) {
-  return checkTensorElementType(loc, context, elementType);
+LogicalResult
+UnrankedTensorType::verifyConstructionInvariants(Location loc,
+                                                 Type elementType) {
+  return checkTensorElementType(loc, elementType);
 }
 
 //===----------------------------------------------------------------------===//
@@ -399,8 +392,7 @@ UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
 UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
                                                   unsigned memorySpace,
                                                   Location location) {
-  return Base::getChecked(location, elementType.getContext(),
-                          StandardTypes::UnrankedMemRef, elementType,
+  return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType,
                           memorySpace);
 }
 
@@ -408,13 +400,13 @@ unsigned UnrankedMemRefType::getMemorySpace() const {
   return getImpl()->memorySpace;
 }
 
-LogicalResult UnrankedMemRefType::verifyConstructionInvariants(
-    Optional<Location> loc, MLIRContext *context, Type elementType,
-    unsigned memorySpace) {
+LogicalResult
+UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
+                                                 unsigned memorySpace) {
   // Check that memref is formed from allowed types.
   if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
       !elementType.isa<ComplexType>())
-    return emitOptionalError(*loc, "invalid memref element type");
+    return emitError(loc, "invalid memref element type");
   return success();
 }
 
@@ -621,16 +613,14 @@ ComplexType ComplexType::get(Type elementType) {
 }
 
 ComplexType ComplexType::getChecked(Type elementType, Location location) {
-  return Base::getChecked(location, elementType.getContext(),
-                          StandardTypes::Complex, elementType);
+  return Base::getChecked(location, StandardTypes::Complex, elementType);
 }
 
 /// Verify the construction of an integer type.
-LogicalResult ComplexType::verifyConstructionInvariants(Optional<Location> loc,
-                                                        MLIRContext *context,
+LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
                                                         Type elementType) {
   if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
-    return emitOptionalError(loc, "invalid element type for complex");
+    return emitError(loc, "invalid element type for complex");
   return success();
 }
 

diff  --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index dfe1a21f90bd..6fcf0016b0c5 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -59,7 +59,7 @@ OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
 
 OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
                                   MLIRContext *context, Location location) {
-  return Base::getChecked(location, context, Kind::Opaque, dialect, typeData);
+  return Base::getChecked(location, Kind::Opaque, dialect, typeData);
 }
 
 /// Returns the dialect namespace of the opaque type.
@@ -71,11 +71,10 @@ Identifier OpaqueType::getDialectNamespace() const {
 StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }
 
 /// Verify the construction of an opaque type.
-LogicalResult OpaqueType::verifyConstructionInvariants(Optional<Location> loc,
-                                                       MLIRContext *context,
+LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
                                                        Identifier dialect,
                                                        StringRef typeData) {
   if (!Dialect::isValidNamespace(dialect.strref()))
-    return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'");
+    return emitError(loc, "invalid dialect namespace '") << dialect << "'";
   return success();
 }


        


More information about the Mlir-commits mailing list