[Mlir-commits] [mlir] 77eee57 - [mlir] Refactor DialectRegistry delayed interface support into a general DialectExtension mechanism
River Riddle
llvmlistbot at llvm.org
Wed Mar 16 22:51:27 PDT 2022
Author: River Riddle
Date: 2022-03-16T22:15:25-07:00
New Revision: 77eee5795e2cf753e4400fb089d01018417c4ee0
URL: https://github.com/llvm/llvm-project/commit/77eee5795e2cf753e4400fb089d01018417c4ee0
DIFF: https://github.com/llvm/llvm-project/commit/77eee5795e2cf753e4400fb089d01018417c4ee0.diff
LOG: [mlir] Refactor DialectRegistry delayed interface support into a general DialectExtension mechanism
The current dialect registry allows for attaching delayed interfaces that are added to attrs/dialects/ops/etc.
when the owning dialect gets loaded. This is clunky for quite a few reasons (e.g. each interface type has a
separate tracking structure) and is also quite limiting. This commit refactors this delayed mutation of
dialect constructs into a more general DialectExtension mechanism. This mechanism is essentially a registration
callback that is invoked when a set of dialects has been loaded. This allows for attaching interfaces directly
on the loaded constructs, and also allows for loading new dependent dialects. The latter is
extremely useful, as it will now enable dependent dialects to only apply in the contexts in which they
are necessary. For example, a dialect dependency can now be conditional on whether a user actually needs the
interface that relies on it.
Differential Revision: https://reviews.llvm.org/D120367
Added:
mlir/include/mlir/IR/DialectRegistry.h
Modified:
mlir/include/mlir/IR/Dialect.h
mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp
mlir/unittests/IR/DialectTest.cpp
mlir/unittests/IR/InterfaceAttachmentTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 798d66faccdfa..c7a70bfc1cbca 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -13,6 +13,7 @@
#ifndef MLIR_IR_DIALECT_H
#define MLIR_IR_DIALECT_H
+#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/TypeID.h"
@@ -26,11 +27,9 @@ class DialectInterface;
class OpBuilder;
class Type;
-using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
-using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
-using DialectInterfaceAllocatorFunction =
- std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
-using ObjectInterfaceAllocatorFunction = std::function<void(MLIRContext *)>;
+//===----------------------------------------------------------------------===//
+// Dialect
+//===----------------------------------------------------------------------===//
/// Dialects are groups of MLIR operations, types and attributes, as well as
/// behavior associated with the entire group. For example, hooks into other
@@ -180,6 +179,16 @@ class Dialect {
getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName));
}
+ /// Register a single dialect interface with this dialect instance.
+ void addInterface(std::unique_ptr<DialectInterface> interface);
+
+ /// Register a set of dialect interfaces, each constructed with `this`.
+ template <typename... Args>
+ void addInterfaces() {
+ (void)std::initializer_list<int>{ // Pack expansion, evaluated in order.
+ 0, (addInterface(std::make_unique<Args>(this)), 0)...};
+ }
+
protected:
/// The constructor takes a unique namespace for this dialect as well as the
/// context to bind to.
@@ -218,15 +227,6 @@ class Dialect {
/// Enable support for unregistered types.
void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
- /// Register a dialect interface with this dialect instance.
- void addInterface(std::unique_ptr<DialectInterface> interface);
-
- /// Register a set of dialect interfaces with this dialect instance.
- template <typename... Args> void addInterfaces() {
- (void)std::initializer_list<int>{
- 0, (addInterface(std::make_unique<Args>(this)), 0)...};
- }
-
private:
Dialect(const Dialect &) = delete;
void operator=(Dialect &) = delete;
@@ -274,168 +274,6 @@ class Dialect {
friend class MLIRContext;
};
-/// The DialectRegistry maps a dialect namespace to a constructor for the
-/// matching dialect.
-/// This allows for decoupling the list of dialects "available" from the
-/// dialects loaded in the Context. The parser in particular will lazily load
-/// dialects in the Context as operations are encountered.
-class DialectRegistry {
- /// Lists of interfaces that need to be registered when the dialect is loaded.
- struct DelayedInterfaces {
- /// Dialect interfaces.
- SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2>
- dialectInterfaces;
- /// Attribute/Operation/Type interfaces.
- SmallVector<std::tuple<TypeID, TypeID, ObjectInterfaceAllocatorFunction>, 2>
- objectInterfaces;
- };
-
- using MapTy =
- std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
- using InterfaceMapTy = DenseMap<TypeID, DelayedInterfaces>;
-
-public:
- explicit DialectRegistry();
-
- template <typename ConcreteDialect> void insert() {
- insert(TypeID::get<ConcreteDialect>(),
- ConcreteDialect::getDialectNamespace(),
- static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
- // Just allocate the dialect, the context
- // takes ownership of it.
- return ctx->getOrLoadDialect<ConcreteDialect>();
- })));
- }
-
- template <typename ConcreteDialect, typename OtherDialect,
- typename... MoreDialects>
- void insert() {
- insert<ConcreteDialect>();
- insert<OtherDialect, MoreDialects...>();
- }
-
- /// Add a new dialect constructor to the registry. The constructor must be
- /// calling MLIRContext::getOrLoadDialect in order for the context to take
- /// ownership of the dialect and for delayed interface registration to happen.
- void insert(TypeID typeID, StringRef name,
- const DialectAllocatorFunction &ctor);
-
- /// Return an allocation function for constructing the dialect identified by
- /// its namespace, or nullptr if the namespace is not in this registry.
- DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
-
- // Register all dialects available in the current registry with the registry
- // in the provided context.
- void appendTo(DialectRegistry &destination) const {
- for (const auto &nameAndRegistrationIt : registry)
- destination.insert(nameAndRegistrationIt.second.first,
- nameAndRegistrationIt.first,
- nameAndRegistrationIt.second.second);
- // Merge interfaces.
- for (auto it : interfaces) {
- TypeID dialect = it.first;
- auto destInterfaces = destination.interfaces.find(dialect);
- if (destInterfaces == destination.interfaces.end()) {
- destination.interfaces[dialect] = it.second;
- continue;
- }
- // The destination already has delayed interface registrations for this
- // dialect. Merge registrations into the destination registry.
- destInterfaces->second.dialectInterfaces.append(
- it.second.dialectInterfaces.begin(),
- it.second.dialectInterfaces.end());
- destInterfaces->second.objectInterfaces.append(
- it.second.objectInterfaces.begin(), it.second.objectInterfaces.end());
- }
- }
-
- /// Return the names of dialects known to this registry.
- auto getDialectNames() const {
- return llvm::map_range(
- registry,
- [](const MapTy::value_type &item) -> StringRef { return item.first; });
- }
-
- /// Add an interface constructed with the given allocation function to the
- /// dialect provided as template parameter. The dialect must be present in
- /// the registry.
- template <typename DialectTy>
- void addDialectInterface(TypeID interfaceTypeID,
- DialectInterfaceAllocatorFunction allocator) {
- addDialectInterface(DialectTy::getDialectNamespace(), interfaceTypeID,
- allocator);
- }
-
- /// Add an interface to the dialect, both provided as template parameter. The
- /// dialect must be present in the registry.
- template <typename DialectTy, typename InterfaceTy>
- void addDialectInterface() {
- addDialectInterface<DialectTy>(
- InterfaceTy::getInterfaceID(), [](Dialect *dialect) {
- return std::make_unique<InterfaceTy>(dialect);
- });
- }
-
- /// Add an external op interface model for an op that belongs to a dialect,
- /// both provided as template parameters. The dialect must be present in the
- /// registry.
- template <typename OpTy, typename ModelTy> void addOpInterface() {
- StringRef opName = OpTy::getOperationName();
- StringRef dialectName = opName.split('.').first;
- addObjectInterface(dialectName, TypeID::get<OpTy>(),
- ModelTy::Interface::getInterfaceID(),
- [](MLIRContext *context) {
- OpTy::template attachInterface<ModelTy>(*context);
- });
- }
-
- /// Add an external attribute interface model for an attribute type `AttrTy`
- /// that is going to belong to `DialectTy`. The dialect must be present in the
- /// registry.
- template <typename DialectTy, typename AttrTy, typename ModelTy>
- void addAttrInterface() {
- addStorageUserInterface<AttrTy, ModelTy>(DialectTy::getDialectNamespace());
- }
-
- /// Add an external type interface model for an type class `TypeTy` that is
- /// going to belong to `DialectTy`. The dialect must be present in the
- /// registry.
- template <typename DialectTy, typename TypeTy, typename ModelTy>
- void addTypeInterface() {
- addStorageUserInterface<TypeTy, ModelTy>(DialectTy::getDialectNamespace());
- }
-
- /// Register any interfaces required for the given dialect (based on its
- /// TypeID). Users are not expected to call this directly.
- void registerDelayedInterfaces(Dialect *dialect) const;
-
-private:
- /// Add an interface constructed with the given allocation function to the
- /// dialect identified by its namespace.
- void addDialectInterface(StringRef dialectName, TypeID interfaceTypeID,
- const DialectInterfaceAllocatorFunction &allocator);
-
- /// Add an attribute/operation/type interface constructible with the given
- /// allocation function to the dialect identified by its namespace.
- void addObjectInterface(StringRef dialectName, TypeID objectID,
- TypeID interfaceTypeID,
- const ObjectInterfaceAllocatorFunction &allocator);
-
- /// Add an external model for an attribute/type interface to the dialect
- /// identified by its namespace.
- template <typename ObjectTy, typename ModelTy>
- void addStorageUserInterface(StringRef dialectName) {
- addObjectInterface(dialectName, TypeID::get<ObjectTy>(),
- ModelTy::Interface::getInterfaceID(),
- [](MLIRContext *context) {
- ObjectTy::template attachInterface<ModelTy>(*context);
- });
- }
-
- MapTy registry;
- InterfaceMapTy interfaces;
-};
-
} // namespace mlir
namespace llvm {
diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
new file mode 100644
index 0000000000000..064c63d9ad562
--- /dev/null
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -0,0 +1,222 @@
+//===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines functionality for registering and extending dialects.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_DIALECTREGISTRY_H
+#define MLIR_IR_DIALECTREGISTRY_H
+
+#include "mlir/IR/MLIRContext.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+#include <map>
+#include <tuple>
+
+namespace mlir {
+class Dialect;
+
+using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
+using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
+
+//===----------------------------------------------------------------------===//
+// DialectExtension
+//===----------------------------------------------------------------------===//
+
+/// This class represents an opaque dialect extension. It contains a set of
+/// required dialects and an application function. The required dialects control
+/// when the extension is applied, i.e. the extension is applied when all
+/// required dialects are loaded. The application function can be used to attach
+/// additional functionality to attributes, dialects, operations, types, etc.,
+/// and may also load additional necessary dialects.
+class DialectExtensionBase {
+public:
+ virtual ~DialectExtensionBase();
+
+ /// Return the dialects that are required by this extension to be loaded
+ /// before applying.
+ ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
+
+ /// Apply this extension to the given context and the required dialects.
+ virtual void apply(MLIRContext *context,
+ MutableArrayRef<Dialect *> dialects) const = 0;
+
+ /// Return a copy of this extension.
+ virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
+
+protected:
+ /// Initialize the extension with a set of required dialects. Note that there
+ /// should always be at least one affected dialect.
+ explicit DialectExtensionBase(ArrayRef<StringRef> dialectNames)
+ : dialectNames(dialectNames.begin(), dialectNames.end()) {
+ assert(!dialectNames.empty() && "expected at least one affected dialect");
+ }
+
+private:
+ /// The names of the dialects affected by this extension.
+ SmallVector<StringRef> dialectNames;
+};
+
+/// This class represents a dialect extension anchored on the given set of
+/// dialects; `DerivedT` is the concrete subclass (used by `clone`). When all
+/// of the specified dialects have been loaded, the typed `apply` is executed.
+template <typename DerivedT, typename... DialectsT>
+class DialectExtension : public DialectExtensionBase {
+public:
+ /// Apply the extension; the dialects are passed in `DialectsT` order.
+ virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;
+
+ /// Return a copy of this extension, made via `DerivedT`'s copy constructor.
+ std::unique_ptr<DialectExtensionBase> clone() const final {
+ return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
+ }
+
+protected:
+ DialectExtension() // Record each `DialectsT` namespace as required.
+ : DialectExtensionBase(
+ ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
+
+ /// Override the base apply method, casting each dialect to its DialectsT type.
+ void apply(MLIRContext *context,
+ MutableArrayRef<Dialect *> dialects) const final {
+ unsigned dialectIdx = 0; // Braced-init elements evaluate left to right.
+ auto derivedDialects = std::tuple<DialectsT *...>{
+ static_cast<DialectsT *>(dialects[dialectIdx++])...};
+ llvm::apply_tuple(
+ [&](DialectsT *...dialect) { apply(context, dialect...); },
+ derivedDialects);
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// DialectRegistry
+//===----------------------------------------------------------------------===//
+
+/// The DialectRegistry maps a dialect namespace to a constructor for the
+/// matching dialect, and holds dialect extensions to be applied once their
+/// required dialects are loaded. This decouples the dialects "available" from
+/// the dialects loaded in the Context. The parser in particular will lazily
+/// load dialects in the Context as operations are encountered.
+class DialectRegistry {
+ using MapTy =
+ std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
+
+public:
+ explicit DialectRegistry();
+
+ template <typename ConcreteDialect>
+ void insert() {
+ insert(TypeID::get<ConcreteDialect>(),
+ ConcreteDialect::getDialectNamespace(),
+ static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
+ // Just allocate the dialect, the context
+ // takes ownership of it.
+ return ctx->getOrLoadDialect<ConcreteDialect>();
+ })));
+ }
+
+ template <typename ConcreteDialect, typename OtherDialect,
+ typename... MoreDialects>
+ void insert() {
+ insert<ConcreteDialect>();
+ insert<OtherDialect, MoreDialects...>();
+ }
+
+ /// Add a new dialect constructor to the registry. The constructor must be
+ /// calling MLIRContext::getOrLoadDialect in order for the context to take
+ /// ownership of the dialect and for dialect extensions to be applied.
+ void insert(TypeID typeID, StringRef name,
+ const DialectAllocatorFunction &ctor);
+
+ /// Return an allocation function for constructing the dialect identified by
+ /// its namespace, or nullptr if the namespace is not in this registry.
+ DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
+
+ /// Append all dialect constructors and (cloned) extensions held by this
+ /// registry to the `destination` registry.
+ void appendTo(DialectRegistry &destination) const {
+ for (const auto &nameAndRegistrationIt : registry)
+ destination.insert(nameAndRegistrationIt.second.first,
+ nameAndRegistrationIt.first,
+ nameAndRegistrationIt.second.second);
+ // Merge the extensions.
+ for (const auto &extension : extensions)
+ destination.extensions.push_back(extension->clone());
+ }
+
+ /// Return the names of dialects known to this registry.
+ auto getDialectNames() const {
+ return llvm::map_range(
+ registry,
+ [](const MapTy::value_type &item) -> StringRef { return item.first; });
+ }
+
+ /// Apply any held extensions that require the given dialect. Users are not
+ /// expected to call this directly.
+ void applyExtensions(Dialect *dialect) const;
+
+ /// Apply any applicable extensions to the given context. Users are not
+ /// expected to call this directly.
+ void applyExtensions(MLIRContext *ctx) const;
+
+ /// Add the given extension to the registry.
+ void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
+ extensions.push_back(std::move(extension));
+ }
+
+ /// Add the given extensions, each default-constructed, to the registry.
+ template <typename... ExtensionsT>
+ void addExtensions() {
+ (void)std::initializer_list<int>{
+ 0, (addExtension(std::make_unique<ExtensionsT>()), 0)...};
+ }
+
+ /// Add an extension function that requires the given dialects.
+ /// Note: This bare functor overload is provided in addition to the
+ /// std::function variant to enable dialect type deduction, e.g.:
+ /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
+ ///
+ /// is equivalent to:
+ /// registry.addExtension<MyDialect>(
+ /// [](MLIRContext *ctx, MyDialect *dialect){ ... }
+ /// )
+ template <typename... DialectsT>
+ void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
+ addExtension<DialectsT...>(
+ std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
+ }
+ template <typename... DialectsT>
+ void
+ addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
+ using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;
+
+ struct Extension : public DialectExtension<Extension, DialectsT...> {
+ Extension(const Extension &) = default;
+ Extension(ExtensionFnT extensionFn)
+ : extensionFn(std::move(extensionFn)) {}
+ ~Extension() override = default;
+
+ void apply(MLIRContext *context, DialectsT *...dialects) const final {
+ extensionFn(context, dialects...);
+ }
+ ExtensionFnT extensionFn;
+ };
+ addExtension(std::make_unique<Extension>(std::move(extensionFn)));
+ }
+
+private:
+ MapTy registry;
+ std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
+};
+
+} // namespace mlir
+
+#endif // MLIR_IR_DIALECTREGISTRY_H
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index 2b08300bb7127..12726a1656bbc 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -154,7 +154,9 @@ struct SelectOpInterface
void mlir::arith::registerBufferizableOpInterfaceExternalModels(
DialectRegistry ®istry) {
- registry.addOpInterface<ConstantOp, ConstantOpInterface>();
- registry.addOpInterface<IndexCastOp, IndexCastOpInterface>();
- registry.addOpInterface<SelectOp, SelectOpInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, ArithmeticDialect *dialect) {
+ ConstantOp::attachInterface<ConstantOpInterface>(*ctx);
+ IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx);
+ SelectOp::attachInterface<SelectOpInterface>(*ctx);
+ });
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index 63073dd1e9ebd..99ce070e94000 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -695,7 +695,9 @@ LogicalResult bufferization::deallocateBuffers(Operation *op) {
void bufferization::registerAllocationOpInterfaceExternalModels(
DialectRegistry ®istry) {
- registry.addOpInterface<memref::AllocOp, DefaultAllocationInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+ memref::AllocOp::attachInterface<DefaultAllocationInterface>(*ctx);
+ });
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 535627469ae7a..fc31b4a260f95 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -962,9 +962,11 @@ struct FuncOpInterface
void mlir::linalg::comprehensive_bufferize::std_ext::
registerModuleBufferizationExternalModels(DialectRegistry ®istry) {
- registry.addOpInterface<func::CallOp, std_ext::CallOpInterface>();
- registry.addOpInterface<func::ReturnOp, std_ext::ReturnOpInterface>();
- registry.addOpInterface<FuncOp, std_ext::FuncOpInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
+ func::CallOp::attachInterface<std_ext::CallOpInterface>(*ctx);
+ func::ReturnOp::attachInterface<std_ext::ReturnOpInterface>(*ctx);
+ func::FuncOp::attachInterface<std_ext::FuncOpInterface>(*ctx);
+ });
}
/// Set the attribute that triggers inplace bufferization on a FuncOp argument
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 2cbed7cebb1e9..4ee0e6360e117 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -246,22 +246,13 @@ struct InitTensorOpInterface
/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
/// the `BufferizableOpInterface` with each of them.
-template <typename... OpTys>
-struct LinalgOpInterfaceHelper;
-
-template <typename First, typename... Others>
-struct LinalgOpInterfaceHelper<First, Others...> {
- static void registerOpInterface(DialectRegistry &registry) {
- registry.addOpInterface<First, LinalgOpInterface<First>>();
- LinalgOpInterfaceHelper<Others...>::registerOpInterface(registry);
+template <typename... Ops>
+struct LinalgOpInterfaceHelper {
+ static void registerOpInterface(MLIRContext *ctx) {
+ (void)std::initializer_list<int>{
+ 0, (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), 0)...};
}
};
-
-template <>
-struct LinalgOpInterfaceHelper<> {
- static void registerOpInterface(DialectRegistry &registry) {}
-};
-
} // namespace
/// Return true if all `neededValues` are in scope at the given
@@ -501,13 +492,15 @@ LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
DialectRegistry &registry) {
- registry.addOpInterface<linalg::InitTensorOp, InitTensorOpInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
+ linalg::InitTensorOp::attachInterface<InitTensorOpInterface>(*ctx);
- // Register all Linalg structured ops. `LinalgOp` is an interface and it is
- // not possible to attach an external interface to an existing interface.
- // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
- LinalgOpInterfaceHelper<
+ // Register all Linalg structured ops. `LinalgOp` is an interface and it is
+ // not possible to attach an external interface to an existing interface.
+ // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
+ LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
- >::registerOpInterface(registry);
+ >::registerOpInterface(ctx);
+ });
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 50136a0bfe47a..fde24fd8b7ab4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -503,8 +503,10 @@ struct YieldOpInterface
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
DialectRegistry ®istry) {
- registry.addOpInterface<ExecuteRegionOp, ExecuteRegionOpInterface>();
- registry.addOpInterface<ForOp, ForOpInterface>();
- registry.addOpInterface<IfOp, IfOpInterface>();
- registry.addOpInterface<YieldOp, YieldOpInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
+ ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
+ ForOp::attachInterface<ForOpInterface>(*ctx);
+ IfOp::attachInterface<IfOpInterface>(*ctx);
+ YieldOp::attachInterface<YieldOpInterface>(*ctx);
+ });
}
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 8796f4cc378e7..ce3c85e6454e0 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -168,6 +168,8 @@ struct AssumingYieldOpInterface
void mlir::shape::registerBufferizableOpInterfaceExternalModels(
DialectRegistry ®istry) {
- registry.addOpInterface<shape::AssumingOp, AssumingOpInterface>();
- registry.addOpInterface<shape::AssumingYieldOp, AssumingYieldOpInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
+ shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
+ shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
+ });
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index c5f8649ceae22..29da4f5554131 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -205,11 +205,11 @@ struct ReifyPadOp
void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
DialectRegistry ®istry) {
- registry
- .addOpInterface<tensor::ExpandShapeOp,
- ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>();
- registry
- .addOpInterface<tensor::CollapseShapeOp,
- ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>();
- registry.addOpInterface<tensor::PadOp, ReifyPadOp>();
+ registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
+ ExpandShapeOp::attachInterface<
+ ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
+ CollapseShapeOp::attachInterface<
+ ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
+ PadOp::attachInterface<ReifyPadOp>(*ctx);
+ });
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 5ecdea177aa78..5f83f016f585a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -283,5 +283,7 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
void mlir::tensor::registerTilingOpInterfaceExternalModels(
DialectRegistry ®istry) {
- registry.addOpInterface<tensor::PadOp, PadOpTiling>();
+ registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
+ tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
+ });
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 0efebdfc9d41a..a9519b98803cc 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -700,15 +700,17 @@ struct RankOpInterface
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
DialectRegistry ®istry) {
- registry.addOpInterface<CastOp, CastOpInterface>();
- registry.addOpInterface<CollapseShapeOp, CollapseShapeOpInterface>();
- registry.addOpInterface<DimOp, DimOpInterface>();
- registry.addOpInterface<ExpandShapeOp, ExpandShapeOpInterface>();
- registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
- registry.addOpInterface<ExtractOp, ExtractOpInterface>();
- registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
- registry.addOpInterface<GenerateOp, GenerateOpInterface>();
- registry.addOpInterface<InsertOp, InsertOpInterface>();
- registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
- registry.addOpInterface<RankOp, RankOpInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
+ CastOp::attachInterface<CastOpInterface>(*ctx);
+ CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
+ DimOp::attachInterface<DimOpInterface>(*ctx);
+ ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
+ ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
+ ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
+ FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
+ GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
+ InsertOp::attachInterface<InsertOpInterface>(*ctx);
+ InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
+ RankOp::attachInterface<RankOpInterface>(*ctx);
+ });
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 48008934e46b6..c823f34b695ac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -121,6 +121,8 @@ struct TransferWriteOpInterface
void mlir::vector::registerBufferizableOpInterfaceExternalModels(
DialectRegistry ®istry) {
- registry.addOpInterface<TransferReadOp, TransferReadOpInterface>();
- registry.addOpInterface<TransferWriteOp, TransferWriteOpInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
+ TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
+ TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
+ });
}
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index b14876a165763..2e983d641a9ec 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -24,97 +24,6 @@
using namespace mlir;
using namespace detail;
-//===----------------------------------------------------------------------===//
-// DialectRegistry
-//===----------------------------------------------------------------------===//
-
-DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
-
-void DialectRegistry::addDialectInterface(
- StringRef dialectName, TypeID interfaceTypeID,
- const DialectInterfaceAllocatorFunction &allocator) {
- assert(allocator && "unexpected null interface allocation function");
- auto it = registry.find(dialectName.str());
- assert(it != registry.end() &&
- "adding an interface for an unregistered dialect");
-
- // Bail out if the interface with the given ID is already in the registry for
- // the given dialect. We expect a small number (dozens) of interfaces so a
- // linear search is fine here.
- auto &ifaces = interfaces[it->second.first];
- for (const auto &kvp : ifaces.dialectInterfaces) {
- if (kvp.first == interfaceTypeID) {
- LLVM_DEBUG(llvm::dbgs()
- << "[" DEBUG_TYPE
- "] repeated interface registration for dialect "
- << dialectName);
- return;
- }
- }
-
- ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator);
-}
-
-void DialectRegistry::addObjectInterface(
- StringRef dialectName, TypeID objectID, TypeID interfaceTypeID,
- const ObjectInterfaceAllocatorFunction &allocator) {
- assert(allocator && "unexpected null interface allocation function");
-
- auto it = registry.find(dialectName.str());
- assert(it != registry.end() &&
- "adding an interface for an op from an unregistered dialect");
-
- auto dialectID = it->second.first;
- auto &ifaces = interfaces[dialectID];
-
- for (const auto &info : ifaces.objectInterfaces) {
- if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) {
- LLVM_DEBUG(llvm::dbgs()
- << "[" DEBUG_TYPE
- "] repeated interface object interface registration");
- return;
- }
- }
-
- ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID, allocator);
-}
-
-DialectAllocatorFunctionRef
-DialectRegistry::getDialectAllocator(StringRef name) const {
- auto it = registry.find(name.str());
- if (it == registry.end())
- return nullptr;
- return it->second.second;
-}
-
-void DialectRegistry::insert(TypeID typeID, StringRef name,
- const DialectAllocatorFunction &ctor) {
- auto inserted = registry.insert(
- std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
- if (!inserted.second && inserted.first->second.first != typeID) {
- llvm::report_fatal_error(
- "Trying to register different dialects for the same namespace: " +
- name);
- }
-}
-
-void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
- auto it = interfaces.find(dialect->getTypeID());
- if (it == interfaces.end())
- return;
-
- // Add an interface if it is not already present.
- for (const auto &kvp : it->getSecond().dialectInterfaces) {
- if (dialect->getRegisteredInterface(kvp.first))
- continue;
- dialect->addInterface(kvp.second(dialect));
- }
-
- // Add attribute, operation and type interfaces.
- for (const auto &info : it->getSecond().objectInterfaces)
- std::get<2>(info)(dialect->getContext());
-}
-
//===----------------------------------------------------------------------===//
// Dialect
//===----------------------------------------------------------------------===//
@@ -189,7 +98,13 @@ void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
auto it = registeredInterfaces.try_emplace(interface->getID(),
std::move(interface));
(void)it;
- assert(it.second && "interface kind has already been registered");
+ LLVM_DEBUG({
+ if (!it.second) {
+ llvm::dbgs() << "[" DEBUG_TYPE
+ "] repeated interface registration for dialect "
+ << getNamespace();
+ }
+ });
}
//===----------------------------------------------------------------------===//
@@ -216,3 +131,100 @@ const DialectInterface *
DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
return getInterfaceFor(op->getDialect());
}
+
+//===----------------------------------------------------------------------===//
+// DialectExtension
+//===----------------------------------------------------------------------===//
+
+DialectExtensionBase::~DialectExtensionBase() = default;
+
+//===----------------------------------------------------------------------===//
+// DialectRegistry
+//===----------------------------------------------------------------------===//
+
+DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
+
+DialectAllocatorFunctionRef
+DialectRegistry::getDialectAllocator(StringRef name) const {
+ auto it = registry.find(name.str());
+ if (it == registry.end())
+ return nullptr;
+ return it->second.second;
+}
+
+void DialectRegistry::insert(TypeID typeID, StringRef name,
+ const DialectAllocatorFunction &ctor) {
+ auto inserted = registry.insert(
+ std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
+ if (!inserted.second && inserted.first->second.first != typeID) {
+ llvm::report_fatal_error(
+ "Trying to register different dialects for the same namespace: " +
+ name);
+ }
+}
+
+void DialectRegistry::applyExtensions(Dialect *dialect) const {
+ MLIRContext *ctx = dialect->getContext();
+ StringRef dialectName = dialect->getNamespace();
+
+ // Functor used to try to apply the given extension.
+ auto applyExtension = [&](const DialectExtensionBase &extension) {
+ ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
+
+ // Handle the simple case of a single dialect name. In this case, the
+ // required dialect should be the current dialect.
+ if (dialectNames.size() == 1) {
+ if (dialectNames.front() == dialectName)
+ extension.apply(ctx, dialect);
+ return;
+ }
+
+ // Otherwise, check to see if this extension requires this dialect.
+ const StringRef *nameIt = llvm::find(dialectNames, dialectName);
+ if (nameIt == dialectNames.end())
+ return;
+
+ // If it does, ensure that all of the other required dialects have been
+ // loaded.
+ SmallVector<Dialect *> requiredDialects;
+ requiredDialects.reserve(dialectNames.size());
+ for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
+ ++it) {
+ // The current dialect is known to be loaded.
+ if (it == nameIt) {
+ requiredDialects.push_back(dialect);
+ continue;
+ }
+ // Otherwise, check if it is loaded.
+ Dialect *loadedDialect = ctx->getLoadedDialect(*it);
+ if (!loadedDialect)
+ return;
+ requiredDialects.push_back(loadedDialect);
+ }
+ extension.apply(ctx, requiredDialects);
+ };
+
+ for (const auto &extension : extensions)
+ applyExtension(*extension);
+}
+
+void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
+ // Functor used to try to apply the given extension.
+ auto applyExtension = [&](const DialectExtensionBase &extension) {
+ ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
+
+ // Check to see if all of the dialects for this extension are loaded.
+ SmallVector<Dialect *> requiredDialects;
+ requiredDialects.reserve(dialectNames.size());
+ for (StringRef dialectName : dialectNames) {
+ Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
+ if (!loadedDialect)
+ return;
+ requiredDialects.push_back(loadedDialect);
+ }
+ extension.apply(ctx, requiredDialects);
+ };
+
+ for (const auto &extension : extensions)
+ applyExtension(*extension);
+}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 1b6d354eff37e..2c0b3ba049d7b 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -357,9 +357,8 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
registry.appendTo(impl->dialectsRegistry);
- // For the already loaded dialects, register the interfaces immediately.
- for (const auto &kvp : impl->loadedDialects)
- registry.registerDelayedInterfaces(kvp.second.get());
+ // For the already loaded dialects, apply any possible extensions immediately.
+ registry.applyExtensions(this);
}
const DialectRegistry &MLIRContext::getDialectRegistry() {
@@ -437,8 +436,8 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
}
- // Actually register the interfaces with delayed registration.
- impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
+ // Apply any extensions to this newly loaded dialect.
+ impl.dialectsRegistry.applyExtensions(dialect.get());
return dialect.get();
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
index 776559c31c9d7..044462d33cfd1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp
@@ -44,8 +44,9 @@ class AMXDialectLLVMIRTranslationInterface
void mlir::registerAMXDialectTranslation(DialectRegistry &registry) {
registry.insert<amx::AMXDialect>();
- registry.addDialectInterface<amx::AMXDialect,
- AMXDialectLLVMIRTranslationInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
+ dialect->addInterfaces<AMXDialectLLVMIRTranslationInterface>();
+ });
}
void mlir::registerAMXDialectTranslation(MLIRContext &context) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.cpp
index 22ea7316d2e95..7098592d506e0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.cpp
@@ -45,8 +45,10 @@ class ArmNeonDialectLLVMIRTranslationInterface
void mlir::registerArmNeonDialectTranslation(DialectRegistry &registry) {
registry.insert<arm_neon::ArmNeonDialect>();
- registry.addDialectInterface<arm_neon::ArmNeonDialect,
- ArmNeonDialectLLVMIRTranslationInterface>();
+ registry.addExtension(
+ +[](MLIRContext *ctx, arm_neon::ArmNeonDialect *dialect) {
+ dialect->addInterfaces<ArmNeonDialectLLVMIRTranslationInterface>();
+ });
}
void mlir::registerArmNeonDialectTranslation(MLIRContext &context) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp
index 364e63d40150f..bc1f0e934fa02 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp
@@ -44,8 +44,9 @@ class ArmSVEDialectLLVMIRTranslationInterface
void mlir::registerArmSVEDialectTranslation(DialectRegistry ®istry) {
registry.insert<arm_sve::ArmSVEDialect>();
- registry.addDialectInterface<arm_sve::ArmSVEDialect,
- ArmSVEDialectLLVMIRTranslationInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, arm_sve::ArmSVEDialect *dialect) {
+ dialect->addInterfaces<ArmSVEDialectLLVMIRTranslationInterface>();
+ });
}
void mlir::registerArmSVEDialectTranslation(MLIRContext &context) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 6701ccb830d53..43f5069ddd2e8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -503,8 +503,9 @@ class LLVMDialectLLVMIRTranslationInterface
void mlir::registerLLVMDialectTranslation(DialectRegistry &registry) {
registry.insert<LLVM::LLVMDialect>();
- registry.addDialectInterface<LLVM::LLVMDialect,
- LLVMDialectLLVMIRTranslationInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
+ dialect->addInterfaces<LLVMDialectLLVMIRTranslationInterface>();
+ });
}
void mlir::registerLLVMDialectTranslation(MLIRContext &context) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index f39b0d3378115..e09260dca28a2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -141,8 +141,9 @@ class NVVMDialectLLVMIRTranslationInterface
void mlir::registerNVVMDialectTranslation(DialectRegistry &registry) {
registry.insert<NVVM::NVVMDialect>();
- registry.addDialectInterface<NVVM::NVVMDialect,
- NVVMDialectLLVMIRTranslationInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
+ dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
+ });
}
void mlir::registerNVVMDialectTranslation(MLIRContext &context) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index bda505d935034..49431d15ed819 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -533,8 +533,9 @@ LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
void mlir::registerOpenACCDialectTranslation(DialectRegistry &registry) {
registry.insert<acc::OpenACCDialect>();
- registry.addDialectInterface<acc::OpenACCDialect,
- OpenACCDialectLLVMIRTranslationInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, acc::OpenACCDialect *dialect) {
+ dialect->addInterfaces<OpenACCDialectLLVMIRTranslationInterface>();
+ });
}
void mlir::registerOpenACCDialectTranslation(MLIRContext &context) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 82fd430d94e8f..a835b72913989 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1270,8 +1270,9 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
registry.insert<omp::OpenMPDialect>();
- registry.addDialectInterface<omp::OpenMPDialect,
- OpenMPDialectLLVMIRTranslationInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
+ dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
+ });
}
void mlir::registerOpenMPDialectTranslation(MLIRContext &context) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 5bc02dc552709..71d0a6123a700 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -107,8 +107,9 @@ class ROCDLDialectLLVMIRTranslationInterface
void mlir::registerROCDLDialectTranslation(DialectRegistry &registry) {
registry.insert<ROCDL::ROCDLDialect>();
- registry.addDialectInterface<ROCDL::ROCDLDialect,
- ROCDLDialectLLVMIRTranslationInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, ROCDL::ROCDLDialect *dialect) {
+ dialect->addInterfaces<ROCDLDialectLLVMIRTranslationInterface>();
+ });
}
void mlir::registerROCDLDialectTranslation(MLIRContext &context) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp
index f2ddc2ab16455..fa5f61420ee8a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp
@@ -45,8 +45,10 @@ class X86VectorDialectLLVMIRTranslationInterface
void mlir::registerX86VectorDialectTranslation(DialectRegistry &registry) {
registry.insert<x86vector::X86VectorDialect>();
- registry.addDialectInterface<x86vector::X86VectorDialect,
- X86VectorDialectLLVMIRTranslationInterface>();
+ registry.addExtension(
+ +[](MLIRContext *ctx, x86vector::X86VectorDialect *dialect) {
+ dialect->addInterfaces<X86VectorDialectLLVMIRTranslationInterface>();
+ });
}
void mlir::registerX86VectorDialectTranslation(MLIRContext &context) {
diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index b4fd697e08820..1da7775d8eca9 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -63,7 +63,9 @@ TEST(Dialect, DelayedInterfaceRegistration) {
registry.insert<TestDialect, SecondTestDialect>();
// Delayed registration of an interface for TestDialect.
- registry.addDialectInterface<TestDialect, TestDialectInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
+ dialect->addInterfaces<TestDialectInterface>();
+ });
MLIRContext context(registry);
@@ -85,8 +87,10 @@ TEST(Dialect, DelayedInterfaceRegistration) {
// loaded dialect and check that the interface is now registered.
DialectRegistry secondRegistry;
secondRegistry.insert<SecondTestDialect>();
- secondRegistry
- .addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
+ secondRegistry.addExtension(
+ +[](MLIRContext *ctx, SecondTestDialect *dialect) {
+ dialect->addInterfaces<SecondTestDialectInterface>();
+ });
context.appendDialectRegistry(secondRegistry);
secondTestDialectInterface =
dyn_cast<SecondTestDialectInterface>(secondTestDialect);
@@ -97,7 +101,9 @@ TEST(Dialect, RepeatedDelayedRegistration) {
// Set up the delayed registration.
DialectRegistry registry;
registry.insert<TestDialect>();
- registry.addDialectInterface<TestDialect, TestDialectInterface>();
+ registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
+ dialect->addInterfaces<TestDialectInterface>();
+ });
MLIRContext context(registry);
// Load the TestDialect and check that the interface got registered for it.
@@ -110,33 +116,12 @@ TEST(Dialect, RepeatedDelayedRegistration) {
// on repeated interface registration.
DialectRegistry secondRegistry;
secondRegistry.insert<TestDialect>();
- secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
+ secondRegistry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
+ dialect->addInterfaces<TestDialectInterface>();
+ });
context.appendDialectRegistry(secondRegistry);
testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
EXPECT_TRUE(testDialectInterface != nullptr);
}
-// A dialect that registers two interfaces with the same InterfaceID, triggering
-// an assertion failure.
-struct RepeatedRegistrationDialect : public Dialect {
- static StringRef getDialectNamespace() { return "repeatedreg"; }
- RepeatedRegistrationDialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context,
- TypeID::get<RepeatedRegistrationDialect>()) {
- addInterfaces<TestDialectInterface>();
- addInterfaces<SecondTestDialectInterface>();
- }
-};
-
-TEST(Dialect, RepeatedInterfaceRegistrationDeath) {
- MLIRContext context;
- (void)context;
-
- // This triggers an assertion in debug mode.
-#ifndef NDEBUG
- ASSERT_DEATH(context.loadDialect<RepeatedRegistrationDialect>(),
- "interface kind has already been registered");
-#endif
-}
-
} // namespace
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index 132625e0b8b42..d5e19d27f8eb5 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -102,7 +102,9 @@ TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
// Put the interface in the registry.
DialectRegistry registry;
registry.insert<test::TestDialect>();
- registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
+ registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
+ test::TestType::attachInterface<TestTypeModel>(*ctx);
+ });
// Check that when a context is constructed with the given registry, the type
// interface gets registered.
@@ -119,7 +121,9 @@ TEST(InterfaceAttachment, TypeDelayedContextAppend) {
// Put the interface in the registry.
DialectRegistry registry;
registry.insert<test::TestDialect>();
- registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
+ registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
+ test::TestType::attachInterface<TestTypeModel>(*ctx);
+ });
// Check that when the registry gets appended to the context, the interface
// becomes available for objects in loaded dialects.
@@ -133,7 +137,9 @@ TEST(InterfaceAttachment, TypeDelayedContextAppend) {
TEST(InterfaceAttachment, RepeatedRegistration) {
DialectRegistry registry;
- registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
+ registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
+ IntegerType::attachInterface<Model>(*ctx);
+ });
MLIRContext context(registry);
// Should't fail on repeated registration through the dialect registry.
@@ -144,7 +150,9 @@ TEST(InterfaceAttachment, TypeBuiltinDelayed) {
// Builtin dialect needs to registration or loading, but delayed interface
// registration must still work.
DialectRegistry registry;
- registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
+ registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
+ IntegerType::attachInterface<Model>(*ctx);
+ });
MLIRContext context(registry);
IntegerType i16 = IntegerType::get(&context, 16);
@@ -238,8 +246,9 @@ TEST(InterfaceAttachmentTest, AttributeDelayed) {
// that the delayed registration work for attributes.
DialectRegistry registry;
registry.insert<test::TestDialect>();
- registry.addAttrInterface<test::TestDialect, test::SimpleAAttr,
- TestExternalSimpleAAttrModel>();
+ registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
+ test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx);
+ });
MLIRContext context(registry);
context.loadDialect<test::TestDialect>();
@@ -343,12 +352,16 @@ struct TestExternalTestOpModel
TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
DialectRegistry registry;
registry.insert<test::TestDialect>();
- registry.addOpInterface<ModuleOp, TestExternalOpModel>();
- registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>();
- registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>();
-
- // Construct the context directly from a registry. The interfaces are expected
- // to be readily available on operations.
+ registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
+ ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
+ });
+ registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
+ test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
+ test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
+ });
+
+ // Construct the context directly from a registry. The interfaces are
+ // expected to be readily available on operations.
MLIRContext context(registry);
context.loadDialect<test::TestDialect>();
@@ -370,9 +383,13 @@ TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
TEST(InterfaceAttachment, OperationDelayedContextAppend) {
DialectRegistry registry;
registry.insert<test::TestDialect>();
- registry.addOpInterface<ModuleOp, TestExternalOpModel>();
- registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>();
- registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>();
+ registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
+ ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
+ });
+ registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
+ test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
+ test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
+ });
// Construct the context, create ops, and only then append the registry. The
// interfaces are expected to be available after appending the registry.
More information about the Mlir-commits
mailing list