[Mlir-commits] [mlir] 31c8866 - [mlir] Remove the use of FilterTypes for template metaprogramming
River Riddle
llvmlistbot at llvm.org
Fri Apr 15 12:57:29 PDT 2022
Author: River Riddle
Date: 2022-04-15T12:57:07-07:00
New Revision: 31c88660ab155beb5e6796ec1382afd2c0e52978
URL: https://github.com/llvm/llvm-project/commit/31c88660ab155beb5e6796ec1382afd2c0e52978
DIFF: https://github.com/llvm/llvm-project/commit/31c88660ab155beb5e6796ec1382afd2c0e52978.diff
LOG: [mlir] Remove the use of FilterTypes for template metaprogramming
This technique results in an explosion in compile time, resulting from a
huge number of std::tuple/concat instantiations. This technique is replaced
by simpler metaprogramming and results in a significant reduction in
compile time. A local debug/asan build saw a 4x speed up in the processing
of ArithmeticOps.h.inc, and given the nature of this change every dialect
should see similar reductions in compile time.
Differential Revision: https://reviews.llvm.org/D123360
Added:
Modified:
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/Support/InterfaceSupport.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 48bdec9483f84..5a1bc0133d1a7 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1540,36 +1540,24 @@ foldTrait(Operation *op, ArrayRef<Attribute> operands,
// fail to fold this trait.
return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
}
+template <typename Trait>
+static inline std::enable_if_t<!detect_has_any_fold_trait<Trait>::value,
+ LogicalResult>
+foldTrait(Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) {
+ return failure();
+}
-/// The internal implementation of `foldTraits` below that returns the result of
-/// folding a set of trait types `Ts` that implement a `foldTrait` method.
+/// Given a tuple type containing a set of traits, return the result of folding
+/// the given operation.
template <typename... Ts>
-static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results,
- std::tuple<Ts...> *) {
+static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
bool anyFolded = false;
(void)std::initializer_list<int>{
(anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
return success(anyFolded);
}
-/// Given a tuple type containing a set of traits that contain a `foldTrait`
-/// method, return the result of folding the given operation.
-template <typename TraitTupleT>
-static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
-foldTraits(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) {
- return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
-}
-/// A variant of the method above that is specialized when there are no traits
-/// that contain a `foldTrait` method.
-template <typename TraitTupleT>
-static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
-foldTraits(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) {
- return failure();
-}
-
//===----------------------------------------------------------------------===//
// Trait Verification
@@ -1587,44 +1575,51 @@ template <typename T>
using detect_has_verify_region_trait =
llvm::is_detected<has_verify_region_trait, T>;
-/// The internal implementation of `verifyTraits` below that returns the result
-/// of verifying the current operation with all of the provided trait types
-/// `Ts`.
+/// Verify the given trait if it provides a verifier.
+template <typename T>
+std::enable_if_t<detect_has_verify_trait<T>::value, LogicalResult>
+verifyTrait(Operation *op) {
+ return T::verifyTrait(op);
+}
+template <typename T>
+inline std::enable_if_t<!detect_has_verify_trait<T>::value, LogicalResult>
+verifyTrait(Operation *) {
+ return success();
+}
+
+/// Given a set of traits, return the result of verifying the given operation.
template <typename... Ts>
-static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
+LogicalResult verifyTraits(Operation *op) {
LogicalResult result = success();
(void)std::initializer_list<int>{
- (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
+ (result = succeeded(result) ? verifyTrait<Ts>(op) : failure(), 0)...};
return result;
}
-/// Given a tuple type containing a set of traits that contain a
-/// `verifyTrait` method, return the result of verifying the given operation.
-template <typename TraitTupleT>
-static LogicalResult verifyTraits(Operation *op) {
- return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
+/// Verify the given trait if it provides a region verifier.
+template <typename T>
+std::enable_if_t<detect_has_verify_region_trait<T>::value, LogicalResult>
+verifyRegionTrait(Operation *op) {
+ return T::verifyRegionTrait(op);
+}
+template <typename T>
+inline std::enable_if_t<!detect_has_verify_region_trait<T>::value,
+ LogicalResult>
+verifyRegionTrait(Operation *) {
+ return success();
}
-/// The internal implementation of `verifyRegionTraits` below that returns the
-/// result of verifying the current operation with all of the provided trait
-/// types `Ts`.
+/// Given a set of traits, return the result of verifying the regions of the
+/// given operation.
template <typename... Ts>
-static LogicalResult verifyRegionTraitsImpl(Operation *op,
- std::tuple<Ts...> *) {
+LogicalResult verifyRegionTraits(Operation *op) {
(void)op;
LogicalResult result = success();
(void)std::initializer_list<int>{
- (result = succeeded(result) ? Ts::verifyRegionTrait(op) : failure(),
+ (result = succeeded(result) ? verifyRegionTrait<Ts>(op) : failure(),
0)...};
return result;
}
-
-/// Given a tuple type containing a set of traits that contain a
-/// `verifyTrait` method, return the result of verifying the given operation.
-template <typename TraitTupleT>
-static LogicalResult verifyRegionTraits(Operation *op) {
- return verifyRegionTraitsImpl(op, (TraitTupleT *)nullptr);
-}
} // namespace op_definition_impl
//===----------------------------------------------------------------------===//
@@ -1733,18 +1728,6 @@ class Op : public OpState, public Traits<ConcreteType>... {
decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
template <typename T>
using detect_has_print = llvm::is_detected<has_print, T>;
- /// A tuple type containing the traits that have a `foldTrait` function.
- using FoldableTraitsTupleT = typename detail::FilterTypes<
- op_definition_impl::detect_has_any_fold_trait,
- Traits<ConcreteType>...>::type;
- /// A tuple type containing the traits that have a verify function.
- using VerifiableTraitsTupleT =
- typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
- Traits<ConcreteType>...>::type;
- /// A tuple type containing the region traits that have a verify function.
- using VerifiableRegionTraitsTupleT = typename detail::FilterTypes<
- op_definition_impl::detect_has_verify_region_trait,
- Traits<ConcreteType>...>::type;
/// Returns an interface map containing the interfaces registered to this
/// operation.
@@ -1794,8 +1777,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
return [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
// In this case, we only need to fold the traits of the operation.
- return op_definition_impl::foldTraits<FoldableTraitsTupleT>(op, operands,
- results);
+ return op_definition_impl::foldTraits<Traits<ConcreteType>...>(
+ op, operands, results);
};
}
/// Return the result of folding a single result operation that defines a
@@ -1809,7 +1792,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
// If the fold failed or was in-place, try to fold the traits of the
// operation.
if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
- if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
+ if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
op, operands, results)))
return success();
return success(static_cast<bool>(result));
@@ -1826,7 +1809,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
// If the fold failed or was in-place, try to fold the traits of the
// operation.
if (failed(result) || results.empty()) {
- if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
+ if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
op, operands, results)))
return success();
}
@@ -1879,7 +1862,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
static_assert(hasNoDataMembers(),
"Op class shouldn't define new data members");
return failure(
- failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
+ failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) ||
failed(cast<ConcreteType>(op).verify()));
}
static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
@@ -1889,9 +1872,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
static LogicalResult verifyRegionInvariants(Operation *op) {
static_assert(hasNoDataMembers(),
"Op class shouldn't define new data members");
- return failure(failed(op_definition_impl::verifyRegionTraits<
- VerifiableRegionTraitsTupleT>(op)) ||
- failed(cast<ConcreteType>(op).verifyRegions()));
+ return failure(
+ failed(op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>(
+ op)) ||
+ failed(cast<ConcreteType>(op).verifyRegions()));
}
static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);
diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index e0f5dd39305b6..f940241679994 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -125,23 +125,17 @@ class Interface : public BaseType {
// InterfaceMap
//===----------------------------------------------------------------------===//
-/// Utility to filter a given sequence of types base upon a predicate.
-template <bool>
-struct FilterTypeT {
- template <class E>
- using type = std::tuple<E>;
-};
-template <>
-struct FilterTypeT<false> {
- template <class E>
- using type = std::tuple<>;
-};
-template <template <class> class Pred, class... Es>
-struct FilterTypes {
- using type = decltype(std::tuple_cat(
- std::declval<
- typename FilterTypeT<Pred<Es>::value>::template type<Es>>()...));
-};
+/// Template utility that computes the number of elements within `T` that
+/// satisfy the given predicate.
+template <template <class> class Pred, size_t N, typename... Ts>
+struct count_if_t_impl : public std::integral_constant<size_t, N> {};
+template <template <class> class Pred, size_t N, typename T, typename... Us>
+struct count_if_t_impl<Pred, N, T, Us...>
+ : public std::integral_constant<
+ size_t,
+ count_if_t_impl<Pred, N + (Pred<T>::value ? 1 : 0), Us...>::value> {};
+template <template <class> class Pred, typename... Ts>
+using count_if_t = count_if_t_impl<Pred, 0, Ts...>;
namespace {
/// Type trait indicating whether all template arguments are
@@ -171,8 +165,7 @@ class InterfaceMap {
template <typename T>
using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
template <typename... Types>
- using num_interface_types = typename std::tuple_size<
- typename FilterTypes<detect_get_interface_id, Types...>::type>;
+ using num_interface_types_t = count_if_t<detect_get_interface_id, Types...>;
public:
InterfaceMap(InterfaceMap &&) = default;
@@ -192,20 +185,17 @@ class InterfaceMap {
/// types, not all of the types need to be interfaces. The provided types that
/// do not represent interfaces are not added to the interface map.
template <typename... Types>
- static std::enable_if_t<num_interface_types<Types...>::value != 0,
- InterfaceMap>
- get() {
- // Filter the provided types for those that are interfaces.
- using FilteredTupleType =
- typename FilterTypes<detect_get_interface_id, Types...>::type;
- return getImpl((FilteredTupleType *)nullptr);
- }
-
- template <typename... Types>
- static std::enable_if_t<num_interface_types<Types...>::value == 0,
- InterfaceMap>
- get() {
- return InterfaceMap();
+ static InterfaceMap get() {
+ // TODO: Use constexpr if here in C++17.
+ constexpr size_t numInterfaces = num_interface_types_t<Types...>::value;
+ if (numInterfaces == 0)
+ return InterfaceMap();
+
+ std::array<std::pair<TypeID, void *>, numInterfaces> elements;
+ std::pair<TypeID, void *> *elementIt = elements.data();
+ (void)std::initializer_list<int>{
+ 0, (addModelAndUpdateIterator<Types>(elementIt), 0)...};
+ return InterfaceMap(elements);
}
/// Returns an instance of the concept object for the given interface if it
@@ -235,21 +225,28 @@ class InterfaceMap {
}
private:
- /// Compare two TypeID instances by comparing the underlying pointer.
- static bool compare(TypeID lhs, TypeID rhs) {
- return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
- }
-
InterfaceMap() = default;
+ /// Assign the interface model of the type to the given opaque element
+ /// iterator and increment it.
+ template <typename T>
+ static inline std::enable_if_t<detect_get_interface_id<T>::value>
+ addModelAndUpdateIterator(std::pair<TypeID, void *> *&elementIt) {
+ *elementIt = {T::getInterfaceID(), new (malloc(sizeof(typename T::ModelT)))
+ typename T::ModelT()};
+ ++elementIt;
+ }
+ /// Overload when `T` isn't an interface.
+ template <typename T>
+ static inline std::enable_if_t<!detect_get_interface_id<T>::value>
+ addModelAndUpdateIterator(std::pair<TypeID, void *> *&) {}
+
+ /// Insert the given set of interface models into the interface map.
void insert(ArrayRef<std::pair<TypeID, void *>> elements);
- template <typename... Ts>
- static InterfaceMap getImpl(std::tuple<Ts...> *) {
- std::pair<TypeID, void *> elements[] = {std::make_pair(
- Ts::getInterfaceID(),
- new (malloc(sizeof(typename Ts::ModelT))) typename Ts::ModelT())...};
- return InterfaceMap(elements);
+ /// Compare two TypeID instances by comparing the underlying pointer.
+ static bool compare(TypeID lhs, TypeID rhs) {
+ return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
}
/// Returns an instance of the concept object for the given interface id if it
More information about the Mlir-commits
mailing list