[libcxx-commits] [libcxx] [libc++] Speed up set_intersection() by fast-forwarding over ranges of non-matching elements with one-sided binary search. (PR #75230)

Louis Dionne via libcxx-commits libcxx-commits at lists.llvm.org
Thu Jul 11 07:07:37 PDT 2024


================
@@ -38,10 +43,94 @@ struct __set_intersection_result {
       : __in1_(std::move(__in_iter1)), __in2_(std::move(__in_iter2)), __out_(std::move(__out_iter)) {}
 };
 
-template <class _AlgPolicy, class _Compare, class _InIter1, class _Sent1, class _InIter2, class _Sent2, class _OutIter>
-_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 __set_intersection_result<_InIter1, _InIter2, _OutIter>
+// Helper for __set_intersection() with one-sided binary search: populate result and advance input iterators if they
+// haven't advanced in the last 2 calls. This function is very intimately related to the way it is used and doesn't
+// attempt to abstract that; it's not appropriate for general usage outside of its context.
+template <class _InForwardIter1, class _InForwardIter2, class _OutIter>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 void __set_intersection_add_output_unless(
+    bool __advanced, _InForwardIter1& __first1, _InForwardIter2& __first2, _OutIter& __result, bool& __prev_advanced) {
+  if (__advanced || __prev_advanced) {
+    __prev_advanced = __advanced;
+  } else {
+    *__result = *__first1;
+    ++__result;
+    ++__first1;
+    ++__first2;
+    __prev_advanced = true;
+  }
+}
+
+// With forward iterators we can make multiple passes over the data, allowing the use of one-sided binary search to
+// reduce best-case complexity to log(N). Understanding how we can use binary search and still respect complexity
+// guarantees is _not_ straightforward: the guarantee is "at most 2*(N+M)-1 comparisons", and one-sided binary search
+// will necessarily overshoot depending on the position of the needle in the haystack -- for instance, if we're
+// searching for 3 in (1, 2, 3, 4), we'll check if 3<1, then 3<2, then 3<4, and, finally, 3<3, for a total of 4
+// comparisons, when linear search would have yielded 3. However, because we won't need to perform the intervening
+// reciprocal comparisons (i.e. 1<3, 2<3, 4<3), that extra comparison doesn't run afoul of the guarantee. Additionally,
+// this type of scenario can only happen for match distances of up to 5 elements, because 2*log2(8) is 6, and we'll
+// still be worse off at position 5 of an 8-element set. From then onwards these scenarios can't happen. TL;DR: we'll be
+// 1 comparison worse off compared to the classic linear-searching algorithm if matching position 3 of a set with 4
+// elements, or position 5 if the set has 7 or 8 elements, but we'll never exceed the complexity guarantees from the
+// standard.
+template <class _AlgPolicy,
+          class _Compare,
+          class _InForwardIter1,
+          class _Sent1,
+          class _InForwardIter2,
+          class _Sent2,
+          class _OutIter>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI
+_LIBCPP_CONSTEXPR_SINCE_CXX20 __set_intersection_result<_InForwardIter1, _InForwardIter2, _OutIter>
 __set_intersection(
-    _InIter1 __first1, _Sent1 __last1, _InIter2 __first2, _Sent2 __last2, _OutIter __result, _Compare&& __comp) {
+    _InForwardIter1 __first1,
+    _Sent1 __last1,
+    _InForwardIter2 __first2,
+    _Sent2 __last2,
+    _OutIter __result,
+    _Compare&& __comp,
+    std::forward_iterator_tag,
+    std::forward_iterator_tag) {
+  _LIBCPP_CONSTEXPR std::__identity __proj;
+  bool __prev_advanced = true;
+
+  while (__first2 != __last2) {
+    _InForwardIter1 __first1_next =
+        std::__lower_bound_onesided<_AlgPolicy>(__first1, __last1, *__first2, __comp, __proj);
+    std::swap(__first1_next, __first1);
+    std::__set_intersection_add_output_unless(__first1 != __first1_next, __first1, __first2, __result, __prev_advanced);
----------------
ldionne wrote:

I think the algorithm would be easier to read if we inlined this function at its two call sites. Why do we need to keep track of whether we previously advanced the iterators?
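
Below is a minimal standalone sketch of what the suggested inlining might look like, and of why a flag like `__prev_advanced` is needed at all. Everything in it is an assumption for illustration: `galloping_set_intersection`, `gallop_lower_bound` and `bounded_advance` are hypothetical names, the search is a generic exponential (galloping) search rather than libc++'s `std::__lower_bound_onesided`, and `prev_stayed` plays the role of the PR's `__prev_advanced` with inverted polarity. The reasoning it encodes: a search that leaves its iterator in place only proves one half of equivalence, e.g. `!(*first1 < *first2)`. An element can be emitted only once the search on the other range has also stayed put, which proves `!(*first2 < *first1)` for the same pair of positions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <forward_list>
#include <functional>
#include <iterator>
#include <vector>

// Hypothetical helper: advance `it` by up to `n` steps, stopping at `last`.
template <class It>
void bounded_advance(It& it, std::ptrdiff_t n, It last) {
  for (; n > 0 && it != last; --n)
    ++it;
}

// Hypothetical one-sided (galloping) lower bound: probe first, first+1,
// first+3, first+7, ... until an element is not less than `value` (or the
// range ends), then binary-search the last bracket.
template <class It, class T, class Comp>
It gallop_lower_bound(It first, It last, const T& value, Comp comp) {
  if (first == last || !comp(*first, value))
    return first;
  // Here comp(*first, value) holds, so the answer lies strictly after first.
  std::ptrdiff_t step = 1;
  for (;;) {
    It probe = first;
    bounded_advance(probe, step, last);
    if (probe == last || !comp(*probe, value))
      return std::lower_bound(std::next(first), probe, value, comp);
    first = probe;  // *probe is still too small: double the step
    step *= 2;
  }
}

// Hypothetical set_intersection for forward iterators with the output logic
// inlined at both call sites. `prev_stayed` is true when the previous search
// (on the other range) left its iterator in place.
template <class It1, class It2, class Out, class Comp = std::less<>>
Out galloping_set_intersection(It1 first1, It1 last1, It2 first2, It2 last2,
                               Out out, Comp comp = Comp{}) {
  bool prev_stayed = false;  // no search has run yet
  while (first1 != last1 && first2 != last2) {
    // Skip the elements of range 1 that are less than *first2.
    It1 next1 = gallop_lower_bound(first1, last1, *first2, comp);
    bool stayed = (next1 == first1);  // if so, !comp(*first1, *first2)
    first1 = next1;
    if (stayed && prev_stayed) {
      // The previous search also stayed, so !comp(*first2, *first1) holds
      // too: the two elements are equivalent.
      *out++ = *first1;
      ++first1;
      ++first2;
      prev_stayed = false;  // the next match needs two fresh searches
      continue;
    }
    prev_stayed = stayed;
    if (first1 == last1)
      break;

    // Symmetric step: skip the elements of range 2 less than *first1.
    It2 next2 = gallop_lower_bound(first2, last2, *first1, comp);
    stayed = (next2 == first2);  // if so, !comp(*first2, *first1)
    first2 = next2;
    if (stayed && prev_stayed) {
      *out++ = *first1;
      ++first1;
      ++first2;
      prev_stayed = false;
    } else {
      prev_stayed = stayed;
    }
  }
  return out;
}
```

Each loop pass either moves an iterator or emits an element, so the loop terminates. Equivalent runs are consumed pairwise, so an element repeated m times in one range and n times in the other is emitted min(m, n) times, as `std::set_intersection` requires.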

https://github.com/llvm/llvm-project/pull/75230
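
The comparison-count argument in the diff comment can be checked with a small counting experiment. This is a sketch under stated assumptions: `gallop_lower_bound` (with its helper `bounded_advance`) is a hypothetical generic galloping search that probes `first`, `first+1`, `first+3`, `first+7`, ... and then binary-searches the last bracket with `std::lower_bound`. It is not libc++'s `std::__lower_bound_onesided`, whose probing pattern may differ. `gallop_comparisons` and `linear_comparisons` are likewise hypothetical and simply count comparator calls.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical helper: advance `it` by up to `n` steps, stopping at `last`.
template <class It>
void bounded_advance(It& it, std::ptrdiff_t n, It last) {
  for (; n > 0 && it != last; --n)
    ++it;
}

// Hypothetical one-sided (galloping) lower bound: probe first, first+1,
// first+3, first+7, ... until an element is not less than `value` (or the
// range ends), then binary-search the last bracket.
template <class It, class T, class Comp>
It gallop_lower_bound(It first, It last, const T& value, Comp comp) {
  if (first == last || !comp(*first, value))
    return first;
  // Here comp(*first, value) holds, so the answer lies strictly after first.
  std::ptrdiff_t step = 1;
  for (;;) {
    It probe = first;
    bounded_advance(probe, step, last);
    if (probe == last || !comp(*probe, value))
      return std::lower_bound(std::next(first), probe, value, comp);
    first = probe;  // *probe is still too small: double the step
    step *= 2;
  }
}

// Hypothetical: comparator calls the galloping search makes to find `needle`.
inline int gallop_comparisons(const std::vector<int>& hay, int needle) {
  int count = 0;
  auto counting_less = [&count](int a, int b) {
    ++count;  // copies of the lambda share `count` through the reference
    return a < b;
  };
  gallop_lower_bound(hay.begin(), hay.end(), needle, counting_less);
  return count;
}

// Hypothetical: comparator calls a linear scan makes to find the same position.
inline int linear_comparisons(const std::vector<int>& hay, int needle) {
  int count = 0;
  for (int x : hay) {
    ++count;
    if (!(x < needle))
      break;
  }
  return count;
}
```

With this particular probing pattern, searching for 3 in (1, 2, 3, 4) costs 4 comparisons against 3 for a linear scan, the same one-comparison overshoot the comment describes, while a needle 1000 positions into a 1024-element range costs about 20 comparisons instead of 1001.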

