[Mlir-commits] [mlir] d7e8912 - [mlir] Enable delayed registration of attribute/operation/type interfaces

Alex Zinenko llvmlistbot at llvm.org
Thu Jun 17 04:19:32 PDT 2021


Author: Alex Zinenko
Date: 2021-06-17T13:19:24+02:00
New Revision: d7e891213444e2990397c623fb0250a470421fce

URL: https://github.com/llvm/llvm-project/commit/d7e891213444e2990397c623fb0250a470421fce
DIFF: https://github.com/llvm/llvm-project/commit/d7e891213444e2990397c623fb0250a470421fce.diff

LOG: [mlir] Enable delayed registration of attribute/operation/type interfaces

This functionality is similar to delayed registration of dialect interfaces. It
allows external interface models to be registered before the dialect containing
the attribute/operation/type interface is loaded, or even before the context is
created.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D104397

Added: 
    

Modified: 
    mlir/include/mlir/IR/AttributeSupport.h
    mlir/include/mlir/IR/Dialect.h
    mlir/include/mlir/IR/TypeSupport.h
    mlir/include/mlir/Support/InterfaceSupport.h
    mlir/lib/IR/Dialect.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/unittests/IR/InterfaceAttachmentTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index b10b54c6d3ca0..c84be620b8f0f 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -50,6 +50,12 @@ class AbstractAttribute {
     return interfaceMap.lookup<T>();
   }
 
+  /// Returns true if the attribute has the interface with the given ID
+  /// registered.
+  bool hasInterface(TypeID interfaceID) const {
+    return interfaceMap.contains(interfaceID);
+  }
+
   /// Return the unique identifier representing the concrete attribute class.
   TypeID getTypeID() const { return typeID; }
 

diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 46782c4353d51..b9c2e4619ed97 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -27,8 +27,9 @@ class Type;
 
 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
 using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
-using InterfaceAllocatorFunction =
+using DialectInterfaceAllocatorFunction =
     std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
+using ObjectInterfaceAllocatorFunction = std::function<void(MLIRContext *)>;
 
 /// Dialects are groups of MLIR operations, types and attributes, as well as
 /// behavior associated with the entire group.  For example, hooks into other
@@ -278,11 +279,19 @@ class Dialect {
 /// dialects loaded in the Context. The parser in particular will lazily load
 /// dialects in the Context as operations are encountered.
 class DialectRegistry {
+  /// Lists of interfaces that need to be registered when the dialect is loaded.
+  struct DelayedInterfaces {
+    /// Dialect interfaces.
+    SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2>
+        dialectInterfaces;
+    /// Attribute/Operation/Type interfaces.
+    SmallVector<std::pair<TypeID, ObjectInterfaceAllocatorFunction>, 2>
+        objectInterfaces;
+  };
+
   using MapTy =
       std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
-  using InterfaceMapTy =
-      DenseMap<TypeID,
-               SmallVector<std::pair<TypeID, InterfaceAllocatorFunction>, 2>>;
+  using InterfaceMapTy = DenseMap<TypeID, DelayedInterfaces>;
 
 public:
   explicit DialectRegistry() {}
@@ -336,7 +345,7 @@ class DialectRegistry {
   /// the registry.
   template <typename DialectTy>
   void addDialectInterface(TypeID interfaceTypeID,
-                           InterfaceAllocatorFunction allocator) {
+                           DialectInterfaceAllocatorFunction allocator) {
     addDialectInterface(DialectTy::getDialectNamespace(), interfaceTypeID,
                         allocator);
   }
@@ -351,6 +360,36 @@ class DialectRegistry {
         });
   }
 
+  /// Add an external op interface model for an op that belongs to a dialect,
+  /// both provided as template parameters. The dialect must be present in the
+  /// registry.
+  template <typename OpTy, typename ModelTy>
+  void addOpInterface() {
+    StringRef opName = OpTy::getOperationName();
+    StringRef dialectName = opName.split('.').first;
+    addObjectInterface(dialectName == opName ? "" : dialectName,
+                       ModelTy::Interface::getInterfaceID(),
+                       [](MLIRContext *context) {
+                         OpTy::template attachInterface<ModelTy>(*context);
+                       });
+  }
+
+  /// Add an external attribute interface model for an attribute type `AttrTy`
+  /// that is going to belong to `DialectTy`. The dialect must be present in the
+  /// registry.
+  template <typename DialectTy, typename AttrTy, typename ModelTy>
+  void addAttrInterface() {
+    addStorageUserInterface<AttrTy, ModelTy>(DialectTy::getDialectNamespace());
+  }
+
+  /// Add an external type interface model for an type class `TypeTy` that is
+  /// going to belong to `DialectTy`. The dialect must be present in the
+  /// registry.
+  template <typename DialectTy, typename TypeTy, typename ModelTy>
+  void addTypeInterface() {
+    addStorageUserInterface<TypeTy, ModelTy>(DialectTy::getDialectNamespace());
+  }
+
   /// Register any interfaces required for the given dialect (based on its
   /// TypeID). Users are not expected to call this directly.
   void registerDelayedInterfaces(Dialect *dialect) const;
@@ -359,7 +398,22 @@ class DialectRegistry {
   /// Add an interface constructed with the given allocation function to the
   /// dialect identified by its namespace.
   void addDialectInterface(StringRef dialectName, TypeID interfaceTypeID,
-                           InterfaceAllocatorFunction allocator);
+                           DialectInterfaceAllocatorFunction allocator);
+
+  /// Add an attribute/operation/type interface constructible with the given
+  /// allocation function to the dialect identified by its namespace.
+  void addObjectInterface(StringRef dialectName, TypeID interfaceTypeID,
+                          ObjectInterfaceAllocatorFunction allocator);
+
+  /// Add an external model for an attribute/type interface to the dialect
+  /// identified by its namespace.
+  template <typename ObjectTy, typename ModelTy>
+  void addStorageUserInterface(StringRef dialectName) {
+    addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(),
+                       [](MLIRContext *context) {
+                         ObjectTy::template attachInterface<ModelTy>(*context);
+                       });
+  }
 
   MapTy registry;
   InterfaceMapTy interfaces;

