[Mlir-commits] [mlir] 31c8866 - [mlir] Remove the use of FilterTypes for template metaprogramming

River Riddle llvmlistbot at llvm.org
Fri Apr 15 12:57:29 PDT 2022


Author: River Riddle
Date: 2022-04-15T12:57:07-07:00
New Revision: 31c88660ab155beb5e6796ec1382afd2c0e52978

URL: https://github.com/llvm/llvm-project/commit/31c88660ab155beb5e6796ec1382afd2c0e52978
DIFF: https://github.com/llvm/llvm-project/commit/31c88660ab155beb5e6796ec1382afd2c0e52978.diff

LOG: [mlir] Remove the use of FilterTypes for template metaprogramming

This technique results in an explosion in compile time, caused by a
huge number of std::tuple/tuple_cat instantiations. This technique is replaced
by simpler metaprogramming and results in a significant reduction in
compile time. A local debug/asan build saw a 4x speedup in the processing
of ArithmeticOps.h.inc, and given the nature of this change every dialect
should see similar reductions in compile time.

Differential Revision: https://reviews.llvm.org/D123360

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/Support/InterfaceSupport.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 48bdec9483f84..5a1bc0133d1a7 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1540,36 +1540,24 @@ foldTrait(Operation *op, ArrayRef<Attribute> operands,
   // fail to fold this trait.
   return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
 }
+template <typename Trait>
+static inline std::enable_if_t<!detect_has_any_fold_trait<Trait>::value,
+                               LogicalResult>
+foldTrait(Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) {
+  return failure();
+}
 
-/// The internal implementation of `foldTraits` below that returns the result of
-/// folding a set of trait types `Ts` that implement a `foldTrait` method.
+/// Given a set of traits, return the result of folding the given operation
+/// with each of them; succeeds if any trait folded the operation.
 template <typename... Ts>
