[Mlir-commits] [mlir] 31bb8ef - [mlir][StorageUniquer] Properly call the destructor on non-trivially destructible storage instances

River Riddle llvmlistbot at llvm.org
Thu Mar 11 11:35:44 PST 2021


Author: River Riddle
Date: 2021-03-11T11:35:32-08:00
New Revision: 31bb8efd698304a8385ff79229ffbaa5613efdfb

URL: https://github.com/llvm/llvm-project/commit/31bb8efd698304a8385ff79229ffbaa5613efdfb
DIFF: https://github.com/llvm/llvm-project/commit/31bb8efd698304a8385ff79229ffbaa5613efdfb.diff

LOG: [mlir][StorageUniquer] Properly call the destructor on non-trivially destructible storage instances

This allows for storage instances to store data that isn't uniqued in the context, or contain otherwise non-trivial logic, in the rare situations that they occur. Storage instances with trivial destructors will still have their destructor skipped. A consequence of this is that the storage instance definition must be visible from the place that registers the type.

Differential Revision: https://reviews.llvm.org/D98311

Added: 
    mlir/unittests/Support/StorageUniquerTest.cpp

Modified: 
    flang/include/flang/Optimizer/Dialect/FIRDialect.h
    flang/lib/Optimizer/Dialect/FIRAttr.cpp
    flang/lib/Optimizer/Dialect/FIRDialect.cpp
    flang/lib/Optimizer/Dialect/FIRType.cpp
    mlir/docs/Tutorials/DefiningAttributesAndTypes.md
    mlir/docs/Tutorials/Toy/Ch-7.md
    mlir/examples/toy/Ch7/mlir/Dialect.cpp
    mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/IR/BuiltinDialect.td
    mlir/include/mlir/Support/StorageUniquer.h
    mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/PDL/IR/PDL.cpp
    mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/BuiltinDialect.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/Location.cpp
    mlir/lib/Support/StorageUniquer.cpp
    mlir/test/lib/Dialect/Test/TestAttributes.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestTypes.cpp
    mlir/unittests/Support/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIRDialect.h b/flang/include/flang/Optimizer/Dialect/FIRDialect.h
index fb82d520dbc2..4bafb4ab7fb6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRDialect.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRDialect.h
@@ -32,6 +32,12 @@ class FIROpsDialect final : public mlir::Dialect {
                                  mlir::Type type) const override;
   void printAttribute(mlir::Attribute attr,
                       mlir::DialectAsmPrinter &p) const override;
+
+private:
+  // Register the Attributes of this dialect.
+  void registerAttributes();
+  // Register the Types of this dialect.
+  void registerTypes();
 };
 
 } // namespace fir

diff  --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
index 035245dbe935..a2fdf7cd43d0 100644
--- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
@@ -243,3 +243,12 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
     os << "<(unknown attribute)>";
   }
 }
+
+//===----------------------------------------------------------------------===//
+// FIROpsDialect
+//===----------------------------------------------------------------------===//
+
+void FIROpsDialect::registerAttributes() {
+  addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr,
+                PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
+}

