[libcxx-commits] [libcxx] 194d196 - Introduce branchless sorting functions for sort3, sort4 and sort5.

Nikolas Klauser via libcxx-commits libcxx-commits at lists.llvm.org
Fri Apr 8 00:05:07 PDT 2022


Author: Marco Gelmi
Date: 2022-04-08T09:00:30+02:00
New Revision: 194d1965d2c841fa81e107d19e27fae1467e7f11

URL: https://github.com/llvm/llvm-project/commit/194d1965d2c841fa81e107d19e27fae1467e7f11
DIFF: https://github.com/llvm/llvm-project/commit/194d1965d2c841fa81e107d19e27fae1467e7f11.diff

LOG: Introduce branchless sorting functions for sort3, sort4 and sort5.

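We are introducing branchless variants of __sort3, __sort4 and __sort5.
These sorting functions were generated using Reinforcement Learning and
replace the branching __sort3, __sort4 and __sort5 variants for
arithmetic types (integral and floating-point) no larger than a pointer,
when the range is contiguous and the comparator is std::less,
std::greater or libc++'s internal __less.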

The libc++ benchmarks were run on isolated Skylake, ARM and AMD machines
and show a statistically significant improvement when sorting random
integers in the test cases from sort1 to sort262144 for uint32 and
uint64.

A full performance overview for Intel Skylake, AMD and Arm can be
found here: https://bit.ly/3AtesYf

Reviewed By: ldionne, #libc, philnik

Spies: daniel.mankowitz, mgrang, Quuxplusone, andreamichi, philnik, libcxx-commits, nilayvaish, kristof.beyls

Differential Revision: https://reviews.llvm.org/D118029

Added: 
    

Modified: 
    libcxx/benchmarks/algorithms.bench.cpp
    libcxx/include/__algorithm/sort.h
    libcxx/test/libcxx/algorithms/robust_against_copying_comparators.pass.cpp
    libcxx/test/std/algorithms/alg.sorting/alg.sort/sort/sort.pass.cpp

Removed: 
    


################################################################################
diff --git a/libcxx/benchmarks/algorithms.bench.cpp b/libcxx/benchmarks/algorithms.bench.cpp
index b3724ecf23adb..5a97df1acc604 100644
--- a/libcxx/benchmarks/algorithms.bench.cpp
+++ b/libcxx/benchmarks/algorithms.bench.cpp
@@ -14,23 +14,17 @@
 
 namespace {
 
-enum class ValueType { Uint32, Uint64, Pair, Tuple, String };
-struct AllValueTypes : EnumValuesAsTuple<AllValueTypes, ValueType, 5> {
-  static constexpr const char* Names[] = {
-      "uint32", "uint64", "pair<uint32, uint32>",
-      "tuple<uint32, uint64, uint32>", "string"};
+enum class ValueType { Uint32, Uint64, Pair, Tuple, String, Float };
+struct AllValueTypes : EnumValuesAsTuple<AllValueTypes, ValueType, 6> {
+  static constexpr const char* Names[] = {"uint32", "uint64", "pair<uint32, uint32>", "tuple<uint32, uint64, uint32>",
+                                          "string", "float"};
 };
 
+using Types = std::tuple< uint32_t, uint64_t, std::pair<uint32_t, uint32_t>, std::tuple<uint32_t, uint64_t, uint32_t>,
+                          std::string, float >;
+
 template <class V>
-using Value = std::conditional_t<
-    V() == ValueType::Uint32, uint32_t,
-    std::conditional_t<
-        V() == ValueType::Uint64, uint64_t,
-        std::conditional_t<
-            V() == ValueType::Pair, std::pair<uint32_t, uint32_t>,
-            std::conditional_t<V() == ValueType::Tuple,
-                               std::tuple<uint32_t, uint64_t, uint32_t>,
-                               std::string> > > >;
+using Value = std::tuple_element_t<(int)V::value, Types>;
 
 enum class Order {
   Random,

diff --git a/libcxx/include/__algorithm/sort.h b/libcxx/include/__algorithm/sort.h
index 27ce647c8129c..3faff6b7db6c3 100644
--- a/libcxx/include/__algorithm/sort.h
+++ b/libcxx/include/__algorithm/sort.h
@@ -123,6 +123,96 @@ __sort5(_ForwardIterator __x1, _ForwardIterator __x2, _ForwardIterator __x3,
     return __r;
 }
 
+template <class _Tp>
+struct __is_simple_comparator : false_type {};
+template <class _Tp>
+struct __is_simple_comparator<__less<_Tp>&> : true_type {};
+template <class _Tp>
+struct __is_simple_comparator<less<_Tp>&> : true_type {};
+template <class _Tp>
+struct __is_simple_comparator<greater<_Tp>&> : true_type {};
+
+template <class _Compare, class _Iter, class _Tp = typename iterator_traits<_Iter>::value_type>
+using __use_branchless_sort =
+    integral_constant<bool, __is_cpp17_contiguous_iterator<_Iter>::value && sizeof(_Tp) <= sizeof(void*) &&
+                                is_arithmetic<_Tp>::value && __is_simple_comparator<_Compare>::value>;
+
+// Ensures that __c(*__x, *__y) is true by swapping *__x and *__y if necessary.
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI void __cond_swap(_RandomAccessIterator __x, _RandomAccessIterator __y, _Compare __c) {
+  using value_type = typename iterator_traits<_RandomAccessIterator>::value_type;
+  bool __r = __c(*__x, *__y);
+  value_type __tmp = __r ? *__x : *__y;
+  *__y = __r ? *__y : *__x;
+  *__x = __tmp;
+}
+
+// Ensures that *__x, *__y and *__z are ordered according to the comparator __c,
+// under the assumption that *__y and *__z are already ordered.
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI void __partially_sorted_swap(_RandomAccessIterator __x, _RandomAccessIterator __y,
+                                                          _RandomAccessIterator __z, _Compare __c) {
+  using value_type = typename iterator_traits<_RandomAccessIterator>::value_type;
+  bool __r = __c(*__z, *__x);
+  value_type __tmp = __r ? *__z : *__x;
+  *__z = __r ? *__x : *__z;
+  __r = __c(__tmp, *__y);
+  *__x = __r ? *__x : *__y;
+  *__y = __r ? *__y : __tmp;
+}
+
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI __enable_if_t<__use_branchless_sort<_Compare, _RandomAccessIterator>::value, void>
+__sort3_maybe_branchless(_RandomAccessIterator __x1, _RandomAccessIterator __x2, _RandomAccessIterator __x3,
+                         _Compare __c) {
+  _VSTD::__cond_swap<_Compare>(__x2, __x3, __c);
+  _VSTD::__partially_sorted_swap<_Compare>(__x1, __x2, __x3, __c);
+}
+
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI __enable_if_t<!__use_branchless_sort<_Compare, _RandomAccessIterator>::value, void>
+__sort3_maybe_branchless(_RandomAccessIterator __x1, _RandomAccessIterator __x2, _RandomAccessIterator __x3,
+                         _Compare __c) {
+  _VSTD::__sort3<_Compare>(__x1, __x2, __x3, __c);
+}
+
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI __enable_if_t<__use_branchless_sort<_Compare, _RandomAccessIterator>::value, void>
+__sort4_maybe_branchless(_RandomAccessIterator __x1, _RandomAccessIterator __x2, _RandomAccessIterator __x3,
+                         _RandomAccessIterator __x4, _Compare __c) {
+  _VSTD::__cond_swap<_Compare>(__x1, __x3, __c);
+  _VSTD::__cond_swap<_Compare>(__x2, __x4, __c);
+  _VSTD::__cond_swap<_Compare>(__x1, __x2, __c);
+  _VSTD::__cond_swap<_Compare>(__x3, __x4, __c);
+  _VSTD::__cond_swap<_Compare>(__x2, __x3, __c);
+}
+
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI __enable_if_t<!__use_branchless_sort<_Compare, _RandomAccessIterator>::value, void>
+__sort4_maybe_branchless(_RandomAccessIterator __x1, _RandomAccessIterator __x2, _RandomAccessIterator __x3,
+                         _RandomAccessIterator __x4, _Compare __c) {
+  _VSTD::__sort4<_Compare>(__x1, __x2, __x3, __x4, __c);
+}
+
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI __enable_if_t<__use_branchless_sort<_Compare, _RandomAccessIterator>::value, void>
+__sort5_maybe_branchless(_RandomAccessIterator __x1, _RandomAccessIterator __x2, _RandomAccessIterator __x3,
+                         _RandomAccessIterator __x4, _RandomAccessIterator __x5, _Compare __c) {
+  _VSTD::__cond_swap<_Compare>(__x1, __x2, __c);
+  _VSTD::__cond_swap<_Compare>(__x4, __x5, __c);
+  _VSTD::__partially_sorted_swap<_Compare>(__x3, __x4, __x5, __c);
+  _VSTD::__cond_swap<_Compare>(__x2, __x5, __c);
+  _VSTD::__partially_sorted_swap<_Compare>(__x1, __x3, __x4, __c);
+  _VSTD::__partially_sorted_swap<_Compare>(__x2, __x3, __x4, __c);
+}
+
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI __enable_if_t<!__use_branchless_sort<_Compare, _RandomAccessIterator>::value, void>
+__sort5_maybe_branchless(_RandomAccessIterator __x1, _RandomAccessIterator __x2, _RandomAccessIterator __x3,
+                         _RandomAccessIterator __x4, _RandomAccessIterator __x5, _Compare __c) {
+  _VSTD::__sort5<_Compare>(__x1, __x2, __x3, __x4, __x5, __c);
+}
+
 // Assumes size > 0
 template <class _Compare, class _BidirectionalIterator>
 _LIBCPP_CONSTEXPR_AFTER_CXX11 void
@@ -163,7 +253,7 @@ __insertion_sort_3(_RandomAccessIterator __first, _RandomAccessIterator __last,
     typedef typename iterator_traits<_RandomAccessIterator>::difference_type difference_type;
     typedef typename iterator_traits<_RandomAccessIterator>::value_type value_type;
     _RandomAccessIterator __j = __first+difference_type(2);
-    _VSTD::__sort3<_Compare>(__first, __first+difference_type(1), __j, __comp);
+    _VSTD::__sort3_maybe_branchless<_Compare>(__first, __first + difference_type(1), __j, __comp);
     for (_RandomAccessIterator __i = __j+difference_type(1); __i != __last; ++__i)
     {
         if (__comp(*__i, *__j))
@@ -197,18 +287,20 @@ __insertion_sort_incomplete(_RandomAccessIterator __first, _RandomAccessIterator
             swap(*__first, *__last);
         return true;
     case 3:
-        _VSTD::__sort3<_Compare>(__first, __first+difference_type(1), --__last, __comp);
-        return true;
+      _VSTD::__sort3_maybe_branchless<_Compare>(__first, __first + difference_type(1), --__last, __comp);
+      return true;
     case 4:
-        _VSTD::__sort4<_Compare>(__first, __first+difference_type(1), __first+difference_type(2), --__last, __comp);
-        return true;
+      _VSTD::__sort4_maybe_branchless<_Compare>(__first, __first + difference_type(1), __first + difference_type(2),
+                                                --__last, __comp);
+      return true;
     case 5:
-        _VSTD::__sort5<_Compare>(__first, __first+difference_type(1), __first+difference_type(2), __first+difference_type(3), --__last, __comp);
-        return true;
+      _VSTD::__sort5_maybe_branchless<_Compare>(__first, __first + difference_type(1), __first + difference_type(2),
+                                                __first + difference_type(3), --__last, __comp);
+      return true;
     }
     typedef typename iterator_traits<_RandomAccessIterator>::value_type value_type;
     _RandomAccessIterator __j = __first+difference_type(2);
-    _VSTD::__sort3<_Compare>(__first, __first+difference_type(1), __j, __comp);
+    _VSTD::__sort3_maybe_branchless<_Compare>(__first, __first + difference_type(1), __j, __comp);
     const unsigned __limit = 8;
     unsigned __count = 0;
     for (_RandomAccessIterator __i = __j+difference_type(1); __i != __last; ++__i)
@@ -290,14 +382,16 @@ __introsort(_RandomAccessIterator __first, _RandomAccessIterator __last, _Compar
                 swap(*__first, *__last);
             return;
         case 3:
-            _VSTD::__sort3<_Compare>(__first, __first+difference_type(1), --__last, __comp);
-            return;
+          _VSTD::__sort3_maybe_branchless<_Compare>(__first, __first + difference_type(1), --__last, __comp);
+          return;
         case 4:
-            _VSTD::__sort4<_Compare>(__first, __first+difference_type(1), __first+difference_type(2), --__last, __comp);
-            return;
+          _VSTD::__sort4_maybe_branchless<_Compare>(__first, __first + difference_type(1), __first + difference_type(2),
+                                                    --__last, __comp);
+          return;
         case 5:
-            _VSTD::__sort5<_Compare>(__first, __first+difference_type(1), __first+difference_type(2), __first+difference_type(3), --__last, __comp);
-            return;
+          _VSTD::__sort5_maybe_branchless<_Compare>(__first, __first + difference_type(1), __first + difference_type(2),
+                                                    __first + difference_type(3), --__last, __comp);
+          return;
         }
         if (__len <= __limit)
         {

diff --git a/libcxx/test/libcxx/algorithms/robust_against_copying_comparators.pass.cpp b/libcxx/test/libcxx/algorithms/robust_against_copying_comparators.pass.cpp
index 7fbdf1f17536f..66b8b3637f5f3 100644
--- a/libcxx/test/libcxx/algorithms/robust_against_copying_comparators.pass.cpp
+++ b/libcxx/test/libcxx/algorithms/robust_against_copying_comparators.pass.cpp
@@ -15,60 +15,67 @@
 
 #include "test_macros.h"
 
+template <class T>
 struct Less {
     int *copies_;
     TEST_CONSTEXPR explicit Less(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 Less(const Less& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 Less& operator=(const Less&) = default;
-    TEST_CONSTEXPR bool operator()(void*, void*) const { return false; }
+    TEST_CONSTEXPR bool operator()(T, T) const { return false; }
 };
 
+template <class T>
 struct Equal {
     int *copies_;
     TEST_CONSTEXPR explicit Equal(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 Equal(const Equal& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 Equal& operator=(const Equal&) = default;
-    TEST_CONSTEXPR bool operator()(void*, void*) const { return true; }
+    TEST_CONSTEXPR bool operator()(T, T) const { return true; }
 };
 
+template <class T>
 struct UnaryVoid {
     int *copies_;
     TEST_CONSTEXPR explicit UnaryVoid(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 UnaryVoid(const UnaryVoid& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 UnaryVoid& operator=(const UnaryVoid&) = default;
-    TEST_CONSTEXPR_CXX14 void operator()(void*) const {}
+    TEST_CONSTEXPR_CXX14 void operator()(T) const {}
 };
 
+template <class T>
 struct UnaryTrue {
     int *copies_;
     TEST_CONSTEXPR explicit UnaryTrue(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 UnaryTrue(const UnaryTrue& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 UnaryTrue& operator=(const UnaryTrue&) = default;
-    TEST_CONSTEXPR bool operator()(void*) const { return true; }
+    TEST_CONSTEXPR bool operator()(T) const { return true; }
 };
 
+template <class T>
 struct NullaryValue {
     int *copies_;
     TEST_CONSTEXPR explicit NullaryValue(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 NullaryValue(const NullaryValue& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 NullaryValue& operator=(const NullaryValue&) = default;
-    TEST_CONSTEXPR std::nullptr_t operator()() const { return nullptr; }
+    TEST_CONSTEXPR T operator()() const { return 0; }
 };
 
+template <class T>
 struct UnaryTransform {
     int *copies_;
     TEST_CONSTEXPR explicit UnaryTransform(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 UnaryTransform(const UnaryTransform& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 UnaryTransform& operator=(const UnaryTransform&) = default;
-    TEST_CONSTEXPR std::nullptr_t operator()(void*) const { return nullptr; }
+    TEST_CONSTEXPR T operator()(T) const { return 0; }
 };
 
+template <class T>
 struct BinaryTransform {
     int *copies_;
     TEST_CONSTEXPR explicit BinaryTransform(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 BinaryTransform(const BinaryTransform& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 BinaryTransform& operator=(const BinaryTransform&) = default;
-    TEST_CONSTEXPR std::nullptr_t operator()(void*, void*) const { return nullptr; }
+    TEST_CONSTEXPR T operator()(T, T) const { return 0; }
 };
 
 #if TEST_STD_VER > 17
@@ -81,124 +88,130 @@ struct ThreeWay {
 };
 #endif
 
+template <class T>
 TEST_CONSTEXPR_CXX20 bool all_the_algorithms()
 {
-    void *a[10] = {};
-    void *b[10] = {};
-    void **first = a;
-    void **mid = a+5;
-    void **last = a+10;
-    void **first2 = b;
-    void **mid2 = b+5;
-    void **last2 = b+10;
-    void *value = nullptr;
+    T a[10] = {};
+    T b[10] = {};
+    T *first = a;
+    T *mid = a+5;
+    T *last = a+10;
+    T *first2 = b;
+    T *mid2 = b+5;
+    T *last2 = b+10;
+    T value = 0;
     int count = 1;
 
     int copies = 0;
-    (void)std::adjacent_find(first, last, Equal(&copies)); assert(copies == 0);
+    (void)std::adjacent_find(first, last, Equal<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER >= 11
-    (void)std::all_of(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::any_of(first, last, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::all_of(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::any_of(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::binary_search(first, last, value, Less(&copies)); assert(copies == 0);
+    (void)std::binary_search(first, last, value, Less<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER > 17
-    (void)std::clamp(value, value, value, Less(&copies)); assert(copies == 0);
+    (void)std::clamp(value, value, value, Less<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::count_if(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::copy_if(first, last, first2, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::equal(first, last, first2, Equal(&copies)); assert(copies == 0);
+    (void)std::count_if(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::copy_if(first, last, first2, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::equal(first, last, first2, Equal<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER > 11
-    (void)std::equal(first, last, first2, last2, Equal(&copies)); assert(copies == 0);
+    (void)std::equal(first, last, first2, last2, Equal<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::equal_range(first, last, value, Less(&copies)); assert(copies == 0);
-    (void)std::find_end(first, last, first2, mid2, Equal(&copies)); assert(copies == 0);
+    (void)std::equal_range(first, last, value, Less<T>(&copies)); assert(copies == 0);
+    (void)std::find_end(first, last, first2, mid2, Equal<T>(&copies)); assert(copies == 0);
     //(void)std::find_first_of(first, last, first2, last2, Equal(&copies)); assert(copies == 0);
-    (void)std::find_if(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::find_if_not(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::for_each(first, last, UnaryVoid(&copies)); assert(copies == 1); copies = 0;
+    (void)std::find_if(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::find_if_not(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::for_each(first, last, UnaryVoid<T>(&copies)); assert(copies == 1); copies = 0;
 #if TEST_STD_VER > 14
-    (void)std::for_each_n(first, count, UnaryVoid(&copies)); assert(copies == 0);
+    (void)std::for_each_n(first, count, UnaryVoid<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::generate(first, last, NullaryValue(&copies)); assert(copies == 0);
-    (void)std::generate_n(first, count, NullaryValue(&copies)); assert(copies == 0);
-    (void)std::includes(first, last, first2, last2, Less(&copies)); assert(copies == 0);
-    (void)std::is_heap(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::is_heap_until(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::is_partitioned(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::is_permutation(first, last, first2, Equal(&copies)); assert(copies == 0);
+    (void)std::generate(first, last, NullaryValue<T>(&copies)); assert(copies == 0);
+    (void)std::generate_n(first, count, NullaryValue<T>(&copies)); assert(copies == 0);
+    (void)std::includes(first, last, first2, last2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::is_heap(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::is_heap_until(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::is_partitioned(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::is_permutation(first, last, first2, Equal<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER > 11
-    (void)std::is_permutation(first, last, first2, last2, Equal(&copies)); assert(copies == 0);
+    (void)std::is_permutation(first, last, first2, last2, Equal<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::is_sorted(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::is_sorted_until(first, last, Less(&copies)); assert(copies == 0);
-    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::inplace_merge(first, mid, last, Less(&copies)); assert(copies == 0); }
-    (void)std::lexicographical_compare(first, last, first2, last2, Less(&copies)); assert(copies == 0);
+    (void)std::is_sorted(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::is_sorted_until(first, last, Less<T>(&copies)); assert(copies == 0);
+    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::inplace_merge(first, mid, last, Less<T>(&copies)); assert(copies == 0); }
+    (void)std::lexicographical_compare(first, last, first2, last2, Less<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER > 17
     //(void)std::lexicographical_compare_three_way(first, last, first2, last2, ThreeWay(&copies)); assert(copies == 0);
 #endif
-    (void)std::lower_bound(first, last, value, Less(&copies)); assert(copies == 0);
-    (void)std::make_heap(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::max(value, value, Less(&copies)); assert(copies == 0);
+    (void)std::lower_bound(first, last, value, Less<T>(&copies)); assert(copies == 0);
+    (void)std::make_heap(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::max(value, value, Less<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER >= 11
-    (void)std::max({ value, value }, Less(&copies)); assert(copies == 0);
+    (void)std::max({ value, value }, Less<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::max_element(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::merge(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
-    (void)std::min(value, value, Less(&copies)); assert(copies == 0);
+    (void)std::max_element(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::merge(first, mid, mid, last, first2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::min(value, value, Less<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER >= 11
-    (void)std::min({ value, value }, Less(&copies)); assert(copies == 0);
+    (void)std::min({ value, value }, Less<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::min_element(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::minmax(value, value, Less(&copies)); assert(copies == 0);
+    (void)std::min_element(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::minmax(value, value, Less<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER >= 11
-    (void)std::minmax({ value, value }, Less(&copies)); assert(copies == 0);
+    (void)std::minmax({ value, value }, Less<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::minmax_element(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::mismatch(first, last, first2, Equal(&copies)); assert(copies == 0);
+    (void)std::minmax_element(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::mismatch(first, last, first2, Equal<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER > 11
-    (void)std::mismatch(first, last, first2, last2, Equal(&copies)); assert(copies == 0);
+    (void)std::mismatch(first, last, first2, last2, Equal<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::next_permutation(first, last, Less(&copies)); assert(copies == 0);
+    (void)std::next_permutation(first, last, Less<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER >= 11
-    (void)std::none_of(first, last, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::none_of(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::nth_element(first, mid, last, Less(&copies)); assert(copies == 0);
-    (void)std::partial_sort(first, mid, last, Less(&copies)); assert(copies == 0);
-    (void)std::partial_sort_copy(first, last, first2, mid2, Less(&copies)); assert(copies == 0);
-    (void)std::partition(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::partition_copy(first, last, first2, last2, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::partition_point(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::pop_heap(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::prev_permutation(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::push_heap(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::remove_copy_if(first, last, first2, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::remove_if(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::replace_copy_if(first, last, first2, UnaryTrue(&copies), value); assert(copies == 0);
-    (void)std::replace_if(first, last, UnaryTrue(&copies), value); assert(copies == 0);
-    (void)std::search(first, last, first2, mid2, Equal(&copies)); assert(copies == 0);
-    (void)std::search_n(first, last, count, value, Equal(&copies)); assert(copies == 0);
-    (void)std::set_difference(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
-    (void)std::set_intersection(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
-    (void)std::set_symmetric_difference(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
-    (void)std::set_union(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
-    (void)std::sort(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::sort_heap(first, last, Less(&copies)); assert(copies == 0);
-    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::stable_partition(first, last, UnaryTrue(&copies)); assert(copies == 0); }
-    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::stable_sort(first, last, Less(&copies)); assert(copies == 0); }
-    (void)std::transform(first, last, first2, UnaryTransform(&copies)); assert(copies == 0);
-    (void)std::transform(first, mid, mid, first2, BinaryTransform(&copies)); assert(copies == 0);
-    (void)std::unique(first, last, Equal(&copies)); assert(copies == 0);
-    (void)std::unique_copy(first, last, first2, Equal(&copies)); assert(copies == 0);
-    (void)std::upper_bound(first, last, value, Less(&copies)); assert(copies == 0);
+    (void)std::nth_element(first, mid, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::partial_sort(first, mid, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::partial_sort_copy(first, last, first2, mid2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::partition(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::partition_copy(first, last, first2, last2, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::partition_point(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::pop_heap(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::prev_permutation(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::push_heap(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::remove_copy_if(first, last, first2, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::remove_if(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::replace_copy_if(first, last, first2, UnaryTrue<T>(&copies), value); assert(copies == 0);
+    (void)std::replace_if(first, last, UnaryTrue<T>(&copies), value); assert(copies == 0);
+    (void)std::search(first, last, first2, mid2, Equal<T>(&copies)); assert(copies == 0);
+    (void)std::search_n(first, last, count, value, Equal<T>(&copies)); assert(copies == 0);
+    (void)std::set_difference(first, mid, mid, last, first2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::set_intersection(first, mid, mid, last, first2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::set_symmetric_difference(first, mid, mid, last, first2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::set_union(first, mid, mid, last, first2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::sort(first, first+3, Less<T>(&copies)); assert(copies == 0);
+    (void)std::sort(first, first+4, Less<T>(&copies)); assert(copies == 0);
+    (void)std::sort(first, first+5, Less<T>(&copies)); assert(copies == 0);
+    (void)std::sort(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::sort_heap(first, last, Less<T>(&copies)); assert(copies == 0);
+    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::stable_partition(first, last, UnaryTrue<T>(&copies)); assert(copies == 0); }
+    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::stable_sort(first, last, Less<T>(&copies)); assert(copies == 0); }
+    (void)std::transform(first, last, first2, UnaryTransform<T>(&copies)); assert(copies == 0);
+    (void)std::transform(first, mid, mid, first2, BinaryTransform<T>(&copies)); assert(copies == 0);
+    (void)std::unique(first, last, Equal<T>(&copies)); assert(copies == 0);
+    (void)std::unique_copy(first, last, first2, Equal<T>(&copies)); assert(copies == 0);
+    (void)std::upper_bound(first, last, value, Less<T>(&copies)); assert(copies == 0);
 
     return true;
 }
 
 int main(int, char**)
 {
-    all_the_algorithms();
+    all_the_algorithms<void*>();
+    all_the_algorithms<int>();
 #if TEST_STD_VER > 17
-    static_assert(all_the_algorithms());
+    static_assert(all_the_algorithms<void*>());
+    static_assert(all_the_algorithms<int>());
 #endif
 
     return 0;

diff --git a/libcxx/test/std/algorithms/alg.sorting/alg.sort/sort/sort.pass.cpp b/libcxx/test/std/algorithms/alg.sorting/alg.sort/sort/sort.pass.cpp
index a36981cc41b02..b3c0a0d76eda1 100644
--- a/libcxx/test/std/algorithms/alg.sorting/alg.sort/sort/sort.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.sorting/alg.sort/sort/sort.pass.cpp
@@ -19,118 +19,157 @@
 #include <numeric>
 #include <random>
 #include <cassert>
+#include <vector>
+#include <deque>
 
 #include "test_macros.h"
 
 std::mt19937 randomness;
 
-template <class RI>
+template <class Container, class RI>
 void
 test_sort_helper(RI f, RI l)
 {
-    typedef typename std::iterator_traits<RI>::value_type value_type;
-    typedef typename std::iterator_traits<RI>::difference_type difference_type;
-
     if (f != l)
     {
-        difference_type len = l - f;
-        value_type* save(new value_type[len]);
+        Container save(l - f);
         do
         {
-            std::copy(f, l, save);
-            std::sort(save, save+len);
-            assert(std::is_sorted(save, save+len));
+            std::copy(f, l, save.begin());
+            std::sort(save.begin(), save.end());
+            assert(std::is_sorted(save.begin(), save.end()));
+            assert(std::is_permutation(save.begin(), save.end(), f));
         } while (std::next_permutation(f, l));
-        delete [] save;
     }
 }
 
-template <class RI>
+template <class T>
+void set_value(T& dest, int value)
+{
+    dest = value;
+}
+
+inline void set_value(std::pair<int, int>& dest, int value)
+{
+    dest.first = value;
+    dest.second = value;
+}
+
+template <class Container, class RI>
 void
 test_sort_driver_driver(RI f, RI l, int start, RI real_last)
 {
     for (RI i = l; i > f + start;)
     {
-        *--i = start;
+        set_value(*--i, start);
         if (f == i)
         {
-            test_sort_helper(f, real_last);
+            test_sort_helper<Container>(f, real_last);
         }
-    if (start > 0)
-        test_sort_driver_driver(f, i, start-1, real_last);
+        if (start > 0)
+            test_sort_driver_driver<Container>(f, i, start-1, real_last);
     }
 }
 
-template <class RI>
+template <class Container, class RI>
 void
 test_sort_driver(RI f, RI l, int start)
 {
-    test_sort_driver_driver(f, l, start, l);
+    test_sort_driver_driver<Container>(f, l, start, l);
 }
 
-template <int sa>
+template <class Container, int sa>
 void
 test_sort_()
 {
-    int ia[sa];
+    Container ia(sa);
     for (int i = 0; i < sa; ++i)
     {
-        test_sort_driver(ia, ia+sa, i);
+        test_sort_driver<Container>(ia.begin(), ia.end(), i);
     }
 }
 
+template <class T>
+T increment_or_reset(T value, int max_value)
+{
+    return value == max_value - 1 ? 0 : value + 1;
+}
+
+inline std::pair<int, int> increment_or_reset(std::pair<int, int> value,
+                                              int max_value)
+{
+    int new_value = value.first + 1;
+    if (new_value == max_value)
+    {
+        new_value = 0;
+    }
+    return std::make_pair(new_value, new_value);
+}
+
+template <class Container>
 void
 test_larger_sorts(int N, int M)
 {
+    using Iter = typename Container::iterator;
+    using ValueType = typename Container::value_type;
     assert(N != 0);
     assert(M != 0);
-    // create array length N filled with M different numbers
-    int* array = new int[N];
-    int x = 0;
+    // create container of length N filled with M different objects
+    Container array(N);
+    ValueType x = ValueType();
     for (int i = 0; i < N; ++i)
     {
         array[i] = x;
-        if (++x == M)
-            x = 0;
+        x = increment_or_reset(x, M);
     }
+    Container original = array;
+    Iter iter = array.begin();
+    Iter original_iter = original.begin();
+
     // test saw tooth pattern
-    std::sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
+    std::sort(iter, iter+N);
+    assert(std::is_sorted(iter, iter+N));
+    assert(std::is_permutation(iter, iter+N, original_iter));
     // test random pattern
-    std::shuffle(array, array+N, randomness);
-    std::sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
+    std::shuffle(iter, iter+N, randomness);
+    std::sort(iter, iter+N);
+    assert(std::is_sorted(iter, iter+N));
+    assert(std::is_permutation(iter, iter+N, original_iter));
     // test sorted pattern
-    std::sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
+    std::sort(iter, iter+N);
+    assert(std::is_sorted(iter, iter+N));
+    assert(std::is_permutation(iter, iter+N, original_iter));
     // test reverse sorted pattern
-    std::reverse(array, array+N);
-    std::sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
+    std::reverse(iter, iter+N);
+    std::sort(iter, iter+N);
+    assert(std::is_sorted(iter, iter+N));
+    assert(std::is_permutation(iter, iter+N, original_iter));
     // test swap ranges 2 pattern
-    std::swap_ranges(array, array+N/2, array+N/2);
-    std::sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
+    std::swap_ranges(iter, iter+N/2, iter+N/2);
+    std::sort(iter, iter+N);
+    assert(std::is_sorted(iter, iter+N));
+    assert(std::is_permutation(iter, iter+N, original_iter));
     // test reverse swap ranges 2 pattern
-    std::reverse(array, array+N);
-    std::swap_ranges(array, array+N/2, array+N/2);
-    std::sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
-    delete [] array;
+    std::reverse(iter, iter+N);
+    std::swap_ranges(iter, iter+N/2, iter+N/2);
+    std::sort(iter, iter+N);
+    assert(std::is_sorted(iter, iter+N));
+    assert(std::is_permutation(iter, iter+N, original_iter));
 }
 
+template <class Container>
 void
 test_larger_sorts(int N)
 {
-    test_larger_sorts(N, 1);
-    test_larger_sorts(N, 2);
-    test_larger_sorts(N, 3);
-    test_larger_sorts(N, N/2-1);
-    test_larger_sorts(N, N/2);
-    test_larger_sorts(N, N/2+1);
-    test_larger_sorts(N, N-2);
-    test_larger_sorts(N, N-1);
-    test_larger_sorts(N, N);
+    test_larger_sorts<Container>(N, 1);
+    test_larger_sorts<Container>(N, 2);
+    test_larger_sorts<Container>(N, 3);
+    test_larger_sorts<Container>(N, N/2-1);
+    test_larger_sorts<Container>(N, N/2);
+    test_larger_sorts<Container>(N, N/2+1);
+    test_larger_sorts<Container>(N, N-2);
+    test_larger_sorts<Container>(N, N-1);
+    test_larger_sorts<Container>(N, N);
 }
 
 void
@@ -205,28 +244,40 @@ void test_adversarial_quicksort(int N) {
   assert(std::is_sorted(V.begin(), V.end()));
 }
 
-int main(int, char**)
+template <class Container>
+void run_sort_tests()
 {
     // test null range
-    int d = 0;
+    using ValueType = typename Container::value_type;
+    ValueType d = ValueType();
     std::sort(&d, &d);
+
     // exhaustively test all possibilities up to length 8
-    test_sort_<1>();
-    test_sort_<2>();
-    test_sort_<3>();
-    test_sort_<4>();
-    test_sort_<5>();
-    test_sort_<6>();
-    test_sort_<7>();
-    test_sort_<8>();
-
-    test_larger_sorts(256);
-    test_larger_sorts(257);
-    test_larger_sorts(499);
-    test_larger_sorts(500);
-    test_larger_sorts(997);
-    test_larger_sorts(1000);
-    test_larger_sorts(1009);
+    test_sort_<Container, 1>();
+    test_sort_<Container, 2>();
+    test_sort_<Container, 3>();
+    test_sort_<Container, 4>();
+    test_sort_<Container, 5>();
+    test_sort_<Container, 6>();
+    test_sort_<Container, 7>();
+    test_sort_<Container, 8>();
+
+    test_larger_sorts<Container>(256);
+    test_larger_sorts<Container>(257);
+    test_larger_sorts<Container>(499);
+    test_larger_sorts<Container>(500);
+    test_larger_sorts<Container>(997);
+    test_larger_sorts<Container>(1000);
+    test_larger_sorts<Container>(1009);
+}
+
+int main(int, char**)
+{
+    // test various combinations of contiguous/non-contiguous containers with
+    // arithmetic/non-arithmetic types
+    run_sort_tests<std::vector<int> >();
+    run_sort_tests<std::deque<int> >();
+    run_sort_tests<std::vector<std::pair<int, int> > >();
 
     test_pointer_sort();
     test_adversarial_quicksort(1 << 20);


        


More information about the libcxx-commits mailing list