-static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
-                                    SmallVectorImpl<OpFoldResult> &results,
-                                    std::tuple<Ts...> *) {
+static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
+                                SmallVectorImpl<OpFoldResult> &results) {
   bool anyFolded = false;
   (void)std::initializer_list<int>{
       (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
   return success(anyFolded);
 }
 
-/// Given a tuple type containing a set of traits that contain a `foldTrait`
-/// method, return the result of folding the given operation.
-template <typename TraitTupleT>
-static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
-foldTraits(Operation *op, ArrayRef<Attribute> operands,
-           SmallVectorImpl<OpFoldResult> &results) {
-  return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
-}
-/// A variant of the method above that is specialized when there are no traits
-/// that contain a `foldTrait` method.
-template <typename TraitTupleT>
-static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
-foldTraits(Operation *op, ArrayRef<Attribute> operands,
-           SmallVectorImpl<OpFoldResult> &results) {
-  return failure();
-}
-
 //===----------------------------------------------------------------------===//
 // Trait Verification
 
@@ -1587,44 +1575,51 @@ template <typename T>
 using detect_has_verify_region_trait =
     llvm::is_detected<has_verify_region_trait, T>;
 
-/// The internal implementation of `verifyTraits` below that returns the result
-/// of verifying the current operation with all of the provided trait types
-/// `Ts`.
+/// Verify the given trait if it provides a verifier.
+template <typename T>
+std::enable_if_t<detect_has_verify_trait<T>::value, LogicalResult>
+verifyTrait(Operation *op) {
+  return T::verifyTrait(op);
+}
+template <typename T>
+inline std::enable_if_t<!detect_has_verify_trait<T>::value, LogicalResult>
+verifyTrait(Operation *) {
+  return success();
+}
+
+/// Given a set of traits, return the result of verifying the given operation.
 template <typename... Ts>
-static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
+LogicalResult verifyTraits(Operation *op) {
   LogicalResult result = success();
   (void)std::initializer_list<int>{
-      (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
+      (result = succeeded(result) ? verifyTrait<Ts>(op) : failure(), 0)...};
   return result;
 }
 
-/// Given a tuple type containing a set of traits that contain a
-/// `verifyTrait` method, return the result of verifying the given operation.
-template <typename TraitTupleT>
-static LogicalResult verifyTraits(Operation *op) {
-  return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
+/// Verify the given trait if it provides a region verifier.
+template <typename T>
+std::enable_if_t<detect_has_verify_region_trait<T>::value, LogicalResult>
+verifyRegionTrait(Operation *op) {
+  return T::verifyRegionTrait(op);
+}
+template <typename T>
+inline std::enable_if_t<!detect_has_verify_region_trait<T>::value,
+                        LogicalResult>
+verifyRegionTrait(Operation *) {
+  return success();
 }
 
-/// The internal implementation of `verifyRegionTraits` below that returns the
-/// result of verifying the current operation with all of the provided trait
-/// types `Ts`.
+/// Given a set of traits, return the result of verifying the regions of the
+/// given operation.
 template <typename... Ts>
-static LogicalResult verifyRegionTraitsImpl(Operation *op,
-                                            std::tuple<Ts...> *) {
+LogicalResult verifyRegionTraits(Operation *op) {
   (void)op;
   LogicalResult result = success();
   (void)std::initializer_list<int>{
-      (result = succeeded(result) ? Ts::verifyRegionTrait(op) : failure(),
+      (result = succeeded(result) ? verifyRegionTrait<Ts>(op) : failure(),
        0)...};
   return result;
 }
-
-/// Given a tuple type containing a set of traits that contain a
-/// `verifyTrait` method, return the result of verifying the given operation.
-template <typename TraitTupleT>
-static LogicalResult verifyRegionTraits(Operation *op) {
-  return verifyRegionTraitsImpl(op, (TraitTupleT *)nullptr);
-}
 } // namespace op_definition_impl
 
 //===----------------------------------------------------------------------===//
@@ -1733,18 +1728,6 @@ class Op : public OpState, public Traits<ConcreteType>... {
       decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
   template <typename T>
   using detect_has_print = llvm::is_detected<has_print, T>;
-  /// A tuple type containing the traits that have a `foldTrait` function.
-  using FoldableTraitsTupleT = typename detail::FilterTypes<
-      op_definition_impl::detect_has_any_fold_trait,
-      Traits<ConcreteType>...>::type;
-  /// A tuple type containing the traits that have a verify function.
-  using VerifiableTraitsTupleT =
-      typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
-                                   Traits<ConcreteType>...>::type;
-  /// A tuple type containing the region traits that have a verify function.
-  using VerifiableRegionTraitsTupleT = typename detail::FilterTypes<
-      op_definition_impl::detect_has_verify_region_trait,
-      Traits<ConcreteType>...>::type;
 
   /// Returns an interface map containing the interfaces registered to this
   /// operation.
@@ -1794,8 +1777,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
     return [](Operation *op, ArrayRef<Attribute> operands,
               SmallVectorImpl<OpFoldResult> &results) {
       // In this case, we only need to fold the traits of the operation.
-      return op_definition_impl::foldTraits<FoldableTraitsTupleT>(op, operands,
-                                                                  results);
+      return op_definition_impl::foldTraits<Traits<ConcreteType>...>(
+          op, operands, results);
     };
   }
   /// Return the result of folding a single result operation that defines a
@@ -1809,7 +1792,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
     // If the fold failed or was in-place, try to fold the traits of the
     // operation.
     if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
-      if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
+      if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
               op, operands, results)))
         return success();
       return success(static_cast<bool>(result));
@@ -1826,7 +1809,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
     // If the fold failed or was in-place, try to fold the traits of the
     // operation.
     if (failed(result) || results.empty()) {
-      if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
+      if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
               op, operands, results)))
         return success();
     }
@@ -1879,7 +1862,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
     static_assert(hasNoDataMembers(),
                   "Op class shouldn't define new data members");
     return failure(
-        failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
+        failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) ||
         failed(cast<ConcreteType>(op).verify()));
   }
   static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
@@ -1889,9 +1872,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
   static LogicalResult verifyRegionInvariants(Operation *op) {
     static_assert(hasNoDataMembers(),
                   "Op class shouldn't define new data members");
-    return failure(failed(op_definition_impl::verifyRegionTraits<
-                          VerifiableRegionTraitsTupleT>(op)) ||
-                   failed(cast<ConcreteType>(op).verifyRegions()));
+    return failure(
+        failed(op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>(
+            op)) ||
+        failed(cast<ConcreteType>(op).verifyRegions()));
   }
   static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
     return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);

