[libcxx-commits] [libcxx] aea7929 - [libc++] Unify __is_trivial_equality_predicate and __is_trivial_plus_operation into __desugars_to (#68642)

via libcxx-commits libcxx-commits at lists.llvm.org
Thu Nov 23 10:56:00 PST 2023


Author: Anton Rydahl
Date: 2023-11-23T13:55:55-05:00
New Revision: aea7929b0a04c5ea0ac85aba2b85fb58c718626f

URL: https://github.com/llvm/llvm-project/commit/aea7929b0a04c5ea0ac85aba2b85fb58c718626f
DIFF: https://github.com/llvm/llvm-project/commit/aea7929b0a04c5ea0ac85aba2b85fb58c718626f.diff

LOG: [libc++] Unify __is_trivial_equality_predicate and __is_trivial_plus_operation into __desugars_to (#68642)

When working on an OpenMP offloading backend for standard parallel
algorithms (https://github.com/llvm/llvm-project/pull/66968) we noticed
the need for a generalization of `__is_trivial_plus_operation`. This patch
merges `__is_trivial_equality_predicate` and `__is_trivial_plus_operation`
into `__desugars_to`, and in the future we might extend the latter to support
other binary operations as well.

Co-authored-by: Louis Dionne <ldionne.2 at gmail.com>

Added: 
    

Modified: 
    libcxx/include/CMakeLists.txt
    libcxx/include/__algorithm/comp.h
    libcxx/include/__algorithm/equal.h
    libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h
    libcxx/include/__functional/operations.h
    libcxx/include/__functional/ranges_operations.h
    libcxx/include/__numeric/pstl_transform_reduce.h
    libcxx/include/__type_traits/operation_traits.h
    libcxx/include/module.modulemap.in

Removed: 
    libcxx/include/__type_traits/predicate_traits.h


################################################################################
diff  --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index 889d7fedbf2965f..aef54ea25fd52f3 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -816,7 +816,6 @@ set(files
   __type_traits/negation.h
   __type_traits/noexcept_move_assign_container.h
   __type_traits/operation_traits.h
-  __type_traits/predicate_traits.h
   __type_traits/promote.h
   __type_traits/rank.h
   __type_traits/remove_all_extents.h

diff  --git a/libcxx/include/__algorithm/comp.h b/libcxx/include/__algorithm/comp.h
index 9474536615ffb67..3902f7560304a12 100644
--- a/libcxx/include/__algorithm/comp.h
+++ b/libcxx/include/__algorithm/comp.h
@@ -11,7 +11,7 @@
 
 #include <__config>
 #include <__type_traits/integral_constant.h>
-#include <__type_traits/predicate_traits.h>
+#include <__type_traits/operation_traits.h>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
 #  pragma GCC system_header
@@ -26,8 +26,8 @@ struct __equal_to {
   }
 };
 
-template <class _Lhs, class _Rhs>
-struct __is_trivial_equality_predicate<__equal_to, _Lhs, _Rhs> : true_type {};
+template <class _Tp, class _Up>
+struct __desugars_to<__equal_tag, __equal_to, _Tp, _Up> : true_type {};
 
 // The definition is required because __less is part of the ABI, but it's empty
 // because all comparisons should be transparent.

diff  --git a/libcxx/include/__algorithm/equal.h b/libcxx/include/__algorithm/equal.h
index b69aeff92bb9289..ca2e49ca5679a4f 100644
--- a/libcxx/include/__algorithm/equal.h
+++ b/libcxx/include/__algorithm/equal.h
@@ -23,7 +23,7 @@
 #include <__type_traits/is_constant_evaluated.h>
 #include <__type_traits/is_equality_comparable.h>
 #include <__type_traits/is_volatile.h>
-#include <__type_traits/predicate_traits.h>
+#include <__type_traits/operation_traits.h>
 #include <__utility/move.h>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -41,13 +41,12 @@ _LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 boo
   return true;
 }
 
-template <
-    class _Tp,
-    class _Up,
-    class _BinaryPredicate,
-    __enable_if_t<__is_trivial_equality_predicate<_BinaryPredicate, _Tp, _Up>::value && !is_volatile<_Tp>::value &&
-                      !is_volatile<_Up>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value,
-                  int> = 0>
+template <class _Tp,
+          class _Up,
+          class _BinaryPredicate,
+          __enable_if_t<__desugars_to<__equal_tag, _BinaryPredicate, _Tp, _Up>::value && !is_volatile<_Tp>::value &&
+                            !is_volatile<_Up>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value,
+                        int> = 0>
 _LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
 __equal_iter_impl(_Tp* __first1, _Tp* __last1, _Up* __first2, _BinaryPredicate&) {
   return std::__constexpr_memcmp_equal(__first1, __first2, __element_count(__last1 - __first1));
@@ -94,12 +93,12 @@ template <class _Tp,
           class _Pred,
           class _Proj1,
           class _Proj2,
-          __enable_if_t<__is_trivial_equality_predicate<_Pred, _Tp, _Up>::value && __is_identity<_Proj1>::value &&
+          __enable_if_t<__desugars_to<__equal_tag, _Pred, _Tp, _Up>::value && __is_identity<_Proj1>::value &&
                             __is_identity<_Proj2>::value && !is_volatile<_Tp>::value && !is_volatile<_Up>::value &&
                             __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value,
                         int> = 0>
-_LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_impl(
-    _Tp* __first1, _Tp* __last1, _Up* __first2, _Up*, _Pred&, _Proj1&, _Proj2&) {
+_LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
+__equal_impl(_Tp* __first1, _Tp* __last1, _Up* __first2, _Up*, _Pred&, _Proj1&, _Proj2&) {
   return std::__constexpr_memcmp_equal(__first1, __first2, __element_count(__last1 - __first1));
 }
 

diff  --git a/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h b/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h
index a5ca9c89d1ab23b..ab2e3172b8b63bd 100644
--- a/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h
+++ b/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h
@@ -29,12 +29,14 @@
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
-template <
-    typename _DifferenceType,
-    typename _Tp,
-    typename _BinaryOperation,
-    typename _UnaryOperation,
-    __enable_if_t<__is_trivial_plus_operation<_BinaryOperation, _Tp, _Tp>::value && is_arithmetic_v<_Tp>, int> = 0>
+template <typename _DifferenceType,
+          typename _Tp,
+          typename _BinaryOperation,
+          typename _UnaryOperation,
+          typename _UnaryResult = invoke_result_t<_UnaryOperation, _DifferenceType>,
+          __enable_if_t<__desugars_to<__plus_tag, _BinaryOperation, _Tp, _UnaryResult>::value && is_arithmetic_v<_Tp> &&
+                            is_arithmetic_v<_UnaryResult>,
+                        int>    = 0>
 _LIBCPP_HIDE_FROM_ABI _Tp
 __simd_transform_reduce(_DifferenceType __n, _Tp __init, _BinaryOperation, _UnaryOperation __f) noexcept {
   _PSTL_PRAGMA_SIMD_REDUCTION(+ : __init)
@@ -43,12 +45,14 @@ __simd_transform_reduce(_DifferenceType __n, _Tp __init, _BinaryOperation, _Unar
   return __init;
 }
 
-template <
-    typename _Size,
-    typename _Tp,
-    typename _BinaryOperation,
-    typename _UnaryOperation,
-    __enable_if_t<!(__is_trivial_plus_operation<_BinaryOperation, _Tp, _Tp>::value && is_arithmetic_v<_Tp>), int> = 0>
+template <typename _Size,
+          typename _Tp,
+          typename _BinaryOperation,
+          typename _UnaryOperation,
+          typename _UnaryResult = invoke_result_t<_UnaryOperation, _Size>,
+          __enable_if_t<!(__desugars_to<__plus_tag, _BinaryOperation, _Tp, _UnaryResult>::value &&
+                          is_arithmetic_v<_Tp> && is_arithmetic_v<_UnaryResult>),
+                        int>    = 0>
 _LIBCPP_HIDE_FROM_ABI _Tp
 __simd_transform_reduce(_Size __n, _Tp __init, _BinaryOperation __binary_op, _UnaryOperation __f) noexcept {
   const _Size __block_size = __lane_size / sizeof(_Tp);

diff  --git a/libcxx/include/__functional/operations.h b/libcxx/include/__functional/operations.h
index 6cdb89d6b449bcd..9812ccf8e4136ff 100644
--- a/libcxx/include/__functional/operations.h
+++ b/libcxx/include/__functional/operations.h
@@ -15,7 +15,6 @@
 #include <__functional/unary_function.h>
 #include <__type_traits/integral_constant.h>
 #include <__type_traits/operation_traits.h>
-#include <__type_traits/predicate_traits.h>
 #include <__utility/forward.h>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -41,13 +40,13 @@ struct _LIBCPP_TEMPLATE_VIS plus
 };
 _LIBCPP_CTAD_SUPPORTED_FOR_TYPE(plus);
 
+// The non-transparent std::plus specialization is only equivalent to a raw plus
+// operator when we don't perform an implicit conversion when calling it.
 template <class _Tp>
-struct __is_trivial_plus_operation<plus<_Tp>, _Tp, _Tp> : true_type {};
+struct __desugars_to<__plus_tag, plus<_Tp>, _Tp, _Tp> : true_type {};
 
-#if _LIBCPP_STD_VER >= 14
 template <class _Tp, class _Up>
-struct __is_trivial_plus_operation<plus<>, _Tp, _Up> : true_type {};
-#endif
+struct __desugars_to<__plus_tag, plus<void>, _Tp, _Up> : true_type {};
 
 #if _LIBCPP_STD_VER >= 14
 template <>
@@ -352,13 +351,14 @@ struct _LIBCPP_TEMPLATE_VIS equal_to<void>
 };
 #endif
 
+// The non-transparent std::equal_to specialization is only equivalent to a raw equality
+// comparison when we don't perform an implicit conversion when calling it.
 template <class _Tp>
-struct __is_trivial_equality_predicate<equal_to<_Tp>, _Tp, _Tp> : true_type {};
+struct __desugars_to<__equal_tag, equal_to<_Tp>, _Tp, _Tp> : true_type {};
 
-#if _LIBCPP_STD_VER >= 14
-template <class _Tp>
-struct __is_trivial_equality_predicate<equal_to<>, _Tp, _Tp> : true_type {};
-#endif
+// In the transparent case, we do not enforce that the two argument types are the same.
+template <class _Tp, class _Up>
+struct __desugars_to<__equal_tag, equal_to<void>, _Tp, _Up> : true_type {};
 
 #if _LIBCPP_STD_VER >= 14
 template <class _Tp = void>

diff  --git a/libcxx/include/__functional/ranges_operations.h b/libcxx/include/__functional/ranges_operations.h
index c344fc38f98ddd9..b54589f8c0d8792 100644
--- a/libcxx/include/__functional/ranges_operations.h
+++ b/libcxx/include/__functional/ranges_operations.h
@@ -14,7 +14,7 @@
 #include <__concepts/totally_ordered.h>
 #include <__config>
 #include <__type_traits/integral_constant.h>
-#include <__type_traits/predicate_traits.h>
+#include <__type_traits/operation_traits.h>
 #include <__utility/forward.h>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -95,8 +95,10 @@ struct greater_equal {
 
 } // namespace ranges
 
-template <class _Lhs, class _Rhs>
-struct __is_trivial_equality_predicate<ranges::equal_to, _Lhs, _Rhs> : true_type {};
+// For ranges we do not require that the operands on each side of the equality
+// operator are of the same type
+template <class _Tp, class _Up>
+struct __desugars_to<__equal_tag, ranges::equal_to, _Tp, _Up> : true_type {};
 
 #endif // _LIBCPP_STD_VER >= 20
 

diff  --git a/libcxx/include/__numeric/pstl_transform_reduce.h b/libcxx/include/__numeric/pstl_transform_reduce.h
index 4127ee21e3045c8..1127726046665c0 100644
--- a/libcxx/include/__numeric/pstl_transform_reduce.h
+++ b/libcxx/include/__numeric/pstl_transform_reduce.h
@@ -84,7 +84,7 @@ _LIBCPP_HIDE_FROM_ABI _Tp transform_reduce(
 }
 
 // This overload doesn't get a customization point because it's trivial to detect (through e.g.
-// __is_trivial_plus_operation) when specializing the more general variant, which should always be preferred
+// __desugars_to) when specializing the more general variant, which should always be preferred
 template <class _ExecutionPolicy,
           class _ForwardIterator1,
           class _ForwardIterator2,

diff  --git a/libcxx/include/__type_traits/operation_traits.h b/libcxx/include/__type_traits/operation_traits.h
index 7dda93e9083a404..ef6e71693430d01 100644
--- a/libcxx/include/__type_traits/operation_traits.h
+++ b/libcxx/include/__type_traits/operation_traits.h
@@ -18,8 +18,22 @@
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
-template <class _Pred, class _Lhs, class _Rhs>
-struct __is_trivial_plus_operation : false_type {};
+// Tags to represent the canonical operations
+struct __equal_tag {};
+struct __plus_tag {};
+
+// This class template is used to determine whether an operation "desugars"
+// (or boils down) to a given canonical operation.
+//
+// For example, `std::equal_to<>`, our internal `std::__equal_to` helper and
+// `ranges::equal_to` are all just fancy ways of representing a transparent
+// equality operation, so they all desugar to `__equal_tag`.
+//
+// This is useful to optimize some functions in cases where we know e.g. the
+// predicate being passed is actually going to call a builtin operator, or has
+// some specific semantics.
+template <class _CanonicalTag, class _Operation, class... _Args>
+struct __desugars_to : false_type {};
 
 _LIBCPP_END_NAMESPACE_STD
 

diff  --git a/libcxx/include/__type_traits/predicate_traits.h b/libcxx/include/__type_traits/predicate_traits.h
deleted file mode 100644
index 872608e6ac3be3f..000000000000000
--- a/libcxx/include/__type_traits/predicate_traits.h
+++ /dev/null
@@ -1,26 +0,0 @@
-//===----------------------------------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef _LIBCPP___TYPE_TRAITS_PREDICATE_TRAITS
-#define _LIBCPP___TYPE_TRAITS_PREDICATE_TRAITS
-
-#include <__config>
-#include <__type_traits/integral_constant.h>
-
-#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
-#  pragma GCC system_header
-#endif
-
-_LIBCPP_BEGIN_NAMESPACE_STD
-
-template <class _Pred, class _Lhs, class _Rhs>
-struct __is_trivial_equality_predicate : false_type {};
-
-_LIBCPP_END_NAMESPACE_STD
-
-#endif // _LIBCPP___TYPE_TRAITS_PREDICATE_TRAITS

diff  --git a/libcxx/include/module.modulemap.in b/libcxx/include/module.modulemap.in
index 17ebe48f329963d..b4a68c8ecde0ab1 100644
--- a/libcxx/include/module.modulemap.in
+++ b/libcxx/include/module.modulemap.in
@@ -2013,7 +2013,6 @@ module std_private_type_traits_nat                                       [system
 module std_private_type_traits_negation                                  [system] { header "__type_traits/negation.h" }
 module std_private_type_traits_noexcept_move_assign_container            [system] { header "__type_traits/noexcept_move_assign_container.h" }
 module std_private_type_traits_operation_traits                          [system] { header "__type_traits/operation_traits.h" }
-module std_private_type_traits_predicate_traits                          [system] { header "__type_traits/predicate_traits.h" }
 module std_private_type_traits_promote                                   [system] { header "__type_traits/promote.h" }
 module std_private_type_traits_rank                                      [system] { header "__type_traits/rank.h" }
 module std_private_type_traits_remove_all_extents                        [system] { header "__type_traits/remove_all_extents.h" }


        


More information about the libcxx-commits mailing list