[Mlir-commits] [mlir] 83a635c - [mlir] Add support for interface inheritance
River Riddle
llvmlistbot at llvm.org
Wed Jan 18 19:16:49 PST 2023
Author: River Riddle
Date: 2023-01-18T19:16:30-08:00
New Revision: 83a635c0d4759bd77bbbb21ff8d202cb8c3ea57b
URL: https://github.com/llvm/llvm-project/commit/83a635c0d4759bd77bbbb21ff8d202cb8c3ea57b
DIFF: https://github.com/llvm/llvm-project/commit/83a635c0d4759bd77bbbb21ff8d202cb8c3ea57b.diff
LOG: [mlir] Add support for interface inheritance
This allows for interfaces to define a set of "base classes",
which are interfaces whose methods/extra class decls/etc.
should be inherited by the derived interface. This more
easily enables combining interfaces and their dependencies,
without lots of awkward casting. Additional implicit conversion
operators also greatly simplify the conversion process.
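The following is a minimal C++ sketch of the mechanism described above. It is a simplified model under stated assumptions, not the code mlir-tblgen generates. Each interface is represented by a "concept" table of function pointers. A derived concept caches a pointer to its base's concept, which is resolved from a per-entity registry in a second initialization step. This loosely mirrors the `impl<Base>` pointers and the `initializeInterfaceConcept` hook checked in the op-interface test further down. The derived handle also offers an implicit conversion to the base handle. All names (`MyOp`, `BaseConcept`, `DerivedConcept`, `ConceptMap`, `initializeDerived`, `BaseIface`, `DerivedIface`) are hypothetical.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for the entity (e.g. an operation) implementing interfaces.
struct MyOp {
  int value;
};

// Concept of the base interface: a table of function pointers.
struct BaseConcept {
  int (*foo)(const MyOp &op, int input);
};

// Concept of the derived interface. Rather than duplicating the base methods,
// it caches a pointer to the single base concept registered for the entity.
struct DerivedConcept {
  int (*bar)(const MyOp &op, int input);
  const BaseConcept *implBase = nullptr;
};

// Hypothetical registry of concepts keyed by interface name (the role an
// interface map plays for an attribute/op/type).
struct ConceptMap {
  std::map<std::string, const void *> concepts;

  template <typename T>
  const T *lookup(const std::string &name) const {
    auto it = concepts.find(name);
    return it == concepts.end() ? nullptr : static_cast<const T *>(it->second);
  }
};

// Second initialization step: resolve the base concept from the registry.
void initializeDerived(DerivedConcept &derived, const ConceptMap &registry) {
  derived.implBase = registry.lookup<BaseConcept>("Base");
  assert(derived.implBase && "derived interface expects its base to be registered");
}

// Lightweight handle pairing an entity with its base-interface concept.
class BaseIface {
public:
  BaseIface(const MyOp &target, const BaseConcept *conceptImpl)
      : op(&target), impl(conceptImpl) {}
  int foo(int input) const { return impl->foo(*op, input); }

private:
  const MyOp *op;
  const BaseConcept *impl;
};

// Handle for the derived interface.
class DerivedIface {
public:
  DerivedIface(const MyOp &target, const DerivedConcept *conceptImpl)
      : op(&target), impl(conceptImpl) {}
  int bar(int input) const { return impl->bar(*op, input); }

  // Inherited method: forwards to the one base implementation.
  int foo(int input) const { return impl->implBase->foo(*op, input); }

  // Implicit conversion to the base interface; no explicit cast required.
  operator BaseIface() const { return BaseIface(*op, impl->implBase); }

private:
  const MyOp *op;
  const DerivedConcept *impl;
};
```

To use it, build both concepts and register the base in a `ConceptMap`. Then call `initializeDerived` and use `foo`/`bar` through a `DerivedIface`, or obtain a `BaseIface` by plain assignment. Because the derived concept only points at the base concept, each entity still holds a single implementation of every base interface.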
One other aspect of this "inheritance" is that we also implicitly
add the base interfaces to the attr/op/type. The user can still
add them manually if desired, but this should help remove some
of the boilerplate when an interface has dependencies.
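Here is a minimal sketch of how base interfaces can be added implicitly. It assumes a simplified trait record; `TraitDef` and `flattenTraits` are hypothetical names, not the TableGen API. Each trait list is walked recursively, duplicates are skipped with a set, and a trait's base interfaces are emitted before the trait itself. This is the same ordering the op-interface test further down expects: base, middle, then derived.

```cpp
#include <cassert>
#include <functional>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a TableGen trait record. For interfaces,
// `baseInterfaces` lists only the direct bases.
struct TraitDef {
  std::string name;
  std::vector<const TraitDef *> baseInterfaces;
};

// Flatten a trait list so that every (transitive) base interface is present,
// bases precede the interfaces derived from them, and duplicates are dropped.
std::vector<std::string>
flattenTraits(const std::vector<const TraitDef *> &traitList) {
  std::set<const TraitDef *> seen;
  std::vector<std::string> result;
  std::function<void(const std::vector<const TraitDef *> &)> process =
      [&](const std::vector<const TraitDef *> &list) {
        for (const TraitDef *trait : list) {
          // Skip traits that were already added, explicitly or via another base.
          if (!seen.insert(trait).second)
            continue;
          // Add the bases first, then the trait itself.
          process(trait->baseInterfaces);
          result.push_back(trait->name);
        }
      };
  process(traitList);
  return result;
}
```

A trait that was already seen is skipped before its bases are visited. As a result, listing a base interface explicitly alongside the derived one yields the same flattened list without duplicates.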
See https://discourse.llvm.org/t/interface-inheritance-and-dependencies-interface-method-visibility-interface-composition
Differential Revision: https://reviews.llvm.org/D140198
Added:
Modified:
mlir/docs/Interfaces.md
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/StorageUniquerSupport.h
mlir/include/mlir/Support/InterfaceSupport.h
mlir/include/mlir/TableGen/Interfaces.h
mlir/lib/IR/ExtensibleDialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Support/InterfaceSupport.cpp
mlir/lib/TableGen/AttrOrTypeDef.cpp
mlir/lib/TableGen/Interfaces.cpp
mlir/lib/TableGen/Operator.cpp
mlir/test/lib/Dialect/Test/TestInterfaces.td
mlir/test/mlir-tblgen/op-interface.td
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index 9482c5a9ab1a5..6bb5070138632 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -379,6 +379,9 @@ comprised of the following components:
* C++ Class Name (Provided via template parameter)
- The name of the C++ interface class.
+* Interface Base Classes
+ - A set of interfaces that the interface class should derive from. See
+ [Interface Inheritance](#interface-inheritance) below for more details.
* Description (`description`)
- A string description of the interface, its invariants, example usages,
etc.
@@ -415,6 +418,8 @@ comprised of the following components:
- The structure of this code block corresponds 1-1 with the structure of a
[`Trait::verifyTrait`](Traits.md) method.
+##### Interface Methods
+
There are two types of methods that can be used with an interface,
`InterfaceMethod` and `StaticInterfaceMethod`. They are both comprised of the
same core components, with the distinction that `StaticInterfaceMethod` models a
@@ -634,6 +639,71 @@ def OpWithOverrideInferTypeInterfaceOp : Op<...
[DeclareOpInterfaceMethods<MyInterface, ["getNumWithDefault"]>]> { ... }
```
+##### Interface Inheritance
+
+Interfaces also support a limited form of inheritance, which allows for
+building upon pre-existing interfaces in a way similar to that of classes in
+programming languages like C++. This more easily allows for building modular
+interfaces, without suffering from the pain of lots of explicit casting. To
+enable inheritance, an interface simply needs to provide the desired set of
+base classes in its definition. For example:
+
+```tablegen
+def MyBaseInterface : OpInterface<"MyBaseInterface"> {
+ ...
+}
+
+def MyInterface : OpInterface<"MyInterface", [MyBaseInterface]> {
+ ...
+}
+```
+
+This will result in `MyInterface` inheriting various components from
+`MyBaseInterface`, namely its interface methods and extra class declarations.
+Given that these inherited components are comprised of opaque C++ blobs, we
+cannot properly sandbox the names. As such, it's important to ensure that inherited
+components do not create name overlaps, as these will result in errors during
+interface generation.
+
+`MyInterface` will also implicitly inherit any base classes defined on
+`MyBaseInterface`. It's important to note, however, that there is only
+ever one instance of each interface for a given attribute, operation, or type.
+Inherited interface methods simply forward to the base interface implementation.
+This produces a simpler system overall, and also removes any potential problems
+surrounding "diamond inheritance". The interfaces on an attribute/op/type can be
+thought of as comprising a set, with each interface (including base interfaces)
+uniqued within this set and referenced elsewhere as necessary.
+
+When adding an interface with inheritance to an attribute, operation, or type,
+all of the base interfaces are also implicitly added as well. The user may still
+manually specify the base interfaces if they desire, such as for use with the
+`Declare<Attr|Op|Type>InterfaceMethods` helper classes.
+
+If our interface were to be specified as:
+
+```tablegen
+def MyBaseInterface : OpInterface<"MyBaseInterface"> {
+ ...
+}
+
+def MyOtherBaseInterface : OpInterface<"MyOtherBaseInterface", [MyBaseInterface]> {
+ ...
+}
+
+def MyInterface : OpInterface<"MyInterface", [MyBaseInterface, MyOtherBaseInterface]> {
+ ...
+}
+```
+
+An operation with `MyInterface` attached would have the following interfaces added:
+
+* MyBaseInterface, MyOtherBaseInterface, MyInterface
+
+The methods from `MyBaseInterface` in both `MyInterface` and `MyOtherBaseInterface` would
+forward to a single unique implementation for the operation.
+
+##### Generation
+
Once the interfaces have been defined, the C++ header and source files can be
generated using the `--gen-<attr|op|type>-interface-decls` and
`--gen-<attr|op|type>-interface-defs` options with mlir-tblgen. Note that when
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index ea9976ca2e24f..e859e40dc1b00 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1947,7 +1947,7 @@ def AttrSizedResultSegments :
def NoRegionArguments : NativeOpTrait<"NoRegionArguments">, StructuralOpTrait;
//===----------------------------------------------------------------------===//
-// OpInterface definitions
+// Interface definitions
//===----------------------------------------------------------------------===//
// Marker used to identify the argument list for an op or interface method.
@@ -2025,7 +2025,7 @@ class StaticInterfaceMethod<string desc, string retTy, string methodName,
defaultImplementation>;
// Interface represents a base interface.
-class Interface<string name> {
+class Interface<string name, list<Interface> baseInterfacesArg = []> {
// A human-readable description of what this interface does.
string description = "";
@@ -2055,28 +2055,36 @@ class Interface<string name> {
// `$_attr`/`$_op`/`$_type` may be used to refer to an instance of the
// entity being checked.
code extraClassOf = "";
+
+ // An optional set of base interfaces that this interface
+ // "derives" from.
+ list<Interface> baseInterfaces = baseInterfacesArg;
}
// AttrInterface represents an interface registered to an attribute.
-class AttrInterface<string name> : Interface<name>, InterfaceTrait<name>,
- Attr<CPred<"$_self.isa<"
- # !if(!empty(cppNamespace),"", cppNamespace # "::") # name # ">()">,
- name # " instance">
-{
+class AttrInterface<string name, list<Interface> baseInterfaces = []>
+ : Interface<name, baseInterfaces>, InterfaceTrait<name>,
+ Attr<CPred<"$_self.isa<"
+ # !if(!empty(cppNamespace),"", cppNamespace # "::") # name # ">()">,
+ name # " instance"
+ > {
let storageType = !if(!empty(cppNamespace), "", cppNamespace # "::") # name;
let returnType = storageType;
let convertFromStorage = "$_self";
}
// OpInterface represents an interface registered to an operation.
-class OpInterface<string name> : Interface<name>, OpInterfaceTrait<name>;
+class OpInterface<string name, list<Interface> baseInterfaces = []>
+ : Interface<name, baseInterfaces>, OpInterfaceTrait<name>;
// TypeInterface represents an interface registered to a type.
-class TypeInterface<string name> : Interface<name>, InterfaceTrait<name>,
- Type<CPred<"$_self.isa<"
- # !if(!empty(cppNamespace),"", cppNamespace # "::") # name # ">()">,
+class TypeInterface<string name, list<Interface> baseInterfaces = []>
+ : Interface<name, baseInterfaces>, InterfaceTrait<name>,
+ Type<CPred<"$_self.isa<"
+ # !if(!empty(cppNamespace),"", cppNamespace # "::") # name # ">()">,
name # " instance",
- !if(!empty(cppNamespace),"", cppNamespace # "::") # name>;
+ !if(!empty(cppNamespace),"", cppNamespace # "::") # name
+ >;
// Whether to declare the interface methods in the user entity's header. This
// class simply wraps an Interface but is used to indicate that the method
@@ -2092,29 +2100,32 @@ class DeclareInterfaceMethods<list<string> overridenMethods = []> {
class DeclareAttrInterfaceMethods<AttrInterface interface,
list<string> overridenMethods = []>
: DeclareInterfaceMethods<overridenMethods>,
- AttrInterface<interface.cppInterfaceName> {
+ AttrInterface<interface.cppInterfaceName, interface.baseInterfaces> {
let description = interface.description;
let cppInterfaceName = interface.cppInterfaceName;
let cppNamespace = interface.cppNamespace;
let methods = interface.methods;
+ let baseInterfaces = interface.baseInterfaces;
}
class DeclareOpInterfaceMethods<OpInterface interface,
list<string> overridenMethods = []>
: DeclareInterfaceMethods<overridenMethods>,
- OpInterface<interface.cppInterfaceName> {
+ OpInterface<interface.cppInterfaceName, interface.baseInterfaces> {
let description = interface.description;
let cppInterfaceName = interface.cppInterfaceName;
let cppNamespace = interface.cppNamespace;
let methods = interface.methods;
+ let baseInterfaces = interface.baseInterfaces;
}
class DeclareTypeInterfaceMethods<TypeInterface interface,
list<string> overridenMethods = []>
: DeclareInterfaceMethods<overridenMethods>,
- TypeInterface<interface.cppInterfaceName> {
+ TypeInterface<interface.cppInterfaceName, interface.baseInterfaces> {
let description = interface.description;
let cppInterfaceName = interface.cppInterfaceName;
let cppNamespace = interface.cppNamespace;
let methods = interface.methods;
+ let baseInterfaces = interface.baseInterfaces;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index fc7281fc526f1..5d526f9d62c62 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -284,8 +284,9 @@ class OperationName {
/// Attach the given models as implementations of the corresponding
/// interfaces for the concrete operation.
- template <typename... Models> void attachInterface() {
- getImpl()->getInterfaceMap().insert<Models...>();
+ template <typename... Models>
+ void attachInterface() {
+ getImpl()->getInterfaceMap().insertModels<Models...>();
}
/// Returns true if this operation has the given interface registered to it.
@@ -378,7 +379,8 @@ class RegisteredOperationName : public OperationName {
public:
/// Implementation of the InterfaceConcept for operation APIs that forwarded
/// to a concrete op implementation.
- template <typename ConcreteOp> struct Model : public Impl {
+ template <typename ConcreteOp>
+ struct Model : public Impl {
Model(Dialect *dialect)
: Impl(ConcreteOp::getOperationName(), dialect,
TypeID::get<ConcreteOp>(), ConcreteOp::getInterfaceMap()) {}
@@ -418,7 +420,8 @@ class RegisteredOperationName : public OperationName {
/// Register a new operation in a Dialect object.
/// This constructor is used by Dialect objects when they register the list
/// of operations they contain.
- template <typename T> static void insert(Dialect &dialect) {
+ template <typename T>
+ static void insert(Dialect &dialect) {
insert(std::make_unique<Model<T>>(&dialect), T::getAttributeNames());
}
/// The use of this method is in general discouraged in favor of
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index ff5a0630e4fff..128ad815556c5 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -140,7 +140,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
"that is not itself registered.");
(checkInterfaceTarget<IfaceModels>(), ...);
- abstract->interfaceMap.template insert<IfaceModels...>();
+ abstract->interfaceMap.template insertModels<IfaceModels...>();
}
/// Get or create a new ConcreteT instance within the ctx. This
diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index 6ba7b33189843..71c4003708356 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -111,8 +111,8 @@ class Interface : public BaseType {
}
/// Constructor for a known concept.
- Interface(ValueT t, Concept *conceptImpl)
- : BaseType(t), conceptImpl(conceptImpl) {
+ Interface(ValueT t, const Concept *conceptImpl)
+ : BaseType(t), conceptImpl(const_cast<Concept *>(conceptImpl)) {
assert(!t || ConcreteType::getInterfaceFor(t) == conceptImpl);
}
@@ -152,25 +152,6 @@ struct count_if_t_impl<Pred, N, T, Us...>
template <template <class> class Pred, typename... Ts>
using count_if_t = count_if_t_impl<Pred, 0, Ts...>;
-namespace {
-/// Type trait indicating whether all template arguments are
-/// trivially-destructible.
-template <typename... Args>
-struct all_trivially_destructible;
-
-template <typename Arg, typename... Args>
-struct all_trivially_destructible<Arg, Args...> {
- static constexpr const bool value =
- std::is_trivially_destructible<Arg>::value &&
- all_trivially_destructible<Args...>::value;
-};
-
-template <>
-struct all_trivially_destructible<> {
- static constexpr const bool value = true;
-};
-} // namespace
-
/// This class provides an efficient mapping between a given `Interface` type,
/// and a particular implementation of its concept.
class InterfaceMap {
@@ -182,7 +163,16 @@ class InterfaceMap {
template <typename... Types>
using num_interface_types_t = count_if_t<detect_get_interface_id, Types...>;
+ /// Trait to check if T provides a 'initializeInterfaceConcept' method.
+ template <typename T, typename... Args>
+ using has_initialize_method =
+ decltype(std::declval<T>().initializeInterfaceConcept(
+ std::declval<InterfaceMap &>()));
+ template <typename T>
+ using detect_initialize_method = llvm::is_detected<has_initialize_method, T>;
+
public:
+ InterfaceMap() = default;
InterfaceMap(InterfaceMap &&) = default;
InterfaceMap &operator=(InterfaceMap &&rhs) {
for (auto &it : interfaces)
@@ -205,11 +195,9 @@ class InterfaceMap {
if constexpr (numInterfaces == 0)
return InterfaceMap();
- std::array<std::pair<TypeID, void *>, numInterfaces> elements;
- std::pair<TypeID, void *> *elementIt = elements.data();
- (void)elementIt;
- (addModelAndUpdateIterator<Types>(elementIt), ...);
- return InterfaceMap(elements);
+ InterfaceMap map;
+ (map.insertPotentialInterface<Types>(), ...);
+ return map;
}
/// Returns an instance of the concept object for the given interface if it
@@ -222,42 +210,40 @@ class InterfaceMap {
/// Returns true if the interface map contains an interface for the given id.
bool contains(TypeID interfaceID) const { return lookup(interfaceID); }
- /// Create an InterfaceMap given with the implementation of the interfaces.
- /// The use of this constructor is in general discouraged in favor of
- /// 'InterfaceMap::get<InterfaceA, ...>()'.
- InterfaceMap(MutableArrayRef<std::pair<TypeID, void *>> elements);
-
- /// Insert the given models as implementations of the corresponding interfaces
- /// for the concrete attribute class.
+ /// Insert the given interface models.
template <typename... IfaceModels>
- void insert() {
- static_assert(all_trivially_destructible<IfaceModels...>::value,
- "interface models must be trivially destructible");
- std::pair<TypeID, void *> elements[] = {
- std::make_pair(IfaceModels::Interface::getInterfaceID(),
- new (malloc(sizeof(IfaceModels))) IfaceModels())...};
- insert(elements);
+ void insertModels() {
+ (insertModel<IfaceModels>(), ...);
}
private:
- InterfaceMap() = default;
-
- /// Assign the interface model of the type to the given opaque element
- /// iterator and increment it.
+ /// Insert the given interface type into the map, ignoring it if it doesn't
+ /// actually represent an interface.
template <typename T>
- static inline std::enable_if_t<detect_get_interface_id<T>::value>
- addModelAndUpdateIterator(std::pair<TypeID, void *> *&elementIt) {
- *elementIt = {T::getInterfaceID(), new (malloc(sizeof(typename T::ModelT)))
- typename T::ModelT()};
- ++elementIt;
+ inline void insertPotentialInterface() {
+ if constexpr (detect_get_interface_id<T>::value)
+ insertModel<typename T::ModelT>();
}
- /// Overload when `T` isn't an interface.
- template <typename T>
- static inline std::enable_if_t<!detect_get_interface_id<T>::value>
- addModelAndUpdateIterator(std::pair<TypeID, void *> *&) {}
- /// Insert the given set of interface models into the interface map.
- void insert(ArrayRef<std::pair<TypeID, void *>> elements);
+ /// Insert the given interface model into the map.
+ template <typename InterfaceModel>
+ void insertModel() {
+ // FIXME(#59975): Uncomment this when SPIRV no longer awkwardly reimplements
+ // interfaces in a way that isn't clean/compatible.
+ // static_assert(std::is_trivially_destructible_v<InterfaceModel>,
+ // "interface models must be trivially destructible");
+
+ // Build the interface model, optionally initializing if necessary.
+ InterfaceModel *model =
+ new (malloc(sizeof(InterfaceModel))) InterfaceModel();
+ if constexpr (detect_initialize_method<InterfaceModel>::value)
+ model->initializeInterfaceConcept(*this);
+
+ insert(InterfaceModel::Interface::getInterfaceID(), model);
+ }
+ /// Insert the given set of interface id and concept implementation into the
+ /// interface map.
+ void insert(TypeID interfaceId, void *conceptImpl);
/// Compare two TypeID instances by comparing the underlying pointer.
static bool compare(TypeID lhs, TypeID rhs) {
diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h
index 7168f130c73a7..15f667e0ffce0 100644
--- a/mlir/include/mlir/TableGen/Interfaces.h
+++ b/mlir/include/mlir/TableGen/Interfaces.h
@@ -12,6 +12,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/iterator.h"
namespace llvm {
class Init;
@@ -72,10 +73,17 @@ class InterfaceMethod {
class Interface {
public:
explicit Interface(const llvm::Record *def);
+ Interface(const Interface &rhs) : def(rhs.def), methods(rhs.methods) {
+ for (auto &base : rhs.baseInterfaces)
+ baseInterfaces.push_back(std::make_unique<Interface>(*base));
+ }
// Return the name of this interface.
StringRef getName() const;
+ // Returns this interface's name prefixed with namespaces.
+ std::string getFullyQualifiedName() const;
+
// Return the C++ namespace of this interface.
StringRef getCppNamespace() const;
@@ -101,6 +109,11 @@ class Interface {
// Return the verify method body if it has one.
std::optional<StringRef> getVerify() const;
+ // Return the base interfaces of this interface.
+ auto getBaseInterfaces() const {
+ return llvm::make_pointee_range(baseInterfaces);
+ }
+
// If there's a verify method, return if it needs to access the ops in the
// regions.
bool verifyWithRegions() const;
@@ -114,6 +127,9 @@ class Interface {
// The methods of this interface.
SmallVector<InterfaceMethod, 8> methods;
+
+ // The base interfaces of this interface.
+ SmallVector<std::unique_ptr<Interface>> baseInterfaces;
};
// An interface that is registered to an Attribute.
diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp
index 5190b5f97af7a..701683fdcb5ce 100644
--- a/mlir/lib/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/IR/ExtensibleDialect.cpp
@@ -299,7 +299,7 @@ DynamicOpDefinition::DynamicOpDefinition(
: Impl(StringAttr::get(dialect->getContext(),
(dialect->getNamespace() + "." + name).str()),
dialect, dialect->allocateTypeID(),
- /*interfaceMap=*/detail::InterfaceMap(std::nullopt)),
+ /*interfaceMap=*/detail::InterfaceMap()),
verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
parseFn(std::move(parseFn)), printFn(std::move(printFn)),
foldHookFn(std::move(foldHookFn)),
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index b0fe94f0c513a..5bbedda196415 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -742,7 +742,7 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
auto nameAttr = StringAttr::get(context, name);
it.first->second = std::make_unique<UnregisteredOpModel>(
nameAttr, nameAttr.getReferencedDialect(), TypeID::get<void>(),
- detail::InterfaceMap(std::nullopt));
+ detail::InterfaceMap());
}
impl = it.first->second.get();
}
diff --git a/mlir/lib/Support/InterfaceSupport.cpp b/mlir/lib/Support/InterfaceSupport.cpp
index 4b8dd58801e23..d813046eca7f3 100644
--- a/mlir/lib/Support/InterfaceSupport.cpp
+++ b/mlir/lib/Support/InterfaceSupport.cpp
@@ -18,27 +18,16 @@
using namespace mlir;
-detail::InterfaceMap::InterfaceMap(
- MutableArrayRef<std::pair<TypeID, void *>> elements)
- : interfaces(elements.begin(), elements.end()) {
- llvm::sort(interfaces, [](const auto &lhs, const auto &rhs) {
- return compare(lhs.first, rhs.first);
- });
-}
-
-void detail::InterfaceMap::insert(
- ArrayRef<std::pair<TypeID, void *>> elements) {
+void detail::InterfaceMap::insert(TypeID interfaceId, void *conceptImpl) {
// Insert directly into the right position to keep the interfaces sorted.
- for (auto &element : elements) {
- TypeID id = element.first;
- auto *it = llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
- return compare(it.first, id);
- });
- if (it != interfaces.end() && it->first == id) {
- LLVM_DEBUG(llvm::dbgs() << "Ignoring repeated interface registration");
- free(element.second);
- continue;
- }
- interfaces.insert(it, element);
+ auto *it =
+ llvm::lower_bound(interfaces, interfaceId, [](const auto &it, TypeID id) {
+ return compare(it.first, id);
+ });
+ if (it != interfaces.end() && it->first == interfaceId) {
+ LLVM_DEBUG(llvm::dbgs() << "Ignoring repeated interface registration");
+ free(conceptImpl);
+ return;
}
+ interfaces.insert(it, {interfaceId, conceptImpl});
}
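The reworked `InterfaceMap::insert` above keeps its entries sorted by interface ID and ignores a repeated registration, which guarantees a single instance per interface. Below is a standalone C++ sketch of the same sorted-insert-with-deduplication pattern. It assumes integer IDs and `std::unique_ptr` ownership in place of `TypeID` and `malloc`/`free`. The `SortedInterfaceMap`, `InterfaceId`, and `Model` names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for TypeID and an interface model.
using InterfaceId = int;
struct Model {
  std::string name;
};

class SortedInterfaceMap {
public:
  // Insert `model` under `id`, keeping entries sorted by ID. A repeated
  // registration is ignored: the new model is destroyed when `model` goes out
  // of scope, and false is returned.
  bool insert(InterfaceId id, std::unique_ptr<Model> model) {
    auto it = std::lower_bound(entries.begin(), entries.end(), id, compareEntry);
    if (it != entries.end() && it->first == id)
      return false;
    entries.emplace(it, id, std::move(model));
    return true;
  }

  // Binary-search lookup; returns nullptr if no model is registered for `id`.
  const Model *lookup(InterfaceId id) const {
    auto it = std::lower_bound(entries.begin(), entries.end(), id, compareEntry);
    return (it != entries.end() && it->first == id) ? it->second.get() : nullptr;
  }

  std::size_t size() const { return entries.size(); }

private:
  using Entry = std::pair<InterfaceId, std::unique_ptr<Model>>;

  // Orders an entry before `id` when its ID is smaller.
  static bool compareEntry(const Entry &entry, InterfaceId id) {
    return entry.first < id;
  }

  std::vector<Entry> entries;
};
```

Binary search keeps both insertion and lookup at O(log n) comparisons. Inserting into the middle of a vector still shifts the later elements, which is acceptable for the small, mostly build-once maps involved here.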
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 042fba42d5509..e9fc9274dee9e 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -8,6 +8,7 @@
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Dialect.h"
+#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/TableGen/Error.h"
@@ -56,9 +57,23 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
if (auto *traitList = def->getValueAsListInit("traits")) {
SmallPtrSet<const llvm::Init *, 32> traitSet;
traits.reserve(traitSet.size());
- for (auto *traitInit : *traitList)
- if (traitSet.insert(traitInit).second)
- traits.push_back(Trait::create(traitInit));
+ llvm::unique_function<void(llvm::ListInit *)> processTraitList =
+ [&](llvm::ListInit *traitList) {
+ for (auto *traitInit : *traitList) {
+ if (!traitSet.insert(traitInit).second)
+ continue;
+
+ // If this is an interface, add any bases to the trait list.
+ auto *traitDef = cast<llvm::DefInit>(traitInit)->getDef();
+ if (traitDef->isSubClassOf("Interface")) {
+ if (auto *bases = traitDef->getValueAsListInit("baseInterfaces"))
+ processTraitList(bases);
+ }
+
+ traits.push_back(Trait::create(traitInit));
+ }
+ };
+ processTraitList(traitList);
}
// Populate the parameters.
diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index bd56f6b027007..c1d1ba0540c32 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Interfaces.h"
+#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
@@ -74,9 +75,25 @@ Interface::Interface(const llvm::Record *def) : def(def) {
assert(def->isSubClassOf("Interface") &&
"must be subclass of TableGen 'Interface' class");
+ // Initialize the interface methods.
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
for (llvm::Init *init : listInit->getValues())
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
+
+ // Initialize the interface base classes.
+ auto *basesInit =
+ dyn_cast<llvm::ListInit>(def->getValueInit("baseInterfaces"));
+ llvm::unique_function<void(Interface)> addBaseInterfaceFn =
+ [&](const Interface &baseInterface) {
+ // Inherit any base interfaces.
+ for (const auto &baseBaseInterface : baseInterface.getBaseInterfaces())
+ addBaseInterfaceFn(baseBaseInterface);
+
+ // Add the base interface.
+ baseInterfaces.push_back(std::make_unique<Interface>(baseInterface));
+ };
+ for (llvm::Init *init : basesInit->getValues())
+ addBaseInterfaceFn(Interface(cast<llvm::DefInit>(init)->getDef()));
}
// Return the name of this interface.
@@ -84,6 +101,15 @@ StringRef Interface::getName() const {
return def->getValueAsString("cppInterfaceName");
}
+// Returns this interface's name prefixed with namespaces.
+std::string Interface::getFullyQualifiedName() const {
+ StringRef cppNamespace = getCppNamespace();
+ StringRef name = getName();
+ if (cppNamespace.empty())
+ return name.str();
+ return (cppNamespace + "::" + name).str();
+}
+
// Return the C++ namespace of this interface.
StringRef Interface::getCppNamespace() const {
return def->getValueAsString("cppNamespace");
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 1b186f81cc0e6..77e27ccbe018a 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -711,13 +711,19 @@ void Operator::populateOpStructure() {
continue;
}
+ // Ignore duplicates.
+ if (!traitSet.insert(traitInit).second)
+ continue;
+
+ // If this is an interface with base classes, add the bases to the
+ // trait list.
+ if (def->isSubClassOf("Interface"))
+ insert(def->getValueAsListInit("baseInterfaces"));
+
// Verify if the trait has all the dependent traits declared before
// itself.
verifyTraitValidity(def);
-
- // Keep traits in the same order while skipping over duplicates.
- if (traitSet.insert(traitInit).second)
- traits.push_back(Trait::create(traitInit));
+ traits.push_back(Trait::create(traitInit));
}
};
insert(traitList);
diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td
index 1d9fd9c21e46d..d7ca7089f8f98 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.td
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td
@@ -12,21 +12,35 @@
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaceBase.td"
-// A type interface used to test the ODS generation of type interfaces.
-def TestTypeInterface : TypeInterface<"TestTypeInterface"> {
+// A set of type interfaces used to test interface inheritance.
+def TestBaseTypeInterfacePrintTypeA : TypeInterface<"TestBaseTypeInterfacePrintTypeA"> {
let cppNamespace = "::test";
let methods = [
InterfaceMethod<"Prints the type name.",
"void", "printTypeA", (ins "::mlir::Location":$loc), [{
emitRemark(loc) << $_type << " - TestA";
}]
- >,
+ >
+ ];
+}
+def TestBaseTypeInterfacePrintTypeB
+ : TypeInterface<"TestBaseTypeInterfacePrintTypeB", [TestBaseTypeInterfacePrintTypeA]> {
+ let cppNamespace = "::test";
+ let methods = [
InterfaceMethod<"Prints the type name.",
"void", "printTypeB", (ins "::mlir::Location":$loc),
[{}], /*defaultImplementation=*/[{
emitRemark(loc) << $_type << " - TestB";
}]
- >,
+ >
+ ];
+}
+
+// A type interface used to test the ODS generation of type interfaces.
+def TestTypeInterface
+ : TypeInterface<"TestTypeInterface", [TestBaseTypeInterfacePrintTypeB]> {
+ let cppNamespace = "::test";
+ let methods = [
InterfaceMethod<"Prints the type name.",
"void", "printTypeC", (ins "::mlir::Location":$loc)
>,
diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td
index 8129eb1a8b3bb..76467f50152cb 100644
--- a/mlir/test/mlir-tblgen/op-interface.td
+++ b/mlir/test/mlir-tblgen/op-interface.td
@@ -1,4 +1,5 @@
// RUN: mlir-tblgen -gen-op-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+// RUN: mlir-tblgen -gen-op-interface-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
// RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s --check-prefix=OP_DECL
// RUN: mlir-tblgen -gen-op-interface-docs -I %S/../../include %s | FileCheck %s --check-prefix=DOCS
@@ -33,6 +34,66 @@ def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> {
// DECL-NEXT: return (*static_cast<ConcreteOp *>(this)).someOtherMethod();
// DECL-NEXT: }
+def TestInheritanceBaseInterface : OpInterface<"TestInheritanceBaseInterface"> {
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{some function comment}],
+ /*retTy=*/"int",
+ /*methodName=*/"foo",
+ /*args=*/(ins "int":$input)
+ >
+ ];
+}
+def TestInheritanceMiddleBaseInterface
+ : OpInterface<"TestInheritanceMiddleBaseInterface", [TestInheritanceBaseInterface]> {
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{some function comment}],
+ /*retTy=*/"int",
+ /*methodName=*/"bar",
+ /*args=*/(ins "int":$input)
+ >
+ ];
+}
+def TestInheritanceZDerivedInterface
+ : OpInterface<"TestInheritanceZDerivedInterface", [TestInheritanceMiddleBaseInterface]>;
+
+// DECL: class TestInheritanceZDerivedInterface
+// DECL: struct Concept {
+// DECL: const TestInheritanceBaseInterface::Concept *implTestInheritanceBaseInterface = nullptr;
+// DECL: const TestInheritanceMiddleBaseInterface::Concept *implTestInheritanceMiddleBaseInterface = nullptr;
+
+// DECL: void initializeInterfaceConcept(::mlir::detail::InterfaceMap &interfaceMap) {
+// DECL: implTestInheritanceBaseInterface = interfaceMap.lookup<TestInheritanceBaseInterface>();
+// DECL: assert(implTestInheritanceBaseInterface && "`TestInheritanceZDerivedInterface` expected its base interface `TestInheritanceBaseInterface` to be registered");
+// DECL: implTestInheritanceMiddleBaseInterface = interfaceMap.lookup<TestInheritanceMiddleBaseInterface>();
+// DECL: assert(implTestInheritanceMiddleBaseInterface
+// DECL: }
+
+// DECL: //===----------------------------------------------------------------===//
+// DECL: // Inherited from TestInheritanceBaseInterface
+// DECL: //===----------------------------------------------------------------===//
+// DECL: operator TestInheritanceBaseInterface () const {
+// DECL: return TestInheritanceBaseInterface(*this, getImpl()->implTestInheritanceBaseInterface);
+// DECL: }
+// DECL: /// some function comment
+// DECL: int foo(int input);
+
+// DECL: //===----------------------------------------------------------------===//
+// DECL: // Inherited from TestInheritanceMiddleBaseInterface
+// DECL: //===----------------------------------------------------------------===//
+// DECL: operator TestInheritanceMiddleBaseInterface () const {
+// DECL: return TestInheritanceMiddleBaseInterface(*this, getImpl()->implTestInheritanceMiddleBaseInterface);
+// DECL: }
+// DECL: /// some function comment
+// DECL: int bar(int input);
+
+// DEF: int TestInheritanceZDerivedInterface::foo(int input) {
+// DEF-NEXT: getImpl()->implTestInheritanceBaseInterface->foo(getImpl()->implTestInheritanceBaseInterface, getOperation(), input);
+
+// DEF: int TestInheritanceZDerivedInterface::bar(int input) {
+// DEF-NEXT: return getImpl()->implTestInheritanceMiddleBaseInterface->bar(getImpl()->implTestInheritanceMiddleBaseInterface, getOperation(), input);
+
def TestOpInterface : OpInterface<"TestOpInterface"> {
let description = [{some op interface description}];
@@ -83,6 +144,8 @@ def TestDialect : Dialect {
def OpInterfaceOp : Op<TestDialect, "op_interface_op", [TestOpInterface]>;
+def OpInterfaceInterfacesOp : Op<TestDialect, "op_inherit_interface_op", [TestInheritanceZDerivedInterface]>;
+
def DeclareMethodsOp : Op<TestDialect, "declare_methods_op",
[DeclareOpInterfaceMethods<TestOpInterface>]>;
@@ -113,6 +176,9 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
// OP_DECL: int foo(int input);
// OP_DECL: int default_foo(int input);
+// OP_DECL: class OpInterfaceInterfacesOp :
+// OP_DECL-SAME: TestInheritanceBaseInterface::Trait, TestInheritanceMiddleBaseInterface::Trait, TestInheritanceZDerivedInterface::Trait
+
// DOCS-LABEL: {{^}}## TestOpInterface (`TestOpInterface`)
// DOCS: some op interface description
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 363bec72649dd..279edfd21443c 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -173,28 +173,21 @@ static void emitInterfaceMethodDoc(const InterfaceMethod &method,
if (std::optional<StringRef> description = method.getDescription())
tblgen::emitDescriptionComment(*description, os, prefix);
}
-
-static void emitInterfaceDef(const Interface &interface, StringRef valueType,
- raw_ostream &os) {
- StringRef interfaceName = interface.getName();
- StringRef cppNamespace = interface.getCppNamespace();
- cppNamespace.consume_front("::");
-
- // Insert the method definitions.
- bool isOpInterface = isa<OpInterface>(interface);
+static void emitInterfaceDefMethods(StringRef interfaceQualName,
+ const Interface &interface,
+ StringRef valueType, const Twine &implValue,
+ raw_ostream &os, bool isOpInterface) {
for (auto &method : interface.getMethods()) {
emitInterfaceMethodDoc(method, os);
emitCPPType(method.getReturnType(), os);
- if (!cppNamespace.empty())
- os << cppNamespace << "::";
- os << interfaceName << "::";
+ os << interfaceQualName << "::";
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
/*addConst=*/!isOpInterface);
// Forward to the method on the concrete operation type.
- os << " {\n return getImpl()->" << method.getName() << '(';
+ os << " {\n return " << implValue << "->" << method.getName() << '(';
if (!method.isStatic()) {
- os << "getImpl(), ";
+ os << implValue << ", ";
os << (isOpInterface ? "getOperation()" : "*this");
os << (method.arg_empty() ? "" : ", ");
}
@@ -205,6 +198,25 @@ static void emitInterfaceDef(const Interface &interface, StringRef valueType,
}
}
+static void emitInterfaceDef(const Interface &interface, StringRef valueType,
+ raw_ostream &os) {
+ std::string interfaceQualNameStr = interface.getFullyQualifiedName();
+ StringRef interfaceQualName = interfaceQualNameStr;
+ interfaceQualName.consume_front("::");
+
+ // Insert the method definitions.
+ bool isOpInterface = isa<OpInterface>(interface);
+ emitInterfaceDefMethods(interfaceQualName, interface, valueType, "getImpl()",
+ os, isOpInterface);
+
+ // Insert the method definitions for base classes.
+ for (auto &base : interface.getBaseInterfaces()) {
+ emitInterfaceDefMethods(interfaceQualName, base, valueType,
+ "getImpl()->impl" + base.getName(), os,
+ isOpInterface);
+ }
+}
+
bool InterfaceGenerator::emitInterfaceDefs() {
llvm::emitSourceFileHeader("Interface Definitions", os);
@@ -221,6 +233,7 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
os << " struct Concept {\n";
// Insert each of the pure virtual concept methods.
+ os << " /// The methods defined by the interface.\n";
for (auto &method : interface.getMethods()) {
os << " ";
emitCPPType(method.getReturnType(), os);
@@ -234,6 +247,33 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
[&](const InterfaceMethod::Argument &arg) { os << arg.type; });
os << ");\n";
}
+
+ // Insert a field containing a concept for each of the base interfaces.
+ auto baseInterfaces = interface.getBaseInterfaces();
+ if (!baseInterfaces.empty()) {
+ os << " /// The base classes of this interface.\n";
+ for (const auto &base : interface.getBaseInterfaces()) {
+ os << " const " << base.getFullyQualifiedName() << "::Concept *impl"
+ << base.getName() << " = nullptr;\n";
+ }
+
+ // Define an "initialize" method that allows for the initialization of the
+ // base class concepts.
+ os << "\n void initializeInterfaceConcept(::mlir::detail::InterfaceMap "
+ "&interfaceMap) {\n";
+ std::string interfaceQualName = interface.getFullyQualifiedName();
+ for (const auto &base : interface.getBaseInterfaces()) {
+ StringRef baseName = base.getName();
+ std::string baseQualName = base.getFullyQualifiedName();
+ os << " impl" << baseName << " = interfaceMap.lookup<"
+ << baseQualName << ">();\n"
+ << " assert(impl" << baseName << " && \"`" << interfaceQualName
+ << "` expected its base interface `" << baseQualName
+ << "` to be registered\");\n";
+ }
+ os << " }\n";
+ }
+
os << " };\n";
}
@@ -242,9 +282,8 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
for (const char *modelClass : {"Model", "FallbackModel"}) {
os << " template<typename " << valueTemplate << ">\n";
os << " class " << modelClass << " : public Concept {\n public:\n";
- os << " using Interface = " << interface.getCppNamespace()
- << (interface.getCppNamespace().empty() ? "" : "::")
- << interface.getName() << ";\n";
+ os << " using Interface = " << interface.getFullyQualifiedName()
+ << ";\n";
os << " " << modelClass << "() : Concept{";
llvm::interleaveComma(
interface.getMethods(), os,
@@ -455,6 +494,27 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
os << " };\n";
}
+static void emitInterfaceDeclMethods(const Interface &interface,
+ raw_ostream &os, StringRef valueType,
+ bool isOpInterface,
+ tblgen::FmtContext &extraDeclsFmt) {
+ for (auto &method : interface.getMethods()) {
+ emitInterfaceMethodDoc(method, os, " ");
+ emitCPPType(method.getReturnType(), os << " ");
+ emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+ /*addConst=*/!isOpInterface);
+ os << ";\n";
+ }
+
+ // Emit any extra declarations.
+ if (std::optional<StringRef> extraDecls =
+ interface.getExtraClassDeclaration())
+ os << extraDecls->rtrim() << "\n";
+ if (std::optional<StringRef> extraDecls =
+ interface.getExtraSharedClassDeclaration())
+ os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n";
+}
+
void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
@@ -495,22 +555,30 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
// Insert the method declarations.
bool isOpInterface = isa<OpInterface>(interface);
- for (auto &method : interface.getMethods()) {
- emitInterfaceMethodDoc(method, os, " ");
- emitCPPType(method.getReturnType(), os << " ");
- emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
- /*addConst=*/!isOpInterface);
- os << ";\n";
+ emitInterfaceDeclMethods(interface, os, valueType, isOpInterface,
+ extraDeclsFmt);
+
+ // Insert the method declarations for base classes.
+ for (auto &base : interface.getBaseInterfaces()) {
+ std::string baseQualName = base.getFullyQualifiedName();
+ os << " //"
+ "===---------------------------------------------------------------"
+ "-===//\n"
+ << " // Inherited from " << baseQualName << "\n"
+ << " //"
+ "===---------------------------------------------------------------"
+ "-===//\n\n";
+
+ // Allow implicit conversion to the base interface.
+ os << " operator " << baseQualName << " () const {\n"
+ << " return " << baseQualName << "(*this, getImpl()->impl"
+ << base.getName() << ");\n"
+ << " }\n\n";
+
+ // Inherit the base interface's methods.
+ emitInterfaceDeclMethods(base, os, valueType, isOpInterface, extraDeclsFmt);
}
- // Emit any extra declarations.
- if (std::optional<StringRef> extraDecls =
- interface.getExtraClassDeclaration())
- os << *extraDecls << "\n";
- if (std::optional<StringRef> extraDecls =
- interface.getExtraSharedClassDeclaration())
- os << tblgen::tgfmt(*extraDecls, &extraDeclsFmt);
-
// Emit classof code if necessary.
if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
auto extraClassOfFmt = tblgen::FmtContext();
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index abe4bc65142b7..4f2b57f9872b5 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -236,6 +236,8 @@ static void emitModelDecl(const Availability &availability, raw_ostream &os) {
os << " template<typename ConcreteOp>\n";
os << " class " << modelClass << " : public Concept {\n"
<< " public:\n"
+ << " using Interface = " << availability.getInterfaceClassName()
+ << ";\n"
<< " " << availability.getQueryFnRetType() << " "
<< availability.getQueryFnName()
<< "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n"
@@ -258,6 +260,7 @@ static void emitInterfaceDecl(const Availability &availability,
StringRef cppNamespace = availability.getInterfaceClassNamespace();
NamespaceEmitter nsEmitter(os, cppNamespace);
+ os << "class " << interfaceName << ";\n\n";
// Emit the traits struct containing the concept and model declarations.
os << "namespace detail {\n"
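The generated code checked in the tests above follows one pattern. The derived interface's `Concept` holds a pointer to each base interface's `Concept`, and `initializeInterfaceConcept` fills that pointer in from the interface map. Inherited methods and the implicit conversion operator then forward through that pointer. The sketch below is a minimal, self-contained C++ illustration of that pattern, not the code `mlir-tblgen` actually emits. `InterfaceMap`, `BaseInterface`, `DerivedInterface`, `attachInterfaces` and the model functions are hypothetical stand-ins, and the map here is keyed by `std::type_index` rather than MLIR's `TypeID`.

```cpp
#include <cassert>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>

// Hypothetical stand-in for mlir::detail::InterfaceMap: maps an interface
// type to the concept object registered for one concrete op/type.
class InterfaceMap {
public:
  template <typename Iface>
  void insert(const typename Iface::Concept *impl) {
    map[std::type_index(typeid(Iface))] = impl;
  }
  template <typename Iface>
  const typename Iface::Concept *lookup() const {
    auto it = map.find(std::type_index(typeid(Iface)));
    return it == map.end()
               ? nullptr
               : static_cast<const typename Iface::Concept *>(it->second);
  }

private:
  std::unordered_map<std::type_index, const void *> map;
};

struct BaseInterface {
  struct Concept {
    int (*foo)(const Concept *impl, int input);
  };
  explicit BaseInterface(const Concept *impl) : impl(impl) {}
  int foo(int input) const { return impl->foo(impl, input); }
  const Concept *impl;
};

struct DerivedInterface {
  struct Concept {
    /// The methods defined by the interface.
    int (*baz)(const Concept *impl, int input);
    /// The base classes of this interface (resolved after registration).
    const BaseInterface::Concept *implBaseInterface = nullptr;

    void initializeInterfaceConcept(const InterfaceMap &interfaceMap) {
      implBaseInterface = interfaceMap.lookup<BaseInterface>();
      assert(implBaseInterface && "expected base interface to be registered");
    }
  };
  explicit DerivedInterface(const Concept *impl) : impl(impl) {}
  int baz(int input) const { return impl->baz(impl, input); }

  // Inherited from BaseInterface: forward through the stored base concept.
  int foo(int input) const {
    return impl->implBaseInterface->foo(impl->implBaseInterface, input);
  }
  // Implicit conversion to the base interface.
  operator BaseInterface() const {
    return BaseInterface(impl->implBaseInterface);
  }
  const Concept *impl;
};

// Hypothetical model implementations for a single concrete op.
static int myFoo(const BaseInterface::Concept *, int input) { return input + 1; }
static int myBaz(const DerivedInterface::Concept *, int input) { return input * 2; }
static BaseInterface::Concept baseModel{&myFoo};
static DerivedInterface::Concept derivedModel{&myBaz};

// Registers both models, then resolves the base-interface pointer in the
// same way the generated initializeInterfaceConcept() hook does.
DerivedInterface attachInterfaces(InterfaceMap &interfaceMap) {
  interfaceMap.insert<BaseInterface>(&baseModel);
  interfaceMap.insert<DerivedInterface>(&derivedModel);
  derivedModel.initializeInterfaceConcept(interfaceMap);
  return DerivedInterface(&derivedModel);
}
```

After `attachInterfaces` runs, `derived.foo(3)` reaches `myFoo` through the stored base concept and returns 4, and `BaseInterface base = derived;` goes through the implicit conversion operator without any explicit cast. That mirrors the "less awkward casting" goal stated in the commit message.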