[flang-commits] [flang] 2996a8d - [mlir] avoid exposing mutable DialectRegistry from MLIRContext

Alex Zinenko via flang-commits flang-commits at lists.llvm.org
Wed Feb 10 03:07:47 PST 2021


Author: Alex Zinenko
Date: 2021-02-10T12:07:34+01:00
New Revision: 2996a8d67553b9d469e01215b49bb1af17ad6d1e

URL: https://github.com/llvm/llvm-project/commit/2996a8d67553b9d469e01215b49bb1af17ad6d1e
DIFF: https://github.com/llvm/llvm-project/commit/2996a8d67553b9d469e01215b49bb1af17ad6d1e.diff

LOG: [mlir] avoid exposing mutable DialectRegistry from MLIRContext

MLIRContext allows its users to directly access the DialectRegistry it
contains. While sometimes useful for registering additional dialects on an
already existing context, this breaks encapsulation by essentially giving
raw access to a part of the context's internal state. Remove this mutable
access and instead provide a method to append a given DialectRegistry to the
one already contained in the context. Also provide a shortcut mechanism to
construct a context from an already existing registry, which seems to be a
common use case in the wild. Keep read-only access to the registry contained in
the context in case it needs to be copied or used for constructing another
context.

With this change, DialectRegistry is no longer concerned with loading the
dialects and deciding whether to invoke delayed interface registration. Loading
is concentrated in the MLIRContext, and the functionality of the registry
better reflects its name.

Depends On D96137

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D96331

Added: 
    

Modified: 
    flang/tools/tco/tco.cpp
    mlir/include/mlir/CAPI/Registration.h
    mlir/include/mlir/IR/Dialect.h
    mlir/include/mlir/IR/MLIRContext.h
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/CAPI/Registration/Registration.cpp
    mlir/lib/ExecutionEngine/JitRunner.cpp
    mlir/lib/IR/Dialect.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/Pass/Pass.cpp
    mlir/lib/Support/MlirOptMain.cpp
    mlir/lib/Target/SPIRV/TranslateRegistration.cpp
    mlir/lib/Translation/Translation.cpp
    mlir/tools/mlir-reduce/mlir-reduce.cpp
    mlir/unittests/IR/DialectTest.cpp

Removed: 
    


################################################################################
diff  --git a/flang/tools/tco/tco.cpp b/flang/tools/tco/tco.cpp
index 5931d7b39206..49f35474048e 100644
--- a/flang/tools/tco/tco.cpp
+++ b/flang/tools/tco/tco.cpp
@@ -61,8 +61,9 @@ compileFIR(const mlir::PassPipelineCLParser &passPipeline) {
   // load the file into a module
   SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), SMLoc());
