[Mlir-commits] [mlir] d0e6fd9 - [mlir] Extend the promise interface mechanism
Fabian Mora
llvmlistbot at llvm.org
Tue Sep 5 06:55:33 PDT 2023
Author: Fabian Mora
Date: 2023-09-05T09:55:27-04:00
New Revision: d0e6fd99aa95ff61372ea328e9f89da2ee39c49c
URL: https://github.com/llvm/llvm-project/commit/d0e6fd99aa95ff61372ea328e9f89da2ee39c49c
DIFF: https://github.com/llvm/llvm-project/commit/d0e6fd99aa95ff61372ea328e9f89da2ee39c49c.diff
LOG: [mlir] Extend the promise interface mechanism
This patch pairs a promised interface with the object (Op/Attr/Type/Dialect) requesting the promise, i.e.:
```
declarePromisedInterface<MyAttr, MyInterface>();
```
This allows making fine-grained promises. It also adds a mechanism to query whether an `Op/Attr/Type` has a specific promise, returning true if the promise is there or if an implementation has been added. Finally, it adds a couple of `Attr|TypeConstraints` that can be used in ODS to query whether the promise or an implementation is present.
This patch tries to solve 2 issues:
1. Different entities cannot use the same promise.
```
declarePromisedInterface<MyInterface>();
// Resolves a promise.
MyAttr1::attachInterface<MyInterface>(ctx);
// Doesn't resolve a promise, as the previous attachment removed the promise.
MyAttr2::attachInterface<MyInterface>(ctx);
```
2. It is not possible to query whether a promise has been declared.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D158464
Added:
Modified:
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/DialectRegistry.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/StorageUniquerSupport.h
mlir/include/mlir/IR/Types.h
mlir/lib/Dialect/Func/IR/FuncOps.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
mlir/lib/IR/Dialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/unittests/IR/InterfaceAttachmentTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index ea4a90400cca32..53a2a1f2ad59f9 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -87,6 +87,15 @@ class Attribute {
friend ::llvm::hash_code hash_value(Attribute arg);
+  /// Returns true if `InterfaceT` is implemented by this attribute, or if the
+  /// attribute's dialect still holds an unresolved promise for it.
+  template <typename InterfaceT>
+  bool hasPromiseOrImplementsInterface() {
+    // Query the pending promise first and only then cast: an interface cast
+    // may reach handleUseOfUndefinedPromisedInterface, which reports a fatal
+    // error for unresolved promises when assertions are enabled.
+    if (dialect_extension_detail::hasPromisedInterface(
+            getDialect(), getTypeID(), InterfaceT::getInterfaceID()))
+      return true;
+    return mlir::isa<InterfaceT>(*this);
+  }
+
/// Returns true if the type was registered with a particular trait.
template <template <typename T> class Trait>
bool hasTrait() {
@@ -289,7 +298,7 @@ class AttributeInterface
// Check that the current interface isn't an unresolved promise for the
// given attribute.
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
- attr.getDialect(), ConcreteType::getInterfaceID(),
+ attr.getDialect(), attr.getTypeID(), ConcreteType::getInterfaceID(),
llvm::getTypeName<ConcreteType>());
#endif
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 16499c073772b9..45f29f37dd3b97 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -160,7 +160,7 @@ class Dialect {
/// nullptr.
DialectInterface *getRegisteredInterface(TypeID interfaceID) {
#ifndef NDEBUG
- handleUseOfUndefinedPromisedInterface(interfaceID);
+ handleUseOfUndefinedPromisedInterface(getTypeID(), interfaceID);
#endif
auto it = registeredInterfaces.find(interfaceID);
@@ -169,7 +169,8 @@ class Dialect {
template <typename InterfaceT>
InterfaceT *getRegisteredInterface() {
#ifndef NDEBUG
- handleUseOfUndefinedPromisedInterface(InterfaceT::getInterfaceID(),
+ handleUseOfUndefinedPromisedInterface(getTypeID(),
+ InterfaceT::getInterfaceID(),
llvm::getTypeName<InterfaceT>());
#endif
@@ -209,18 +210,21 @@ class Dialect {
/// registration. The promised interface type can be an interface of any type
/// not just a dialect interface, i.e. it may also be an
/// AttributeInterface/OpInterface/TypeInterface/etc.
- template <typename InterfaceT>
+  template <typename ConcreteT, typename InterfaceT>
void declarePromisedInterface() {
- unresolvedPromisedInterfaces.insert(InterfaceT::getInterfaceID());
+    // Promises are keyed on the requesting construct as well as the interface,
+    // so distinct constructs can each promise the same interface.
+    TypeID requestorID = TypeID::get<ConcreteT>();
+    unresolvedPromisedInterfaces.insert(
+        {requestorID, InterfaceT::getInterfaceID()});
}
/// Checks if the given interface, which is attempting to be used, is a
/// promised interface of this dialect that has yet to be implemented. If so,
/// emits a fatal error. `interfaceName` is an optional string that contains a
/// more user readable name for the interface (such as the class name).
- void handleUseOfUndefinedPromisedInterface(TypeID interfaceID,
+ void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID,
+ TypeID interfaceID,
StringRef interfaceName = "") {
- if (unresolvedPromisedInterfaces.count(interfaceID)) {
+ if (unresolvedPromisedInterfaces.count(
+ {interfaceRequestorID, interfaceID})) {
llvm::report_fatal_error(
"checking for an interface (`" + interfaceName +
"`) that was promised by dialect '" + getNamespace() +
@@ -229,11 +233,27 @@ class Dialect {
"registered.");
}
}
+
/// Checks if the given interface, which is attempting to be attached to a
/// construct owned by this dialect, is a promised interface of this dialect
/// that has yet to be implemented. If so, it resolves the interface promise.
- void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceID) {
- unresolvedPromisedInterfaces.erase(interfaceID);
+  void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID,
+                                                  TypeID interfaceID) {
+    // Only the (requestor, interface) pair being implemented is resolved;
+    // promises of the same interface made by other constructs remain pending.
+    unresolvedPromisedInterfaces.erase({interfaceRequestorID, interfaceID});
+  }
+
+  /// Returns true while `interfaceID` is still an unresolved promise made on
+  /// behalf of the construct identified by `interfaceRequestorID`.
+  bool hasPromisedInterface(TypeID interfaceRequestorID,
+                            TypeID interfaceID) const {
+    return unresolvedPromisedInterfaces.contains(
+        {interfaceRequestorID, interfaceID});
+  }
+
+  /// Typed convenience form of the pair query: derives the requestor ID from
+  /// `ConcreteT` and the interface ID from `InterfaceT`.
+  template <typename ConcreteT, typename InterfaceT>
+  bool hasPromisedInterface() const {
+    TypeID requestorID = TypeID::get<ConcreteT>();
+    return hasPromisedInterface(requestorID, InterfaceT::getInterfaceID());
}
protected:
@@ -332,7 +352,7 @@ class Dialect {
/// A set of interfaces that the dialect (or its constructs, i.e.
/// Attributes/Operations/Types/etc.) has promised to implement, but has yet
/// to provide an implementation for.
- DenseSet<TypeID> unresolvedPromisedInterfaces;
+ DenseSet<std::pair<TypeID, TypeID>> unresolvedPromisedInterfaces;
friend class DialectRegistry;
friend void registerDialect();
diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
index b49bdc91536ad9..c13a1a1858eb14 100644
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -102,15 +102,29 @@ namespace dialect_extension_detail {
/// Checks if the given interface, which is attempting to be used, is a
/// promised interface of this dialect that has yet to be implemented. If so,
/// emits a fatal error.
-void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceID,
+void handleUseOfUndefinedPromisedInterface(Dialect &dialect,
+ TypeID interfaceRequestorID,
+ TypeID interfaceID,
StringRef interfaceName);
/// Checks if the given interface, which is attempting to be attached, is a
/// promised interface of this dialect that has yet to be implemented. If so,
/// the promised interface is marked as resolved.
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect,
+ TypeID interfaceRequestorID,
TypeID interfaceID);
+/// Checks if a promise has been made for the interface/requestor pair.
+bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID,
+ TypeID interfaceID);
+
+/// Checks if a promise has been made for the interface/requestor pair.
+template <typename ConcreteT, typename InterfaceT>
+bool hasPromisedInterface(Dialect &dialect) {
+ return hasPromisedInterface(dialect, TypeID::get<ConcreteT>(),
+ InterfaceT::getInterfaceID());
+}
+
} // namespace dialect_extension_detail
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 47678a498f7952..236dd74839dfb0 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -450,6 +450,30 @@ class Results<dag rets> {
dag results = rets;
}
+//===----------------------------------------------------------------------===//
+// Common promised interface constraints
+//===----------------------------------------------------------------------===//
+
+// This predicate checks if the attribute promises or implements an attr
+// interface. It can be used directly in predicate lists such as the one taken
+// by `ConfinedAttr`.
+class HasPromiseOrImplementsAttrInterface<AttrInterface interface> :
+  CPred<"$_self.hasPromiseOrImplementsInterface<" #
+    !if(!empty(interface.cppNamespace),
+        "",
+        interface.cppNamespace # "::") # interface.cppInterfaceName #">()">;
+
+// This constraint represents a promise or an implementation of an attr
+// interface.
+class PromisedAttrInterface<AttrInterface interface> : AttrConstraint<
+  HasPromiseOrImplementsAttrInterface<interface>,
+  "promising or implementing the `" # interface.cppInterfaceName # "` attr interface">;
+
+// This predicate checks if the type promises or implements a type interface.
+// It can be used directly in predicate lists such as the one taken by
+// `ConfinedType`.
+class HasPromiseOrImplementsTypeInterface<TypeInterface interface> :
+  CPred<"$_self.hasPromiseOrImplementsInterface<" #
+    !if(!empty(interface.cppNamespace),
+        "",
+        interface.cppNamespace # "::") # interface.cppInterfaceName #">()">;
+
+// This constraint represents a promise or an implementation of a type
+// interface.
+class PromisedTypeInterface<TypeInterface interface> : TypeConstraint<
+  HasPromiseOrImplementsTypeInterface<interface>,
+  "promising or implementing the `" # interface.cppInterfaceName # "` type interface">;
+
//===----------------------------------------------------------------------===//
// Common op type constraints
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index f241c194a0d39d..67d923adbf9374 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -2075,7 +2075,7 @@ class OpInterface
// given operation.
if (Dialect *dialect = name.getDialect()) {
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
- *dialect, ConcreteType::getInterfaceID(),
+ *dialect, name.getTypeID(), ConcreteType::getInterfaceID(),
llvm::getTypeName<ConcreteType>());
}
#endif
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 79e493022e0dbd..361a38e87b6ba3 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -698,6 +698,13 @@ class alignas(8) Operation final
/// If folding was unsuccessful, this function returns "failure".
LogicalResult fold(SmallVectorImpl<OpFoldResult> &results);
+ /// Returns true if `InterfaceT` has been promised by the dialect or
+ /// implemented.
+  template <typename InterfaceT>
+  bool hasPromiseOrImplementsInterface() const {
+    // Promises are tracked per operation name, so defer to it.
+    const OperationName &opName = name;
+    return opName.hasPromiseOrImplementsInterface<InterfaceT>();
+  }
+
/// Returns true if the operation was registered with a particular trait, e.g.
/// hasTrait<OperandsAreSignlessIntegerLike>().
template <template <typename T> class Trait>
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 8519c703a134e4..670dd289c480a3 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -351,12 +351,21 @@ class OperationName {
void attachInterface() {
// Handle the case where the models resolve a promised interface.
(dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
- *getDialect(), Models::Interface::getInterfaceID()),
+ *getDialect(), getTypeID(), Models::Interface::getInterfaceID()),
...);
getImpl()->getInterfaceMap().insertModels<Models...>();
}
+  /// Returns true if `InterfaceT` has been promised by the dialect or
+  /// implemented. An operation name without a dialect cannot carry a promise,
+  /// so only the registered interfaces are consulted in that case.
+  template <typename InterfaceT>
+  bool hasPromiseOrImplementsInterface() const {
+    // `getDialect()` returns a possibly-null `Dialect *` for operation names,
+    // while `hasPromisedInterface` takes a `Dialect &`; passing the pointer
+    // directly does not compile and would not handle the null case.
+    if (Dialect *dialect = getDialect()) {
+      if (dialect_extension_detail::hasPromisedInterface(
+              *dialect, getTypeID(), InterfaceT::getInterfaceID()))
+        return true;
+    }
+    return hasInterface<InterfaceT>();
+  }
+
/// Returns true if this operation has the given interface registered to it.
template <typename T>
bool hasInterface() const {
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 5aa0dc4b02957e..c466e230d341d3 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -163,7 +163,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
// Handle the case where the models resolve a promised interface.
(dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
- abstract->getDialect(), IfaceModels::Interface::getInterfaceID()),
+ abstract->getDialect(), abstract->getTypeID(),
+ IfaceModels::Interface::getInterfaceID()),
...);
(checkInterfaceTarget<IfaceModels>(), ...);
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 5c4e06da829d9c..8443518027c0b6 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -180,6 +180,15 @@ class Type {
return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
}
+  /// Returns true if `InterfaceT` is implemented by this type, or if the
+  /// type's dialect still holds an unresolved promise for it.
+  template <typename InterfaceT>
+  bool hasPromiseOrImplementsInterface() {
+    // Query the pending promise first and only then cast: an interface cast
+    // may reach handleUseOfUndefinedPromisedInterface, which reports a fatal
+    // error for unresolved promises when assertions are enabled.
+    if (dialect_extension_detail::hasPromisedInterface(
+            getDialect(), getTypeID(), InterfaceT::getInterfaceID()))
+      return true;
+    return mlir::isa<InterfaceT>(*this);
+  }
+
/// Returns true if the type was registered with a particular trait.
template <template <typename T> class Trait>
bool hasTrait() {
@@ -274,7 +283,7 @@ class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
// Check that the current interface isn't an unresolved promise for the
// given type.
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
- type.getDialect(), ConcreteType::getInterfaceID(),
+ type.getDialect(), type.getTypeID(), ConcreteType::getInterfaceID(),
llvm::getTypeName<ConcreteType>());
#endif
diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index 756cdd93e63ac9..ca9b19c66595a8 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -40,7 +40,7 @@ void FuncDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
>();
- declarePromisedInterface<DialectInlinerInterface>();
+ declarePromisedInterface<FuncDialect, DialectInlinerInterface>();
}
/// Materialize a single constant operation from a given attribute value with
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 2abeb919d68a47..84965bd6880757 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -994,8 +994,8 @@ void NVVMDialect::initialize() {
// Support unknown operations because not all NVVM operations are
// registered.
allowUnknownOperations();
- declarePromisedInterface<ConvertToLLVMPatternInterface>();
- declarePromisedInterface<gpu::TargetAttrInterface>();
+ declarePromisedInterface<NVVMDialect, ConvertToLLVMPatternInterface>();
+ declarePromisedInterface<NVVMTargetAttr, gpu::TargetAttrInterface>();
}
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index a6ebdf409768b5..32f34a8889af54 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -247,7 +247,7 @@ void ROCDLDialect::initialize() {
// Support unknown operations because not all ROCDL operations are registered.
allowUnknownOperations();
- declarePromisedInterface<gpu::TargetAttrInterface>();
+ declarePromisedInterface<ROCDLTargetAttr, gpu::TargetAttrInterface>();
}
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index e860299fe4c496..965386681f2709 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -97,7 +97,7 @@ bool Dialect::isValidNamespace(StringRef str) {
/// Register a set of dialect interfaces with this dialect instance.
void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
// Handle the case where the models resolve a promised interface.
- handleAdditionOfUndefinedPromisedInterface(interface->getID());
+ handleAdditionOfUndefinedPromisedInterface(getTypeID(), interface->getID());
auto it = registeredInterfaces.try_emplace(interface->getID(),
std::move(interface));
@@ -125,8 +125,8 @@ DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
for (auto *dialect : ctx->getLoadedDialects()) {
#ifndef NDEBUG
- dialect->handleUseOfUndefinedPromisedInterface(interfaceKind,
- interfaceName);
+ dialect->handleUseOfUndefinedPromisedInterface(
+ dialect->getTypeID(), interfaceKind, interfaceName);
#endif
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
interfaces.insert(interface);
@@ -151,13 +151,22 @@ DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
DialectExtensionBase::~DialectExtensionBase() = default;
void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
- Dialect &dialect, TypeID interfaceID, StringRef interfaceName) {
- dialect.handleUseOfUndefinedPromisedInterface(interfaceID, interfaceName);
+    Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
+    StringRef interfaceName) {
+  // The dialect owns the set of unresolved promises; simply forward to it.
+  dialect.handleUseOfUndefinedPromisedInterface(
+      interfaceRequestorID, interfaceID, interfaceName);
}
void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
- Dialect &dialect, TypeID interfaceID) {
- dialect.handleAdditionOfUndefinedPromisedInterface(interfaceID);
+    Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
+  dialect.handleAdditionOfUndefinedPromisedInterface(
+      interfaceRequestorID, interfaceID);
+}
+
+bool dialect_extension_detail::hasPromisedInterface(
+    Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
+  // Out-of-line wrapper around the dialect's own pair query.
+  return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 7c059ea0c1189c..3c9f230f60c56f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -368,6 +368,20 @@ def DenseArrayNonNegativeOp : TEST_Op<"confined_non_negative_attr"> {
);
}
+//===----------------------------------------------------------------------===//
+// Test Promised Interfaces Constraints
+//===----------------------------------------------------------------------===//
+
+// Test op whose arguments exercise the promised-interface constraints: the
+// attribute must promise or implement `TestExternalAttrInterface`, and the
+// operand's type must promise or implement `TestExternalTypeInterface`.
+def PromisedInterfacesOp : TEST_Op<"promised_interfaces"> {
+ let arguments = (ins
+ ConfinedAttr<AnyAttr,
+ [PromisedAttrInterface<TestExternalAttrInterface>]>:$promisedAttr,
+ // `ConfinedType` takes a list of predicates, hence the Pred form here.
+ ConfinedType<AnyType,
+ [HasPromiseOrImplementsTypeInterface<TestExternalTypeInterface>]
+ >:$promisedType
+ );
+}
+
//===----------------------------------------------------------------------===//
// Test Enum Attributes
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index fe855164f87484..2e1309ad776fe5 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -417,4 +417,30 @@ TEST(InterfaceAttachment, OperationDelayedContextAppend) {
EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
}
+TEST(InterfaceAttachment, PromisedInterfaces) {
+  // Attribute interfaces use the exact same mechanism as types, so just check
+  // that the promise mechanism works for attributes.
+  MLIRContext context;
+  auto *testDialect = context.getOrLoadDialect<test::TestDialect>();
+  auto attr = test::SimpleAAttr::get(&context);
+
+  // `SimpleAAttr` neither implements nor promises the
+  // `TestExternalAttrInterface` interface.
+  EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
+  EXPECT_FALSE(
+      attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
+
+  // Add a promise for `TestExternalAttrInterface`; the query succeeds even
+  // though no implementation has been attached yet.
+  testDialect->declarePromisedInterface<test::SimpleAAttr,
+                                        TestExternalAttrInterface>();
+  EXPECT_TRUE(
+      attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
+
+  // Attaching the interface resolves the promise, and the attribute now
+  // implements it.
+  test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context);
+  EXPECT_FALSE((testDialect->hasPromisedInterface<test::SimpleAAttr,
+                                                  TestExternalAttrInterface>()));
+  EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
+  EXPECT_TRUE(
+      attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
+}
+
} // namespace
More information about the Mlir-commits
mailing list