diff  --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index e0f5dd39305b6..f940241679994 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -125,23 +125,17 @@ class Interface : public BaseType {
 // InterfaceMap
 //===----------------------------------------------------------------------===//
 
-/// Utility to filter a given sequence of types base upon a predicate.
-template <bool>
-struct FilterTypeT {
-  template <class E>
-  using type = std::tuple<E>;
-};
-template <>
-struct FilterTypeT<false> {
-  template <class E>
-  using type = std::tuple<>;
-};
-template <template <class> class Pred, class... Es>
-struct FilterTypes {
-  using type = decltype(std::tuple_cat(
-      std::declval<
-          typename FilterTypeT<Pred<Es>::value>::template type<Es>>()...));
-};
+/// Template utility that computes the number of types within `Ts` that
+/// satisfy the given predicate.
+template <template <class> class Pred, size_t N, typename... Ts>
+struct count_if_t_impl : public std::integral_constant<size_t, N> {};
+template <template <class> class Pred, size_t N, typename T, typename... Us>
+struct count_if_t_impl<Pred, N, T, Us...>
+    : public std::integral_constant<
+          size_t,
+          count_if_t_impl<Pred, N + (Pred<T>::value ? 1 : 0), Us...>::value> {};
+template <template <class> class Pred, typename... Ts>
+using count_if_t = count_if_t_impl<Pred, 0, Ts...>;
 
 namespace {
 /// Type trait indicating whether all template arguments are
@@ -171,8 +165,7 @@ class InterfaceMap {
   template <typename T>
   using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
   template <typename... Types>
-  using num_interface_types = typename std::tuple_size<
-      typename FilterTypes<detect_get_interface_id, Types...>::type>;
+  using num_interface_types_t = count_if_t<detect_get_interface_id, Types...>;
 
 public:
   InterfaceMap(InterfaceMap &&) = default;
@@ -192,20 +185,17 @@ class InterfaceMap {
   /// types, not all of the types need to be interfaces. The provided types that
   /// do not represent interfaces are not added to the interface map.
   template <typename... Types>
-  static std::enable_if_t<num_interface_types<Types...>::value != 0,
-                          InterfaceMap>
-  get() {
-    // Filter the provided types for those that are interfaces.
-    using FilteredTupleType =
-        typename FilterTypes<detect_get_interface_id, Types...>::type;
-    return getImpl((FilteredTupleType *)nullptr);
-  }
-
-  template <typename... Types>
-  static std::enable_if_t<num_interface_types<Types...>::value == 0,
-                          InterfaceMap>
-  get() {
-    return InterfaceMap();
+  static InterfaceMap get() {
+    // TODO: Use constexpr if here in C++17.
+    constexpr size_t numInterfaces = num_interface_types_t<Types...>::value;
+    if (numInterfaces == 0)
+      return InterfaceMap();
+
+    std::array<std::pair<TypeID, void *>, numInterfaces> elements;
+    std::pair<TypeID, void *> *elementIt = elements.data();
+    (void)std::initializer_list<int>{
+        0, (addModelAndUpdateIterator<Types>(elementIt), 0)...};
+    return InterfaceMap(elements);
   }
 
   /// Returns an instance of the concept object for the given interface if it
@@ -235,21 +225,28 @@ class InterfaceMap {
   }
 
 private:
-  /// Compare two TypeID instances by comparing the underlying pointer.
-  static bool compare(TypeID lhs, TypeID rhs) {
-    return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
-  }
-
   InterfaceMap() = default;
 
+  /// Assign the interface model of the type to the given opaque element
+  /// iterator and increment it.
+  template <typename T>
+  static inline std::enable_if_t<detect_get_interface_id<T>::value>
+  addModelAndUpdateIterator(std::pair<TypeID, void *> *&elementIt) {
+    *elementIt = {T::getInterfaceID(), new (malloc(sizeof(typename T::ModelT)))
+                                           typename T::ModelT()};
+    ++elementIt;
+  }
+  /// Overload when `T` isn't an interface.
+  template <typename T>
+  static inline std::enable_if_t<!detect_get_interface_id<T>::value>
+  addModelAndUpdateIterator(std::pair<TypeID, void *> *&) {}
+
+  /// Insert the given set of interface models into the interface map.
   void insert(ArrayRef<std::pair<TypeID, void *>> elements);
 
-  template <typename... Ts>
-  static InterfaceMap getImpl(std::tuple<Ts...> *) {
-    std::pair<TypeID, void *> elements[] = {std::make_pair(
-        Ts::getInterfaceID(),
-        new (malloc(sizeof(typename Ts::ModelT))) typename Ts::ModelT())...};
-    return InterfaceMap(elements);
+  /// Compare two TypeID instances by comparing the underlying pointer.
+  static bool compare(TypeID lhs, TypeID rhs) {
+    return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
   }
 
   /// Returns an instance of the concept object for the given interface id if it


        


More information about the Mlir-commits mailing list