diff  --git a/flang/lib/Optimizer/Dialect/FIRDialect.cpp b/flang/lib/Optimizer/Dialect/FIRDialect.cpp
index 889b5ef55366..f80aa7d3380e 100644
--- a/flang/lib/Optimizer/Dialect/FIRDialect.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRDialect.cpp
@@ -19,13 +19,8 @@ using namespace fir;
 
 fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx)
     : mlir::Dialect("fir", ctx, mlir::TypeID::get<FIROpsDialect>()) {
-  addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
-           FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
-           PointerType, RealType, RecordType, ReferenceType, SequenceType,
-           ShapeType, ShapeShiftType, ShiftType, SliceType, TypeDescType,
-           fir::VectorType>();
-  addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr,
-                PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
+  registerTypes();
+  registerAttributes();
   addOperations<
 #define GET_OP_LIST
 #include "flang/Optimizer/Dialect/FIROps.cpp.inc"

diff  --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index 873f589e8a4b..beab54e4f1f8 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -866,3 +866,15 @@ mlir::LogicalResult fir::VectorType::verify(
 bool fir::VectorType::isValidElementType(mlir::Type t) {
   return isa_real(t) || isa_integer(t);
 }
+
+//===----------------------------------------------------------------------===//
+// FIROpsDialect
+//===----------------------------------------------------------------------===//
+
+void FIROpsDialect::registerTypes() {
+  addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
+           FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
+           PointerType, RealType, RecordType, ReferenceType, SequenceType,
+           ShapeType, ShapeShiftType, ShiftType, SliceType, TypeDescType,
+           fir::VectorType>();
+}

diff  --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
index 6a261da8a6c2..7942fa2e1868 100644
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
+++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
@@ -319,7 +319,9 @@ public:
 
 Once the dialect types have been defined, they must then be registered with a
 `Dialect`. This is done via a similar mechanism to
-[operations](LangRef.md#operations), with the `addTypes` method.
+[operations](LangRef.md#operations), with the `addTypes` method. The one
+distinct difference from operations is that when a type is registered, the
+definition of its storage class must be visible.
 
 ```c++
 struct MyDialect : public Dialect {

diff  --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md
index 7074521eb5f4..315cf3237b40 100644
--- a/mlir/docs/Tutorials/Toy/Ch-7.md
+++ b/mlir/docs/Tutorials/Toy/Ch-7.md
@@ -187,6 +187,9 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
 }
 ```
 
+(An important note here is that when registering a type, the definition of the
+storage class must be visible.)
+
 With this we can now use our `StructType` when generating MLIR from Toy. See
 examples/toy/Ch7/mlir/MLIRGen.cpp for more details.
 

diff  --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 8170b5b579cb..cbcf53f313c1 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -76,33 +76,6 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
   }
 };
 
-//===----------------------------------------------------------------------===//
-// ToyDialect
-//===----------------------------------------------------------------------===//
-
-/// Dialect creation, the instance will be owned by the context. This is the
-/// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
-    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
-  addOperations<
-#define GET_OP_LIST
-#include "toy/Ops.cpp.inc"
-      >();
-  addInterfaces<ToyInlinerInterface>();
-  addTypes<StructType>();
-}
-
-mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
-                                                 mlir::Attribute value,
-                                                 mlir::Type type,
-                                                 mlir::Location loc) {
-  if (type.isa<StructType>())
-    return builder.create<StructConstantOp>(loc, type,
-                                            value.cast<mlir::ArrayAttr>());
-  return builder.create<ConstantOp>(loc, type,
-                                    value.cast<mlir::DenseElementsAttr>());
-}
-
 //===----------------------------------------------------------------------===//
 // Toy Operations
 //===----------------------------------------------------------------------===//
@@ -566,3 +539,30 @@ void ToyDialect::printType(mlir::Type type,
 
 #define GET_OP_CLASSES
 #include "toy/Ops.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// ToyDialect
+//===----------------------------------------------------------------------===//
+
+/// Dialect creation, the instance will be owned by the context. This is the
+/// point of registration of custom types and operations for the dialect.
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
+  addOperations<
+#define GET_OP_LIST
+#include "toy/Ops.cpp.inc"
+      >();
+  addInterfaces<ToyInlinerInterface>();
+  addTypes<StructType>();
+}
+
+mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
+                                                 mlir::Attribute value,
+                                                 mlir::Type type,
+                                                 mlir::Location loc) {
+  if (type.isa<StructType>())
+    return builder.create<StructConstantOp>(loc, type,
+                                            value.cast<mlir::ArrayAttr>());
+  return builder.create<ConstantOp>(loc, type,
+                                    value.cast<mlir::DenseElementsAttr>());
+}

