[Mlir-commits] [mlir] 7d1452d - [mlir] Refactor OpInterface internals to be faster and factor out common bits.

River Riddle llvmlistbot at llvm.org
Wed Jun 24 17:32:47 PDT 2020


Author: River Riddle
Date: 2020-06-24T17:23:58-07:00
New Revision: 7d1452d8373e5aaaa94b5d0d6c9a1dc4be457311

URL: https://github.com/llvm/llvm-project/commit/7d1452d8373e5aaaa94b5d0d6c9a1dc4be457311
DIFF: https://github.com/llvm/llvm-project/commit/7d1452d8373e5aaaa94b5d0d6c9a1dc4be457311.diff

LOG: [mlir] Refactor OpInterface internals to be faster and factor out common bits.

This revision adds a new support header, InterfaceSupport, to contain various generic bits of functionality for implementing "Interfaces". Interfaces embody a mechanism for attaching concept-based polymorphism to a type system. With this refactoring a new InterfaceMap type is added to allow for efficient interface lookups without going through an indirect call. This should provide a decent performance speedup without changing the size of AbstractOperation.

In a future revision, this functionality will also be used to bring Interface-like functionality to Attributes and Types.

Differential Revision: https://reviews.llvm.org/D81882

Added: 
    mlir/include/mlir/Support/InterfaceSupport.h

Modified: 
    mlir/docs/Interfaces.md
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/MLIRContext.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index d1050e76de5b..aca4545a7c5d 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -131,14 +131,14 @@ struct ExampleOpInterfaceTraits {
   /// to be overridden.
   struct Concept {
     virtual ~Concept();
-    virtual unsigned getNumInputs(Operation *op) = 0;
+    virtual unsigned getNumInputs(Operation *op) const = 0;
   };
 
   /// Define a model class that specializes a concept on a given operation type.
   template <typename OpT>
   struct Model : public Concept {
     /// Override the method to dispatch on the concrete operation.
-    unsigned getNumInputs(Operation *op) final {
+    unsigned getNumInputs(Operation *op) const final {
       return llvm::cast<OpT>(op).getNumInputs();
     }
   };
@@ -151,7 +151,7 @@ public:
   using OpInterface<ExampleOpInterface, ExampleOpInterfaceTraits>::OpInterface;
 
   /// The interface dispatches to 'getImpl()', an instance of the concept.
-  unsigned getNumInputs() {
+  unsigned getNumInputs() const {
     return getImpl()->getNumInputs(getOperation());
   }
 };

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 3fc68d09f840..8251906dacb2 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1348,120 +1348,39 @@ class Op : public OpState,
                               traitID);
   }
 
-  /// Returns an opaque pointer to a concept instance of the interface with the
-  /// given ID if one was registered to this operation.
-  static void *getRawInterface(TypeID id) {
-    return InterfaceLookup::template lookup<Traits<ConcreteType>...>(id);
-  }
-
-  struct InterfaceLookup {
-    /// Trait to check if T provides a static 'getInterfaceID' method.
-    template <typename T, typename... Args>
-    using has_get_interface_id = decltype(T::getInterfaceID());
-
-    /// If 'T' is the same interface as 'interfaceID' return the concept
-    /// instance.
-    template <typename T>
-    static typename std::enable_if<
-        llvm::is_detected<has_get_interface_id, T>::value, void *>::type
-    lookup(TypeID interfaceID) {
-      return (T::getInterfaceID() == interfaceID) ? &T::instance() : nullptr;
-    }
-
-    /// 'T' is known to not be an interface, return nullptr.
-    template <typename T>
-    static typename std::enable_if<
-        !llvm::is_detected<has_get_interface_id, T>::value, void *>::type
-    lookup(TypeID) {
-      return nullptr;
-    }
-
-    template <typename T, typename T2, typename... Ts>
-    static void *lookup(TypeID interfaceID) {
-      auto *concept = lookup<T>(interfaceID);
-      return concept ? concept : lookup<T2, Ts...>(interfaceID);
-    }
-  };
+  /// Returns an interface map for the interfaces registered to this operation.
+  static detail::InterfaceMap getInterfaceMap() {
+    return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
+  }
 
-  /// Allow access to 'hasTrait' and 'getRawInterface'.
+  /// Allow access to 'hasTrait' and 'getInterfaceMap'.
   friend AbstractOperation;
 };
 
