[libcxx-commits] [libcxx] [libc++] Add input validation for set_intersection() in debug mode. (PR #101508)
via libcxx-commits
libcxx-commits at lists.llvm.org
Thu Aug 1 09:14:53 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-libcxx
Author: Iuri Chaer (ichaer)
<details>
<summary>Changes</summary>
The use of one-sided binary search introduced by a066217 changes behaviour on invalid, unsorted input (see https://github.com/llvm/llvm-project/pull/75230#issuecomment-2259408046). Add input validation in `_LIBCPP_HARDENING_MODE_DEBUG` to help users diagnose such input.
* Change the interface of `__is_sorted_until()` so that it accepts a sentinel whose type differs from that of the begin iterator, and so that it never copies the comparison function object.
* Add one assertion for each input range confirming that it is sorted.
* Stop validating the complexity of `set_intersection()` in debug mode: it's hopeless and also not meaningful. There are no complexity guarantees in debug mode; we're happy to trade performance for diagnosability.
* Fix bugs in `ranges_robust_against_differing_projections.pass`: we were using an input range as the output of `std::ranges::partial_sort_copy()`. Also, the projections negate their argument, so ranges sorted in ascending order are no longer sorted with respect to `std::ranges::less` once projected; algorithms that require sorted ranges now use `std::ranges::greater` instead. Added `const` where appropriate to make sure we weren't using inputs as outputs in other places.
---
Full diff: https://github.com/llvm/llvm-project/pull/101508.diff
4 Files Affected:
- (modified) libcxx/include/__algorithm/is_sorted_until.h (+7-7)
- (modified) libcxx/include/__algorithm/set_intersection.h (+11)
- (modified) libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp (+24-24)
- (modified) libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp (+19-18)
``````````diff
diff --git a/libcxx/include/__algorithm/is_sorted_until.h b/libcxx/include/__algorithm/is_sorted_until.h
index 53a49f00de31e..f84c990ff2675 100644
--- a/libcxx/include/__algorithm/is_sorted_until.h
+++ b/libcxx/include/__algorithm/is_sorted_until.h
@@ -20,18 +20,18 @@
_LIBCPP_BEGIN_NAMESPACE_STD
-template <class _Compare, class _ForwardIterator>
+// Returns the first iterator in [__first, __last) whose element compares less (per __comp) than the
+// element before it, or the iterator equal to __last if no such element exists. __last may be a
+// sentinel of a different type than __first; __comp is bound by reference and never copied here.
+template <class _Compare, class _ForwardIterator, class _Sent>
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _ForwardIterator
-__is_sorted_until(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) {
+__is_sorted_until(_ForwardIterator __first, _Sent __last, _Compare&& __comp) {
if (__first != __last) {
_ForwardIterator __i = __first;
+ // __i trails __first by one position: it always refers to the element preceding *__first.
- while (++__i != __last) {
- if (__comp(*__i, *__first))
- return __i;
- __first = __i;
+ while (++__first != __last) {
+ if (__comp(*__first, *__i))
+ return __first;
+ __i = __first;
}
}
- return __last;
+ // Reached only when the range is sorted; __first now equals __last but has iterator type.
+ return __first;
}
template <class _ForwardIterator, class _Compare>
diff --git a/libcxx/include/__algorithm/set_intersection.h b/libcxx/include/__algorithm/set_intersection.h
index bb0d86cd0f58d..2cfd984b202a2 100644
--- a/libcxx/include/__algorithm/set_intersection.h
+++ b/libcxx/include/__algorithm/set_intersection.h
@@ -11,12 +11,15 @@
#include <__algorithm/comp.h>
#include <__algorithm/comp_ref_type.h>
+#include <__algorithm/is_sorted_until.h>
#include <__algorithm/iterator_operations.h>
#include <__algorithm/lower_bound.h>
+#include <__assert>
#include <__config>
#include <__functional/identity.h>
#include <__iterator/iterator_traits.h>
#include <__iterator/next.h>
+#include <__type_traits/is_constant_evaluated.h>
#include <__type_traits/is_same.h>
#include <__utility/exchange.h>
#include <__utility/move.h>
@@ -95,6 +98,19 @@ __set_intersection(
_Compare&& __comp,
std::forward_iterator_tag,
std::forward_iterator_tag) {
+#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
+ // Sorted input is a precondition the *caller* must meet, so a violation is reported as a semantic
+ // requirement rather than as an internal libc++ invariant failure. The O(n) scans run only in debug
+ // mode and are skipped during constant evaluation.
+ // NOTE(review): ranges overloads may pass a comparator that projects its two arguments differently;
+ // both checks compare two elements of the same range -- confirm this is valid for every caller.
+ if (!__libcpp_is_constant_evaluated()) {
+ _LIBCPP_ASSERT_SEMANTIC_REQUIREMENT(
+ std::__is_sorted_until(__first1, __last1, __comp) == __last1, "set_intersection: input range 1 must be sorted");
+ _LIBCPP_ASSERT_SEMANTIC_REQUIREMENT(
+ std::__is_sorted_until(__first2, __last2, __comp) == __last2, "set_intersection: input range 2 must be sorted");
+ }
+#endif
_LIBCPP_CONSTEXPR std::__identity __proj;
bool __prev_may_be_equal = false;
diff --git a/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp b/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
index ddf4087ddd6cd..7c0c394d1f23f 100644
--- a/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
@@ -43,16 +43,15 @@
#include "test_iterators.h"
-namespace {
-
-// __debug_less will perform an additional comparison in an assertion
-static constexpr unsigned std_less_comparison_count_multiplier() noexcept {
-#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
- return 2;
+// Debug mode provides no complexity guarantees, so only results are checked there; the test still runs
+// to make sure no assertion fires. _LIBCPP_HARDENING_MODE_DEBUG is a mode value, so compare, don't #ifdef.
+#if defined(_LIBCPP_HARDENING_MODE) && _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
+# define ASSERT_COMPLEXITY(expression)
#else
- return 1;
+# define ASSERT_COMPLEXITY(expression) assert(expression)
#endif
-}
+
+namespace {
struct [[nodiscard]] OperationCounts {
std::size_t comparisons{};
@@ -60,16 +59,16 @@ struct [[nodiscard]] OperationCounts {
std::size_t proj{};
IteratorOpCounts iterops;
- [[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) {
+ [[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) const noexcept {
return proj >= other.proj && iterops.increments + iterops.decrements + iterops.zero_moves >=
other.iterops.increments + other.iterops.decrements + other.iterops.zero_moves;
}
};
std::array<PerInput, 2> in;
- [[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) {
- return std_less_comparison_count_multiplier() * comparisons >= expect.comparisons &&
- in[0].isNotBetterThan(expect.in[0]) && in[1].isNotBetterThan(expect.in[1]);
+ [[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) const noexcept {
+ return comparisons >= expect.comparisons && in[0].isNotBetterThan(expect.in[0]) &&
+ in[1].isNotBetterThan(expect.in[1]);
}
};
@@ -80,16 +79,17 @@ struct counted_set_intersection_result {
constexpr counted_set_intersection_result() = default;
- constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) : result{contents} {}
+ constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) noexcept
+ : result{contents} {}
- constexpr void assertNotBetterThan(const counted_set_intersection_result& other) {
+ constexpr void assertNotBetterThan(const counted_set_intersection_result& other) const noexcept {
assert(result == other.result);
- assert(opcounts.isNotBetterThan(other.opcounts));
+ ASSERT_COMPLEXITY(opcounts.isNotBetterThan(other.opcounts));
}
};
template <std::size_t ResultSize>
counted_set_intersection_result(std::array<int, ResultSize>) -> counted_set_intersection_result<ResultSize>;
template <template <class...> class InIterType1,
template <class...>
@@ -306,7 +306,7 @@ constexpr bool testComplexityBasic() {
std::array<int, 5> r2{2, 4, 6, 8, 10};
std::array<int, 0> expected{};
- const std::size_t maxOperation = std_less_comparison_count_multiplier() * (2 * (r1.size() + r2.size()) - 1);
+ [[maybe_unused]] const std::size_t maxOperation = 2 * (r1.size() + r2.size()) - 1;
// std::set_intersection
{
@@ -321,7 +321,7 @@ constexpr bool testComplexityBasic() {
std::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp);
assert(std::ranges::equal(out, expected));
- assert(numberOfComp <= maxOperation);
+ ASSERT_COMPLEXITY(numberOfComp <= maxOperation);
}
// ranges::set_intersection iterator overload
@@ -349,9 +349,9 @@ constexpr bool testComplexityBasic() {
std::ranges::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp, proj1, proj2);
assert(std::ranges::equal(out, expected));
- assert(numberOfComp <= maxOperation);
- assert(numberOfProj1 <= maxOperation);
- assert(numberOfProj2 <= maxOperation);
+ ASSERT_COMPLEXITY(numberOfComp <= maxOperation);
+ ASSERT_COMPLEXITY(numberOfProj1 <= maxOperation);
+ ASSERT_COMPLEXITY(numberOfProj2 <= maxOperation);
}
// ranges::set_intersection range overload
@@ -379,9 +379,9 @@ constexpr bool testComplexityBasic() {
std::ranges::set_intersection(r1, r2, out.data(), comp, proj1, proj2);
assert(std::ranges::equal(out, expected));
- assert(numberOfComp < maxOperation);
- assert(numberOfProj1 < maxOperation);
- assert(numberOfProj2 < maxOperation);
+ ASSERT_COMPLEXITY(numberOfComp < maxOperation);
+ ASSERT_COMPLEXITY(numberOfProj1 < maxOperation);
+ ASSERT_COMPLEXITY(numberOfProj2 < maxOperation);
}
return true;
}
diff --git a/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp b/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp
index 82792249ef5c7..a5f1ceee6fb3d 100644
--- a/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp
+++ b/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp
@@ -40,19 +40,20 @@ constexpr bool test_all() {
constexpr auto operator<=>(const A&) const = default;
};
- std::array in = {1, 2, 3};
- std::array in2 = {A{4}, A{5}, A{6}};
+ // Inputs are const so that no algorithm below can accidentally use them as an output.
+ const std::array in = {1, 2, 3};
+ const std::array in2 = {A{4}, A{5}, A{6}};
std::array output = {7, 8, 9, 10, 11, 12};
auto out = output.begin();
std::array output2 = {A{7}, A{8}, A{9}, A{10}, A{11}, A{12}};
auto out2 = output2.begin();
- std::ranges::equal_to eq;
- std::ranges::less less;
- auto sum = [](int lhs, A rhs) { return lhs + rhs.x; };
- auto proj1 = [](int x) { return x * -1; };
- auto proj2 = [](A a) { return a.x * -1; };
+ const std::ranges::equal_to eq;
+ const std::ranges::less less;
+ const std::ranges::greater greater;
+ const auto sum = [](int lhs, A rhs) { return lhs + rhs.x; };
+ const auto proj1 = [](int x) { return x * -1; };
+ const auto proj2 = [](A a) { return a.x * -1; };
#if TEST_STD_VER >= 23
test(std::ranges::ends_with, in, in2, eq, proj1, proj2);
@@ -67,17 +68,17 @@ constexpr bool test_all() {
test(std::ranges::find_end, in, in2, eq, proj1, proj2);
test(std::ranges::transform, in, in2, out, sum, proj1, proj2);
test(std::ranges::transform, in, in2, out2, sum, proj1, proj2);
- test(std::ranges::partial_sort_copy, in, in2, less, proj1, proj2);
- test(std::ranges::merge, in, in2, out, less, proj1, proj2);
- test(std::ranges::merge, in, in2, out2, less, proj1, proj2);
- test(std::ranges::set_intersection, in, in2, out, less, proj1, proj2);
- test(std::ranges::set_intersection, in, in2, out2, less, proj1, proj2);
- test(std::ranges::set_difference, in, in2, out, less, proj1, proj2);
- test(std::ranges::set_difference, in, in2, out2, less, proj1, proj2);
- test(std::ranges::set_symmetric_difference, in, in2, out, less, proj1, proj2);
- test(std::ranges::set_symmetric_difference, in, in2, out2, less, proj1, proj2);
- test(std::ranges::set_union, in, in2, out, less, proj1, proj2);
- test(std::ranges::set_union, in, in2, out2, less, proj1, proj2);
+ // partial_sort_copy writes into its second range, so that range must be a mutable output, not an input.
+ test(std::ranges::partial_sort_copy, in, output, less, proj1, proj2);
+ // proj1/proj2 negate their argument, so the ascending `in` and `in2` are sorted with respect to `greater`
+ // once projected; algorithms that require sorted inputs must therefore use `greater`, not `less`.
+ test(std::ranges::merge, in, in2, out, greater, proj1, proj2);
+ test(std::ranges::merge, in, in2, out2, greater, proj1, proj2);
+ test(std::ranges::set_intersection, in, in2, out, greater, proj1, proj2);
+ test(std::ranges::set_intersection, in, in2, out2, greater, proj1, proj2);
+ test(std::ranges::set_difference, in, in2, out, greater, proj1, proj2);
+ test(std::ranges::set_difference, in, in2, out2, greater, proj1, proj2);
+ test(std::ranges::set_symmetric_difference, in, in2, out, greater, proj1, proj2);
+ test(std::ranges::set_symmetric_difference, in, in2, out2, greater, proj1, proj2);
+ test(std::ranges::set_union, in, in2, out, greater, proj1, proj2);
+ test(std::ranges::set_union, in, in2, out2, greater, proj1, proj2);
#if TEST_STD_VER > 20
test(std::ranges::starts_with, in, in2, eq, proj1, proj2);
#endif
``````````
</details>
https://github.com/llvm/llvm-project/pull/101508
More information about the libcxx-commits mailing list