[flang-commits] [flang] 31bb8ef - [mlir][StorageUniquer] Properly call the destructor on non-trivially destructible storage instances
River Riddle via flang-commits
flang-commits at lists.llvm.org
Thu Mar 11 11:35:45 PST 2021
Author: River Riddle
Date: 2021-03-11T11:35:32-08:00
New Revision: 31bb8efd698304a8385ff79229ffbaa5613efdfb
URL: https://github.com/llvm/llvm-project/commit/31bb8efd698304a8385ff79229ffbaa5613efdfb
DIFF: https://github.com/llvm/llvm-project/commit/31bb8efd698304a8385ff79229ffbaa5613efdfb.diff
LOG: [mlir][StorageUniquer] Properly call the destructor on non-trivially destructible storage instances
This allows for storage instances to store data that isn't uniqued in the context, or contain otherwise non-trivial logic, in the rare situations that they occur. Storage instances with trivial destructors will still have their destructor skipped. A consequence of this is that the storage instance definition must be visible from the place that registers the type.
Differential Revision: https://reviews.llvm.org/D98311
Added:
mlir/unittests/Support/StorageUniquerTest.cpp
Modified:
flang/include/flang/Optimizer/Dialect/FIRDialect.h
flang/lib/Optimizer/Dialect/FIRAttr.cpp
flang/lib/Optimizer/Dialect/FIRDialect.cpp
flang/lib/Optimizer/Dialect/FIRType.cpp
mlir/docs/Tutorials/DefiningAttributesAndTypes.md
mlir/docs/Tutorials/Toy/Ch-7.md
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/IR/BuiltinDialect.td
mlir/include/mlir/Support/StorageUniquer.h
mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/PDL/IR/PDL.cpp
mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinDialect.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/Location.cpp
mlir/lib/Support/StorageUniquer.cpp
mlir/test/lib/Dialect/Test/TestAttributes.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/unittests/Support/CMakeLists.txt
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIRDialect.h b/flang/include/flang/Optimizer/Dialect/FIRDialect.h
index fb82d520dbc2..4bafb4ab7fb6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRDialect.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRDialect.h
@@ -32,6 +32,12 @@ class FIROpsDialect final : public mlir::Dialect {
mlir::Type type) const override;
void printAttribute(mlir::Attribute attr,
mlir::DialectAsmPrinter &p) const override;
+
+private:
+ // Register the Attributes of this dialect.
+ void registerAttributes();
+ // Register the Types of this dialect.
+ void registerTypes();
};
} // namespace fir
diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
index 035245dbe935..a2fdf7cd43d0 100644
--- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
@@ -243,3 +243,12 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
os << "<(unknown attribute)>";
}
}
+
+//===----------------------------------------------------------------------===//
+// FIROpsDialect
+//===----------------------------------------------------------------------===//
+
+void FIROpsDialect::registerAttributes() {
+ addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr,
+ PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
+}
diff --git a/flang/lib/Optimizer/Dialect/FIRDialect.cpp b/flang/lib/Optimizer/Dialect/FIRDialect.cpp
index 889b5ef55366..f80aa7d3380e 100644
--- a/flang/lib/Optimizer/Dialect/FIRDialect.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRDialect.cpp
@@ -19,13 +19,8 @@ using namespace fir;
fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx)
: mlir::Dialect("fir", ctx, mlir::TypeID::get<FIROpsDialect>()) {
- addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
- FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
- PointerType, RealType, RecordType, ReferenceType, SequenceType,
- ShapeType, ShapeShiftType, ShiftType, SliceType, TypeDescType,
- fir::VectorType>();
- addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr,
- PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
+ registerTypes();
+ registerAttributes();
addOperations<
#define GET_OP_LIST
#include "flang/Optimizer/Dialect/FIROps.cpp.inc"
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index 873f589e8a4b..beab54e4f1f8 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -866,3 +866,15 @@ mlir::LogicalResult fir::VectorType::verify(
bool fir::VectorType::isValidElementType(mlir::Type t) {
return isa_real(t) || isa_integer(t);
}
+
+//===----------------------------------------------------------------------===//
+// FIROpsDialect
+//===----------------------------------------------------------------------===//
+
+void FIROpsDialect::registerTypes() {
+ addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
+ FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
+ PointerType, RealType, RecordType, ReferenceType, SequenceType,
+ ShapeType, ShapeShiftType, ShiftType, SliceType, TypeDescType,
+ fir::VectorType>();
+}
diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
index 6a261da8a6c2..7942fa2e1868 100644
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
+++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
@@ -319,7 +319,9 @@ public:
Once the dialect types have been defined, they must then be registered with a
`Dialect`. This is done via a similar mechanism to
-[operations](LangRef.md#operations), with the `addTypes` method.
+[operations](LangRef.md#operations), with the `addTypes` method. The one
+distinct difference with operations, is that when a type is registered the
+definition of its storage class must be visible.
```c++
struct MyDialect : public Dialect {
diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md
index 7074521eb5f4..315cf3237b40 100644
--- a/mlir/docs/Tutorials/Toy/Ch-7.md
+++ b/mlir/docs/Tutorials/Toy/Ch-7.md
@@ -187,6 +187,9 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
}
```
+(An important note here is that when registering a type, the definition of the
+storage class must be visible.)
+
With this we can now use our `StructType` when generating MLIR from Toy. See
examples/toy/Ch7/mlir/MLIRGen.cpp for more details.
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 8170b5b579cb..cbcf53f313c1 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -76,33 +76,6 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
}
};
-//===----------------------------------------------------------------------===//
-// ToyDialect
-//===----------------------------------------------------------------------===//
-
-/// Dialect creation, the instance will be owned by the context. This is the
-/// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
- : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
- addOperations<
-#define GET_OP_LIST
-#include "toy/Ops.cpp.inc"
- >();
- addInterfaces<ToyInlinerInterface>();
- addTypes<StructType>();
-}
-
-mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
- mlir::Attribute value,
- mlir::Type type,
- mlir::Location loc) {
- if (type.isa<StructType>())
- return builder.create<StructConstantOp>(loc, type,
- value.cast<mlir::ArrayAttr>());
- return builder.create<ConstantOp>(loc, type,
- value.cast<mlir::DenseElementsAttr>());
-}
-
//===----------------------------------------------------------------------===//
// Toy Operations
//===----------------------------------------------------------------------===//
@@ -566,3 +539,30 @@ void ToyDialect::printType(mlir::Type type,
#define GET_OP_CLASSES
#include "toy/Ops.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// ToyDialect
+//===----------------------------------------------------------------------===//
+
+/// Dialect creation, the instance will be owned by the context. This is the
+/// point of registration of custom types and operations for the dialect.
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+ : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
+ addOperations<
+#define GET_OP_LIST
+#include "toy/Ops.cpp.inc"
+ >();
+ addInterfaces<ToyInlinerInterface>();
+ addTypes<StructType>();
+}
+
+mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
+ mlir::Attribute value,
+ mlir::Type type,
+ mlir::Location loc) {
+ if (type.isa<StructType>())
+ return builder.create<StructConstantOp>(loc, type,
+ value.cast<mlir::ArrayAttr>());
+ return builder.create<ConstantOp>(loc, type,
+ value.cast<mlir::DenseElementsAttr>());
+}
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
index c6d76be48494..afdf50673ed4 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
@@ -64,6 +64,9 @@ def PDL_Dialect : Dialect {
let name = "pdl";
let cppNamespace = "::mlir::pdl";
+ let extraClassDeclaration = [{
+ void registerTypes();
+ }];
}
#endif // MLIR_DIALECT_PDL_IR_PDLDIALECT
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index f18dcef1997a..d293a6a88afd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -52,6 +52,9 @@ def SPIRV_Dialect : Dialect {
let hasRegionResultAttrVerify = 1;
let extraClassDeclaration = [{
+ void registerAttributes();
+ void registerTypes();
+
//===------------------------------------------------------------------===//
// Attribute
//===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td
index 383f87bd5d60..a52257c9d1f4 100644
--- a/mlir/include/mlir/IR/BuiltinDialect.td
+++ b/mlir/include/mlir/IR/BuiltinDialect.td
@@ -22,6 +22,17 @@ def Builtin_Dialect : Dialect {
let name = "";
let cppNamespace = "::mlir";
+ let extraClassDeclaration = [{
+ private:
+ // Register the builtin Attributes.
+ void registerAttributes();
+ // Register the builtin Location Attributes.
+ void registerLocationAttributes();
+ // Register the builtin Types.
+ void registerTypes();
+
+ public:
+ }];
}
#endif // BUILTIN_BASE
diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index fc7ffa74f3b5..2b66edb51ac9 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -135,7 +135,13 @@ class StorageUniquer {
/// instances of this class type. `id` is the type identifier that will be
/// used to identify this type when creating instances of it via 'get'.
template <typename Storage> void registerParametricStorageType(TypeID id) {
- registerParametricStorageTypeImpl(id);
+ // If the storage is trivially destructible, we don't need a destructor
+ // function.
+ if (std::is_trivially_destructible<Storage>::value)
+ return registerParametricStorageTypeImpl(id, nullptr);
+ registerParametricStorageTypeImpl(id, [](BaseStorage *storage) {
+ static_cast<Storage *>(storage)->~Storage();
+ });
}
/// Utility override when the storage type represents the type id.
template <typename Storage> void registerParametricStorageType() {
@@ -244,8 +250,10 @@ class StorageUniquer {
function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
/// Implementation for registering an instance of a derived type with
- /// parametric storage.
- void registerParametricStorageTypeImpl(TypeID id);
+ /// parametric storage. This method takes an optional destructor function that
+ /// destructs storage instances when necessary.
+ void registerParametricStorageTypeImpl(
+ TypeID id, function_ref<void(BaseStorage *)> destructorFn);
/// Implementation for getting an instance of a derived type with default
/// storage.
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index a796b0725d36..895c3b0f4734 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -20,6 +20,12 @@
using namespace mlir;
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
+
void arm_sve::ArmSVEDialect::initialize() {
addOperations<
#define GET_OP_LIST
@@ -31,12 +37,6 @@ void arm_sve::ArmSVEDialect::initialize() {
>();
}
-#define GET_OP_CLASSES
-#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
-
-#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
-
//===----------------------------------------------------------------------===//
// ScalableVectorType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 02b4687b35a4..bfc90b2ac674 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "TypeDetail.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index e82c4ab6fb16..beb43d7072f2 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -25,10 +25,7 @@ void PDLDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
>();
- addTypes<
-#define GET_TYPEDEF_LIST
-#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
- >();
+ registerTypes();
}
/// Returns true if the given operation is used by a "binding" pdl operation
diff --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
index b16fade224fc..20f013af246f 100644
--- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
@@ -26,6 +26,13 @@ using namespace mlir::pdl;
// PDLDialect
//===----------------------------------------------------------------------===//
+void PDLDialect::registerTypes() {
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
+ >();
+}
+
static Type parsePDLType(DialectAsmParser &parser) {
StringRef typeTag;
if (parser.parseKeyword(&typeTag))
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
index c74c34d88dd7..a514b44a8991 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
@@ -350,3 +351,11 @@ spirv::TargetEnvAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Dialect
+//===----------------------------------------------------------------------===//
+
+void spirv::SPIRVDialect::registerAttributes() {
+ addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
+}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a3b639fa4e05..81b3ee5e525f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -115,10 +115,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
//===----------------------------------------------------------------------===//
void SPIRVDialect::initialize() {
- addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
- PointerType, RuntimeArrayType, SampledImageType, StructType>();
-
- addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
+ registerAttributes();
+ registerTypes();
// Add SPIR-V ops.
addOperations<
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 2bfd9b8f084f..17ee2dfb0ec0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -1154,3 +1154,12 @@ void MatrixType::getCapabilities(
// Add any capabilities associated with the underlying vectors (i.e., columns)
getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
}
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Dialect
+//===----------------------------------------------------------------------===//
+
+void SPIRVDialect::registerTypes() {
+ addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
+ PointerType, RuntimeArrayType, SampledImageType, StructType>();
+}
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 8945253644d8..39d016b8d0fe 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -100,36 +100,6 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
return false;
}
-//===----------------------------------------------------------------------===//
-// VectorDialect
-//===----------------------------------------------------------------------===//
-
-void VectorDialect::initialize() {
- addAttributes<CombiningKindAttr>();
-
- addOperations<
-#define GET_OP_LIST
-#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
- >();
-}
-
-/// Materialize a single constant operation from a given attribute value with
-/// the desired resultant type.
-Operation *VectorDialect::materializeConstant(OpBuilder &builder,
- Attribute value, Type type,
- Location loc) {
- return builder.create<ConstantOp>(loc, type, value);
-}
-
-IntegerType vector::getVectorSubscriptType(Builder &builder) {
- return builder.getIntegerType(64);
-}
-
-ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
- ArrayRef<int64_t> values) {
- return builder.getI64ArrayAttr(values);
-}
-
//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//
@@ -230,6 +200,36 @@ void VectorDialect::printAttribute(Attribute attr,
llvm_unreachable("Unknown attribute type");
}
+//===----------------------------------------------------------------------===//
+// VectorDialect
+//===----------------------------------------------------------------------===//
+
+void VectorDialect::initialize() {
+ addAttributes<CombiningKindAttr>();
+
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
+ >();
+}
+
+/// Materialize a single constant operation from a given attribute value with
+/// the desired resultant type.
+Operation *VectorDialect::materializeConstant(OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ return builder.create<ConstantOp>(loc, type, value);
+}
+
+IntegerType vector::getVectorSubscriptType(Builder &builder) {
+ return builder.getIntegerType(64);
+}
+
+ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
+ ArrayRef<int64_t> values) {
+ return builder.getI64ArrayAttr(values);
+}
+
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 1d76122996de..5efb8f7c70ff 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -9,6 +9,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "AttributeDetail.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
@@ -28,6 +29,18 @@ using namespace mlir::detail;
#define GET_ATTRDEF_CLASSES
#include "mlir/IR/BuiltinAttributes.cpp.inc"
+//===----------------------------------------------------------------------===//
+// BuiltinDialect
+//===----------------------------------------------------------------------===//
+
+void BuiltinDialect::registerAttributes() {
+ addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
+ DenseStringElementsAttr, DictionaryAttr, FloatAttr,
+ SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
+ OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
+ UnitAttr>();
+}
+
//===----------------------------------------------------------------------===//
// DictionaryAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index b19d541e5045..28aef1500a00 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -60,17 +60,9 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
} // end anonymous namespace.
void BuiltinDialect::initialize() {
- addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
- Float80Type, Float128Type, FunctionType, IndexType, IntegerType,
- MemRefType, UnrankedMemRefType, NoneType, OpaqueType,
- RankedTensorType, TupleType, UnrankedTensorType, VectorType>();
- addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
- DenseStringElementsAttr, DictionaryAttr, FloatAttr,
- SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
- OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
- UnitAttr>();
- addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
- UnknownLoc>();
+ registerTypes();
+ registerAttributes();
+ registerLocationAttributes();
addOperations<
#define GET_OP_LIST
#include "mlir/IR/BuiltinOps.cpp.inc"
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 652883f745e3..758e16bf1999 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -30,6 +30,17 @@ using namespace mlir::detail;
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"
+//===----------------------------------------------------------------------===//
+// BuiltinDialect
+//===----------------------------------------------------------------------===//
+
+void BuiltinDialect::registerTypes() {
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/IR/BuiltinTypes.cpp.inc"
+ >();
+}
+
//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//
@@ -514,7 +525,7 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
if (!BaseMemRefType::isValidElementType(elementType))
return emitError() << "invalid memref element type";
- // Negative sizes are not allowed except for `-1` that means dynamic size.
+ // Negative sizes are not allowed except for `-1` that means dynamic size.
for (int64_t s : shape)
if (s < -1)
return emitError() << "invalid memref size";
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index 93a9d265209d..cf730199e693 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Location.h"
+#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Identifier.h"
#include "llvm/ADT/SetVector.h"
@@ -20,6 +21,17 @@ using namespace mlir::detail;
#define GET_ATTRDEF_CLASSES
#include "mlir/IR/BuiltinLocationAttributes.cpp.inc"
+//===----------------------------------------------------------------------===//
+// BuiltinDialect
+//===----------------------------------------------------------------------===//
+
+void BuiltinDialect::registerLocationAttributes() {
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/IR/BuiltinLocationAttributes.cpp.inc"
+ >();
+}
+
//===----------------------------------------------------------------------===//
// LocationAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp
index 7a802430f010..e7805150e37d 100644
--- a/mlir/lib/Support/StorageUniquer.cpp
+++ b/mlir/lib/Support/StorageUniquer.cpp
@@ -100,12 +100,23 @@ class ParametricStorageUniquer {
return storage;
}
+ /// Destroy all of the storage instances within the given shard.
+ void destroyShardInstances(Shard &shard) {
+ if (!destructorFn)
+ return;
+ for (HashedStorage &instance : shard.instances)
+ destructorFn(instance.storage);
+ }
+
public:
#if LLVM_ENABLE_THREADS != 0
/// Initialize the storage uniquer with a given number of storage shards to
- /// use. The provided shard number is required to be a valid power of 2.
- ParametricStorageUniquer(size_t numShards = 8)
- : shards(new std::atomic<Shard *>[numShards]), numShards(numShards) {
+ /// use. The provided shard number is required to be a valid power of 2. The
+ /// destructor function is used to destroy any allocated storage instances.
+ ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
+ size_t numShards = 8)
+ : shards(new std::atomic<Shard *>[numShards]), numShards(numShards),
+ destructorFn(destructorFn) {
assert(llvm::isPowerOf2_64(numShards) &&
"the number of shards is required to be a power of 2");
for (size_t i = 0; i < numShards; i++)
@@ -113,9 +124,12 @@ class ParametricStorageUniquer {
}
~ParametricStorageUniquer() {
// Free all of the allocated shards.
- for (size_t i = 0; i != numShards; ++i)
- if (Shard *shard = shards[i].load())
+ for (size_t i = 0; i != numShards; ++i) {
+ if (Shard *shard = shards[i].load()) {
+ destroyShardInstances(*shard);
delete shard;
+ }
+ }
}
/// Get or create an instance of a parametric type.
BaseStorage *
@@ -204,10 +218,17 @@ class ParametricStorageUniquer {
/// The number of available shards.
size_t numShards;
+ /// Function to used to destruct any allocated storage instances.
+ function_ref<void(BaseStorage *)> destructorFn;
+
#else
/// If multi-threading is disabled, ignore the shard parameter as we will
- /// always use one shard.
- ParametricStorageUniquer(size_t numShards = 0) {}
+ /// always use one shard. The destructor function is used to destroy any
+ /// allocated storage instances.
+ ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
+ size_t numShards = 0)
+ : destructorFn(destructorFn) {}
+ ~ParametricStorageUniquer() { destroyShardInstances(shard); }
/// Get or create an instance of a parametric type.
BaseStorage *
@@ -228,6 +249,9 @@ class ParametricStorageUniquer {
private:
/// The main uniquer shard that is used for allocating storage instances.
Shard shard;
+
+ /// Function to used to destruct any allocated storage instances.
+ function_ref<void(BaseStorage *)> destructorFn;
#endif
};
} // end anonymous namespace
@@ -323,9 +347,10 @@ auto StorageUniquer::getParametricStorageTypeImpl(
/// Implementation for registering an instance of a derived type with
/// parametric storage.
-void StorageUniquer::registerParametricStorageTypeImpl(TypeID id) {
+void StorageUniquer::registerParametricStorageTypeImpl(
+ TypeID id, function_ref<void(BaseStorage *)> destructorFn) {
impl->parametricUniquers.try_emplace(
- id, std::make_unique<ParametricStorageUniquer>());
+ id, std::make_unique<ParametricStorageUniquer>(destructorFn));
}
/// Implementation for getting an instance of a derived type with default
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 72921a22e475..8dcc3498c964 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -100,6 +100,13 @@ void CompoundAAttr::print(DialectAsmPrinter &printer) const {
// TestDialect
//===----------------------------------------------------------------------===//
+void TestDialect::registerAttributes() {
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "TestAttrDefs.cpp.inc"
+ >();
+}
+
Attribute TestDialect::parseAttribute(DialectAsmParser &parser,
Type type) const {
StringRef attrTag;
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 0244d1073623..991094d3b0b0 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -166,20 +166,14 @@ struct TestInlinerInterface : public DialectInlinerInterface {
//===----------------------------------------------------------------------===//
void TestDialect::initialize() {
+ registerAttributes();
+ registerTypes();
addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
- addAttributes<
-#define GET_ATTRDEF_LIST
-#include "TestAttrDefs.cpp.inc"
- >();
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
TestInlinerInterface>();
- addTypes<TestType, TestTypeWithLayout, TestRecursiveType,
-#define GET_TYPEDEF_LIST
-#include "TestTypeDefs.cpp.inc"
- >();
allowUnknownOperations();
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 39b055691db2..1968ebd46f6f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -32,6 +32,9 @@ def Test_Dialect : Dialect {
let dependentDialects = ["::mlir::DLTIDialect"];
let extraClassDeclaration = [{
+ void registerAttributes();
+ void registerTypes();
+
Attribute parseAttribute(DialectAsmParser &parser,
Type type) const override;
void printAttribute(Attribute attr,
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index f8d0c6a83f07..38ab9c819974 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -164,6 +164,13 @@ unsigned TestTypeWithLayout::extractKind(DataLayoutEntryListRef params,
// TestDialect
//===----------------------------------------------------------------------===//
+void TestDialect::registerTypes() {
+ addTypes<TestType, TestTypeWithLayout, TestRecursiveType,
+#define GET_TYPEDEF_LIST
+#include "TestTypeDefs.cpp.inc"
+ >();
+}
+
static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
llvm::SetVector<Type> &stack) {
StringRef typeTag;
diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
index 7ea17583bc3e..6616a793ec12 100644
--- a/mlir/unittests/Support/CMakeLists.txt
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_unittest(MLIRSupportTests
DebugCounterTest.cpp
IndentedOstreamTest.cpp
MathExtrasTest.cpp
+ StorageUniquerTest.cpp
)
target_link_libraries(MLIRSupportTests
diff --git a/mlir/unittests/Support/StorageUniquerTest.cpp b/mlir/unittests/Support/StorageUniquerTest.cpp
new file mode 100644
index 000000000000..6db6783bb89f
--- /dev/null
+++ b/mlir/unittests/Support/StorageUniquerTest.cpp
@@ -0,0 +1,60 @@
+//===- StorageUniquerTest.cpp - StorageUniquer Tests ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/StorageUniquer.h"
+#include "gmock/gmock.h"
+
+using namespace mlir;
+
+namespace {
+/// Simple storage class used for testing.
+template <typename ConcreteT, typename... Args>
+struct SimpleStorage : public StorageUniquer::BaseStorage {
+ using Base = SimpleStorage<ConcreteT, Args...>;
+ using KeyTy = std::tuple<Args...>;
+
+ SimpleStorage(KeyTy key) : key(key) {}
+
+ /// Get an instance of this storage instance.
+ template <typename... ParamsT>
+ static ConcreteT *get(StorageUniquer &uniquer, ParamsT &&...params) {
+ return uniquer.get<ConcreteT>(
+ /*initFn=*/{}, std::make_tuple(std::forward<ParamsT>(params)...));
+ }
+
+ /// Construct an instance with the given storage allocator.
+ static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
+ KeyTy key) {
+ return new (alloc.allocate<ConcreteT>())
+ ConcreteT(std::forward<KeyTy>(key));
+ }
+ bool operator==(const KeyTy &key) const { return this->key == key; }
+
+ KeyTy key;
+};
+} // namespace
+
+TEST(StorageUniquerTest, NonTrivialDestructor) {
+ struct NonTrivialStorage : public SimpleStorage<NonTrivialStorage, bool *> {
+ using Base::Base;
+ ~NonTrivialStorage() {
+ bool *wasDestructed = std::get<0>(key);
+ *wasDestructed = true;
+ }
+ };
+
+ // Verify that the storage instance destructor was properly called.
+ bool wasDestructed = false;
+ {
+ StorageUniquer uniquer;
+ uniquer.registerParametricStorageType<NonTrivialStorage>();
+ NonTrivialStorage::get(uniquer, &wasDestructed);
+ }
+
+ EXPECT_TRUE(wasDestructed);
+}
More information about the flang-commits
mailing list