-/// This class represents the base of an operation interface. Operation
-/// interfaces provide access to derived *Op properties through an opaquely
-/// Operation instance. Derived interfaces must also provide a 'Traits' class
-/// that defines a 'Concept' and a 'Model' class. The 'Concept' class defines an
-/// abstract virtual interface, where as the 'Model' class implements this
-/// interface for a specific derived *Op type. Both of these classes *must* not
-/// contain non-static data. A simple example is shown below:
-///
-///  struct ExampleOpInterfaceTraits {
-///    struct Concept {
-///      virtual unsigned getNumInputs(Operation *op) = 0;
-///    };
-///    template <typename OpT> class Model {
-///      unsigned getNumInputs(Operation *op) final {
-///        return cast<OpT>(op).getNumInputs();
-///      }
-///    };
-///  };
-///
+/// This class represents the base of an operation interface. See the definition
+/// of `detail::Interface` for requirements on the `Traits` type.
 template <typename ConcreteType, typename Traits>
-class OpInterface : public Op<ConcreteType> {
+class OpInterface
+    : public detail::Interface<ConcreteType, Operation *, Traits,
+                               Op<ConcreteType>, OpTrait::TraitBase> {
 public:
-  using Concept = typename Traits::Concept;
-  template <typename T> using Model = typename Traits::template Model<T>;
   using Base = OpInterface<ConcreteType, Traits>;
+  using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
+                                          Op<ConcreteType>, OpTrait::TraitBase>;
 
-  OpInterface(Operation *op = nullptr)
-      : Op<ConcreteType>(op), impl(op ? getInterfaceFor(op) : nullptr) {
-    assert((!op || impl) &&
-           "instantiating an interface with an unregistered operation");
-  }
-
-  /// Support 'classof' by checking if the given operation defines the concrete
-  /// interface.
-  static bool classof(Operation *op) { return getInterfaceFor(op); }
-
-  /// Define an accessor for the ID of this interface.
-  static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
-
-  /// This is a special trait that registers a given interface with an
-  /// operation.
-  template <typename ConcreteOp>
-  struct Trait : public OpTrait::TraitBase<ConcreteOp, Trait> {
-    /// Define an accessor for the ID of this interface.
-    static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
-
-    /// Provide an accessor to a static instance of the interface model for the
-    /// concrete operation type.
-    /// The implementation is inspired from Sean Parent's concept-based
-    /// polymorphism. A key difference is that the set of classes erased is
-    /// statically known, which alleviates the need for using dynamic memory
-    /// allocation.
-    /// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
-    /// virtual table and generate a singleton object for each instantiation of
-    /// this class.
-    static Concept &instance() {
-      static Model<ConcreteOp> singleton;
-      return singleton;
-    }
-  };
-
-protected:
-  /// Get the raw concept in the correct derived concept type.
-  Concept *getImpl() { return impl; }
+  /// Inherit the `detail::Interface` constructor, which looks up the concept.
+  using InterfaceBase::InterfaceBase;
 
 private:
   /// Returns the impl interface instance for the given operation.
-  static Concept *getInterfaceFor(Operation *op) {
+  static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
     // Access the raw interface from the abstract operation.
     auto *abstractOp = op->getAbstractOperation();
     return abstractOp ? abstractOp->getInterface<ConcreteType>() : nullptr;
   }
 
-  /// A pointer to the impl concept object.
-  Concept *impl;
+  /// Allow `InterfaceBase` to call the private `getInterfaceFor` hook.
+  friend InterfaceBase;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index aa75c57db1ff..e3afaf316154 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -19,7 +19,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
-#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/InterfaceSupport.h"
 #include "llvm/ADT/BitmaskEnum.h"
 #include "llvm/ADT/PointerUnion.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
@@ -136,8 +136,7 @@ class AbstractOperation {
   /// was registered to this operation, null otherwise. This should not be used
   /// directly.
   template <typename T> typename T::Concept *getInterface() const {
-    return reinterpret_cast<typename T::Concept *>(
-        getRawInterface(T::getInterfaceID()));
+    return interfaceMap.lookup<T>();
   }
 
   /// Returns if the operation has a particular trait.
@@ -157,7 +156,7 @@ class AbstractOperation {
         T::getOperationName(), dialect, T::getOperationProperties(),
         TypeID::get<T>(), T::parseAssembly, T::printAssembly,
         T::verifyInvariants, T::foldHook, T::getCanonicalizationPatterns,
-        T::getRawInterface, T::hasTrait);
+        T::getInterfaceMap(), T::hasTrait);
   }
 
 private:
@@ -171,22 +170,19 @@ class AbstractOperation {
                                 SmallVectorImpl<OpFoldResult> &results),
       void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
                                           MLIRContext *context),
