[Mlir-commits] [mlir] d7e8912 - [mlir] Enable delayed registration of attribute/operation/type interfaces
Alex Zinenko
llvmlistbot at llvm.org
Thu Jun 17 04:19:32 PDT 2021
Author: Alex Zinenko
Date: 2021-06-17T13:19:24+02:00
New Revision: d7e891213444e2990397c623fb0250a470421fce
URL: https://github.com/llvm/llvm-project/commit/d7e891213444e2990397c623fb0250a470421fce
DIFF: https://github.com/llvm/llvm-project/commit/d7e891213444e2990397c623fb0250a470421fce.diff
LOG: [mlir] Enable delayed registration of attribute/operation/type interfaces
This functionality is similar to delayed registration of dialect interfaces. It
allows external interface models to be registered before the dialect containing
the attribute/operation/type interface is loaded, or even before the context is
created.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D104397
Added:
Modified:
mlir/include/mlir/IR/AttributeSupport.h
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/TypeSupport.h
mlir/include/mlir/Support/InterfaceSupport.h
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/unittests/IR/InterfaceAttachmentTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index b10b54c6d3ca0..c84be620b8f0f 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -50,6 +50,12 @@ class AbstractAttribute {
return interfaceMap.lookup<T>();
}
+ /// Returns true if the attribute has the interface with the given ID
+ /// registered.
+ bool hasInterface(TypeID interfaceID) const {
+ return interfaceMap.contains(interfaceID);
+ }
+
/// Return the unique identifier representing the concrete attribute class.
TypeID getTypeID() const { return typeID; }
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 46782c4353d51..b9c2e4619ed97 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -27,8 +27,9 @@ class Type;
using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
-using InterfaceAllocatorFunction =
+using DialectInterfaceAllocatorFunction =
std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
+using ObjectInterfaceAllocatorFunction = std::function<void(MLIRContext *)>;
/// Dialects are groups of MLIR operations, types and attributes, as well as
/// behavior associated with the entire group. For example, hooks into other
@@ -278,11 +279,19 @@ class Dialect {
/// dialects loaded in the Context. The parser in particular will lazily load
/// dialects in the Context as operations are encountered.
class DialectRegistry {
+ /// Lists of interfaces that need to be registered when the dialect is loaded.
+ struct DelayedInterfaces {
+ /// Dialect interfaces.
+ SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2>
+ dialectInterfaces;
+ /// Attribute/Operation/Type interfaces.
+ SmallVector<std::pair<TypeID, ObjectInterfaceAllocatorFunction>, 2>
+ objectInterfaces;
+ };
+
using MapTy =
std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
- using InterfaceMapTy =
- DenseMap<TypeID,
- SmallVector<std::pair<TypeID, InterfaceAllocatorFunction>, 2>>;
+ using InterfaceMapTy = DenseMap<TypeID, DelayedInterfaces>;
public:
explicit DialectRegistry() {}
@@ -336,7 +345,7 @@ class DialectRegistry {
/// the registry.
template <typename DialectTy>
void addDialectInterface(TypeID interfaceTypeID,
- InterfaceAllocatorFunction allocator) {
+ DialectInterfaceAllocatorFunction allocator) {
addDialectInterface(DialectTy::getDialectNamespace(), interfaceTypeID,
allocator);
}
@@ -351,6 +360,36 @@ class DialectRegistry {
});
}
+  /// Add an external interface model `ModelTy` for the operation `OpTy`. The
+  /// op's dialect is derived from its name and must be present in the
+  /// registry, unless it is the builtin dialect.
+ template <typename OpTy, typename ModelTy>
+ void addOpInterface() {
+ StringRef opName = OpTy::getOperationName();
+ StringRef dialectName = opName.split('.').first;
+ addObjectInterface(dialectName == opName ? "" : dialectName,
+ ModelTy::Interface::getInterfaceID(),
+ [](MLIRContext *context) {
+ OpTy::template attachInterface<ModelTy>(*context);
+ });
+ }
+
+ /// Add an external attribute interface model for an attribute type `AttrTy`
+ /// that is going to belong to `DialectTy`. The dialect must be present in the
+ /// registry.
+ template <typename DialectTy, typename AttrTy, typename ModelTy>
+ void addAttrInterface() {
+ addStorageUserInterface<AttrTy, ModelTy>(DialectTy::getDialectNamespace());
+ }
+
+  /// Add an external type interface model for a type class `TypeTy` that is
+  /// going to belong to `DialectTy`. The dialect must be present in the
+  /// registry.
+ template <typename DialectTy, typename TypeTy, typename ModelTy>
+ void addTypeInterface() {
+ addStorageUserInterface<TypeTy, ModelTy>(DialectTy::getDialectNamespace());
+ }
+
/// Register any interfaces required for the given dialect (based on its
/// TypeID). Users are not expected to call this directly.
void registerDelayedInterfaces(Dialect *dialect) const;
@@ -359,7 +398,22 @@ class DialectRegistry {
/// Add an interface constructed with the given allocation function to the
/// dialect identified by its namespace.
void addDialectInterface(StringRef dialectName, TypeID interfaceTypeID,
- InterfaceAllocatorFunction allocator);
+ DialectInterfaceAllocatorFunction allocator);
+
+ /// Add an attribute/operation/type interface constructible with the given
+ /// allocation function to the dialect identified by its namespace.
+ void addObjectInterface(StringRef dialectName, TypeID interfaceTypeID,
+ ObjectInterfaceAllocatorFunction allocator);
+
+ /// Add an external model for an attribute/type interface to the dialect
+ /// identified by its namespace.
+ template <typename ObjectTy, typename ModelTy>
+ void addStorageUserInterface(StringRef dialectName) {
+ addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(),
+ [](MLIRContext *context) {
+ ObjectTy::template attachInterface<ModelTy>(*context);
+ });
+ }
MapTy registry;
InterfaceMapTy interfaces;
diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index df05e9ab198b7..40113c41fc229 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -58,6 +58,11 @@ class AbstractType {
return interfaceMap.lookup<T>();
}
+ /// Returns true if the type has the interface with the given ID.
+ bool hasInterface(TypeID interfaceID) const {
+ return interfaceMap.contains(interfaceID);
+ }
+
/// Return the unique identifier representing the concrete type class.
TypeID getTypeID() const { return typeID; }
diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index a49d96d576278..0161866909a29 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -16,6 +16,7 @@
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/TypeName.h"
namespace mlir {
@@ -236,8 +237,10 @@ class InterfaceMap {
llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
return compare(it.first, id);
});
- if (it != interfaces.end() && it->first == id)
- llvm::report_fatal_error("Interface already registered");
+ if (it != interfaces.end() && it->first == id) {
+ LLVM_DEBUG(llvm::dbgs() << "Ignoring repeated interface registration");
+ continue;
+ }
interfaces.insert(it, element);
}
}
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 612c902d47079..f7c1883451b60 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectInterface.h"
@@ -31,7 +32,7 @@ DialectAsmParser::~DialectAsmParser() {}
void DialectRegistry::addDialectInterface(
StringRef dialectName, TypeID interfaceTypeID,
- InterfaceAllocatorFunction allocator) {
+ DialectInterfaceAllocatorFunction allocator) {
assert(allocator && "unexpected null interface allocation function");
auto it = registry.find(dialectName.str());
assert(it != registry.end() &&
@@ -40,8 +41,8 @@ void DialectRegistry::addDialectInterface(
// Bail out if the interface with the given ID is already in the registry for
// the given dialect. We expect a small number (dozens) of interfaces so a
// linear search is fine here.
- auto &dialectInterfaces = interfaces[it->second.first];
- for (const auto &kvp : dialectInterfaces) {
+ auto &ifaces = interfaces[it->second.first];
+ for (const auto &kvp : ifaces.dialectInterfaces) {
if (kvp.first == interfaceTypeID) {
LLVM_DEBUG(llvm::dbgs()
<< "[" DEBUG_TYPE
@@ -51,7 +52,36 @@ void DialectRegistry::addDialectInterface(
}
}
- dialectInterfaces.emplace_back(interfaceTypeID, allocator);
+ ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator);
+}
+
+void DialectRegistry::addObjectInterface(
+    StringRef dialectName, TypeID interfaceTypeID,
+    ObjectInterfaceAllocatorFunction allocator) {
+  assert(allocator && "unexpected null interface allocation function");
+
+  // Builtin dialect has an empty prefix and is always registered.
+  TypeID dialectTypeID;
+  if (!dialectName.empty()) {
+    auto it = registry.find(dialectName.str());
+    assert(it != registry.end() &&
+           "adding an interface for an object from an unregistered dialect");
+    dialectTypeID = it->second.first;
+  } else {
+    dialectTypeID = TypeID::get<BuiltinDialect>();
+  }
+
+  // Entries are grouped per dialect, but each allocator targets a single
+  // attribute/operation/type. Deduplicating by interface ID alone would
+  // silently drop models attached to *different* objects of the same dialect
+  // (e.g. two ops implementing the same interface), so every entry is kept.
+  // Registering the same (object, model) pair twice is harmless: the
+  // allocator just attaches the model again, and a repeated interface
+  // registration is ignored by InterfaceMap.
+  // TODO: also deduplicate on the object's TypeID once it is passed here.
+  // The allocator is invoked later by registerDelayedInterfaces.
+  auto &ifaces = interfaces[dialectTypeID];
+  ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator);
}
DialectAllocatorFunctionRef
@@ -79,11 +109,15 @@ void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
return;
// Add an interface if it is not already present.
- for (const auto &kvp : it->second) {
+ for (const auto &kvp : it->getSecond().dialectInterfaces) {
if (dialect->getRegisteredInterface(kvp.first))
continue;
dialect->addInterface(kvp.second(dialect));
}
+
+ // Add attribute, operation and type interfaces.
+ for (const auto &kvp : it->getSecond().objectInterfaces)
+ kvp.second(dialect->getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index da4b08ccd8825..ab12e5ff56693 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -356,12 +356,12 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry)
printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
}
- // Ensure the builtin dialect is always pre-loaded.
- getOrLoadDialect<BuiltinDialect>();
-
// Pre-populate the registry.
registry.appendTo(impl->dialectsRegistry);
+ // Ensure the builtin dialect is always pre-loaded.
+ getOrLoadDialect<BuiltinDialect>();
+
// Initialize several common attributes and types to avoid the need to lock
// the context when accessing them.
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index 6ad65438ae6c3..10cc6f85d6480 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "gtest/gtest.h"
@@ -87,6 +88,74 @@ TEST(InterfaceAttachment, Type) {
EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>());
}
+/// External interface model for the test type from the test dialect.
+struct TestTypeModel
+ : public TestExternalTypeInterface::ExternalModel<TestTypeModel,
+ test::TestType> {
+ unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; }
+
+ static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; }
+};
+
+TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
+ // Put the interface in the registry.
+ DialectRegistry registry;
+ registry.insert<test::TestDialect>();
+ registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
+
+ // Check that when a context is constructed with the given registry, the type
+ // interface gets registered.
+ MLIRContext context(registry);
+ context.loadDialect<test::TestDialect>();
+ test::TestType testType = test::TestType::get(&context);
+ auto iface = testType.dyn_cast<TestExternalTypeInterface>();
+ ASSERT_TRUE(iface != nullptr);
+ EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
+ EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
+}
+
+TEST(InterfaceAttachment, TypeDelayedContextAppend) {
+ // Put the interface in the registry.
+ DialectRegistry registry;
+ registry.insert<test::TestDialect>();
+ registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
+
+ // Check that when the registry gets appended to the context, the interface
+ // becomes available for objects in loaded dialects.
+ MLIRContext context;
+ context.loadDialect<test::TestDialect>();
+ test::TestType testType = test::TestType::get(&context);
+ EXPECT_FALSE(testType.isa<TestExternalTypeInterface>());
+ context.appendDialectRegistry(registry);
+ EXPECT_TRUE(testType.isa<TestExternalTypeInterface>());
+}
+
+TEST(InterfaceAttachment, RepeatedRegistration) {
+  DialectRegistry registry;
+  registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
+  MLIRContext context(registry);
+
+  // Shouldn't fail on repeated registration through the dialect registry.
+  context.appendDialectRegistry(registry);
+}
+
+TEST(InterfaceAttachment, TypeBuiltinDelayed) {
+  // Builtin dialect needs no registration or loading, but delayed interface
+  // registration must still work.
+  DialectRegistry registry;
+  registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
+
+  MLIRContext context(registry);
+  IntegerType i16 = IntegerType::get(&context, 16);
+  EXPECT_TRUE(i16.isa<TestExternalTypeInterface>());
+
+  MLIRContext initiallyEmpty;
+  IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
+  EXPECT_FALSE(i32.isa<TestExternalTypeInterface>());
+  initiallyEmpty.appendDialectRegistry(registry);
+  EXPECT_TRUE(i32.isa<TestExternalTypeInterface>());
+}
+
/// The interface provides a default implementation that expects
/// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
/// just derives from the ExternalModel.
@@ -128,9 +197,9 @@ TEST(InterfaceAttachment, Fallback) {
}
/// External model for attribute interfaces.
-struct TextExternalIntegerAttrModel
+struct TestExternalIntegerAttrModel
: public TestExternalAttrInterface::ExternalModel<
- TextExternalIntegerAttrModel, IntegerAttr> {
+ TestExternalIntegerAttrModel, IntegerAttr> {
const Dialect *getDialectPtr(Attribute attr) const {
return &attr.cast<IntegerAttr>().getDialect();
}
@@ -145,13 +214,45 @@ TEST(InterfaceAttachment, Attribute) {
// that the basics work for attributes.
IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
ASSERT_FALSE(attr.isa<TestExternalAttrInterface>());
- IntegerAttr::attachInterface<TextExternalIntegerAttrModel>(context);
+ IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
auto iface = attr.dyn_cast<TestExternalAttrInterface>();
ASSERT_TRUE(iface != nullptr);
EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
EXPECT_EQ(iface.getSomeNumber(), 42);
}
+/// External model for an interface attachable to a non-builtin attribute.
+struct TestExternalSimpleAAttrModel
+ : public TestExternalAttrInterface::ExternalModel<
+ TestExternalSimpleAAttrModel, test::SimpleAAttr> {
+ const Dialect *getDialectPtr(Attribute attr) const {
+ return &attr.getDialect();
+ }
+
+ static int getSomeNumber() { return 21; }
+};
+
+TEST(InterfaceAttachment, AttributeDelayed) {
+  // Attribute interfaces use the exact same mechanism as types, so just check
+  // that the delayed registration works for attributes.
+  DialectRegistry registry;
+  registry.insert<test::TestDialect>();
+  registry.addAttrInterface<test::TestDialect, test::SimpleAAttr,
+                            TestExternalSimpleAAttrModel>();
+
+  MLIRContext context(registry);
+  context.loadDialect<test::TestDialect>();
+  auto attr = test::SimpleAAttr::get(&context);
+  EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
+
+  MLIRContext initiallyEmpty;
+  initiallyEmpty.loadDialect<test::TestDialect>();
+  attr = test::SimpleAAttr::get(&initiallyEmpty);
+  EXPECT_FALSE(attr.isa<TestExternalAttrInterface>());
+  initiallyEmpty.appendDialectRegistry(registry);
+  EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
+}
+
/// External interface model for the module operation. Only provides non-default
/// methods.
struct TestExternalOpModel
@@ -220,4 +321,55 @@ TEST(InterfaceAttachment, Operation) {
ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp.getOperation()));
}
+struct TestExternalTestOpModel
+ : public TestExternalOpInterface::ExternalModel<TestExternalTestOpModel,
+ test::OpJ> {
+ unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
+ return op->getName().getStringRef().size() + arg;
+ }
+
+ static unsigned getNameLengthPlusArgTwice(unsigned arg) {
+ return test::OpJ::getOperationName().size() + 2 * arg;
+ }
+};
+
+TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
+ DialectRegistry registry;
+ registry.insert<test::TestDialect>();
+ registry.addOpInterface<ModuleOp, TestExternalOpModel>();
+ registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
+
+ // Construct the context directly from a registry. The interfaces are expected
+ // to be readily available on operations.
+ MLIRContext context(registry);
+ context.loadDialect<test::TestDialect>();
+ ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
+ OpBuilder builder(module);
+ auto op =
+ builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
+ EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
+ EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
+}
+
+TEST(InterfaceAttachment, OperationDelayedContextAppend) {
+ DialectRegistry registry;
+ registry.insert<test::TestDialect>();
+ registry.addOpInterface<ModuleOp, TestExternalOpModel>();
+ registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
+
+ // Construct the context, create ops, and only then append the registry. The
+ // interfaces are expected to be available after appending the registry.
+ MLIRContext context;
+ context.loadDialect<test::TestDialect>();
+ ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
+ OpBuilder builder(module);
+ auto op =
+ builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
+ EXPECT_FALSE(isa<TestExternalOpInterface>(module.getOperation()));
+ EXPECT_FALSE(isa<TestExternalOpInterface>(op.getOperation()));
+ context.appendDialectRegistry(registry);
+ EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
+ EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
+}
+
} // end namespace
More information about the Mlir-commits
mailing list