diff  --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index df05e9ab198b7..40113c41fc229 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -58,6 +58,11 @@ class AbstractType {
     return interfaceMap.lookup<T>();
   }
 
+  /// Returns true if the type has the interface with the given ID.
+  bool hasInterface(TypeID interfaceID) const {
+    return interfaceMap.contains(interfaceID);
+  }
+
   /// Return the unique identifier representing the concrete type class.
   TypeID getTypeID() const { return typeID; }
 

diff  --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index a49d96d576278..0161866909a29 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -16,6 +16,7 @@
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/TypeName.h"
 
 namespace mlir {
@@ -236,8 +237,10 @@ class InterfaceMap {
           llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
             return compare(it.first, id);
           });
-      if (it != interfaces.end() && it->first == id)
-        llvm::report_fatal_error("Interface already registered");
+      if (it != interfaces.end() && it->first == id) {
+        LLVM_DEBUG(llvm::dbgs() << "Ignoring repeated interface registration");
+        continue;
+      }
       interfaces.insert(it, element);
     }
   }

diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 612c902d47079..f7c1883451b60 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/DialectInterface.h"
@@ -31,7 +32,7 @@ DialectAsmParser::~DialectAsmParser() {}
 
 void DialectRegistry::addDialectInterface(
     StringRef dialectName, TypeID interfaceTypeID,
-    InterfaceAllocatorFunction allocator) {
+    DialectInterfaceAllocatorFunction allocator) {
   assert(allocator && "unexpected null interface allocation function");
   auto it = registry.find(dialectName.str());
   assert(it != registry.end() &&
@@ -40,8 +41,8 @@ void DialectRegistry::addDialectInterface(
   // Bail out if the interface with the given ID is already in the registry for
   // the given dialect. We expect a small number (dozens) of interfaces so a
   // linear search is fine here.
-  auto &dialectInterfaces = interfaces[it->second.first];
-  for (const auto &kvp : dialectInterfaces) {
+  auto &ifaces = interfaces[it->second.first];
+  for (const auto &kvp : ifaces.dialectInterfaces) {
     if (kvp.first == interfaceTypeID) {
       LLVM_DEBUG(llvm::dbgs()
                  << "[" DEBUG_TYPE
@@ -51,7 +52,36 @@ void DialectRegistry::addDialectInterface(
     }
   }
 
-  dialectInterfaces.emplace_back(interfaceTypeID, allocator);
+  ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator);
+}
+
+void DialectRegistry::addObjectInterface(
+    StringRef dialectName, TypeID interfaceTypeID,
+    ObjectInterfaceAllocatorFunction allocator) {
+  assert(allocator && "unexpected null interface allocation function");
+
+  // Builtin dialect has an empty prefix and is always registered.
+  TypeID dialectTypeID;
+  if (!dialectName.empty()) {
+    auto it = registry.find(dialectName.str());
+    assert(it != registry.end() &&
+           "adding an interface for an op from an unregistered dialect");
+    dialectTypeID = it->second.first;
+  } else {
+    dialectTypeID = TypeID::get<BuiltinDialect>();
+  }
+
+  auto &ifaces = interfaces[dialectTypeID];
+  for (const auto &kvp : ifaces.objectInterfaces) {
+    if (kvp.first == interfaceTypeID) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "[" DEBUG_TYPE
+                    "] repeated interface object interface registration");
+      return;
+    }
+  }
+
+  ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator);
 }
 
 DialectAllocatorFunctionRef
@@ -79,11 +109,15 @@ void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
     return;
 
   // Add an interface if it is not already present.
-  for (const auto &kvp : it->second) {
+  for (const auto &kvp : it->getSecond().dialectInterfaces) {
     if (dialect->getRegisteredInterface(kvp.first))
       continue;
     dialect->addInterface(kvp.second(dialect));
   }
+
+  // Add attribute, operation and type interfaces.
+  for (const auto &kvp : it->getSecond().objectInterfaces)
+    kvp.second(dialect->getContext());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index da4b08ccd8825..ab12e5ff56693 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -356,12 +356,12 @@ MLIRContext::MLIRContext(const DialectRegistry &registry)
     printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
   }
 
-  // Ensure the builtin dialect is always pre-loaded.
-  getOrLoadDialect<BuiltinDialect>();
-
   // Pre-populate the registry.
   registry.appendTo(impl->dialectsRegistry);
 
+  // Ensure the builtin dialect is always pre-loaded.
+  getOrLoadDialect<BuiltinDialect>();
+
   // Initialize several common attributes and types to avoid the need to lock
   // the context when accessing them.
 

diff  --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index 6ad65438ae6c3..10cc6f85d6480 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "gtest/gtest.h"
@@ -87,6 +88,74 @@ TEST(InterfaceAttachment, Type) {
   EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>());
 }
 
