[Mlir-commits] [mlir] 7d1452d - [mlir] Refactor OpInterface internals to be faster and factor out common bits.
River Riddle
llvmlistbot at llvm.org
Wed Jun 24 17:32:47 PDT 2020
Author: River Riddle
Date: 2020-06-24T17:23:58-07:00
New Revision: 7d1452d8373e5aaaa94b5d0d6c9a1dc4be457311
URL: https://github.com/llvm/llvm-project/commit/7d1452d8373e5aaaa94b5d0d6c9a1dc4be457311
DIFF: https://github.com/llvm/llvm-project/commit/7d1452d8373e5aaaa94b5d0d6c9a1dc4be457311.diff
LOG: [mlir] Refactor OpInterface internals to be faster and factor out common bits.
This revision adds a new support header, InterfaceSupport, to contain various generic bits of functionality for implementing "Interfaces". Interfaces embody a mechanism for attaching concept-based polymorphism to a type system. With this refactoring a new InterfaceMap type is added to allow for efficient interface lookups without going through an indirect call. This should provide a decent performance speedup without changing the size of AbstractOperation.
In a future revision, this functionality will also be used to bring Interface-like functionality to Attributes and Types.
Differential Revision: https://reviews.llvm.org/D81882
Added:
mlir/include/mlir/Support/InterfaceSupport.h
Modified:
mlir/docs/Interfaces.md
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/MLIRContext.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index d1050e76de5b..aca4545a7c5d 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -131,14 +131,14 @@ struct ExampleOpInterfaceTraits {
/// to be overridden.
struct Concept {
virtual ~Concept();
- virtual unsigned getNumInputs(Operation *op) = 0;
+ virtual unsigned getNumInputs(Operation *op) const = 0;
};
/// Define a model class that specializes a concept on a given operation type.
template <typename OpT>
struct Model : public Concept {
/// Override the method to dispatch on the concrete operation.
- unsigned getNumInputs(Operation *op) final {
+ unsigned getNumInputs(Operation *op) const final {
return llvm::cast<OpT>(op).getNumInputs();
}
};
@@ -151,7 +151,7 @@ public:
using OpInterface<ExampleOpInterface, ExampleOpInterfaceTraits>::OpInterface;
/// The interface dispatches to 'getImpl()', an instance of the concept.
- unsigned getNumInputs() {
+ unsigned getNumInputs() const {
return getImpl()->getNumInputs(getOperation());
}
};
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 3fc68d09f840..8251906dacb2 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1348,120 +1348,39 @@ class Op : public OpState,
traitID);
}
- /// Returns an opaque pointer to a concept instance of the interface with the
- /// given ID if one was registered to this operation.
- static void *getRawInterface(TypeID id) {
- return InterfaceLookup::template lookup<Traits<ConcreteType>...>(id);
- }
-
- struct InterfaceLookup {
- /// Trait to check if T provides a static 'getInterfaceID' method.
- template <typename T, typename... Args>
- using has_get_interface_id = decltype(T::getInterfaceID());
-
- /// If 'T' is the same interface as 'interfaceID' return the concept
- /// instance.
- template <typename T>
- static typename std::enable_if<
- llvm::is_detected<has_get_interface_id, T>::value, void *>::type
- lookup(TypeID interfaceID) {
- return (T::getInterfaceID() == interfaceID) ? &T::instance() : nullptr;
- }
-
- /// 'T' is known to not be an interface, return nullptr.
- template <typename T>
- static typename std::enable_if<
- !llvm::is_detected<has_get_interface_id, T>::value, void *>::type
- lookup(TypeID) {
- return nullptr;
- }
-
- template <typename T, typename T2, typename... Ts>
- static void *lookup(TypeID interfaceID) {
- auto *concept = lookup<T>(interfaceID);
- return concept ? concept : lookup<T2, Ts...>(interfaceID);
- }
- };
+ /// Returns an interface map for the interfaces registered to this operation.
+ static detail::InterfaceMap getInterfaceMap() {
+ return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
+ }
- /// Allow access to 'hasTrait' and 'getRawInterface'.
+ /// Allow access to 'hasTrait' and 'getInterfaceMap'.
friend AbstractOperation;
};
-/// This class represents the base of an operation interface. Operation
-/// interfaces provide access to derived *Op properties through an opaquely
-/// Operation instance. Derived interfaces must also provide a 'Traits' class
-/// that defines a 'Concept' and a 'Model' class. The 'Concept' class defines an
-/// abstract virtual interface, where as the 'Model' class implements this
-/// interface for a specific derived *Op type. Both of these classes *must* not
-/// contain non-static data. A simple example is shown below:
-///
-/// struct ExampleOpInterfaceTraits {
-/// struct Concept {
-/// virtual unsigned getNumInputs(Operation *op) = 0;
-/// };
-/// template <typename OpT> class Model {
-/// unsigned getNumInputs(Operation *op) final {
-/// return cast<OpT>(op).getNumInputs();
-/// }
-/// };
-/// };
-///
+/// This class represents the base of an operation interface. See the definition
+/// of `detail::Interface` for requirements on the `Traits` type.
template <typename ConcreteType, typename Traits>
-class OpInterface : public Op<ConcreteType> {
+class OpInterface
+ : public detail::Interface<ConcreteType, Operation *, Traits,
+ Op<ConcreteType>, OpTrait::TraitBase> {
public:
- using Concept = typename Traits::Concept;
- template <typename T> using Model = typename Traits::template Model<T>;
using Base = OpInterface<ConcreteType, Traits>;
+ using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
+ Op<ConcreteType>, OpTrait::TraitBase>;
- OpInterface(Operation *op = nullptr)
- : Op<ConcreteType>(op), impl(op ? getInterfaceFor(op) : nullptr) {
- assert((!op || impl) &&
- "instantiating an interface with an unregistered operation");
- }
-
- /// Support 'classof' by checking if the given operation defines the concrete
- /// interface.
- static bool classof(Operation *op) { return getInterfaceFor(op); }
-
- /// Define an accessor for the ID of this interface.
- static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
-
- /// This is a special trait that registers a given interface with an
- /// operation.
- template <typename ConcreteOp>
- struct Trait : public OpTrait::TraitBase<ConcreteOp, Trait> {
- /// Define an accessor for the ID of this interface.
- static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
-
- /// Provide an accessor to a static instance of the interface model for the
- /// concrete operation type.
- /// The implementation is inspired from Sean Parent's concept-based
- /// polymorphism. A key difference is that the set of classes erased is
- /// statically known, which alleviates the need for using dynamic memory
- /// allocation.
- /// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
- /// virtual table and generate a singleton object for each instantiation of
- /// this class.
- static Concept &instance() {
- static Model<ConcreteOp> singleton;
- return singleton;
- }
- };
-
-protected:
- /// Get the raw concept in the correct derived concept type.
- Concept *getImpl() { return impl; }
+ /// Inherit the base class constructor.
+ using InterfaceBase::InterfaceBase;
private:
/// Returns the impl interface instance for the given operation.
- static Concept *getInterfaceFor(Operation *op) {
+ static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
// Access the raw interface from the abstract operation.
auto *abstractOp = op->getAbstractOperation();
return abstractOp ? abstractOp->getInterface<ConcreteType>() : nullptr;
}
- /// A pointer to the impl concept object.
- Concept *impl;
+ /// Allow `detail::Interface` to call the private `getInterfaceFor`.
+ friend InterfaceBase;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index aa75c57db1ff..e3afaf316154 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -19,7 +19,7 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
-#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/InterfaceSupport.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
@@ -136,8 +136,7 @@ class AbstractOperation {
/// was registered to this operation, null otherwise. This should not be used
/// directly.
template <typename T> typename T::Concept *getInterface() const {
- return reinterpret_cast<typename T::Concept *>(
- getRawInterface(T::getInterfaceID()));
+ return interfaceMap.lookup<T>();
}
/// Returns if the operation has a particular trait.
@@ -157,7 +156,7 @@ class AbstractOperation {
T::getOperationName(), dialect, T::getOperationProperties(),
TypeID::get<T>(), T::parseAssembly, T::printAssembly,
T::verifyInvariants, T::foldHook, T::getCanonicalizationPatterns,
- T::getRawInterface, T::hasTrait);
+ T::getInterfaceMap(), T::hasTrait);
}
private:
@@ -171,22 +170,19 @@ class AbstractOperation {
SmallVectorImpl<OpFoldResult> &results),
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
MLIRContext *context),
- void *(&getRawInterface)(TypeID interfaceID),
- bool (&hasTrait)(TypeID traitID))
+ detail::InterfaceMap &&interfaceMap, bool (&hasTrait)(TypeID traitID))
: name(name), dialect(dialect), typeID(typeID),
parseAssembly(parseAssembly), printAssembly(printAssembly),
verifyInvariants(verifyInvariants), foldHook(foldHook),
getCanonicalizationPatterns(getCanonicalizationPatterns),
- opProperties(opProperties), getRawInterface(getRawInterface),
+ opProperties(opProperties), interfaceMap(std::move(interfaceMap)),
hasRawTrait(hasTrait) {}
/// The properties of the operation.
const OperationProperties opProperties;
- /// Returns a raw instance of the concept for the given interface id if it is
- /// registered to this operation, nullptr otherwise. This should not be used
- /// directly.
- void *(&getRawInterface)(TypeID interfaceID);
+ /// A map of interfaces that were registered to this operation.
+ detail::InterfaceMap interfaceMap;
/// This hook returns if the operation contains the trait corresponding
/// to the given TypeID.
diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
new file mode 100644
index 000000000000..c29c49d20e07
--- /dev/null
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -0,0 +1,215 @@
+//===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines several support classes for defining interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_INTERFACESUPPORT_H
+#define MLIR_SUPPORT_INTERFACESUPPORT_H
+
+#include "mlir/Support/TypeID.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/TypeName.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+#include <initializer_list>
+#include <memory>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+
+namespace mlir {
+namespace detail {
+//===----------------------------------------------------------------------===//
+// Interface
+//===----------------------------------------------------------------------===//
+
+/// This class represents an abstract interface. An interface is a simplified
+/// mechanism for attaching concept based polymorphism to a class hierarchy. An
+/// interface is comprised of two components:
+/// * The derived interface class: This is what users interact with, and invoke
+///   methods on.
+/// * An interface `Trait` class: This is the class that is attached to the
+///   object implementing the interface. It is the mechanism with which models
+///   are specialized.
+///
+/// Derived interface types must provide the following template types:
+/// * ConcreteType: The CRTP derived type.
+/// * ValueT: The opaque type the derived interface operates on. For example
+///           `Operation*` for operation interfaces, or `Attribute` for
+///           attribute interfaces.
+/// * Traits: A class that contains definitions for a 'Concept' and a 'Model'
+///           class. The 'Concept' class defines an abstract virtual interface,
+///           whereas the 'Model' class implements this interface for a
+///           specific derived type. Both of these classes *must* not contain
+///           non-static data. A simple example is shown below:
+///
+/// ```c++
+///    struct ExampleInterfaceTraits {
+///      struct Concept {
+///        virtual unsigned getNumInputs(ValueT t) const = 0;
+///      };
+///      template <typename DerivedT> struct Model : public Concept {
+///        unsigned getNumInputs(ValueT t) const final {
+///          return cast<DerivedT>(t).getNumInputs();
+///        }
+///      };
+///    };
+/// ```
+///
+/// * BaseType: A desired base type for the interface. This is a class that
+///             provides specific functionality for the `ValueT`
+///             value. For instance the specific `Op` that will wrap the
+///             `Operation*` for an `OpInterface`.
+/// * BaseTrait: The base type for the interface trait. This is the base class
+///              to use for the interface trait that will be attached to each
+///              instance of `ValueT` that implements this interface.
+///
+template <typename ConcreteType, typename ValueT, typename Traits,
+          typename BaseType,
+          template <typename, template <typename> class> class BaseTrait>
+class Interface : public BaseType {
+public:
+  using Concept = typename Traits::Concept;
+  template <typename T> using Model = typename Traits::template Model<T>;
+  using InterfaceBase =
+      Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
+
+  /// Construct an interface wrapping `t`. A null `t` yields a null interface;
+  /// a non-null `t` is asserted to implement this interface.
+  Interface(ValueT t = ValueT())
+      : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
+    assert((!t || impl) &&
+           "instantiating an interface with an unregistered operation");
+  }
+
+  /// Support 'classof' by checking if the given object defines the concrete
+  /// interface.
+  static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); }
+
+  /// Define an accessor for the ID of this interface.
+  static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
+
+  /// This is a special trait that registers a given interface with an object.
+  template <typename ConcreteT>
+  struct Trait : public BaseTrait<ConcreteT, Trait> {
+    /// Define an accessor for the ID of this interface.
+    static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
+
+    /// Provide an accessor to a static instance of the interface model for the
+    /// concrete T type.
+    /// The implementation is inspired from Sean Parent's concept-based
+    /// polymorphism. A key difference is that the set of classes erased is
+    /// statically known, which alleviates the need for using dynamic memory
+    /// allocation.
+    /// `Model<ConcreteT>` holds no non-static data; a single function-local
+    /// static instance is created per instantiation to provide the virtual
+    /// dispatch.
+    static Concept &instance() {
+      static Model<ConcreteT> singleton;
+      return singleton;
+    }
+  };
+
+protected:
+  /// Get the raw concept in the correct derived concept type.
+  const Concept *getImpl() const { return impl; }
+  Concept *getImpl() { return impl; }
+
+private:
+  /// A pointer to the impl concept object; null for a null interface.
+  Concept *impl;
+};
+
+//===----------------------------------------------------------------------===//
+// InterfaceMap
+//===----------------------------------------------------------------------===//
+
+/// This class provides an efficient mapping between a given `Interface` type,
+/// and a particular implementation of its concept.
+class InterfaceMap {
+public:
+  /// Construct an InterfaceMap with the given set of template types. For
+  /// convenience given that object trait lists may contain other non-interface
+  /// types, not all of the types need to be interfaces. The provided types that
+  /// do not represent interfaces are not added to the interface map.
+  template <typename... Types> static InterfaceMap get() {
+    return InterfaceMap(MapBuilder::create<Types...>());
+  }
+
+  /// Returns an instance of the concept object for the given interface if it
+  /// was registered to this map, null otherwise.
+  template <typename T> typename T::Concept *lookup() const {
+    if (!interfaces)
+      return nullptr;
+    // Concepts are stored as `void *`; `static_cast` restores the type.
+    return static_cast<typename T::Concept *>(
+        interfaces->lookup(T::getInterfaceID()));
+  }
+
+private:
+  /// The underlying map from interface ID to its type-erased concept instance.
+  using MapT = llvm::SmallDenseMap<TypeID, void *>;
+
+  /// This struct provides support for building a map of interfaces.
+  class MapBuilder {
+  public:
+    template <typename... Types> static std::unique_ptr<MapT> create() {
+      // Filter the provided types for those that are interfaces. This reduces
+      // the amount of maps that are generated.
+      using InterfaceTuple =
+          typename FilterTypes<detect_get_interface_id, Types...>::type;
+      return createImpl(static_cast<InterfaceTuple *>(nullptr));
+    }
+
+  private:
+    /// Trait to check if T provides a static 'getInterfaceID' method.
+    template <typename T, typename... Args>
+    using has_get_interface_id = decltype(T::getInterfaceID());
+    template <typename T>
+    using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
+
+    /// Utility to filter a given sequence of types based upon a predicate:
+    /// `type` is a `std::tuple` of each `Es` for which `Pred<Es>::value` holds.
+    /// `std::conditional` is used rather than an explicitly specialized helper
+    /// struct, because explicit specializations at class scope are rejected by
+    /// GCC.
+    template <template <class> class Pred, class... Es> struct FilterTypes {
+      using type = decltype(std::tuple_cat(
+          std::declval<typename std::conditional<
+              Pred<Es>::value, std::tuple<Es>, std::tuple<>>::type>()...));
+    };
+
+    template <typename... Ts>
+    static std::unique_ptr<MapT> createImpl(std::tuple<Ts...> *) {
+      // Only create an instance of the map if there are any interface types.
+      if (sizeof...(Ts) == 0)
+        return std::unique_ptr<MapT>();
+
+      auto map = std::make_unique<MapT>();
+      // Register the concept singleton of each interface under its ID.
+      (void)std::initializer_list<int>{
+          0, (map->try_emplace(Ts::getInterfaceID(), &Ts::instance()), 0)...};
+      return map;
+    }
+  };
+
+  explicit InterfaceMap(std::unique_ptr<MapT> interfaces)
+      : interfaces(std::move(interfaces)) {}
+
+  /// The internal map of interfaces, created once per call to `get`. This is
+  /// null when none of the types provided to `get` were interfaces.
+  std::unique_ptr<MapT> interfaces;
+};
+
+} // end namespace detail
+} // end namespace mlir
+
+#endif // MLIR_SUPPORT_INTERFACESUPPORT_H
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 2beb1a91dbae..a6c3a4ca1a2f 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -560,9 +560,10 @@ void Dialect::addOperation(AbstractOperation opInfo) {
auto &impl = context->getImpl();
// Lock access to the context registry.
+ StringRef opName = opInfo.name;
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
- if (!impl.registeredOperations.insert({opInfo.name, opInfo}).second) {
- llvm::errs() << "error: operation named '" << opInfo.name
+ if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) {
+ llvm::errs() << "error: operation named '" << opName
<< "' is already registered.\n";
abort();
}
More information about the Mlir-commits
mailing list