diff  --git a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
index c6d76be48494..afdf50673ed4 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
@@ -64,6 +64,9 @@ def PDL_Dialect : Dialect {
 
   let name = "pdl";
   let cppNamespace = "::mlir::pdl";
+  let extraClassDeclaration = [{
+    void registerTypes();
+  }];
 }
 
 #endif // MLIR_DIALECT_PDL_IR_PDLDIALECT

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index f18dcef1997a..d293a6a88afd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -52,6 +52,9 @@ def SPIRV_Dialect : Dialect {
   let hasRegionResultAttrVerify = 1;
 
   let extraClassDeclaration = [{
+    void registerAttributes();
+    void registerTypes();
+
     //===------------------------------------------------------------------===//
     // Attribute
     //===------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td
index 383f87bd5d60..a52257c9d1f4 100644
--- a/mlir/include/mlir/IR/BuiltinDialect.td
+++ b/mlir/include/mlir/IR/BuiltinDialect.td
@@ -22,6 +22,17 @@ def Builtin_Dialect : Dialect {
 
   let name = "";
   let cppNamespace = "::mlir";
+  let extraClassDeclaration = [{
+  private:
+    // Register the builtin Attributes.
+    void registerAttributes();
+    // Register the builtin Location Attributes.
+    void registerLocationAttributes();
+    // Register the builtin Types.
+    void registerTypes();
+
+  public:
+  }];
 }
 
 #endif // BUILTIN_BASE

diff  --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index fc7ffa74f3b5..2b66edb51ac9 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -135,7 +135,13 @@ class StorageUniquer {
   /// instances of this class type. `id` is the type identifier that will be
   /// used to identify this type when creating instances of it via 'get'.
   template <typename Storage> void registerParametricStorageType(TypeID id) {
-    registerParametricStorageTypeImpl(id);
+    // If the storage is trivially destructible, we don't need a destructor
+    // function.
+    if (std::is_trivially_destructible<Storage>::value)
+      return registerParametricStorageTypeImpl(id, nullptr);
+    registerParametricStorageTypeImpl(id, [](BaseStorage *storage) {
+      static_cast<Storage *>(storage)->~Storage();
+    });
   }
   /// Utility override when the storage type represents the type id.
   template <typename Storage> void registerParametricStorageType() {
@@ -244,8 +250,10 @@ class StorageUniquer {
       function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
 
   /// Implementation for registering an instance of a derived type with
-  /// parametric storage.
-  void registerParametricStorageTypeImpl(TypeID id);
+  /// parametric storage. This method takes an optional destructor function that
+  /// destructs storage instances when necessary.
+  void registerParametricStorageTypeImpl(
+      TypeID id, function_ref<void(BaseStorage *)> destructorFn);
 
   /// Implementation for getting an instance of a derived type with default
   /// storage.

diff  --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index a796b0725d36..895c3b0f4734 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -20,6 +20,12 @@
 
 using namespace mlir;
 
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
+
 void arm_sve::ArmSVEDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
@@ -31,12 +37,6 @@ void arm_sve::ArmSVEDialect::initialize() {
       >();
 }
 
-#define GET_OP_CLASSES
-#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
-
-#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
-
 //===----------------------------------------------------------------------===//
 // ScalableVectorType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 02b4687b35a4..bfc90b2ac674 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "TypeDetail.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"

diff  --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index e82c4ab6fb16..beb43d7072f2 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -25,10 +25,7 @@ void PDLDialect::initialize() {
 #define GET_OP_LIST
 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
       >();
-  addTypes<
-#define GET_TYPEDEF_LIST
-#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
-      >();
+  registerTypes();
 }
 
 /// Returns true if the given operation is used by a "binding" pdl operation

diff  --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
index b16fade224fc..20f013af246f 100644
--- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
@@ -26,6 +26,13 @@ using namespace mlir::pdl;
 // PDLDialect
 //===----------------------------------------------------------------------===//
 
+void PDLDialect::registerTypes() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
+      >();
+}
+
 static Type parsePDLType(DialectAsmParser &parser) {
   StringRef typeTag;
   if (parser.parseKeyword(&typeTag))

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
index c74c34d88dd7..a514b44a8991 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/IR/Builders.h"
 
@@ -350,3 +351,11 @@ spirv::TargetEnvAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Dialect
+//===----------------------------------------------------------------------===//
+
+void spirv::SPIRVDialect::registerAttributes() {
+  addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
+}

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a3b639fa4e05..81b3ee5e525f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -115,10 +115,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
 //===----------------------------------------------------------------------===//
 
 void SPIRVDialect::initialize() {
-  addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
-           PointerType, RuntimeArrayType, SampledImageType, StructType>();
-
-  addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
+  registerAttributes();
+  registerTypes();
 
   // Add SPIR-V ops.
   addOperations<

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 2bfd9b8f084f..17ee2dfb0ec0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -1154,3 +1154,12 @@ void MatrixType::getCapabilities(
   // Add any capabilities associated with the underlying vectors (i.e., columns)
   getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
 }
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Dialect
+//===----------------------------------------------------------------------===//
+
+void SPIRVDialect::registerTypes() {
+  addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
+           PointerType, RuntimeArrayType, SampledImageType, StructType>();
+}

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 8945253644d8..39d016b8d0fe 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -100,36 +100,6 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
   return false;
 }
 
-//===----------------------------------------------------------------------===//
-// VectorDialect
-//===----------------------------------------------------------------------===//
-
-void VectorDialect::initialize() {
-  addAttributes<CombiningKindAttr>();
-
-  addOperations<
-#define GET_OP_LIST
-#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
-      >();
-}
-
-/// Materialize a single constant operation from a given attribute value with
-/// the desired resultant type.
-Operation *VectorDialect::materializeConstant(OpBuilder &builder,
-                                              Attribute value, Type type,
-                                              Location loc) {
-  return builder.create<ConstantOp>(loc, type, value);
-}
-
-IntegerType vector::getVectorSubscriptType(Builder &builder) {
-  return builder.getIntegerType(64);
-}
-
-ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
-                                         ArrayRef<int64_t> values) {
-  return builder.getI64ArrayAttr(values);
-}
-
 //===----------------------------------------------------------------------===//
 // CombiningKindAttr
 //===----------------------------------------------------------------------===//
@@ -230,6 +200,36 @@ void VectorDialect::printAttribute(Attribute attr,
     llvm_unreachable("Unknown attribute type");
 }
 
+//===----------------------------------------------------------------------===//
+// VectorDialect
+//===----------------------------------------------------------------------===//
+
+void VectorDialect::initialize() {
+  addAttributes<CombiningKindAttr>();
+
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
+      >();
+}
+
+/// Materialize a single constant operation from a given attribute value with
+/// the desired resultant type.
+Operation *VectorDialect::materializeConstant(OpBuilder &builder,
+                                              Attribute value, Type type,
+                                              Location loc) {
+  return builder.create<ConstantOp>(loc, type, value);
+}
+
+IntegerType vector::getVectorSubscriptType(Builder &builder) {
+  return builder.getIntegerType(64);
+}
+
+ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
+                                         ArrayRef<int64_t> values) {
+  return builder.getI64ArrayAttr(values);
+}
+
 //===----------------------------------------------------------------------===//
 // ReductionOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 1d76122996de..5efb8f7c70ff 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -9,6 +9,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "AttributeDetail.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/IntegerSet.h"
@@ -28,6 +29,18 @@ using namespace mlir::detail;
 #define GET_ATTRDEF_CLASSES
 #include "mlir/IR/BuiltinAttributes.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// BuiltinDialect
+//===----------------------------------------------------------------------===//
+
+void BuiltinDialect::registerAttributes() {
+  addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
+                DenseStringElementsAttr, DictionaryAttr, FloatAttr,
+                SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
+                OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
+                UnitAttr>();
+}
+
 //===----------------------------------------------------------------------===//
 // DictionaryAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index b19d541e5045..28aef1500a00 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -60,17 +60,9 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
 } // end anonymous namespace.
 
 void BuiltinDialect::initialize() {
-  addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
-           Float80Type, Float128Type, FunctionType, IndexType, IntegerType,
-           MemRefType, UnrankedMemRefType, NoneType, OpaqueType,
-           RankedTensorType, TupleType, UnrankedTensorType, VectorType>();
-  addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
-                DenseStringElementsAttr, DictionaryAttr, FloatAttr,
-                SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
-                OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
-                UnitAttr>();
-  addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
-                UnknownLoc>();
+  registerTypes();
+  registerAttributes();
+  registerLocationAttributes();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/IR/BuiltinOps.cpp.inc"

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 652883f745e3..758e16bf1999 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -30,6 +30,17 @@ using namespace mlir::detail;
 #define GET_TYPEDEF_CLASSES
 #include "mlir/IR/BuiltinTypes.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// BuiltinDialect
+//===----------------------------------------------------------------------===//
+
+void BuiltinDialect::registerTypes() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/IR/BuiltinTypes.cpp.inc"
+      >();
+}
+
 //===----------------------------------------------------------------------===//
 /// ComplexType
 //===----------------------------------------------------------------------===//
@@ -514,7 +525,7 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
   if (!BaseMemRefType::isValidElementType(elementType))
     return emitError() << "invalid memref element type";
 
-    // Negative sizes are not allowed except for `-1` that means dynamic size.
+  // Negative sizes are not allowed except for `-1` that means dynamic size.
   for (int64_t s : shape)
     if (s < -1)
       return emitError() << "invalid memref size";

diff  --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index 93a9d265209d..cf730199e693 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/Location.h"
+#include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Identifier.h"
 #include "llvm/ADT/SetVector.h"
 
@@ -20,6 +21,17 @@ using namespace mlir::detail;
 #define GET_ATTRDEF_CLASSES
 #include "mlir/IR/BuiltinLocationAttributes.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// BuiltinDialect
+//===----------------------------------------------------------------------===//
+
+void BuiltinDialect::registerLocationAttributes() {
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/IR/BuiltinLocationAttributes.cpp.inc"
+      >();
+}
+
 //===----------------------------------------------------------------------===//
 // LocationAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp
index 7a802430f010..e7805150e37d 100644
--- a/mlir/lib/Support/StorageUniquer.cpp
+++ b/mlir/lib/Support/StorageUniquer.cpp
@@ -100,12 +100,23 @@ class ParametricStorageUniquer {
     return storage;
   }
 
+  /// Destroy all of the storage instances within the given shard.
+  void destroyShardInstances(Shard &shard) {
+    if (!destructorFn)
+      return;
+    for (HashedStorage &instance : shard.instances)
+      destructorFn(instance.storage);
+  }
+
 public:
 #if LLVM_ENABLE_THREADS != 0
   /// Initialize the storage uniquer with a given number of storage shards to
-  /// use. The provided shard number is required to be a valid power of 2.
-  ParametricStorageUniquer(size_t numShards = 8)
-      : shards(new std::atomic<Shard *>[numShards]), numShards(numShards) {
+  /// use. The provided shard number is required to be a valid power of 2. The
+  /// destructor function is used to destroy any allocated storage instances.
+  ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
+                           size_t numShards = 8)
+      : shards(new std::atomic<Shard *>[numShards]), numShards(numShards),
+        destructorFn(destructorFn) {
     assert(llvm::isPowerOf2_64(numShards) &&
            "the number of shards is required to be a power of 2");
     for (size_t i = 0; i < numShards; i++)
@@ -113,9 +124,12 @@ class ParametricStorageUniquer {
   }
   ~ParametricStorageUniquer() {
     // Free all of the allocated shards.
-    for (size_t i = 0; i != numShards; ++i)
-      if (Shard *shard = shards[i].load())
+    for (size_t i = 0; i != numShards; ++i) {
+      if (Shard *shard = shards[i].load()) {
+        destroyShardInstances(*shard);
         delete shard;
+      }
+    }
   }
   /// Get or create an instance of a parametric type.
   BaseStorage *
@@ -204,10 +218,17 @@ class ParametricStorageUniquer {
   /// The number of available shards.
   size_t numShards;
 
+  /// Function used to destruct any allocated storage instances.
+  function_ref<void(BaseStorage *)> destructorFn;
+
 #else
   /// If multi-threading is disabled, ignore the shard parameter as we will
-  /// always use one shard.
-  ParametricStorageUniquer(size_t numShards = 0) {}
+  /// always use one shard. The destructor function is used to destroy any
+  /// allocated storage instances.
+  ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
+                           size_t numShards = 0)
+      : destructorFn(destructorFn) {}
+  ~ParametricStorageUniquer() { destroyShardInstances(shard); }
 
   /// Get or create an instance of a parametric type.
   BaseStorage *
@@ -228,6 +249,9 @@ class ParametricStorageUniquer {
 private:
   /// The main uniquer shard that is used for allocating storage instances.
   Shard shard;
+
+  /// Function used to destruct any allocated storage instances.
+  function_ref<void(BaseStorage *)> destructorFn;
 #endif
 };
 } // end anonymous namespace
@@ -323,9 +347,10 @@ auto StorageUniquer::getParametricStorageTypeImpl(
 
 /// Implementation for registering an instance of a derived type with
 /// parametric storage.
-void StorageUniquer::registerParametricStorageTypeImpl(TypeID id) {
+void StorageUniquer::registerParametricStorageTypeImpl(
+    TypeID id, function_ref<void(BaseStorage *)> destructorFn) {
   impl->parametricUniquers.try_emplace(
-      id, std::make_unique<ParametricStorageUniquer>());
+      id, std::make_unique<ParametricStorageUniquer>(destructorFn));
 }
 
 /// Implementation for getting an instance of a derived type with default

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 72921a22e475..8dcc3498c964 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -100,6 +100,13 @@ void CompoundAAttr::print(DialectAsmPrinter &printer) const {
 // TestDialect
 //===----------------------------------------------------------------------===//
 
+void TestDialect::registerAttributes() {
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "TestAttrDefs.cpp.inc"
+      >();
+}
+
 Attribute TestDialect::parseAttribute(DialectAsmParser &parser,
                                       Type type) const {
   StringRef attrTag;

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 0244d1073623..991094d3b0b0 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -166,20 +166,14 @@ struct TestInlinerInterface : public DialectInlinerInterface {
 //===----------------------------------------------------------------------===//
 
 void TestDialect::initialize() {
+  registerAttributes();
+  registerTypes();
   addOperations<
 #define GET_OP_LIST
 #include "TestOps.cpp.inc"
       >();
-  addAttributes<
-#define GET_ATTRDEF_LIST
-#include "TestAttrDefs.cpp.inc"
-      >();
   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                 TestInlinerInterface>();
-  addTypes<TestType, TestTypeWithLayout, TestRecursiveType,
-#define GET_TYPEDEF_LIST
-#include "TestTypeDefs.cpp.inc"
-           >();
   allowUnknownOperations();
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 39b055691db2..1968ebd46f6f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -32,6 +32,9 @@ def Test_Dialect : Dialect {
   let dependentDialects = ["::mlir::DLTIDialect"];
 
   let extraClassDeclaration = [{
+    void registerAttributes();
+    void registerTypes();
+
     Attribute parseAttribute(DialectAsmParser &parser,
                              Type type) const override;
     void printAttribute(Attribute attr,

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index f8d0c6a83f07..38ab9c819974 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -164,6 +164,13 @@ unsigned TestTypeWithLayout::extractKind(DataLayoutEntryListRef params,
 // TestDialect
 //===----------------------------------------------------------------------===//
 
+void TestDialect::registerTypes() {
+  addTypes<TestType, TestTypeWithLayout, TestRecursiveType,
+#define GET_TYPEDEF_LIST
+#include "TestTypeDefs.cpp.inc"
+           >();
+}
+
 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
                           llvm::SetVector<Type> &stack) {
   StringRef typeTag;

diff  --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
index 7ea17583bc3e..6616a793ec12 100644
--- a/mlir/unittests/Support/CMakeLists.txt
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_unittest(MLIRSupportTests
   DebugCounterTest.cpp
   IndentedOstreamTest.cpp
   MathExtrasTest.cpp
+  StorageUniquerTest.cpp
 )
 
 target_link_libraries(MLIRSupportTests

diff  --git a/mlir/unittests/Support/StorageUniquerTest.cpp b/mlir/unittests/Support/StorageUniquerTest.cpp
new file mode 100644
index 000000000000..6db6783bb89f
--- /dev/null
+++ b/mlir/unittests/Support/StorageUniquerTest.cpp
@@ -0,0 +1,60 @@
+//===- StorageUniquerTest.cpp - StorageUniquer Tests ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/StorageUniquer.h"
+#include "gmock/gmock.h"
+
+using namespace mlir;
+
+namespace {
+/// Simple storage class used for testing.
+template <typename ConcreteT, typename... Args>
+struct SimpleStorage : public StorageUniquer::BaseStorage {
+  using Base = SimpleStorage<ConcreteT, Args...>;
+  using KeyTy = std::tuple<Args...>;
+
+  SimpleStorage(KeyTy key) : key(key) {}
+
+  /// Get an instance of this storage instance.
+  template <typename... ParamsT>
+  static ConcreteT *get(StorageUniquer &uniquer, ParamsT &&...params) {
+    return uniquer.get<ConcreteT>(
+        /*initFn=*/{}, std::make_tuple(std::forward<ParamsT>(params)...));
+  }
+
+  /// Construct an instance with the given storage allocator.
+  static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
+                              KeyTy key) {
+    return new (alloc.allocate<ConcreteT>())
+        ConcreteT(std::forward<KeyTy>(key));
+  }
+  bool operator==(const KeyTy &key) const { return this->key == key; }
+
+  KeyTy key;
+};
+} // namespace
+
+TEST(StorageUniquerTest, NonTrivialDestructor) {
+  struct NonTrivialStorage : public SimpleStorage<NonTrivialStorage, bool *> {
+    using Base::Base;
+    ~NonTrivialStorage() {
+      bool *wasDestructed = std::get<0>(key);
+      *wasDestructed = true;
+    }
+  };
+
+  // Verify that the storage instance destructor was properly called.
+  bool wasDestructed = false;
+  {
+    StorageUniquer uniquer;
+    uniquer.registerParametricStorageType<NonTrivialStorage>();
+    NonTrivialStorage::get(uniquer, &wasDestructed);
+  }
+
+  EXPECT_TRUE(wasDestructed);
+}


        


More information about the Mlir-commits mailing list