-      void *(&getRawInterface)(TypeID interfaceID),
-      bool (&hasTrait)(TypeID traitID))
+      detail::InterfaceMap &&interfaceMap, bool (&hasTrait)(TypeID traitID))
       : name(name), dialect(dialect), typeID(typeID),
         parseAssembly(parseAssembly), printAssembly(printAssembly),
         verifyInvariants(verifyInvariants), foldHook(foldHook),
         getCanonicalizationPatterns(getCanonicalizationPatterns),
-        opProperties(opProperties), getRawInterface(getRawInterface),
+        opProperties(opProperties), interfaceMap(std::move(interfaceMap)),
         hasRawTrait(hasTrait) {}
 
   /// The properties of the operation.
   const OperationProperties opProperties;
 
-  /// Returns a raw instance of the concept for the given interface id if it is
-  /// registered to this operation, nullptr otherwise. This should not be used
-  /// directly.
-  void *(&getRawInterface)(TypeID interfaceID);
+  /// A map of interfaces that were registered to this operation.
+  detail::InterfaceMap interfaceMap;
 
   /// This hook returns if the operation contains the trait corresponding
   /// to the given TypeID.

diff  --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
new file mode 100644
index 000000000000..c29c49d20e07
--- /dev/null
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -0,0 +1,204 @@
+//===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines several support classes for defining interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_INTERFACESUPPORT_H
+#define MLIR_SUPPORT_INTERFACESUPPORT_H
+
+#include "mlir/Support/TypeID.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/TypeName.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace detail {
+//===----------------------------------------------------------------------===//
+// Interface
+//===----------------------------------------------------------------------===//
+
+/// This class represents an abstract interface. An interface is a simplified
+/// mechanism for attaching concept-based polymorphism to a class hierarchy. An
+/// interface is comprised of two components:
+/// * The derived interface class: This is what users interact with, and invoke
+///   methods on.
+/// * An interface `Trait` class: This is the class that is attached to the
+///   object implementing the interface. It is the mechanism with which models
+///   are specialized.
+///
+/// Derived interface types must provide the following template types:
+/// * ConcreteType: The CRTP derived type. It must provide a static
+///   `getInterfaceFor(ValueT)` that returns the concept instance, or null.
+/// * ValueT: The opaque type the derived interface operates on. For example
+///           `Operation*` for operation interfaces, or `Attribute` for
+///           attribute interfaces.
+/// * Traits: A class that contains definitions for a 'Concept' and a 'Model'
+///           class. The 'Concept' class defines an abstract virtual interface,
+///           whereas the 'Model' class implements this interface for a
+///           specific derived T type. Both of these classes *must* not contain
+///           non-static data. A simple example is shown below:
+///
+/// ```c++
+///    struct ExampleInterfaceTraits {
+///      struct Concept {
+///        virtual unsigned getNumInputs(ValueT t) const = 0;
+///      };
+///      template <typename DerivedT> struct Model : public Concept {
+///        unsigned getNumInputs(ValueT t) const final {
+///          return cast<DerivedT>(t).getNumInputs();
+///        }
+///      };
+///    };
+/// ```
+///
+/// * BaseType: A desired base type for the interface that provides specific
+///             functionality for the `ValueT` value, e.g. the specific `Op`
+///             that will wrap the `Operation*` for an `OpInterface`.
+/// * BaseTrait: The base type for the interface trait. This is the base class
+///              to use for the interface trait that will be attached to each
+///              instance of `ValueT` that implements this interface.
+///
+template <typename ConcreteType, typename ValueT, typename Traits,
+          typename BaseType,
+          template <typename, template <typename> class> class BaseTrait>
+class Interface : public BaseType {
+public:
+  using Concept = typename Traits::Concept;
+  template <typename T> using Model = typename Traits::template Model<T>;
+  using InterfaceBase =
+      Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
+
+  Interface(ValueT t = ValueT())
+      : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
+    assert((!t || impl) &&
+           "expected value to provide interface instance");
+  }
+
+  /// Support 'classof' by checking if the given object defines the concrete
+  /// interface.
+  static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); }
+
+  /// Define an accessor for the ID of this interface.
+  static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
+
+  /// This is a special trait that registers a given interface with an object.
+  template <typename ConcreteT>
+  struct Trait : public BaseTrait<ConcreteT, Trait> {
+    /// Define an accessor for the ID of this interface.
+    static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
+
+    /// Provide an accessor to a static instance of the interface model for the
+    /// concrete T type.
+    /// The implementation is inspired by Sean Parent's concept-based
+    /// polymorphism. A key difference is that the set of classes erased is
+    /// statically known, which alleviates the need for using dynamic memory
+    /// allocation.
+    /// We use a stateless templated class `Model<ConcreteT>` to emit the
+    /// virtual table and generate a singleton object for each instantiation of
+    /// this class.
+    static Concept &instance() {
+      static Model<ConcreteT> singleton;
+      return singleton;
+    }
+  };
+
+protected:
+  /// Get the raw concept in the correct derived concept type.
+  const Concept *getImpl() const { return impl; }
+  Concept *getImpl() { return impl; }
+
+private:
+  /// A pointer to the impl concept object (null for a null value).
+  Concept *impl;
+};
+
+//===----------------------------------------------------------------------===//
+// InterfaceMap
+//===----------------------------------------------------------------------===//
+
+/// This class provides an efficient mapping between a given `Interface` type,
+/// and a particular implementation of its concept.
+class InterfaceMap {
+public:
+  /// Construct an InterfaceMap with the given set of template types. For
+  /// convenience given that object trait lists may contain other non-interface
+  /// types, not all of the types need to be interfaces. The provided types that
+  /// do not represent interfaces are not added to the interface map.
+  template <typename... Types> static InterfaceMap get() {
+    return InterfaceMap(MapBuilder::create<Types...>());
+  }
+
+  /// Returns an instance of the concept object for the given interface if it
+  /// was registered to this map, null otherwise.
+  template <typename T> typename T::Concept *lookup() const {
+    if (!interfaces)
+      return nullptr;
+    return static_cast<typename T::Concept *>(
+        interfaces->lookup(T::getInterfaceID()));
+  }
+
+private:
+  /// This struct provides support for building a map of interfaces.
+  class MapBuilder {
+  public:
+    template <typename... Types>
+    static std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>> create() {
+      // Filter the provided types down to those that are interfaces, so that
+      // non-interface traits never reach `createImpl`.
+      return createImpl((typename FilterTypes<detect_get_interface_id,
+                                              Types...>::type *)nullptr);
+    }
+
+  private:
+    /// Trait to check if T provides a static 'getInterfaceID' method.
+    template <typename T, typename... Args>
+    using has_get_interface_id = decltype(T::getInterfaceID());
+    template <typename T>
+    using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
+
+    /// Utility to filter a given sequence of types based upon a predicate.
+    /// Each type satisfying `Pred` maps to a one-element tuple and every other
+    /// type to an empty tuple; `std::tuple_cat` then joins the results. This
+    /// uses `std::conditional` instead of a member class template with an
+    /// explicit specialization, because explicit specializations at class
+    /// scope are rejected by some compilers (e.g. GCC: "explicit
+    /// specialization in non-namespace scope").
+    template <template <class> class Pred, class... Es> struct FilterTypes {
+      using type = decltype(std::tuple_cat(
+          std::declval<typename std::conditional<
+              Pred<Es>::value, std::tuple<Es>, std::tuple<>>::type>()...));
+    };
+    template <typename... Ts>
+    static std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>>
+    createImpl(std::tuple<Ts...> *) {
+      // Only create an instance of the map if there are any interface types.
+      if (sizeof...(Ts) == 0)
+        return std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>>();
+
+      auto map = std::make_unique<llvm::SmallDenseMap<TypeID, void *>>();
+      (void)std::initializer_list<int>{
+          0, (map->try_emplace(Ts::getInterfaceID(), &Ts::instance()), 0)...};
+      return map;
+    }
+  };
+
+private:
+  InterfaceMap(std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>> interfaces)
+      : interfaces(std::move(interfaces)) {}
+
+  /// The internal map of interfaces, from interface ID to concept instance.
+  /// This is null when none of the provided types were interfaces.
+  std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>> interfaces;
+};
+
+} // end namespace detail
+} // end namespace mlir
+
+#endif

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 2beb1a91dbae..a6c3a4ca1a2f 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -560,9 +560,10 @@ void Dialect::addOperation(AbstractOperation opInfo) {
   auto &impl = context->getImpl();
 
   // Lock access to the context registry.
+  StringRef opName = opInfo.name;
   ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
-  if (!impl.registeredOperations.insert({opInfo.name, opInfo}).second) {
-    llvm::errs() << "error: operation named '" << opInfo.name
+  if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) {
+    llvm::errs() << "error: operation named '" << opName
                  << "' is already registered.\n";
     abort();
   }


        


More information about the Mlir-commits mailing list