+/// External interface model for the test type from the test dialect.
+struct TestTypeModel
+    : public TestExternalTypeInterface::ExternalModel<TestTypeModel,
+                                                      test::TestType> {
+  unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; }
+
+  static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; }
+};
+
+TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
+  // Put the interface in the registry.
+  DialectRegistry registry;
+  registry.insert<test::TestDialect>();
+  registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
+
+  // Check that when a context is constructed with the given registry, the type
+  // interface gets registered.
+  MLIRContext context(registry);
+  context.loadDialect<test::TestDialect>();
+  test::TestType testType = test::TestType::get(&context);
+  auto iface = testType.dyn_cast<TestExternalTypeInterface>();
+  ASSERT_TRUE(iface != nullptr);
+  EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
+  EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
+}
+
+TEST(InterfaceAttachment, TypeDelayedContextAppend) {
+  // Put the interface in the registry.
+  DialectRegistry registry;
+  registry.insert<test::TestDialect>();
+  registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
+
+  // Check that when the registry gets appended to the context, the interface
+  // becomes available for objects in loaded dialects.
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+  test::TestType testType = test::TestType::get(&context);
+  EXPECT_FALSE(testType.isa<TestExternalTypeInterface>());
+  context.appendDialectRegistry(registry);
+  EXPECT_TRUE(testType.isa<TestExternalTypeInterface>());
+}
+
+TEST(InterfaceAttachment, RepeatedRegistration) {
+  DialectRegistry registry;
+  registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
+  MLIRContext context(registry);
+
+  // Should't fail on repeated registration through the dialect registry.
+  context.appendDialectRegistry(registry);
+}
+
+TEST(InterfaceAttachment, TypeBuiltinDelayed) {
+  // Builtin dialect needs to registration or loading, but delayed interface
+  // registration must still work.
+  DialectRegistry registry;
+  registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
+
+  MLIRContext context(registry);
+  IntegerType i16 = IntegerType::get(&context, 16);
+  EXPECT_TRUE(i16.isa<TestExternalTypeInterface>());
+
+  MLIRContext initiallyEmpty;
+  IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
+  EXPECT_FALSE(i32.isa<TestExternalTypeInterface>());
+  initiallyEmpty.appendDialectRegistry(registry);
+  EXPECT_TRUE(i32.isa<TestExternalTypeInterface>());
+}
+
 /// The interface provides a default implementation that expects
 /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
 /// just derives from the ExternalModel.