-  mlir::MLIRContext context;
-  fir::registerFIRDialects(context.getDialectRegistry());
+  mlir::DialectRegistry registry;
+  fir::registerFIRDialects(registry);
+  mlir::MLIRContext context(registry);
   auto owningRef = mlir::parseSourceFile(sourceMgr, &context);
 
   if (!owningRef) {

diff  --git a/mlir/include/mlir/CAPI/Registration.h b/mlir/include/mlir/CAPI/Registration.h
index 7601f9fc0e63..ac909d1dd9da 100644
--- a/mlir/include/mlir/CAPI/Registration.h
+++ b/mlir/include/mlir/CAPI/Registration.h
@@ -35,7 +35,9 @@ typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
 
 #define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName)      \
   static void mlirContextRegister##Name##Dialect(MlirContext context) {        \
-    unwrap(context)->getDialectRegistry().insert<ClassName>();                 \
+    mlir::DialectRegistry registry;                                            \
+    registry.insert<ClassName>();                                              \
+    unwrap(context)->appendDialectRegistry(registry);                          \
   }                                                                            \
   static MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) {     \
     return wrap(unwrap(context)->getOrLoadDialect<ClassName>());               \

diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 978531f3098c..7798f92400b3 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -26,6 +26,7 @@ class OpBuilder;
 class Type;
 
 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
+using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
 using InterfaceAllocatorFunction =
     std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
 
@@ -241,8 +242,7 @@ class DialectRegistry {
       DenseMap<TypeID, SmallVector<InterfaceAllocatorFunction, 2>>;
 
 public:
-  explicit DialectRegistry(MLIRContext *context = nullptr)
-      : owningContext(context) {}
+  explicit DialectRegistry() {}
 
   template <typename ConcreteDialect>
   void insert() {
@@ -267,42 +267,37 @@ class DialectRegistry {
   /// ownership of the dialect and for delayed interface registration to happen.
   void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor);
 
-  /// Load a dialect for this namespace in the provided context.
-  Dialect *loadByName(StringRef name, MLIRContext *context);
+  /// Return an allocation function for constructing the dialect identified by
+  /// its namespace, or nullptr if the namespace is not in this registry.
+  DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
 
   // Register all dialects available in the current registry with the registry
   // in the provided context.
-  void appendTo(DialectRegistry &destination) {
+  void appendTo(DialectRegistry &destination) const {
     for (const auto &nameAndRegistrationIt : registry)
       destination.insert(nameAndRegistrationIt.second.first,
                          nameAndRegistrationIt.first,
                          nameAndRegistrationIt.second.second);
     destination.interfaces.insert(interfaces.begin(), interfaces.end());
   }
-  // Load all dialects available in the registry in the provided context.
-  void loadAll(MLIRContext *context) {
-    for (const auto &nameAndRegistrationIt : registry)
-      nameAndRegistrationIt.second.second(context);
-  }
 
   /// Return the names of dialects known to this registry.
-  auto getDialectNames() {
+  auto getDialectNames() const {
     return llvm::map_range(
-        registry, [](const MapTy::value_type &item) { return item.first; });
+        registry,
+        [](const MapTy::value_type &item) -> StringRef { return item.first; });
   }
 
   /// Add an interface constructed with the given allocation function to the
   /// dialect provided as template parameter. The dialect must be present in
-  /// the registry, but may or may not be loaded. If it is not loaded, the
-  /// interface registration is delayed until the loading.
+  /// the registry.
   template <typename DialectTy>
   void addDialectInterface(InterfaceAllocatorFunction allocator) {
     addDialectInterface(DialectTy::getDialectNamespace(), allocator);
   }
 
   /// Add an interface to the dialect, both provided as template parameter. The
-  /// dialect must be present in the registry, but may or may not be loaded. If
-  /// it is not loaded, the interface registration is delayed until the loading.
+  /// dialect must be present in the registry.
   template <typename DialectTy, typename InterfaceTy>
   void addDialectInterface() {
     addDialectInterface<DialectTy>([](Dialect *dialect) {
@@ -312,7 +307,7 @@ class DialectRegistry {
 
   /// Register any interfaces required for the given dialect (based on its
   /// TypeID). Users are not expected to call this directly.
-  void registerDelayedInterfaces(Dialect *dialect);
+  void registerDelayedInterfaces(Dialect *dialect) const;
 
 private:
   /// Add an interface constructed with the given allocation function to the
@@ -322,10 +317,6 @@ class DialectRegistry {
 
   MapTy registry;
   InterfaceMapTy interfaces;
-
-  /// If this registry belongs to a context, this points back to the context.
-  /// Useful for checking if a dialect is loaded in the context.
-  MLIRContext *owningContext;
 };
 
 } // namespace mlir

diff  --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index eace86f9cb7a..0a7d15070618 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -36,17 +36,19 @@ class StorageUniquer;
 class MLIRContext {
 public:
   /// Create a new Context.
-  /// The loadAllDialects parameters allows to load all dialects from the global
-  /// registry on Context construction. It is deprecated and will be removed
-  /// soon.
   explicit MLIRContext();
+  explicit MLIRContext(const DialectRegistry &registry);
   ~MLIRContext();
 
   /// Return information about all IR dialects loaded in the context.
   std::vector<Dialect *> getLoadedDialects();
 
   /// Return the dialect registry associated with this context.
-  DialectRegistry &getDialectRegistry();
+  const DialectRegistry &getDialectRegistry();
+
+  /// Append the contents of the given dialect registry to the registry
+  /// associated with this context.
+  void appendDialectRegistry(const DialectRegistry &registry);
 
   /// Return information about all available dialects in the registry in this
   /// context.
@@ -87,6 +89,9 @@ class MLIRContext {
     loadDialect<OtherDialect, MoreDialects...>();
   }
 
+  /// Load all dialects available in the registry in this context.
+  void loadAllAvailableDialects();
+
   /// Get (or create) a dialect for the given derived dialect name.
   /// The dialect will be loaded from the registry if no dialect is found.
   /// If no dialect is loaded for this name and none is available in the

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 7fd063d11464..7b7698832065 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -45,7 +45,7 @@
 
 namespace mlir {
 
-// Add all the MLIR dialects to the provided registry.
+/// Add all the MLIR dialects to the provided registry.
 inline void registerAllDialects(DialectRegistry &registry) {
   // clang-format off
   registry.insert<acc::OpenACCDialect,
@@ -78,6 +78,13 @@ inline void registerAllDialects(DialectRegistry &registry) {
   // clang-format on
 }
 
+/// Append all the MLIR dialects to the registry contained in the given context.
+inline void registerAllDialects(MLIRContext &context) {
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  context.appendDialectRegistry(registry);
+}
+
 } // namespace mlir
 
 #endif // MLIR_INITALLDIALECTS_H_

diff  --git a/mlir/lib/CAPI/Registration/Registration.cpp b/mlir/lib/CAPI/Registration/Registration.cpp
index 1d6294dbaaba..223188ac6ccd 100644
--- a/mlir/lib/CAPI/Registration/Registration.cpp
+++ b/mlir/lib/CAPI/Registration/Registration.cpp
@@ -12,7 +12,7 @@
 #include "mlir/InitAllDialects.h"
 
 void mlirRegisterAllDialects(MlirContext context) {
-  registerAllDialects(unwrap(context)->getDialectRegistry());
+  mlir::registerAllDialects(*unwrap(context));
   // TODO: we may not want to eagerly load here.
-  unwrap(context)->getDialectRegistry().loadAll(unwrap(context));
+  unwrap(context)->loadAllAvailableDialects();
 }

diff  --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index a69d19746a86..4139e4014f89 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -331,7 +331,7 @@ int mlir::JitRunnerMain(int argc, char **argv, JitRunnerConfig config) {
   }
 
   MLIRContext context;
-  registerAllDialects(context.getDialectRegistry());
+  registerAllDialects(context);
 
   auto m = parseMLIRInput(options.inputFilename, &context);
   if (!m) {

diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 01f8ec1f38cc..a77f31cc7f40 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -29,27 +29,18 @@ DialectAsmParser::~DialectAsmParser() {}
 void DialectRegistry::addDialectInterface(
     StringRef dialectName, InterfaceAllocatorFunction allocator) {
   assert(allocator && "unexpected null interface allocation function");
-
-  // If the dialect is already loaded, directly add the interface.
-  if (Dialect *dialect = owningContext
-                             ? owningContext->getLoadedDialect(dialectName)
-                             : nullptr) {
-    dialect->addInterface(allocator(dialect));
-    return;
-  }
-
-  // Otherwise, store it in the interface map for delayed registration.
   auto it = registry.find(dialectName.str());
   assert(it != registry.end() &&
          "adding an interface for an unregistered dialect");
   interfaces[it->second.first].push_back(allocator);
 }
 
-Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) {
+DialectAllocatorFunctionRef
+DialectRegistry::getDialectAllocator(StringRef name) const {
   auto it = registry.find(name.str());
   if (it == registry.end())
     return nullptr;
-  return it->second.second(context);
+  return it->second.second;
 }
 
 void DialectRegistry::insert(TypeID typeID, StringRef name,
@@ -63,7 +54,7 @@ void DialectRegistry::insert(TypeID typeID, StringRef name,
   }
 }
 
-void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) {
+void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
   auto it = interfaces.find(dialect->getTypeID());
   if (it == interfaces.end())
     return;

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 832eea747771..f637bb261958 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -326,8 +326,7 @@ class MLIRContextImpl {
   DictionaryAttr emptyDictionaryAttr;
 
 public:
-  MLIRContextImpl(MLIRContext *ctx)
-      : dialectsRegistry(ctx), identifiers(identifierAllocator) {}
+  MLIRContextImpl() : identifiers(identifierAllocator) {}
   ~MLIRContextImpl() {
     for (auto typeMapping : registeredTypes)
       typeMapping.second->~AbstractType();
@@ -337,7 +336,10 @@ class MLIRContextImpl {
 };
 } // end namespace mlir
 
-MLIRContext::MLIRContext() : impl(new MLIRContextImpl(this)) {
+MLIRContext::MLIRContext() : MLIRContext(DialectRegistry()) {}
+
+MLIRContext::MLIRContext(const DialectRegistry &registry)
+    : impl(new MLIRContextImpl) {
   // Initialize values based on the command line flags if they were provided.
   if (clOptions.isConstructed()) {
     disableMultithreading(clOptions->disableThreading);
@@ -348,6 +350,9 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl(this)) {
   // Ensure the builtin dialect is always pre-loaded.
   getOrLoadDialect<BuiltinDialect>();
 
+  // Pre-populate the registry.
+  registry.appendTo(impl->dialectsRegistry);
+
   // Initialize several common attributes and types to avoid the need to lock
   // the context when accessing them.
 
@@ -424,7 +429,15 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
 // Dialect and Operation Registration
 //===----------------------------------------------------------------------===//
 
-DialectRegistry &MLIRContext::getDialectRegistry() {
+void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
+  registry.appendTo(impl->dialectsRegistry);
+
+  // For the already loaded dialects, register the interfaces immediately.
+  for (const auto &kvp : impl->loadedDialects)
+    registry.registerDelayedInterfaces(kvp.second.get());
+}
+
+const DialectRegistry &MLIRContext::getDialectRegistry() {
   return impl->dialectsRegistry;
 }
 
@@ -459,7 +472,9 @@ Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
   Dialect *dialect = getLoadedDialect(name);
   if (dialect)
     return dialect;
-  return impl->dialectsRegistry.loadByName(name, this);
+  DialectAllocatorFunctionRef allocator =
+      impl->dialectsRegistry.getDialectAllocator(name);
+  return allocator ? allocator(this) : nullptr;
 }
 
 /// Get a dialect for the provided namespace and TypeID: abort the program if a
@@ -507,6 +522,11 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
   return dialect.get();
 }
 
+void MLIRContext::loadAllAvailableDialects() {
+  for (StringRef name : getAvailableDialects())
+    getOrLoadDialect(name);
+}
+
 llvm::hash_code MLIRContext::getRegistryHash() {
   llvm::hash_code hash(0);
   // Factor in number of loaded dialects, attributes, operations, types.

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index d9f2e7f23508..f4779cb860a0 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -865,7 +865,9 @@ LogicalResult PassManager::run(Operation *op) {
   // Register all dialects for the current pipeline.
   DialectRegistry dependentDialects;
   getDependentDialects(dependentDialects);
-  dependentDialects.loadAll(context);
+  context->appendDialectRegistry(dependentDialects);
+  for (StringRef name : dependentDialects.getDialectNames())
+    context->getOrLoadDialect(name);
 
   // Initialize all of the passes within the pass manager with a new generation.
   llvm::hash_code newInitKey = context->getRegistryHash();

diff  --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
index 27968517987d..8f250669564b 100644
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -95,10 +95,9 @@ static LogicalResult processBuffer(raw_ostream &os,
   sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
 
   // Parse the input file.
-  MLIRContext context;
-  registry.appendTo(context.getDialectRegistry());
+  MLIRContext context(registry);
   if (preloadDialectsInContext)
-    registry.loadAll(&context);
+    context.loadAllAvailableDialects();
   context.allowUnregisteredDialects(allowUnregisteredDialects);
   context.printOpOnDiagnostic(!verifyDiagnostics);
 

diff  --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
index 70af479f414b..703ef5ab71a7 100644
--- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
+++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
@@ -136,8 +136,9 @@ static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
   if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo)))
     return failure();
 
-  MLIRContext deserializationContext;
-  context->getDialectRegistry().loadAll(&deserializationContext);
+  MLIRContext deserializationContext(context->getDialectRegistry());
+  // TODO: we should only load the required dialects instead of all dialects.
+  deserializationContext.loadAllAvailableDialects();
   // Then deserialize to get back a SPIR-V module.
   spirv::OwningSPIRVModuleRef spirvModule =
       spirv::deserialize(binary, &deserializationContext);

diff  --git a/mlir/lib/Translation/Translation.cpp b/mlir/lib/Translation/Translation.cpp
index a797c5aa6060..d357a58ac892 100644
--- a/mlir/lib/Translation/Translation.cpp
+++ b/mlir/lib/Translation/Translation.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Translation.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Parser.h"
 #include "mlir/Support/FileUtilities.h"
@@ -97,7 +98,9 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
   registerTranslation(name, [function, dialectRegistration](
                                 llvm::SourceMgr &sourceMgr, raw_ostream &output,
                                 MLIRContext *context) {
-    dialectRegistration(context->getDialectRegistry());
+    DialectRegistry registry;
+    dialectRegistration(registry);
+    context->appendDialectRegistry(registry);
     auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
     if (!module)
       return failure();

diff  --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp
index 432e8dfe8468..d995683bb30c 100644
--- a/mlir/tools/mlir-reduce/mlir-reduce.cpp
+++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp
@@ -89,11 +89,12 @@ int main(int argc, char **argv) {
   if (!output)
     llvm::report_fatal_error(errorMessage);
 
-  mlir::MLIRContext context;
-  registerAllDialects(context.getDialectRegistry());
+  mlir::DialectRegistry registry;
+  registerAllDialects(registry);
 #ifdef MLIR_INCLUDE_TESTS
-  mlir::test::registerTestDialect(context.getDialectRegistry());
+  mlir::test::registerTestDialect(registry);
 #endif
+  mlir::MLIRContext context(registry);
 
   mlir::OwningModuleRef moduleRef;
   if (failed(loadModule(context, moduleRef, inputFilename)))

diff  --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index ed19558ef5ca..64d207bec453 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -65,8 +65,7 @@ TEST(Dialect, DelayedInterfaceRegistration) {
   // Delayed registration of an interface for TestDialect.
   registry.addDialectInterface<TestDialect, TestDialectInterface>();
 
-  MLIRContext context;
-  registry.appendTo(context.getDialectRegistry());
+  MLIRContext context(registry);
 
   // Load the TestDialect and check that the interface got registered for it.
   auto *testDialect = context.getOrLoadDialect<TestDialect>();
@@ -85,8 +84,11 @@ TEST(Dialect, DelayedInterfaceRegistration) {
 
   // Use the same mechanism as for delayed registration but for an already
   // loaded dialect and check that the interface is now registered.
-  context.getDialectRegistry()
+  DialectRegistry secondRegistry;
+  secondRegistry.insert<SecondTestDialect>();
+  secondRegistry
       .addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
+  context.appendDialectRegistry(secondRegistry);
   secondTestDialectInterface =
       secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
   EXPECT_TRUE(secondTestDialectInterface != nullptr);


        


More information about the flang-commits mailing list