[Mlir-commits] [mlir] 575b22b - Revisit Dialect registration: require and store a TypeID on dialects

Mehdi Amini llvmlistbot at llvm.org
Fri Aug 7 08:57:19 PDT 2020


Author: Mehdi Amini
Date: 2020-08-07T15:57:08Z
New Revision: 575b22b5d11bc4c4eb85dde456d9ac7f3cfa3924

URL: https://github.com/llvm/llvm-project/commit/575b22b5d11bc4c4eb85dde456d9ac7f3cfa3924
DIFF: https://github.com/llvm/llvm-project/commit/575b22b5d11bc4c4eb85dde456d9ac7f3cfa3924.diff

LOG: Revisit Dialect registration: require and store a TypeID on dialects

This patch moves the registration to a method in the MLIRContext: getOrCreateDialect<ConcreteDialect>()

This method requires the dialect to provide a static getDialectNamespace()
and to store a TypeID on the Dialect itself, which allows lazily creating
a dialect when it is not yet loaded in the context.
As a side effect, registering the same dialect more than once is no
longer an issue.

To limit the boilerplate, TableGen dialect generation is modified to
emit the constructor entirely and to have it call a separate
"initialize()" method that the user implements.

Differential Revision: https://reviews.llvm.org/D85495

Added: 
    

Modified: 
    mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp
    mlir/examples/toy/Ch2/mlir/Dialect.cpp
    mlir/examples/toy/Ch3/mlir/Dialect.cpp
    mlir/examples/toy/Ch4/mlir/Dialect.cpp
    mlir/examples/toy/Ch5/mlir/Dialect.cpp
    mlir/examples/toy/Ch6/mlir/Dialect.cpp
    mlir/examples/toy/Ch7/mlir/Dialect.cpp
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
    mlir/include/mlir/IR/Dialect.h
    mlir/include/mlir/IR/MLIRContext.h
    mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/lib/Dialect/Quant/IR/QuantOps.cpp
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/IR/Dialect.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/tools/mlir-tblgen/DialectGen.cpp
    mlir/unittests/IR/DialectTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp b/mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp
index 3a253f394b72..acdf88ab9b43 100644
--- a/mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp
+++ b/mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp
@@ -16,8 +16,7 @@ using namespace mlir::standalone;
 // Standalone dialect.
 //===----------------------------------------------------------------------===//
 
-StandaloneDialect::StandaloneDialect(mlir::MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void StandaloneDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "Standalone/StandaloneOps.cpp.inc"

diff  --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
index 4be6bdc60205..86d638395268 100644
--- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
@@ -26,7 +26,8 @@ using namespace mlir::toy;
 
 /// Dialect creation, the instance will be owned by the context. This is the
 /// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
   addOperations<
 #define GET_OP_LIST
 #include "toy/Ops.cpp.inc"

diff  --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
index 4be6bdc60205..86d638395268 100644
--- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
@@ -26,7 +26,8 @@ using namespace mlir::toy;
 
 /// Dialect creation, the instance will be owned by the context. This is the
 /// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
   addOperations<
 #define GET_OP_LIST
 #include "toy/Ops.cpp.inc"

diff  --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index 97c97b07199b..ca568a55d8ea 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -75,7 +75,8 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
 
 /// Dialect creation, the instance will be owned by the context. This is the
 /// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
   addOperations<
 #define GET_OP_LIST
 #include "toy/Ops.cpp.inc"

diff  --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index 3f7dafa1e7a2..d1a518ee8ed9 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -75,7 +75,8 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
 
 /// Dialect creation, the instance will be owned by the context. This is the
 /// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
   addOperations<
 #define GET_OP_LIST
 #include "toy/Ops.cpp.inc"

diff  --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index 3f7dafa1e7a2..d1a518ee8ed9 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -75,7 +75,8 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
 
 /// Dialect creation, the instance will be owned by the context. This is the
 /// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
   addOperations<
 #define GET_OP_LIST
 #include "toy/Ops.cpp.inc"

diff  --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index fc7bf2a2375c..e233a5549934 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -76,7 +76,8 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
 
 /// Dialect creation, the instance will be owned by the context. This is the
 /// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
   addOperations<
 #define GET_OP_LIST
 #include "toy/Ops.cpp.inc"

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 9b708fea9037..d21f5bc0b49b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -27,7 +27,9 @@ def LLVM_Dialect : Dialect {
   private:
     friend LLVMType;
 
-    std::unique_ptr<detail::LLVMDialectImpl> impl;
+    // This can't be a unique_ptr because the ctor is generated inline
+    // in the class definition at the moment.
+    detail::LLVMDialectImpl *impl;
   }];
 }
 

