[libcxx-commits] [libcxx] [libc++] Optimize make_heap() and sift_down() (PR #121480)

Yang Kun via libcxx-commits libcxx-commits at lists.llvm.org
Sun Jan 19 03:59:56 PST 2025


https://github.com/omikrun updated https://github.com/llvm/llvm-project/pull/121480

>From 1b61da143d85fb3dc26007fee338957059c497ba Mon Sep 17 00:00:00 2001
From: Yang Kun <193369907+omikrun at users.noreply.github.com>
Date: Sat, 4 Jan 2025 19:44:44 +0800
Subject: [PATCH] [libc++] Optimize make_heap() and sift_down()

---
 libcxx/include/__algorithm/make_heap.h |  8 ++--
 libcxx/include/__algorithm/pop_heap.h  |  3 +-
 libcxx/include/__algorithm/push_heap.h | 16 +++----
 libcxx/include/__algorithm/sift_down.h | 64 +++++++++++++++-----------
 4 files changed, 50 insertions(+), 41 deletions(-)

diff --git a/libcxx/include/__algorithm/make_heap.h b/libcxx/include/__algorithm/make_heap.h
index e8f0cdb27333a4..746c2621fa70e6 100644
--- a/libcxx/include/__algorithm/make_heap.h
+++ b/libcxx/include/__algorithm/make_heap.h
@@ -34,10 +34,12 @@ __make_heap(_RandomAccessIterator __first, _RandomAccessIterator __last, _Compar
   using difference_type = typename iterator_traits<_RandomAccessIterator>::difference_type;
   difference_type __n   = __last - __first;
   if (__n > 1) {
-    // start from the first parent, there is no need to consider children
-    for (difference_type __start = (__n - 2) / 2; __start >= 0; --__start) {
+    // __n / 2 - 1 is the last node with a child; sift down each such node, ending at the root.
+    difference_type __start = __n / 2;
+    do {
+      --__start;
       std::__sift_down<_AlgPolicy>(__first, __comp_ref, __n, __first + __start);
-    }
+    } while (__start != 0);
   }
 }

diff --git a/libcxx/include/__algorithm/pop_heap.h b/libcxx/include/__algorithm/pop_heap.h
index 6d23830097ff96..60481938c4bc57 100644
--- a/libcxx/include/__algorithm/pop_heap.h
+++ b/libcxx/include/__algorithm/pop_heap.h
@@ -51,9 +51,8 @@ __pop_heap(_RandomAccessIterator __first,
       *__hole = std::move(__top);
     } else {
       *__hole = _IterOps<_AlgPolicy>::__iter_move(__last);
-      ++__hole;
       *__last = std::move(__top);
-      std::__sift_up<_AlgPolicy>(__first, __hole, __comp_ref, __hole - __first);
+      std::__sift_up<_AlgPolicy>(__first, __hole, __comp_ref); // __hole is the element to sift up, not one past it
     }
   }
 }
diff --git a/libcxx/include/__algorithm/push_heap.h b/libcxx/include/__algorithm/push_heap.h
index ec0b445f2b70f3..3a1f95c8651ad6 100644
--- a/libcxx/include/__algorithm/push_heap.h
+++ b/libcxx/include/__algorithm/push_heap.h
@@ -29,17 +29,18 @@ _LIBCPP_BEGIN_NAMESPACE_STD
 
 template <class _AlgPolicy, class _Compare, class _RandomAccessIterator>
 _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 void
