[libcxx-commits] [libcxx] [libc++] Add input validation for set_intersection() in debug mode. (PR #101508)

via libcxx-commits libcxx-commits at lists.llvm.org
Thu Aug 1 09:14:53 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-libcxx

Author: Iuri Chaer (ichaer)

<details>
<summary>Changes</summary>

The use of one-sided binary search introduced by a066217 changes the behaviour of `set_intersection()` on invalid (unsorted) input (see https://github.com/llvm/llvm-project/pull/75230#issuecomment-2259408046). Add input validation under `_LIBCPP_HARDENING_MODE_DEBUG` to help users.

* Change the interface of `__is_sorted_until()` so that it accepts a sentinel whose type differs from that of the beginning iterator, and so that it won't try to copy the comparison function object.
* Add one assertion for each input range, confirming that it is sorted.
* Stop validating the complexity of `set_intersection()` in debug mode. Doing so is hopeless and also not meaningful: there are no complexity guarantees in debug mode, and we're happy to trade performance for diagnosability.
* Fix bugs in `ranges_robust_against_differing_projections.pass.cpp`: we were using an input range as the output of `std::ranges::partial_sort_copy()`, and, because the projections negate their argument, algorithms that require sorted input can't be given ranges sorted in ascending order when the comparator is `std::ranges::less`. Also add `const` where appropriate to make sure we aren't using inputs as outputs in other places.

---
Full diff: https://github.com/llvm/llvm-project/pull/101508.diff


4 Files Affected:

- (modified) libcxx/include/__algorithm/is_sorted_until.h (+7-7) 
- (modified) libcxx/include/__algorithm/set_intersection.h (+11) 
- (modified) libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp (+24-24) 
- (modified) libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp (+19-18) 


``````````diff
diff --git a/libcxx/include/__algorithm/is_sorted_until.h b/libcxx/include/__algorithm/is_sorted_until.h
index 53a49f00de31e..f84c990ff2675 100644
--- a/libcxx/include/__algorithm/is_sorted_until.h
+++ b/libcxx/include/__algorithm/is_sorted_until.h
@@ -20,18 +20,18 @@
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
-template <class _Compare, class _ForwardIterator>
+template <class _Compare, class _ForwardIterator, class _Sent>
 _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _ForwardIterator
-__is_sorted_until(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) {
+__is_sorted_until(_ForwardIterator __first, _Sent __last, _Compare&& __comp) {
   if (__first != __last) {
     _ForwardIterator __i = __first;
-    while (++__i != __last) {
-      if (__comp(*__i, *__first))
-        return __i;
-      __first = __i;
+    while (++__first != __last) {
+      if (__comp(*__first, *__i))
+        return __first;
+      __i = __first;
     }
   }
-  return __last;
+  return __first;
 }
 
 template <class _ForwardIterator, class _Compare>
diff --git a/libcxx/include/__algorithm/set_intersection.h b/libcxx/include/__algorithm/set_intersection.h
index bb0d86cd0f58d..2cfd984b202a2 100644
--- a/libcxx/include/__algorithm/set_intersection.h
+++ b/libcxx/include/__algorithm/set_intersection.h
@@ -11,12 +11,15 @@
 
 #include <__algorithm/comp.h>
 #include <__algorithm/comp_ref_type.h>
+#include <__algorithm/is_sorted_until.h>
 #include <__algorithm/iterator_operations.h>
 #include <__algorithm/lower_bound.h>
+#include <__assert>
 #include <__config>
 #include <__functional/identity.h>
 #include <__iterator/iterator_traits.h>
 #include <__iterator/next.h>
+#include <__type_traits/is_constant_evaluated.h>
 #include <__type_traits/is_same.h>
 #include <__utility/exchange.h>
 #include <__utility/move.h>
@@ -95,6 +98,14 @@ __set_intersection(
     _Compare&& __comp,
     std::forward_iterator_tag,
     std::forward_iterator_tag) {
+#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
+  if (!__libcpp_is_constant_evaluated()) {
+    _LIBCPP_ASSERT_INTERNAL(
+        std::__is_sorted_until(__first1, __last1, __comp) == __last1, "set_intersection: input range 1 must be sorted");
+    _LIBCPP_ASSERT_INTERNAL(
+        std::__is_sorted_until(__first2, __last2, __comp) == __last2, "set_intersection: input range 2 must be sorted");
+  }
+#endif
   _LIBCPP_CONSTEXPR std::__identity __proj;
   bool __prev_may_be_equal = false;
 
diff --git a/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp b/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
index ddf4087ddd6cd..7c0c394d1f23f 100644
--- a/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
@@ -43,16 +43,15 @@
 
 #include "test_iterators.h"
 
-namespace {
-
-// __debug_less will perform an additional comparison in an assertion
-static constexpr unsigned std_less_comparison_count_multiplier() noexcept {
-#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
-  return 2;
+// debug mode provides no complexity guarantees, testing them would be a waste of effort
+// but we still want to run this test, to ensure we don't trigger any assertions
+#ifdef _LIBCPP_HARDENING_MODE_DEBUG
+#  define ASSERT_COMPLEXITY(expression)
 #else
-  return 1;
+#  define ASSERT_COMPLEXITY(expression) assert(expression)
 #endif
-}
+
+namespace {
 
 struct [[nodiscard]] OperationCounts {
   std::size_t comparisons{};
@@ -60,16 +59,16 @@ struct [[nodiscard]] OperationCounts {
     std::size_t proj{};
     IteratorOpCounts iterops;
 
-    [[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) {
+    [[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) const noexcept {
       return proj >= other.proj && iterops.increments + iterops.decrements + iterops.zero_moves >=
                                        other.iterops.increments + other.iterops.decrements + other.iterops.zero_moves;
     }
   };
   std::array<PerInput, 2> in;
 
-  [[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) {
-    return std_less_comparison_count_multiplier() * comparisons >= expect.comparisons &&
-           in[0].isNotBetterThan(expect.in[0]) && in[1].isNotBetterThan(expect.in[1]);
+  [[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) const noexcept {
+    return comparisons >= expect.comparisons && in[0].isNotBetterThan(expect.in[0]) &&
+           in[1].isNotBetterThan(expect.in[1]);
   }
 };
 
@@ -80,16 +79,17 @@ struct counted_set_intersection_result {
 
   constexpr counted_set_intersection_result() = default;
 
-  constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) : result{contents} {}
+  constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) noexcept
+      : result{contents} {}
 
-  constexpr void assertNotBetterThan(const counted_set_intersection_result& other) {
+  constexpr void assertNotBetterThan(const counted_set_intersection_result& other) const noexcept {
     assert(result == other.result);
-    assert(opcounts.isNotBetterThan(other.opcounts));
+    ASSERT_COMPLEXITY(opcounts.isNotBetterThan(other.opcounts));
   }
 };
 
 template <std::size_t ResultSize>
-counted_set_intersection_result(std::array<int, ResultSize>) -> counted_set_intersection_result<ResultSize>;
+counted_set_intersection_result(std::array<int, ResultSize>) noexcept -> counted_set_intersection_result<ResultSize>;
 
 template <template <class...> class InIterType1,
           template <class...>
@@ -306,7 +306,7 @@ constexpr bool testComplexityBasic() {
   std::array<int, 5> r2{2, 4, 6, 8, 10};
   std::array<int, 0> expected{};
 
-  const std::size_t maxOperation = std_less_comparison_count_multiplier() * (2 * (r1.size() + r2.size()) - 1);
+  [[maybe_unused]] const std::size_t maxOperation = 2 * (r1.size() + r2.size()) - 1;
 
   // std::set_intersection
   {
@@ -321,7 +321,7 @@ constexpr bool testComplexityBasic() {
     std::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp);
 
     assert(std::ranges::equal(out, expected));
-    assert(numberOfComp <= maxOperation);
+    ASSERT_COMPLEXITY(numberOfComp <= maxOperation);
   }
 
   // ranges::set_intersection iterator overload
@@ -349,9 +349,9 @@ constexpr bool testComplexityBasic() {
     std::ranges::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp, proj1, proj2);
 
     assert(std::ranges::equal(out, expected));
-    assert(numberOfComp <= maxOperation);
-    assert(numberOfProj1 <= maxOperation);
-    assert(numberOfProj2 <= maxOperation);
+    ASSERT_COMPLEXITY(numberOfComp <= maxOperation);
+    ASSERT_COMPLEXITY(numberOfProj1 <= maxOperation);
+    ASSERT_COMPLEXITY(numberOfProj2 <= maxOperation);
   }
 
   // ranges::set_intersection range overload
@@ -379,9 +379,9 @@ constexpr bool testComplexityBasic() {
     std::ranges::set_intersection(r1, r2, out.data(), comp, proj1, proj2);
 
     assert(std::ranges::equal(out, expected));
-    assert(numberOfComp < maxOperation);
-    assert(numberOfProj1 < maxOperation);
-    assert(numberOfProj2 < maxOperation);
+    ASSERT_COMPLEXITY(numberOfComp < maxOperation);
+    ASSERT_COMPLEXITY(numberOfProj1 < maxOperation);
+    ASSERT_COMPLEXITY(numberOfProj2 < maxOperation);
   }
   return true;
 }
diff --git a/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp b/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp
index 82792249ef5c7..a5f1ceee6fb3d 100644
--- a/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp
+++ b/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp
@@ -40,19 +40,20 @@ constexpr bool test_all() {
     constexpr auto operator<=>(const A&) const = default;
   };
 
-  std::array in = {1, 2, 3};
-  std::array in2 = {A{4}, A{5}, A{6}};
+  const std::array in  = {1, 2, 3};
+  const std::array in2 = {A{4}, A{5}, A{6}};
 
   std::array output = {7, 8, 9, 10, 11, 12};
   auto out = output.begin();
   std::array output2 = {A{7}, A{8}, A{9}, A{10}, A{11}, A{12}};
   auto out2 = output2.begin();
 
-  std::ranges::equal_to eq;
-  std::ranges::less less;
-  auto sum = [](int lhs, A rhs) { return lhs + rhs.x; };
-  auto proj1 = [](int x) { return x * -1; };
-  auto proj2 = [](A a) { return a.x * -1; };
+  const std::ranges::equal_to eq;
+  const std::ranges::less less;
+  const std::ranges::greater greater;
+  const auto sum   = [](int lhs, A rhs) { return lhs + rhs.x; };
+  const auto proj1 = [](int x) { return x * -1; };
+  const auto proj2 = [](A a) { return a.x * -1; };
 
 #if TEST_STD_VER >= 23
   test(std::ranges::ends_with, in, in2, eq, proj1, proj2);
@@ -67,17 +68,17 @@ constexpr bool test_all() {
   test(std::ranges::find_end, in, in2, eq, proj1, proj2);
   test(std::ranges::transform, in, in2, out, sum, proj1, proj2);
   test(std::ranges::transform, in, in2, out2, sum, proj1, proj2);
-  test(std::ranges::partial_sort_copy, in, in2, less, proj1, proj2);
-  test(std::ranges::merge, in, in2, out, less, proj1, proj2);
-  test(std::ranges::merge, in, in2, out2, less, proj1, proj2);
-  test(std::ranges::set_intersection, in, in2, out, less, proj1, proj2);
-  test(std::ranges::set_intersection, in, in2, out2, less, proj1, proj2);
-  test(std::ranges::set_difference, in, in2, out, less, proj1, proj2);
-  test(std::ranges::set_difference, in, in2, out2, less, proj1, proj2);
-  test(std::ranges::set_symmetric_difference, in, in2, out, less, proj1, proj2);
-  test(std::ranges::set_symmetric_difference, in, in2, out2, less, proj1, proj2);
-  test(std::ranges::set_union, in, in2, out, less, proj1, proj2);
-  test(std::ranges::set_union, in, in2, out2, less, proj1, proj2);
+  test(std::ranges::partial_sort_copy, in, output, less, proj1, proj2);
+  test(std::ranges::merge, in, in2, out, greater, proj1, proj2);
+  test(std::ranges::merge, in, in2, out2, greater, proj1, proj2);
+  test(std::ranges::set_intersection, in, in2, out, greater, proj1, proj2);
+  test(std::ranges::set_intersection, in, in2, out2, greater, proj1, proj2);
+  test(std::ranges::set_difference, in, in2, out, greater, proj1, proj2);
+  test(std::ranges::set_difference, in, in2, out2, greater, proj1, proj2);
+  test(std::ranges::set_symmetric_difference, in, in2, out, greater, proj1, proj2);
+  test(std::ranges::set_symmetric_difference, in, in2, out2, greater, proj1, proj2);
+  test(std::ranges::set_union, in, in2, out, greater, proj1, proj2);
+  test(std::ranges::set_union, in, in2, out2, greater, proj1, proj2);
 #if TEST_STD_VER > 20
   test(std::ranges::starts_with, in, in2, eq, proj1, proj2);
 #endif

``````````

</details>


https://github.com/llvm/llvm-project/pull/101508


More information about the libcxx-commits mailing list