diff  --git a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
index 0993b438a967..901ada94482f 100644
--- a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
+++ b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
@@ -17,7 +17,8 @@ class MLIRContext;
 
 class SDBMDialect : public Dialect {
 public:
-  SDBMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) {}
+  SDBMDialect(MLIRContext *context)
+      : Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {}
 
   /// Since there are no other virtual methods in this derived class, override
   /// the destructor so that key methods get defined in the corresponding

diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 043fde9da729..bd9f3c12f64d 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -14,6 +14,7 @@
 #define MLIR_IR_DIALECT_H
 
 #include "mlir/IR/OperationSupport.h"
+#include "mlir/Support/TypeID.h"
 
 namespace mlir {
 class DialectAsmParser;
@@ -49,6 +50,9 @@ class Dialect {
 
   StringRef getNamespace() const { return name; }
 
+  /// Returns the unique identifier that corresponds to this dialect.
+  TypeID getTypeID() const { return dialectID; }
+
   /// Returns true if this dialect allows for unregistered operations, i.e.
   /// operations prefixed with the dialect namespace but not registered with
   /// addOperation.
@@ -177,7 +181,7 @@ class Dialect {
   ///       with the namespace followed by '.'.
   /// Example:
   ///       - "tf" for the TensorFlow ops like "tf.add".
-  Dialect(StringRef name, MLIRContext *context);
+  Dialect(StringRef name, MLIRContext *context, TypeID id);
 
   /// This method is used by derived classes to add their operations to the set.
   ///
@@ -223,13 +227,13 @@ class Dialect {
   Dialect(const Dialect &) = delete;
   void operator=(Dialect &) = delete;
 
-  /// Register this dialect object with the specified context.  The context
-  /// takes ownership of the heap allocated dialect.
-  void registerDialect(MLIRContext *context);
-
   /// The namespace of this dialect.
   StringRef name;
 
+  /// The unique identifier of the derived Op class, this is used in the context
+  /// to allow registering multiple times the same dialect.
+  TypeID dialectID;
+
   /// This is the context that owns this Dialect object.
   MLIRContext *context;
 
@@ -255,7 +259,9 @@ class Dialect {
                            const DialectAllocatorFunction &function);
   template <typename ConcreteDialect>
   friend void registerDialect();
+  friend class MLIRContext;
 };
+
 /// Registers all dialects and hooks from the global registries with the
 /// specified MLIRContext.
 /// Note: This method is not thread-safe.
@@ -265,12 +271,9 @@ void registerAllDialects(MLIRContext *context);
 /// global registry by calling registerDialect<MyDialect>();
 /// Note: This method is not thread-safe.
 template <typename ConcreteDialect> void registerDialect() {
-  Dialect::registerDialectAllocator(TypeID::get<ConcreteDialect>(),
-                                    [](MLIRContext *ctx) {
-                                      // Just allocate the dialect, the context
-                                      // takes ownership of it.
-                                      new ConcreteDialect(ctx);
-                                    });
+  Dialect::registerDialectAllocator(
+      TypeID::get<ConcreteDialect>(),
+      [](MLIRContext *ctx) { ctx->getOrCreateDialect<ConcreteDialect>(); });
 }
 
 /// DialectRegistration provides a global initializer that registers a Dialect
@@ -291,7 +294,7 @@ namespace llvm {
 template <typename T>
 struct isa_impl<T, ::mlir::Dialect> {
   static inline bool doit(const ::mlir::Dialect &dialect) {
-    return T::getDialectNamespace() == dialect.getNamespace();
+    return mlir::TypeID::get<T>() == dialect.getTypeID();
   }
 };
 } // namespace llvm

diff  --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 8e75bb624449..0192a8ae06af 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -10,6 +10,7 @@
 #define MLIR_IR_MLIRCONTEXT_H
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
 #include <functional>
 #include <memory>
 #include <vector>
@@ -49,6 +50,18 @@ class MLIRContext {
     return static_cast<T *>(getRegisteredDialect(T::getDialectNamespace()));
   }
 
+  /// Get (or create) a dialect for the given derived dialect type. The derived
+  /// type must provide a static 'getDialectNamespace' method.
+  template <typename T>
+  T *getOrCreateDialect() {
+    return static_cast<T *>(getOrCreateDialect(
+        T::getDialectNamespace(), TypeID::get<T>(), [this]() {
+          std::unique_ptr<T> dialect(new T(this));
+          dialect->dialectID = TypeID::get<T>();
+          return dialect;
+        }));
+  }
+
   /// Return true if we allow to create operation for unregistered dialects.
   bool allowsUnregisteredDialects();
 
@@ -109,6 +122,12 @@ class MLIRContext {
 private:
   const std::unique_ptr<MLIRContextImpl> impl;
 
+  /// Get a dialect for the provided namespace and TypeID: abort the program if
+  /// a dialect exist for this namespace with different TypeID. Returns a
+  /// pointer to the dialect owned by the context.
+  Dialect *getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
+                              function_ref<std::unique_ptr<Dialect>()> ctor);
+
   MLIRContext(const MLIRContext &) = delete;
   void operator=(const MLIRContext &) = delete;
 };

diff  --git a/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
index aade931ee4e7..3595970c38f2 100644
--- a/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
+++ b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
@@ -18,8 +18,7 @@
 
 using namespace mlir;
 
-avx512::AVX512Dialect::AVX512Dialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void avx512::AVX512Dialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/AVX512/AVX512.cpp.inc"

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 5c9dc8f3be50..fa98f63706df 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -68,8 +68,7 @@ struct AffineInlinerInterface : public DialectInlinerInterface {
 // AffineDialect
 //===----------------------------------------------------------------------===//
 
-AffineDialect::AffineDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void AffineDialect::initialize() {
   addOperations<AffineDmaStartOp, AffineDmaWaitOp,
 #define GET_OP_LIST
 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index dd8200d3687b..58f9480c37be 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -35,8 +35,7 @@ bool GPUDialect::isKernel(Operation *op) {
   return static_cast<bool>(isKernelAttr);
 }
 
-GPUDialect::GPUDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void GPUDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/GPU/GPUOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp
index bde81144fb54..9f7e66b0ae0a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp
@@ -20,8 +20,7 @@
 
 using namespace mlir;
 
-LLVM::LLVMAVX512Dialect::LLVMAVX512Dialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void LLVM::LLVMAVX512Dialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc"

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e03d0256eea4..515e120888b4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1683,9 +1683,8 @@ struct LLVMDialectImpl {
 } // end namespace LLVM
 } // end namespace mlir
 
-LLVMDialect::LLVMDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context),
-      impl(new detail::LLVMDialectImpl()) {
+void LLVMDialect::initialize() {
+  impl = new detail::LLVMDialectImpl();
   // clang-format off
   addTypes<LLVMVoidType,
            LLVMHalfType,
@@ -1716,7 +1715,7 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
   allowUnknownOperations();
 }
 
-LLVMDialect::~LLVMDialect() {}
+LLVMDialect::~LLVMDialect() { delete impl; }
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 9a09488570e1..cc809b581c84 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -136,7 +136,7 @@ static LogicalResult verify(MmaOp op) {
 //===----------------------------------------------------------------------===//
 
 // TODO: This should be the llvm.nvvm dialect once this is supported.
-NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) {
+void NVVMDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 47089b9d934d..70c3558638e6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -81,7 +81,7 @@ static ParseResult parseROCDLMubufStoreOp(OpAsmParser &parser,
 //===----------------------------------------------------------------------===//
 
 // TODO: This should be the llvm.rocdl dialect once this is supported.
-ROCDLDialect::ROCDLDialect(MLIRContext *context) : Dialect("rocdl", context) {
+void ROCDLDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index a55d4676b54a..50924f7b7866 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -24,8 +24,7 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void mlir::linalg::LinalgDialect::initialize() {
   addTypes<RangeType>();
   addOperations<
 #define GET_OP_LIST

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 4467d3361f51..9159e87509c6 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -26,8 +26,7 @@
 using namespace mlir;
 using namespace mlir::omp;
 
-OpenMPDialect::OpenMPDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void OpenMPDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index 07f881fbc52c..e7df59abc945 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -23,8 +23,7 @@ using namespace mlir;
 using namespace mlir::quant;
 using namespace mlir::quant::detail;
 
-QuantizationDialect::QuantizationDialect(MLIRContext *context)
-    : Dialect(/*name=*/"quant", context) {
+void QuantizationDialect::initialize() {
   addTypes<AnyQuantizedType, UniformQuantizedType,
            UniformQuantizedPerAxisType>();
   addOperations<

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index d0958e54269f..6f3f1e4dc0d1 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -53,8 +53,7 @@ struct SCFInlinerInterface : public DialectInlinerInterface {
 // SCFDialect
 //===----------------------------------------------------------------------===//
 
-SCFDialect::SCFDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void SCFDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/SCF/SCFOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index a2659d6a0eec..01c305720571 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -112,8 +112,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
 // SPIR-V Dialect
 //===----------------------------------------------------------------------===//
 
-SPIRVDialect::SPIRVDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void SPIRVDialect::initialize() {
   addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
            PointerType, RuntimeArrayType, StructType>();
 

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index be4c3c721572..47c592e51a40 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -59,8 +59,7 @@ static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
   return success();
 }
 
-ShapeDialect::ShapeDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void ShapeDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 74e1e20ac1a9..a19d579bcc5d 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -145,8 +145,7 @@ static LogicalResult verifyCastOp(T op) {
   return success();
 }
 
-StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void StandardOpsDialect::initialize() {
   addOperations<DmaStartOp, DmaWaitOp,
 #define GET_OP_LIST
 #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc"

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index e04091e8574f..7c715bfdb6d0 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -34,8 +34,7 @@ using namespace mlir::vector;
 // VectorDialect
 //===----------------------------------------------------------------------===//
 
-VectorDialect::VectorDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void VectorDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Vector/VectorOps.cpp.inc"

diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 501cddad2a48..02448b3d00b2 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -66,10 +66,9 @@ void mlir::registerAllDialects(MLIRContext *context) {
 // Dialect
 //===----------------------------------------------------------------------===//
 
-Dialect::Dialect(StringRef name, MLIRContext *context)
-    : name(name), context(context) {
+Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
+    : name(name), dialectID(id), context(context) {
   assert(isValidNamespace(name) && "invalid dialect namespace");
-  registerDialect(context);
 }
 
 Dialect::~Dialect() {}

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index a4e833cbf77c..32ee60f5512a 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -85,7 +85,8 @@ namespace {
 /// A builtin dialect to define types/etc that are necessary for the validity of
 /// the IR.
 struct BuiltinDialect : public Dialect {
-  BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) {
+  BuiltinDialect(MLIRContext *context)
+      : Dialect(/*name=*/"", context, TypeID::get<BuiltinDialect>()) {
     addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
                   DenseStringElementsAttr, DictionaryAttr, FloatAttr,
                   SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
@@ -102,6 +103,7 @@ struct BuiltinDialect : public Dialect {
     // have been fully decoupled from the core.
     addOperations<FuncOp, ModuleOp, ModuleTerminatorOp>();
   }
+  static StringRef getDialectNamespace() { return ""; }
 };
 } // end anonymous namespace.
 
@@ -349,7 +351,7 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
   }
 
   // Register dialects with this context.
-  new BuiltinDialect(this);
+  getOrCreateDialect<BuiltinDialect>();
   registerAllDialects(this);
 
   // Initialize several common attributes and types to avoid the need to lock
@@ -446,25 +448,33 @@ Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
              : nullptr;
 }
 
-/// Register this dialect object with the specified context.  The context
-/// takes ownership of the heap allocated dialect.
-void Dialect::registerDialect(MLIRContext *context) {
-  auto &impl = context->getImpl();
-  std::unique_ptr<Dialect> dialect(this);
-
+/// Get a dialect for the provided namespace and TypeID: abort the program if a
+/// dialect exist for this namespace with different TypeID. Returns a pointer to
+/// the dialect owned by the context.
+Dialect *
+MLIRContext::getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
+                                function_ref<std::unique_ptr<Dialect>()> ctor) {
+  auto &impl = getImpl();
   // Get the correct insertion position sorted by namespace.
-  auto insertPt = llvm::lower_bound(
-      impl.dialects, dialect, [](const auto &lhs, const auto &rhs) {
-        return lhs->getNamespace() < rhs->getNamespace();
-      });
+  auto insertPt =
+      llvm::lower_bound(impl.dialects, nullptr,
+                        [&](const std::unique_ptr<Dialect> &lhs,
+                            const std::unique_ptr<Dialect> &rhs) {
+                          if (!lhs)
+                            return dialectNamespace < rhs->getNamespace();
+                          return lhs->getNamespace() < dialectNamespace;
+                        });
 
   // Abort if dialect with namespace has already been registered.
   if (insertPt != impl.dialects.end() &&
-      (*insertPt)->getNamespace() == getNamespace()) {
-    llvm::report_fatal_error("a dialect with namespace '" + getNamespace() +
+      (*insertPt)->getNamespace() == dialectNamespace) {
+    if ((*insertPt)->getTypeID() == dialectID)
+      return insertPt->get();
+    llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
                              "' has already been registered");
   }
-  impl.dialects.insert(insertPt, std::move(dialect));
+  auto it = impl.dialects.insert(insertPt, ctor());
+  return &**it;
 }
 
 bool MLIRContext::allowsUnregisteredDialects() {

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index cdbf974679bd..c9cfdc5ff415 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -130,8 +130,7 @@ struct TestInlinerInterface : public DialectInlinerInterface {
 // TestDialect
 //===----------------------------------------------------------------------===//
 
-TestDialect::TestDialect(MLIRContext *context)
-    : Dialect(getDialectNamespace(), context) {
+void TestDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "TestOps.cpp.inc"

diff  --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 4a9109d360c3..13421c42c3c2 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -63,8 +63,14 @@ filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
 /// {1}: The dialect namespace.
 static const char *const dialectDeclBeginStr = R"(
 class {0} : public ::mlir::Dialect {
+  explicit {0}(::mlir::MLIRContext *context)
+    : ::mlir::Dialect(getDialectNamespace(), context,
+      ::mlir::TypeID::get<{0}>()) {{
+    initialize();
+  }
+  void initialize();
+  friend class ::mlir::MLIRContext;
 public:
-  explicit {0}(::mlir::MLIRContext *context);
   static ::llvm::StringRef getDialectNamespace() { return "{1}"; }
 )";
 

diff  --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index 49d2e272d7f1..bc389ce1f0da 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -14,7 +14,15 @@ using namespace mlir::detail;
 
 namespace {
 struct TestDialect : public Dialect {
-  TestDialect(MLIRContext *context) : Dialect(/*name=*/"test", context) {}
+  static StringRef getDialectNamespace() { return "test"; };
+  TestDialect(MLIRContext *context)
+      : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
+};
+struct AnotherTestDialect : public Dialect {
+  static StringRef getDialectNamespace() { return "test"; };
+  AnotherTestDialect(MLIRContext *context)
+      : Dialect(getDialectNamespace(), context,
+                TypeID::get<AnotherTestDialect>()) {}
 };
 
 TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
@@ -22,8 +30,8 @@ TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
 
   // Registering a dialect with the same namespace twice should result in a
   // failure.
-  new TestDialect(&context);
-  ASSERT_DEATH(new TestDialect(&context), "");
+  context.getOrCreateDialect<TestDialect>();
+  ASSERT_DEATH(context.getOrCreateDialect<AnotherTestDialect>(), "");
 }
 
 } // end namespace


        


More information about the Mlir-commits mailing list