[Mlir-commits] [mlir] 9b50844 - [mlir] Fix delayed object interfaces registration

Vladislav Vinogradov llvmlistbot at llvm.org
Tue Aug 3 02:21:05 PDT 2021


Author: Vladislav Vinogradov
Date: 2021-08-03T12:21:55+03:00
New Revision: 9b50844fd798b5a81afd4aeb44b053d622747a42

URL: https://github.com/llvm/llvm-project/commit/9b50844fd798b5a81afd4aeb44b053d622747a42
DIFF: https://github.com/llvm/llvm-project/commit/9b50844fd798b5a81afd4aeb44b053d622747a42.diff

LOG: [mlir] Fix delayed object interfaces registration

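Store both the interface ID and the object ID as the key for the interface registration callback.
Otherwise the implementation allows registering an external model for a given interface on only one object (attribute, operation, or type) per dialect: registering the same interface for a second object in that dialect is treated as a repeated registration of the first.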

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D107274

Added: 
    

Modified: 
    mlir/include/mlir/IR/Dialect.h
    mlir/lib/IR/Dialect.cpp
    mlir/unittests/IR/InterfaceAttachmentTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index e16b3e47a0140..f615819fd16bb 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -17,6 +17,7 @@
 #include "mlir/Support/TypeID.h"
 
 #include <map>
+#include <tuple>
 
 namespace mlir {
 class DialectAsmParser;
@@ -285,7 +286,7 @@ class DialectRegistry {
     SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2>
         dialectInterfaces;
     /// Attribute/Operation/Type interfaces.
-    SmallVector<std::pair<TypeID, ObjectInterfaceAllocatorFunction>, 2>
+    SmallVector<std::tuple<TypeID, TypeID, ObjectInterfaceAllocatorFunction>, 2>
         objectInterfaces;
   };
 
@@ -367,7 +368,8 @@ class DialectRegistry {
   void addOpInterface() {
     StringRef opName = OpTy::getOperationName();
     StringRef dialectName = opName.split('.').first;
-    addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(),
+    addObjectInterface(dialectName, TypeID::get<OpTy>(),
+                       ModelTy::Interface::getInterfaceID(),
                        [](MLIRContext *context) {
                          OpTy::template attachInterface<ModelTy>(*context);
                        });
@@ -401,14 +403,16 @@ class DialectRegistry {
 
   /// Add an attribute/operation/type interface constructible with the given
   /// allocation function to the dialect identified by its namespace.
-  void addObjectInterface(StringRef dialectName, TypeID interfaceTypeID,
+  void addObjectInterface(StringRef dialectName, TypeID objectID,
+                          TypeID interfaceTypeID,
                           ObjectInterfaceAllocatorFunction allocator);
 
   /// Add an external model for an attribute/type interface to the dialect
   /// identified by its namespace.
   template <typename ObjectTy, typename ModelTy>
   void addStorageUserInterface(StringRef dialectName) {
-    addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(),
+    addObjectInterface(dialectName, TypeID::get<ObjectTy>(),
+                       ModelTy::Interface::getInterfaceID(),
                        [](MLIRContext *context) {
                          ObjectTy::template attachInterface<ModelTy>(*context);
                        });

diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 4713463124d92..80c8dabe1f3b9 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -58,16 +58,19 @@ void DialectRegistry::addDialectInterface(
 }
 
 void DialectRegistry::addObjectInterface(
-    StringRef dialectName, TypeID interfaceTypeID,
+    StringRef dialectName, TypeID objectID, TypeID interfaceTypeID,
     ObjectInterfaceAllocatorFunction allocator) {
   assert(allocator && "unexpected null interface allocation function");
+
   auto it = registry.find(dialectName.str());
   assert(it != registry.end() &&
          "adding an interface for an op from an unregistered dialect");
 
-  auto &ifaces = interfaces[it->second.first];
-  for (const auto &kvp : ifaces.objectInterfaces) {
-    if (kvp.first == interfaceTypeID) {
+  auto dialectID = it->second.first;
+  auto &ifaces = interfaces[dialectID];
+
+  for (const auto &info : ifaces.objectInterfaces) {
+    if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) {
       LLVM_DEBUG(llvm::dbgs()
                  << "[" DEBUG_TYPE
                     "] repeated interface object interface registration");
@@ -75,7 +78,7 @@ void DialectRegistry::addObjectInterface(
     }
   }
 
-  ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator);
+  ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID, allocator);
 }
 
 DialectAllocatorFunctionRef
@@ -110,8 +113,8 @@ void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
   }
 
   // Add attribute, operation and type interfaces.
-  for (const auto &kvp : it->getSecond().objectInterfaces)
-    kvp.second(dialect->getContext());
+  for (const auto &info : it->getSecond().objectInterfaces)
+    std::get<2>(info)(dialect->getContext());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index b83e5a0bf2f77..76124707cbfc7 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -321,15 +321,16 @@ TEST(InterfaceAttachment, Operation) {
   ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp.getOperation()));
 }
 
+template <class ConcreteOp>
 struct TestExternalTestOpModel
-    : public TestExternalOpInterface::ExternalModel<TestExternalTestOpModel,
-                                                    test::OpJ> {
+    : public TestExternalOpInterface::ExternalModel<
+          TestExternalTestOpModel<ConcreteOp>, ConcreteOp> {
   unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
     return op->getName().getStringRef().size() + arg;
   }
 
   static unsigned getNameLengthPlusArgTwice(unsigned arg) {
-    return test::OpJ::getOperationName().size() + 2 * arg;
+    return ConcreteOp::getOperationName().size() + 2 * arg;
   }
 };
 
@@ -337,39 +338,61 @@ TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
   DialectRegistry registry;
   registry.insert<test::TestDialect>();
   registry.addOpInterface<ModuleOp, TestExternalOpModel>();
-  registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
+  registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>();
+  registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>();
 
   // Construct the context directly from a registry. The interfaces are expected
   // to be readily available on operations.
   MLIRContext context(registry);
   context.loadDialect<test::TestDialect>();
+
   ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
   OpBuilder builder(module);
-  auto op =
+  auto opJ =
       builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
+  auto opH =
+      builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
+  auto opI =
+      builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
+
   EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
-  EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
+  EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
+  EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
+  EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
 }
 
 TEST(InterfaceAttachment, OperationDelayedContextAppend) {
   DialectRegistry registry;
   registry.insert<test::TestDialect>();
   registry.addOpInterface<ModuleOp, TestExternalOpModel>();
-  registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
+  registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>();
+  registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>();
 
   // Construct the context, create ops, and only then append the registry. The
   // interfaces are expected to be available after appending the registry.
   MLIRContext context;
   context.loadDialect<test::TestDialect>();
+
   ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
   OpBuilder builder(module);
-  auto op =
+  auto opJ =
       builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
+  auto opH =
+      builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
+  auto opI =
+      builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
+
   EXPECT_FALSE(isa<TestExternalOpInterface>(module.getOperation()));
-  EXPECT_FALSE(isa<TestExternalOpInterface>(op.getOperation()));
+  EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation()));
+  EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation()));
+  EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
+
   context.appendDialectRegistry(registry);
+
   EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
-  EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
+  EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
+  EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
+  EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
 }
 
 } // end namespace


        


More information about the Mlir-commits mailing list