[libcxx-commits] [libcxx] [libcxx][algorithm] Optimize std::stable_sort via radix sort algorithm (PR #104683)

Дмитрий Изволов via libcxx-commits libcxx-commits at lists.llvm.org
Thu Nov 28 12:40:14 PST 2024


================
@@ -0,0 +1,332 @@
+// -*- C++ -*-
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___ALGORITHM_RADIX_SORT_H
+#define _LIBCPP___ALGORITHM_RADIX_SORT_H
+
+// This is an implementation of the classic LSD radix sort algorithm. It runs in linear time and uses `O(max(N, M))`
+// additional memory, where `N` is the size of the input range and `M` is the maximum value of
+// a radix of the sorted integer type. The type of the radix and its maximum value are determined at compile time
+// from the type returned by the function `__radix`. The default radix is uint8.
+
+// The algorithm is equivalent to several consecutive calls of counting sort, one for each
+// radix of the sorted numbers, going from the low byte to the high byte.
+// The algorithm uses a temporary buffer whose size equals the size of the input range. The `i`-th pass
+// of the algorithm sorts values by the `i`-th radix and moves them to the temporary buffer (for each even `i`, counted
+// from zero), or moves them back to the initial range (for each odd `i`). If the sorted integers have only one radix
+// (e.g. int8), the sorted values are placed into the buffer and then moved back to the initial range.
+
+// The implementation also has several optimizations:
+// - the counters for the counting sort are calculated in a single pass for all radices;
+// - if all values of a radix are the same, we do not sort by that radix and just move the items to the buffer;
+// - if two consecutive radices satisfy the condition above, we do nothing for these two radices.
+
+#include <__algorithm/for_each.h>
+#include <__algorithm/move.h>
+#include <__bit/bit_log2.h>
+#include <__bit/countl.h>
+#include <__config>
+#include <__functional/identity.h>
+#include <__iterator/distance.h>
+#include <__iterator/iterator_traits.h>
+#include <__iterator/move_iterator.h>
+#include <__iterator/next.h>
+#include <__iterator/reverse_iterator.h>
+#include <__numeric/partial_sum.h>
+#include <__type_traits/decay.h>
+#include <__type_traits/enable_if.h>
+#include <__type_traits/invoke.h>
+#include <__type_traits/is_assignable.h>
+#include <__type_traits/is_integral.h>
+#include <__type_traits/is_unsigned.h>
+#include <__type_traits/make_unsigned.h>
+#include <__utility/forward.h>
+#include <__utility/integer_sequence.h>
+#include <__utility/move.h>
+#include <__utility/pair.h>
+#include <climits>
+#include <cstdint>
+#include <initializer_list>
+#include <limits>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+#  pragma GCC system_header
+#endif
+
+_LIBCPP_PUSH_MACROS
+#include <__undef_macros>
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+#if _LIBCPP_STD_VER >= 14
+
+// Computes the inclusive running sums of [__first, __last) into the range starting at __result
+// (in the manner of std::partial_sum) and, in the same pass, finds the largest input element.
+//
+// Returns a pair of:
+//  - the output iterator one past the last element written;
+//  - the largest element of the input range, or 0 if the input range is empty.
+template <class _InputIterator, class _OutputIterator>
+_LIBCPP_HIDE_FROM_ABI pair<_OutputIterator, __iter_value_type<_InputIterator>>
+__partial_sum_max(_InputIterator __first, _InputIterator __last, _OutputIterator __result) {
+  // Nothing to accumulate: no output is written and the reported maximum is zero.
+  if (__first == __last)
+    return {__result, 0};
+
+  // The first element seeds both the running maximum and the running sum.
+  auto __largest                                      = *__first;
+  __iter_value_type<_InputIterator> __running_total = *__first;
+  *__result                                           = __running_total;
+
+  for (++__first; __first != __last; ++__first) {
+    auto&& __current = *__first;
+    if (__largest < __current)
+      __largest = __current;
+
+    // Moving out of the accumulator lets non-trivial value types reuse its storage.
+    __running_total = std::move(__running_total) + __current;
+
+    ++__result;
+    *__result = __running_total;
+  }
+
+  ++__result;
+  return {__result, __largest};
+}
+
+template <class _Value, class _Map, class _Radix>
+struct __radix_sort_traits {
+  using __image_type = decay_t<typename __invoke_of<_Map, _Value>::type>;
+  static_assert(is_unsigned<__image_type>::value, "");
+
+  using __radix_type = decay_t<typename __invoke_of<_Radix, __image_type>::type>;
+  static_assert(is_integral<__radix_type>::value, "");
+
+  constexpr static auto __radix_value_range = numeric_limits<__radix_type>::max() + 1;
+  constexpr static auto __radix_size        = std::__bit_log2<uint64_t>(__radix_value_range);
+  constexpr static auto __radix_count       = sizeof(__image_type) * CHAR_BIT / __radix_size;
----------------
izvolov wrote:

Done.

https://github.com/llvm/llvm-project/pull/104683


More information about the libcxx-commits mailing list