-__sift_up(_RandomAccessIterator __first,
-          _RandomAccessIterator __last,
-          _Compare&& __comp,
-          typename iterator_traits<_RandomAccessIterator>::difference_type __len) {
-  using value_type = typename iterator_traits<_RandomAccessIterator>::value_type;
+__sift_up(_RandomAccessIterator __first, _RandomAccessIterator __last, _Compare&& __comp) {
+  using difference_type = typename iterator_traits<_RandomAccessIterator>::difference_type;
+  using value_type      = typename iterator_traits<_RandomAccessIterator>::value_type;
 
-  if (__len > 1) {
-    __len                       = (__len - 2) / 2;
+  // __last refers to the element being sifted up (not one past it), so __len is its
+  // index; it has a parent whenever that index is non-zero.
+  difference_type __len = __last - __first;
+  if (__len > 0) {
+    __len                       = (__len - 1) / 2;
     _RandomAccessIterator __ptr = __first + __len;
 
-    if (__comp(*__ptr, *--__last)) {
+    if (__comp(*__ptr, *__last)) {
       value_type __t(_IterOps<_AlgPolicy>::__iter_move(__last));
       do {
         *__last = _IterOps<_AlgPolicy>::__iter_move(__ptr);
@@ -58,8 +59,7 @@ __sift_up(_RandomAccessIterator __first,
 template <class _AlgPolicy, class _RandomAccessIterator, class _Compare>
 inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 void
 __push_heap(_RandomAccessIterator __first, _RandomAccessIterator __last, _Compare& __comp) {
-  typename iterator_traits<_RandomAccessIterator>::difference_type __len = __last - __first;
-  std::__sift_up<_AlgPolicy, __comp_ref_type<_Compare> >(std::move(__first), std::move(__last), __comp, __len);
+  std::__sift_up<_AlgPolicy, __comp_ref_type<_Compare> >(std::move(__first), std::move(--__last), __comp);
 }
 
 template <class _RandomAccessIterator, class _Compare>
diff --git a/libcxx/include/__algorithm/sift_down.h b/libcxx/include/__algorithm/sift_down.h
index 42803e30631fb1..0cd5aa32556f3d 100644
--- a/libcxx/include/__algorithm/sift_down.h
+++ b/libcxx/include/__algorithm/sift_down.h
@@ -34,20 +34,25 @@ __sift_down(_RandomAccessIterator __first,
 
   typedef typename iterator_traits<_RandomAccessIterator>::difference_type difference_type;
   typedef typename iterator_traits<_RandomAccessIterator>::value_type value_type;
-  // left-child of __start is at 2 * __start + 1
-  // right-child of __start is at 2 * __start + 2
-  difference_type __child = __start - __first;
 
-  if (__len < 2 || (__len - 2) / 2 < __child)
+  if (__len < 2)
     return;
 
-  __child                         = 2 * __child + 1;
+  // __start must have at least one child; otherwise __child_i below would point
+  // past the end of the heap.
+  _LIBCPP_ASSERT_INTERNAL(__start - __first < __len / 2, "__start must not be a leaf");
+  // left-child of __start is at 2 * __start + 1
+  // right-child of __start is at 2 * __start + 2
+  difference_type __child         = 2 * (__start - __first) + 1;
   _RandomAccessIterator __child_i = __first + __child;
 
-  if ((__child + 1) < __len && __comp(*__child_i, *(__child_i + difference_type(1)))) {
-    // right-child exists and is greater than left-child
-    ++__child_i;
-    ++__child;
+  if ((__child + 1) < __len) {
+    _RandomAccessIterator __right_i = _Ops::next(__child_i);
+    if (__comp(*__child_i, *__right_i)) {
+      // right-child exists and is greater than left-child
+      __child_i = __right_i;
+      ++__child;
+    }
   }
 
   // check if we are in heap-order
@@ -68,10 +73,13 @@ __sift_down(_RandomAccessIterator __first,
     __child   = 2 * __child + 1;
     __child_i = __first + __child;
 
-    if ((__child + 1) < __len && __comp(*__child_i, *(__child_i + difference_type(1)))) {
-      // right-child exists and is greater than left-child
-      ++__child_i;
-      ++__child;
+    if ((__child + 1) < __len) {
+      _RandomAccessIterator __right_i = _Ops::next(__child_i);
+      if (__comp(*__child_i, *__right_i)) {
+        // right-child exists and is greater than left-child
+        __child_i = __right_i;
+        ++__child;
+      }
     }
 
     // check if we are in heap-order
@@ -84,29 +92,37 @@ _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _RandomAccessIterator __floy
     _RandomAccessIterator __first,
     _Compare&& __comp,
     typename iterator_traits<_RandomAccessIterator>::difference_type __len) {
-  using difference_type = typename iterator_traits<_RandomAccessIterator>::difference_type;
-  _LIBCPP_ASSERT_INTERNAL(__len >= 2, "shouldn't be called unless __len >= 2");
+  _LIBCPP_ASSERT_INTERNAL(__len > 1, "shouldn't be called unless __len > 1");
 
-  _RandomAccessIterator __hole    = __first;
-  _RandomAccessIterator __child_i = __first;
-  difference_type __child         = 0;
+  using _Ops = _IterOps<_AlgPolicy>;
 
-  while (true) {
-    __child_i += difference_type(__child + 1);
-    __child = 2 * __child + 1;
+  typedef typename iterator_traits<_RandomAccessIterator>::difference_type difference_type;
 
-    if ((__child + 1) < __len && __comp(*__child_i, *(__child_i + difference_type(1)))) {
-      // right-child exists and is greater than left-child
-      ++__child_i;
-      ++__child;
+  // __child is always one more than the 0-based index of __child_i. In this 1-based
+  // numbering the left child of node k is 2 * k, which is also the 0-based index of
+  // the right child, so the right child exists iff __child < __len after doubling.
+  difference_type __child      = 1;
+  _RandomAccessIterator __hole = __first, __child_i = __first;
+
+  while (true) {
+    __child_i += __child;
+    __child *= 2;
+
+    if (__child < __len) {
+      _RandomAccessIterator __right_i = _Ops::next(__child_i);
+      if (__comp(*__child_i, *__right_i)) {
+        // right-child exists and is greater than left-child
+        __child_i = __right_i;
+        ++__child;
+      }
     }
 
     // swap __hole with its largest child
-    *__hole = _IterOps<_AlgPolicy>::__iter_move(__child_i);
+    *__hole = _Ops::__iter_move(__child_i);
     __hole  = __child_i;
 
     // if __hole is now a leaf, we're done
-    if (__child > (__len - 2) / 2)
+    if (__child > __len / 2)
       return __hole;
   }
 }



More information about the libcxx-commits mailing list