@@ -128,9 +197,9 @@ TEST(InterfaceAttachment, Fallback) {
 }
 
 /// External model for attribute interfaces.
-struct TextExternalIntegerAttrModel
+struct TestExternalIntegerAttrModel
     : public TestExternalAttrInterface::ExternalModel<
-          TextExternalIntegerAttrModel, IntegerAttr> {
+          TestExternalIntegerAttrModel, IntegerAttr> {
   const Dialect *getDialectPtr(Attribute attr) const {
     return &attr.cast<IntegerAttr>().getDialect();
   }
@@ -145,13 +214,45 @@ TEST(InterfaceAttachment, Attribute) {
   // that the basics work for attributes.
   IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
   ASSERT_FALSE(attr.isa<TestExternalAttrInterface>());
-  IntegerAttr::attachInterface<TextExternalIntegerAttrModel>(context);
+  IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
   auto iface = attr.dyn_cast<TestExternalAttrInterface>();
   ASSERT_TRUE(iface != nullptr);
   EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
   EXPECT_EQ(iface.getSomeNumber(), 42);
 }
 
+/// External model for an interface attachable to a non-builtin attribute.
+struct TestExternalSimpleAAttrModel
+    : public TestExternalAttrInterface::ExternalModel<
+          TestExternalSimpleAAttrModel, test::SimpleAAttr> {
+  const Dialect *getDialectPtr(Attribute attr) const {
+    return &attr.getDialect();
+  }
+
+  static int getSomeNumber() { return 21; }
+};
+
+TEST(InterfaceAttachmentTest, AttributeDelayed) {
+  // Attribute interfaces use the exact same mechanism as types, so just check
+  // that the delayed registration work for attributes.
+  DialectRegistry registry;
+  registry.insert<test::TestDialect>();
+  registry.addAttrInterface<test::TestDialect, test::SimpleAAttr,
+                            TestExternalSimpleAAttrModel>();
+
+  MLIRContext context(registry);
+  context.loadDialect<test::TestDialect>();
+  auto attr = test::SimpleAAttr::get(&context);
+  EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
+
+  MLIRContext initiallyEmpty;
+  initiallyEmpty.loadDialect<test::TestDialect>();
+  attr = test::SimpleAAttr::get(&initiallyEmpty);
+  EXPECT_FALSE(attr.isa<TestExternalAttrInterface>());
+  initiallyEmpty.appendDialectRegistry(registry);
+  EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
+}
+
 /// External interface model for the module operation. Only provides non-default
 /// methods.
 struct TestExternalOpModel
@@ -220,4 +321,55 @@ TEST(InterfaceAttachment, Operation) {
   ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp.getOperation()));
 }
 
+struct TestExternalTestOpModel
+    : public TestExternalOpInterface::ExternalModel<TestExternalTestOpModel,
+                                                    test::OpJ> {
+  unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
+    return op->getName().getStringRef().size() + arg;
+  }
+
+  static unsigned getNameLengthPlusArgTwice(unsigned arg) {
+    return test::OpJ::getOperationName().size() + 2 * arg;
+  }
+};
+
+TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
+  DialectRegistry registry;
+  registry.insert<test::TestDialect>();
+  registry.addOpInterface<ModuleOp, TestExternalOpModel>();
+  registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
+
+  // Construct the context directly from a registry. The interfaces are expected
+  // to be readily available on operations.
+  MLIRContext context(registry);
+  context.loadDialect<test::TestDialect>();
+  ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
+  OpBuilder builder(module);
+  auto op =
+      builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
+  EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
+  EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
+}
+
+TEST(InterfaceAttachment, OperationDelayedContextAppend) {
+  DialectRegistry registry;
+  registry.insert<test::TestDialect>();
+  registry.addOpInterface<ModuleOp, TestExternalOpModel>();
+  registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
+
+  // Construct the context, create ops, and only then append the registry. The
+  // interfaces are expected to be available after appending the registry.
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+  ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
+  OpBuilder builder(module);
+  auto op =
+      builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
+  EXPECT_FALSE(isa<TestExternalOpInterface>(module.getOperation()));
+  EXPECT_FALSE(isa<TestExternalOpInterface>(op.getOperation()));
+  context.appendDialectRegistry(registry);
+  EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
+  EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
+}
+
 } // end namespace


        


More information about the Mlir-commits mailing list