[libcxx-commits] [libcxx] [libcxx][algorithm] Optimize std::stable_sort via radix sort algorithm (PR #104683)

Дмитрий Изволов via libcxx-commits libcxx-commits at lists.llvm.org
Wed Jan 8 01:12:25 PST 2025


https://github.com/izvolov updated https://github.com/llvm/llvm-project/pull/104683

>From 75155aec4551b8023228a0b14a57a0898bc3d42f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=94=D0=BC=D0=B8=D1=82=D1=80=D0=B8=D0=B9=20=D0=98=D0=B7?=
 =?UTF-8?q?=D0=B2=D0=BE=D0=BB=D0=BE=D0=B2?= <dmitriy at izvolov.ru>
Date: Wed, 8 Jan 2025 12:11:33 +0300
Subject: [PATCH] [libcxx][algorithm] Optimize std::stable_sort via radix sort
 algorithm

---
 libcxx/docs/ReleaseNotes/19.rst               |   2 +
 libcxx/include/CMakeLists.txt                 |   1 +
 libcxx/include/__algorithm/radix_sort.h       | 332 ++++++++++++++++++
 libcxx/include/__algorithm/stable_sort.h      |  44 +++
 libcxx/include/__bit/bit_log2.h               |  11 +-
 libcxx/include/module.modulemap               |   1 +
 .../alg.sort/stable.sort/stable_sort.pass.cpp | 117 +++---
 7 files changed, 449 insertions(+), 59 deletions(-)
 create mode 100644 libcxx/include/__algorithm/radix_sort.h

diff --git a/libcxx/docs/ReleaseNotes/19.rst b/libcxx/docs/ReleaseNotes/19.rst
index e8f76773c54373..aec7865d52fada 100644
--- a/libcxx/docs/ReleaseNotes/19.rst
+++ b/libcxx/docs/ReleaseNotes/19.rst
@@ -126,6 +126,8 @@ Improvements and New Features
 
 - In C++23 and C++26 the number of transitive includes in several headers has been reduced, improving the compilation speed.
 
+- ``std::stable_sort`` now uses radix sort for integral types, which can improve performance by up to 10 times,
+  depending on the type of the sorted elements and the initial state of the sorted array.
 
 Deprecations and Removals
 -------------------------
diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index 0b484ebe5e87c8..96071527b52837 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -73,6 +73,7 @@ set(files
   __algorithm/prev_permutation.h
   __algorithm/pstl.h
   __algorithm/push_heap.h
+  __algorithm/radix_sort.h
   __algorithm/ranges_adjacent_find.h
   __algorithm/ranges_all_of.h
   __algorithm/ranges_any_of.h
diff --git a/libcxx/include/__algorithm/radix_sort.h b/libcxx/include/__algorithm/radix_sort.h
new file mode 100644
index 00000000000000..c39b26a8620f46
--- /dev/null
+++ b/libcxx/include/__algorithm/radix_sort.h
@@ -0,0 +1,332 @@
+// -*- C++ -*-
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___ALGORITHM_RADIX_SORT_H
+#define _LIBCPP___ALGORITHM_RADIX_SORT_H
+
+// This is an implementation of the classic LSD radix sort algorithm, running in linear time and using `O(max(N, M))`
+// additional memory, where `N` is the size of the input range and `M` is the maximum value of
+// a radix of the sorted integer type. The type of the radix and its maximum value are determined at compile time
+// based on the type returned by the function `__radix`. The default radix is uint8.
+
+// The algorithm is equivalent to several consecutive calls of counting sort for each
+// radix of the sorted numbers from low to high byte.
+// The algorithm uses a temporary buffer of size equal to size of the input range. Each `i`-th pass
+// of the algorithm sorts values by `i`-th radix and moves values to the temporary buffer (for each even `i`, counted
+// from zero), or moves them back to the initial range (for each odd `i`). If there is only one radix in sorted integers
+// (e.g. int8), the sorted values are placed into the buffer, and then moved back to the initial range.
+
+// The implementation also has several optimizations:
+// - the counters for the counting sort are calculated in one pass for all radices;
+// - if all values of a radix are the same, we do not sort that radix, and just move items to the buffer;
+// - if two consecutive radices satisfy the condition above, we do nothing for these two radices.
+
+#include <__algorithm/for_each.h>
+#include <__algorithm/move.h>
+#include <__bit/bit_log2.h>
+#include <__bit/countl.h>
+#include <__config>
+#include <__functional/identity.h>
+#include <__iterator/distance.h>
+#include <__iterator/iterator_traits.h>
+#include <__iterator/move_iterator.h>
+#include <__iterator/next.h>
+#include <__iterator/reverse_iterator.h>
+#include <__numeric/partial_sum.h>
+#include <__type_traits/decay.h>
+#include <__type_traits/enable_if.h>
+#include <__type_traits/invoke.h>
+#include <__type_traits/is_assignable.h>
+#include <__type_traits/is_integral.h>
+#include <__type_traits/is_unsigned.h>
+#include <__type_traits/make_unsigned.h>
+#include <__utility/forward.h>
+#include <__utility/integer_sequence.h>
+#include <__utility/move.h>
+#include <__utility/pair.h>
+#include <climits>
+#include <cstdint>
+#include <initializer_list>
+#include <limits>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+#  pragma GCC system_header
+#endif
+
+_LIBCPP_PUSH_MACROS
+#include <__undef_macros>
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+#if _LIBCPP_STD_VER >= 14
+
+template <class _InputIterator, class _OutputIterator>
+_LIBCPP_HIDE_FROM_ABI pair<_OutputIterator, __iter_value_type<_InputIterator>>
+__partial_sum_max(_InputIterator __first, _InputIterator __last, _OutputIterator __result) {
+  if (__first == __last)
+    return {__result, 0};
+
+  auto __max                              = *__first;
+  __iter_value_type<_InputIterator> __sum = *__first;
+  *__result                               = __sum;
+
+  while (++__first != __last) {
+    if (__max < *__first) {
+      __max = *__first;
+    }
+    __sum       = std::move(__sum) + *__first;
+    *++__result = __sum;
+  }
+  return {++__result, __max};
+}
+
+template <class _Value, class _Map, class _Radix>
+struct __radix_sort_traits {
+  using __image_type = decay_t<typename __invoke_of<_Map, _Value>::type>;
+  static_assert(is_unsigned<__image_type>::value);
+
+  using __radix_type = decay_t<typename __invoke_of<_Radix, __image_type>::type>;
+  static_assert(is_integral<__radix_type>::value);
+
+  static constexpr auto __radix_value_range = numeric_limits<__radix_type>::max() + 1;
+  static constexpr auto __radix_size        = std::__bit_log2<uint64_t>(__radix_value_range);
+  static constexpr auto __radix_count       = sizeof(__image_type) * CHAR_BIT / __radix_size;
+};
+
+template <class _Value, class _Map>
+struct __counting_sort_traits {
+  using __image_type = decay_t<typename __invoke_of<_Map, _Value>::type>;
+  static_assert(is_unsigned<__image_type>::value);
+
+  static constexpr const auto __value_range = numeric_limits<__image_type>::max() + 1;
+  static constexpr auto __radix_size        = std::__bit_log2<uint64_t>(__value_range);
+};
+
+template <class _Radix, class _Integer>
+_LIBCPP_HIDE_FROM_ABI auto __nth_radix(size_t __radix_number, _Radix __radix, _Integer __n) {
+  static_assert(is_unsigned<_Integer>::value);
+  using __traits = __counting_sort_traits<_Integer, _Radix>;
+
+  return __radix(static_cast<_Integer>(__n >> __traits::__radix_size * __radix_number));
+}
+
+template <class _ForwardIterator, class _Map, class _RandomAccessIterator>
+_LIBCPP_HIDE_FROM_ABI void
+__collect(_ForwardIterator __first, _ForwardIterator __last, _Map __map, _RandomAccessIterator __counters) {
+  using __value_type = __iter_value_type<_ForwardIterator>;
+  using __traits     = __counting_sort_traits<__value_type, _Map>;
+
+  std::for_each(__first, __last, [&__counters, &__map](const auto& __preimage) { ++__counters[__map(__preimage)]; });
+
+  const auto __counters_end = __counters + __traits::__value_range;
+  std::partial_sum(__counters, __counters_end, __counters);
+}
+
+template <class _ForwardIterator, class _RandomAccessIterator1, class _Map, class _RandomAccessIterator2>
+_LIBCPP_HIDE_FROM_ABI void
+__dispose(_ForwardIterator __first,
+          _ForwardIterator __last,
+          _RandomAccessIterator1 __result,
+          _Map __map,
+          _RandomAccessIterator2 __counters) {
+  std::for_each(__first, __last, [&__result, &__counters, &__map](auto&& __preimage) {
+    auto __index      = __counters[__map(__preimage)]++;
+    __result[__index] = std::move(__preimage);
+  });
+}
+
+template <class _ForwardIterator,
+          class _Map,
+          class _Radix,
+          class _RandomAccessIterator1,
+          class _RandomAccessIterator2,
+          size_t... _Radices>
+_LIBCPP_HIDE_FROM_ABI bool __collect_impl(
+    _ForwardIterator __first,
+    _ForwardIterator __last,
+    _Map __map,
+    _Radix __radix,
+    _RandomAccessIterator1 __counters,
+    _RandomAccessIterator2 __maximums,
+    index_sequence<_Radices...>) {
+  using __value_type                 = __iter_value_type<_ForwardIterator>;
+  constexpr auto __radix_value_range = __radix_sort_traits<__value_type, _Map, _Radix>::__radix_value_range;
+
+  auto __previous  = numeric_limits<typename __invoke_of<_Map, __value_type>::type>::min();
+  auto __is_sorted = true;
+  std::for_each(__first, __last, [&__counters, &__map, &__radix, &__previous, &__is_sorted](const auto& __value) {
+    auto __current = __map(__value);
+    __is_sorted &= (__current >= __previous);
+    __previous = __current;
+
+    (++__counters[_Radices][std::__nth_radix(_Radices, __radix, __current)], ...);
+  });
+
+  ((__maximums[_Radices] =
+        std::__partial_sum_max(__counters[_Radices], __counters[_Radices] + __radix_value_range, __counters[_Radices])
+            .second),
+   ...);
+
+  return __is_sorted;
+}
+
+template <class _ForwardIterator, class _Map, class _Radix, class _RandomAccessIterator1, class _RandomAccessIterator2>
+_LIBCPP_HIDE_FROM_ABI bool
+__collect(_ForwardIterator __first,
+          _ForwardIterator __last,
+          _Map __map,
+          _Radix __radix,
+          _RandomAccessIterator1 __counters,
+          _RandomAccessIterator2 __maximums) {
+  using __value_type           = __iter_value_type<_ForwardIterator>;
+  constexpr auto __radix_count = __radix_sort_traits<__value_type, _Map, _Radix>::__radix_count;
+  return std::__collect_impl(
+      __first, __last, __map, __radix, __counters, __maximums, make_index_sequence<__radix_count>());
+}
+
+template <class _BidirectionalIterator, class _RandomAccessIterator1, class _Map, class _RandomAccessIterator2>
+_LIBCPP_HIDE_FROM_ABI void __dispose_backward(
+    _BidirectionalIterator __first,
+    _BidirectionalIterator __last,
+    _RandomAccessIterator1 __result,
+    _Map __map,
+    _RandomAccessIterator2 __counters) {
+  std::for_each(std::make_reverse_iterator(__last),
+                std::make_reverse_iterator(__first),
+                [&__result, &__counters, &__map](auto&& __preimage) {
+                  auto __index      = --__counters[__map(__preimage)];
+                  __result[__index] = std::move(__preimage);
+                });
+}
+
+template <class _ForwardIterator, class _RandomAccessIterator, class _Map>
+_LIBCPP_HIDE_FROM_ABI _RandomAccessIterator
+__counting_sort_impl(_ForwardIterator __first, _ForwardIterator __last, _RandomAccessIterator __result, _Map __map) {
+  using __value_type = __iter_value_type<_ForwardIterator>;
+  using __traits     = __counting_sort_traits<__value_type, _Map>;
+
+  __iter_diff_t<_RandomAccessIterator> __counters[__traits::__value_range + 1] = {0};
+
+  std::__collect(__first, __last, __map, std::next(std::begin(__counters)));
+  std::__dispose(__first, __last, __result, __map, std::begin(__counters));
+
+  return __result + __counters[__traits::__value_range];
+}
+
+template <class _RandomAccessIterator1,
+          class _RandomAccessIterator2,
+          class _Map,
+          class _Radix,
+          enable_if_t< __radix_sort_traits<__iter_value_type<_RandomAccessIterator1>, _Map, _Radix>::__radix_count == 1,
+                       int> = 0>
+_LIBCPP_HIDE_FROM_ABI void __radix_sort_impl(
+    _RandomAccessIterator1 __first,
+    _RandomAccessIterator1 __last,
+    _RandomAccessIterator2 __buffer,
+    _Map __map,
+    _Radix __radix) {
+  auto __buffer_end = std::__counting_sort_impl(__first, __last, __buffer, [&__map, &__radix](const auto& __value) {
+    return __radix(__map(__value));
+  });
+
+  std::move(__buffer, __buffer_end, __first);
+}
+
+template <
+    class _RandomAccessIterator1,
+    class _RandomAccessIterator2,
+    class _Map,
+    class _Radix,
+    enable_if_t< __radix_sort_traits<__iter_value_type<_RandomAccessIterator1>, _Map, _Radix>::__radix_count % 2 == 0,
+                 int> = 0 >
+_LIBCPP_HIDE_FROM_ABI void __radix_sort_impl(
+    _RandomAccessIterator1 __first,
+    _RandomAccessIterator1 __last,
+    _RandomAccessIterator2 __buffer_begin,
+    _Map __map,
+    _Radix __radix) {
+  using __value_type = __iter_value_type<_RandomAccessIterator1>;
+  using __traits     = __radix_sort_traits<__value_type, _Map, _Radix>;
+
+  __iter_diff_t<_RandomAccessIterator1> __counters[__traits::__radix_count][__traits::__radix_value_range] = {{0}};
+  __iter_diff_t<_RandomAccessIterator1> __maximums[__traits::__radix_count]                                = {0};
+  const auto __is_sorted = std::__collect(__first, __last, __map, __radix, __counters, __maximums);
+  if (!__is_sorted) {
+    const auto __range_size = std::distance(__first, __last);
+    auto __buffer_end       = __buffer_begin + __range_size;
+    for (size_t __radix_number = 0; __radix_number < __traits::__radix_count; __radix_number += 2) {
+      const auto __n0th_is_single = __maximums[__radix_number] == __range_size;
+      const auto __n1th_is_single = __maximums[__radix_number + 1] == __range_size;
+
+      if (__n0th_is_single && __n1th_is_single) {
+        continue;
+      }
+
+      if (__n0th_is_single) {
+        std::move(__first, __last, __buffer_begin);
+      } else {
+        auto __n0th = [__radix_number, &__map, &__radix](const auto& __v) {
+          return std::__nth_radix(__radix_number, __radix, __map(__v));
+        };
+        std::__dispose_backward(__first, __last, __buffer_begin, __n0th, __counters[__radix_number]);
+      }
+
+      if (__n1th_is_single) {
+        std::move(__buffer_begin, __buffer_end, __first);
+      } else {
+        auto __n1th = [__radix_number, &__map, &__radix](const auto& __v) {
+          return std::__nth_radix(__radix_number + 1, __radix, __map(__v));
+        };
+        std::__dispose_backward(__buffer_begin, __buffer_end, __first, __n1th, __counters[__radix_number + 1]);
+      }
+    }
+  }
+}
+
+_LIBCPP_HIDE_FROM_ABI constexpr auto __shift_to_unsigned(bool __b) { return __b; }
+
+template <class _Ip>
+_LIBCPP_HIDE_FROM_ABI constexpr auto __shift_to_unsigned(_Ip __n) {
+  constexpr const auto __min_value = numeric_limits<_Ip>::min();
+  return static_cast<make_unsigned_t<_Ip> >(__n ^ __min_value);
+}
+
+struct __low_byte_fn {
+  template <class _Ip>
+  _LIBCPP_HIDE_FROM_ABI constexpr uint8_t operator()(_Ip __integer) const {
+    static_assert(is_unsigned<_Ip>::value);
+
+    return static_cast<uint8_t>(__integer & 0xff);
+  }
+};
+
+template <class _RandomAccessIterator1, class _RandomAccessIterator2, class _Map, class _Radix>
+_LIBCPP_HIDE_FROM_ABI void
+__radix_sort(_RandomAccessIterator1 __first,
+             _RandomAccessIterator1 __last,
+             _RandomAccessIterator2 __buffer,
+             _Map __map,
+             _Radix __radix) {
+  auto __map_to_unsigned = [__map = std::move(__map)](const auto& __x) { return std::__shift_to_unsigned(__map(__x)); };
+  std::__radix_sort_impl(__first, __last, __buffer, __map_to_unsigned, __radix);
+}
+
+template <class _RandomAccessIterator1, class _RandomAccessIterator2>
+_LIBCPP_HIDE_FROM_ABI void
+__radix_sort(_RandomAccessIterator1 __first, _RandomAccessIterator1 __last, _RandomAccessIterator2 __buffer) {
+  std::__radix_sort(__first, __last, __buffer, __identity{}, __low_byte_fn{});
+}
+
+#endif // _LIBCPP_STD_VER >= 14
+
+_LIBCPP_END_NAMESPACE_STD
+
+_LIBCPP_POP_MACROS
+
+#endif // _LIBCPP___ALGORITHM_RADIX_SORT_H
diff --git a/libcxx/include/__algorithm/stable_sort.h b/libcxx/include/__algorithm/stable_sort.h
index 1111f5509bc384..70a85023a17f0d 100644
--- a/libcxx/include/__algorithm/stable_sort.h
+++ b/libcxx/include/__algorithm/stable_sort.h
@@ -13,6 +13,7 @@
 #include <__algorithm/comp_ref_type.h>
 #include <__algorithm/inplace_merge.h>
 #include <__algorithm/iterator_operations.h>
+#include <__algorithm/radix_sort.h>
 #include <__algorithm/sort.h>
 #include <__config>
 #include <__cstddef/ptrdiff_t.h>
@@ -21,7 +22,12 @@
 #include <__memory/destruct_n.h>
 #include <__memory/unique_ptr.h>
 #include <__memory/unique_temporary_buffer.h>
+#include <__type_traits/desugars_to.h>
+#include <__type_traits/enable_if.h>
+#include <__type_traits/is_integral.h>
+#include <__type_traits/is_same.h>
 #include <__type_traits/is_trivially_assignable.h>
+#include <__type_traits/remove_cvref.h>
 #include <__utility/move.h>
 #include <__utility/pair.h>
 
@@ -189,6 +195,28 @@ struct __stable_sort_switch {
   static const unsigned value = 128 * is_trivially_copy_assignable<_Tp>::value;
 };
 
+#if _LIBCPP_STD_VER >= 17
+template <class _Tp>
+_LIBCPP_HIDE_FROM_ABI constexpr unsigned __radix_sort_min_bound() {
+  static_assert(is_integral<_Tp>::value);
+  if constexpr (sizeof(_Tp) == 1) {
+    return 1 << 8;
+  }
+
+  return 1 << 10;
+}
+
+template <class _Tp>
+_LIBCPP_HIDE_FROM_ABI constexpr unsigned __radix_sort_max_bound() {
+  static_assert(is_integral<_Tp>::value);
+  if constexpr (sizeof(_Tp) >= 8) {
+    return 1 << 15;
+  }
+
+  return 1 << 16;
+}
+#endif // _LIBCPP_STD_VER >= 17
+
 template <class _AlgPolicy, class _Compare, class _RandomAccessIterator>
 void __stable_sort(_RandomAccessIterator __first,
                    _RandomAccessIterator __last,
@@ -211,6 +239,22 @@ void __stable_sort(_RandomAccessIterator __first,
     std::__insertion_sort<_AlgPolicy, _Compare>(__first, __last, __comp);
     return;
   }
+
+#if _LIBCPP_STD_VER >= 17
+  constexpr auto __default_comp =
+      __desugars_to_v<__totally_ordered_less_tag, __remove_cvref_t<_Compare>, value_type, value_type >;
+  constexpr auto __integral_value =
+      is_integral_v<value_type > && is_same_v< value_type&, __iter_reference<_RandomAccessIterator>>;
+  constexpr auto __allowed_radix_sort = __default_comp && __integral_value;
+  if constexpr (__allowed_radix_sort) {
+    if (__len <= __buff_size && __len >= static_cast<difference_type>(__radix_sort_min_bound<value_type>()) &&
+        __len <= static_cast<difference_type>(__radix_sort_max_bound<value_type>())) {
+      std::__radix_sort(__first, __last, __buff);
+      return;
+    }
+  }
+#endif // _LIBCPP_STD_VER >= 17
+
   typename iterator_traits<_RandomAccessIterator>::difference_type __l2 = __len / 2;
   _RandomAccessIterator __m                                             = __first + __l2;
   if (__len <= __buff_size) {
diff --git a/libcxx/include/__bit/bit_log2.h b/libcxx/include/__bit/bit_log2.h
index 62936f67868600..94ee6c3b2bb1d7 100644
--- a/libcxx/include/__bit/bit_log2.h
+++ b/libcxx/include/__bit/bit_log2.h
@@ -10,8 +10,8 @@
 #define _LIBCPP___BIT_BIT_LOG2_H
 
 #include <__bit/countl.h>
-#include <__concepts/arithmetic.h>
 #include <__config>
+#include <__type_traits/is_unsigned_integer.h>
 #include <limits>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -20,14 +20,15 @@
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
-#if _LIBCPP_STD_VER >= 20
+#if _LIBCPP_STD_VER >= 14
 
-template <__libcpp_unsigned_integer _Tp>
+template <class _Tp>
 _LIBCPP_HIDE_FROM_ABI constexpr _Tp __bit_log2(_Tp __t) noexcept {
-  return numeric_limits<_Tp>::digits - 1 - std::countl_zero(__t);
+  static_assert(__libcpp_is_unsigned_integer<_Tp>::value, "__bit_log2 requires an unsigned integer type");
+  return numeric_limits<_Tp>::digits - 1 - std::__countl_zero(__t);
 }
 
-#endif // _LIBCPP_STD_VER >= 20
+#endif // _LIBCPP_STD_VER >= 14
 
 _LIBCPP_END_NAMESPACE_STD
 
diff --git a/libcxx/include/module.modulemap b/libcxx/include/module.modulemap
index 86efbd36b20d1d..70aee7da79cba8 100644
--- a/libcxx/include/module.modulemap
+++ b/libcxx/include/module.modulemap
@@ -488,6 +488,7 @@ module std [system] {
     module prev_permutation                       { header "__algorithm/prev_permutation.h" }
     module pstl                                   { header "__algorithm/pstl.h" }
     module push_heap                              { header "__algorithm/push_heap.h" }
+    module radix_sort                             { header "__algorithm/radix_sort.h" }
     module ranges_adjacent_find                   { header "__algorithm/ranges_adjacent_find.h" }
     module ranges_all_of                          { header "__algorithm/ranges_all_of.h" }
     module ranges_any_of                          { header "__algorithm/ranges_any_of.h" }
diff --git a/libcxx/test/std/algorithms/alg.sorting/alg.sort/stable.sort/stable_sort.pass.cpp b/libcxx/test/std/algorithms/alg.sorting/alg.sort/stable.sort/stable_sort.pass.cpp
index 4301d22027de85..621234354092d3 100644
--- a/libcxx/test/std/algorithms/alg.sorting/alg.sort/stable.sort/stable_sort.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.sorting/alg.sort/stable.sort/stable_sort.pass.cpp
@@ -50,16 +50,16 @@ template <class RI>
 void
 test_sort_driver_driver(RI f, RI l, int start, RI real_last)
 {
-    for (RI i = l; i > f + start;)
-    {
-        *--i = start;
-        if (f == i)
-        {
-            test_sort_helper(f, real_last);
-        }
+  using value_type = typename std::iterator_traits<RI>::value_type;
+
+  for (RI i = l; i > f + start;) {
+    *--i = static_cast<value_type>(start);
+    if (f == i) {
+      test_sort_helper(f, real_last);
+    }
     if (start > 0)
         test_sort_driver_driver(f, i, start-1, real_last);
-    }
+  }
 }
 
 template <class RI>
@@ -69,55 +69,61 @@ test_sort_driver(RI f, RI l, int start)
     test_sort_driver_driver(f, l, start, l);
 }
 
+template <int sa, class V>
+void test_sort_() {
+  V ia[sa];
+  for (int i = 0; i < sa; ++i) {
+    test_sort_driver(ia, ia + sa, i);
+  }
+}
+
 template <int sa>
-void
-test_sort_()
-{
-    int ia[sa];
-    for (int i = 0; i < sa; ++i)
-    {
-        test_sort_driver(ia, ia+sa, i);
-    }
+void test_sort_() {
+  test_sort_<sa, int>();
+  test_sort_<sa, float>();
 }
 
-void
-test_larger_sorts(int N, int M)
-{
-    assert(N != 0);
-    assert(M != 0);
-    // create array length N filled with M different numbers
-    int* array = new int[N];
-    int x = 0;
-    for (int i = 0; i < N; ++i)
-    {
-        array[i] = x;
-        if (++x == M)
-            x = 0;
-    }
-    // test saw tooth pattern
-    std::stable_sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
-    // test random pattern
-    std::shuffle(array, array+N, randomness);
-    std::stable_sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
-    // test sorted pattern
-    std::stable_sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
-    // test reverse sorted pattern
-    std::reverse(array, array+N);
-    std::stable_sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
-    // test swap ranges 2 pattern
-    std::swap_ranges(array, array+N/2, array+N/2);
-    std::stable_sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
-    // test reverse swap ranges 2 pattern
-    std::reverse(array, array+N);
-    std::swap_ranges(array, array+N/2, array+N/2);
-    std::stable_sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
-    delete [] array;
+template <class V>
+void test_larger_sorts(int N, int M) {
+  assert(N != 0);
+  assert(M != 0);
+  // create an array of length N filled with M different numbers
+  V* array = new V[N];
+  int x    = 0;
+  for (int i = 0; i < N; ++i) {
+    array[i] = static_cast<V>(x);
+    if (++x == M)
+      x = 0;
+  }
+  // test saw tooth pattern
+  std::stable_sort(array, array + N);
+  assert(std::is_sorted(array, array + N));
+  // test random pattern
+  std::shuffle(array, array + N, randomness);
+  std::stable_sort(array, array + N);
+  assert(std::is_sorted(array, array + N));
+  // test sorted pattern
+  std::stable_sort(array, array + N);
+  assert(std::is_sorted(array, array + N));
+  // test reverse sorted pattern
+  std::reverse(array, array + N);
+  std::stable_sort(array, array + N);
+  assert(std::is_sorted(array, array + N));
+  // test swap ranges 2 pattern
+  std::swap_ranges(array, array + N / 2, array + N / 2);
+  std::stable_sort(array, array + N);
+  assert(std::is_sorted(array, array + N));
+  // test reverse swap ranges 2 pattern
+  std::reverse(array, array + N);
+  std::swap_ranges(array, array + N / 2, array + N / 2);
+  std::stable_sort(array, array + N);
+  assert(std::is_sorted(array, array + N));
+  delete[] array;
+}
+
+void test_larger_sorts(int N, int M) {
+  test_larger_sorts<int>(N, M);
+  test_larger_sorts<float>(N, M);
 }
 
 void
@@ -156,6 +162,9 @@ int main(int, char**)
     test_larger_sorts(997);
     test_larger_sorts(1000);
     test_larger_sorts(1009);
+    test_larger_sorts(1024);
+    test_larger_sorts(1031);
+    test_larger_sorts(2053);
 
 #if !defined(TEST_HAS_NO_EXCEPTIONS)
     { // check that the algorithm works without memory



More information about the libcxx-commits mailing list