[libcxx-commits] [libcxx] [libc++] Speed up set_intersection() by fast-forwarding over ranges of non-matching elements with one-sided binary search. (PR #75230)

Iuri Chaer via libcxx-commits libcxx-commits at lists.llvm.org
Mon May 27 14:42:13 PDT 2024


================
@@ -272,6 +278,234 @@ constexpr void runAllIteratorPermutationsTests() {
   static_assert(withAllPermutationsOfInIter1AndInIter2<contiguous_iterator<int*>>());
 }
 
+namespace {
+// Tally of the work performed by the algorithm under test. `comparisons` counts
+// comparator invocations; `in[0]` / `in[1]` hold the counters for the first and
+// second input ranges respectively (wired to stride_counting_iterator below).
+struct [[nodiscard]] OperationCounts {
+  std::size_t comparisons{};
+  struct PerInput {
+    std::size_t proj{};
+    std::size_t iterator_strides{};
+    std::ptrdiff_t iterator_displacement{};
+
+    // Equality deliberately IGNORES `proj`: only the iterator counters are compared.
+    [[nodiscard]] constexpr bool operator==(const PerInput& o) const {
+      return iterator_strides == o.iterator_strides && iterator_displacement == o.iterator_displacement;
+    }
+
+    // True when every counter is at most the corresponding value in `expect`,
+    // i.e. `expect` is treated as an upper bound, not an exact match.
+    [[nodiscard]] constexpr bool matchesExpectation(const PerInput& expect) const {
+      return proj <= expect.proj && iterator_strides <= expect.iterator_strides &&
+             iterator_displacement <= expect.iterator_displacement;
+    }
+  };
+  std::array<PerInput, 2> in{};
+
+  // Upper-bound check over all counters; see PerInput::matchesExpectation.
+  [[nodiscard]] constexpr bool matchesExpectation(const OperationCounts& expect) const {
+    // __debug_less will perform an additional comparison in an assertion
+    constexpr std::size_t comparison_multiplier =
+#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
+        2;
+#else
+        1;
+#endif
+    return comparisons <= comparison_multiplier * expect.comparisons && in[0].matchesExpectation(expect.in[0]) &&
+           in[1].matchesExpectation(expect.in[1]);
+  }
+
+  // Exact equality of all counters (subject to PerInput ignoring `proj`).
+  [[nodiscard]] constexpr bool operator==(const OperationCounts& o) const {
+    return comparisons == o.comparisons && std::ranges::equal(in, o.in);
+  }
+};
+} // namespace
+
+template <template <class...> class In1,
+          template <class...>
+          class In2,
+          class Out,
+          std::size_t N1,
+          std::size_t N2,
+          std::size_t N3>
+constexpr void testSetIntersectionAndReturnOpCounts(
+    std::array<int, N1> in1,
+    std::array<int, N2> in2,
+    std::array<int, N3> expected,
+    const OperationCounts& expectedOpCounts) {
+  OperationCounts ops;
+
+  const auto comp = [&ops](int x, int y) {
+    ++ops.comparisons;
+    return x < y;
+  };
+
+  std::array<int, N3> out;
+
+  stride_counting_iterator b1(
+      In1<decltype(in1.begin())>(in1.begin()), &ops.in[0].iterator_strides, &ops.in[0].iterator_displacement);
+  stride_counting_iterator e1(
+      In1<decltype(in1.end()) >(in1.end()), &ops.in[0].iterator_strides, &ops.in[0].iterator_displacement);
+  stride_counting_iterator b2(
+      In2<decltype(in2.begin())>(in2.begin()), &ops.in[1].iterator_strides, &ops.in[1].iterator_displacement);
+  stride_counting_iterator e2(
+      In2<decltype(in2.end()) >(in2.end()), &ops.in[1].iterator_strides, &ops.in[1].iterator_displacement);
+
----------------
ichaer wrote:

I went for option number 2.

https://github.com/llvm/llvm-project/pull/75230


More information about the libcxx-commits mailing list