[libcxx-commits] [libcxx] [libc++] Add input validation for set_intersection() in debug mode. (PR #101508)
Iuri Chaer via libcxx-commits
libcxx-commits at lists.llvm.org
Sat Nov 9 08:46:54 PST 2024
https://github.com/ichaer updated https://github.com/llvm/llvm-project/pull/101508
>From 5620dce62efa2444af6e8b0ff8b1c4c92c989324 Mon Sep 17 00:00:00 2001
From: Iuri Chaer <ichaer at splunk.com>
Date: Wed, 31 Jul 2024 21:32:42 +0100
Subject: [PATCH 1/4] [libc++] Add input validation for set_intersection() in
debug mode.
The use of one-sided binary search introduced by a066217 changes behaviour on invalid, unsorted input (see https://github.com/llvm/llvm-project/pull/75230#issuecomment-2259408046). Add input validation on `_LIBCPP_HARDENING_MODE_DEBUG` to help users.
* Change the interface of `__is_sorted_until()` so that it accepts a sentinel whose type differs from that of the beginning iterator, and so that it never copies the comparison function object.
* Add one assertion for each input range confirming that they are sorted.
* Stop validating the complexity of `set_intersection()` in debug mode: it's hopeless and also not meaningful, because there are no complexity guarantees in debug mode and we're happy to trade performance for diagnosability.
* Fix bugs in `ranges_robust_against_differing_projections.pass.cpp`: we were using an input range as the output of `std::ranges::partial_sort_copy()`, and because the projections negate each value, algorithms that require sorted input cannot be given ranges sorted in ascending order when the comparator is `std::ranges::less` (the projected values are then in descending order). Added `const` where appropriate to make sure we weren't using inputs as outputs in other places.
---
libcxx/include/__algorithm/is_sorted_until.h | 14 +++---
libcxx/include/__algorithm/set_intersection.h | 11 +++++
.../set_intersection_complexity.pass.cpp | 48 +++++++++----------
...ust_against_differing_projections.pass.cpp | 37 +++++++-------
4 files changed, 61 insertions(+), 49 deletions(-)
diff --git a/libcxx/include/__algorithm/is_sorted_until.h b/libcxx/include/__algorithm/is_sorted_until.h
index 53a49f00de31e8..f84c990ff26751 100644
--- a/libcxx/include/__algorithm/is_sorted_until.h
+++ b/libcxx/include/__algorithm/is_sorted_until.h
@@ -20,18 +20,18 @@
_LIBCPP_BEGIN_NAMESPACE_STD
-template <class _Compare, class _ForwardIterator>
+template <class _Compare, class _ForwardIterator, class _Sent>
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _ForwardIterator
-__is_sorted_until(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) {
+__is_sorted_until(_ForwardIterator __first, _Sent __last, _Compare&& __comp) {
if (__first != __last) {
_ForwardIterator __i = __first;
- while (++__i != __last) {
- if (__comp(*__i, *__first))
- return __i;
- __first = __i;
+ while (++__first != __last) {
+ if (__comp(*__first, *__i))
+ return __first;
+ __i = __first;
}
}
- return __last;
+ return __first;
}
template <class _ForwardIterator, class _Compare>
diff --git a/libcxx/include/__algorithm/set_intersection.h b/libcxx/include/__algorithm/set_intersection.h
index bb0d86cd0f58d2..2cfd984b202a23 100644
--- a/libcxx/include/__algorithm/set_intersection.h
+++ b/libcxx/include/__algorithm/set_intersection.h
@@ -11,12 +11,15 @@
#include <__algorithm/comp.h>
#include <__algorithm/comp_ref_type.h>
+#include <__algorithm/is_sorted_until.h>
#include <__algorithm/iterator_operations.h>
#include <__algorithm/lower_bound.h>
+#include <__assert>
#include <__config>
#include <__functional/identity.h>
#include <__iterator/iterator_traits.h>
#include <__iterator/next.h>
+#include <__type_traits/is_constant_evaluated.h>
#include <__type_traits/is_same.h>
#include <__utility/exchange.h>
#include <__utility/move.h>
@@ -95,6 +98,14 @@ __set_intersection(
_Compare&& __comp,
std::forward_iterator_tag,
std::forward_iterator_tag) {
+#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
+ if (!__libcpp_is_constant_evaluated()) {
+ _LIBCPP_ASSERT_INTERNAL(
+ std::__is_sorted_until(__first1, __last1, __comp) == __last1, "set_intersection: input range 1 must be sorted");
+ _LIBCPP_ASSERT_INTERNAL(
+ std::__is_sorted_until(__first2, __last2, __comp) == __last2, "set_intersection: input range 2 must be sorted");
+ }
+#endif
_LIBCPP_CONSTEXPR std::__identity __proj;
bool __prev_may_be_equal = false;
diff --git a/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp b/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
index ddf4087ddd6cd1..7c0c394d1f23fe 100644
--- a/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
@@ -43,16 +43,15 @@
#include "test_iterators.h"
-namespace {
-
-// __debug_less will perform an additional comparison in an assertion
-static constexpr unsigned std_less_comparison_count_multiplier() noexcept {
-#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
- return 2;
+// debug mode provides no complexity guarantees, testing them would be a waste of effort
+// but we still want to run this test, to ensure we don't trigger any assertions
+#ifdef _LIBCPP_HARDENING_MODE_DEBUG
+# define ASSERT_COMPLEXITY(expression)
#else
- return 1;
+# define ASSERT_COMPLEXITY(expression) assert(expression)
#endif
-}
+
+namespace {
struct [[nodiscard]] OperationCounts {
std::size_t comparisons{};
@@ -60,16 +59,16 @@ struct [[nodiscard]] OperationCounts {
std::size_t proj{};
IteratorOpCounts iterops;
- [[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) {
+ [[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) const noexcept {
return proj >= other.proj && iterops.increments + iterops.decrements + iterops.zero_moves >=
other.iterops.increments + other.iterops.decrements + other.iterops.zero_moves;
}
};
std::array<PerInput, 2> in;
- [[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) {
- return std_less_comparison_count_multiplier() * comparisons >= expect.comparisons &&
- in[0].isNotBetterThan(expect.in[0]) && in[1].isNotBetterThan(expect.in[1]);
+ [[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) const noexcept {
+ return comparisons >= expect.comparisons && in[0].isNotBetterThan(expect.in[0]) &&
+ in[1].isNotBetterThan(expect.in[1]);
}
};
@@ -80,16 +79,17 @@ struct counted_set_intersection_result {
constexpr counted_set_intersection_result() = default;
- constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) : result{contents} {}
+ constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) noexcept
+ : result{contents} {}
- constexpr void assertNotBetterThan(const counted_set_intersection_result& other) {
+ constexpr void assertNotBetterThan(const counted_set_intersection_result& other) const noexcept {
assert(result == other.result);
- assert(opcounts.isNotBetterThan(other.opcounts));
+ ASSERT_COMPLEXITY(opcounts.isNotBetterThan(other.opcounts));
}
};
template <std::size_t ResultSize>
-counted_set_intersection_result(std::array<int, ResultSize>) -> counted_set_intersection_result<ResultSize>;
+counted_set_intersection_result(std::array<int, ResultSize>) noexcept -> counted_set_intersection_result<ResultSize>;
template <template <class...> class InIterType1,
template <class...>
@@ -306,7 +306,7 @@ constexpr bool testComplexityBasic() {
std::array<int, 5> r2{2, 4, 6, 8, 10};
std::array<int, 0> expected{};
- const std::size_t maxOperation = std_less_comparison_count_multiplier() * (2 * (r1.size() + r2.size()) - 1);
+ [[maybe_unused]] const std::size_t maxOperation = 2 * (r1.size() + r2.size()) - 1;
// std::set_intersection
{
@@ -321,7 +321,7 @@ constexpr bool testComplexityBasic() {
std::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp);
assert(std::ranges::equal(out, expected));
- assert(numberOfComp <= maxOperation);
+ ASSERT_COMPLEXITY(numberOfComp <= maxOperation);
}
// ranges::set_intersection iterator overload
@@ -349,9 +349,9 @@ constexpr bool testComplexityBasic() {
std::ranges::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp, proj1, proj2);
assert(std::ranges::equal(out, expected));
- assert(numberOfComp <= maxOperation);
- assert(numberOfProj1 <= maxOperation);
- assert(numberOfProj2 <= maxOperation);
+ ASSERT_COMPLEXITY(numberOfComp <= maxOperation);
+ ASSERT_COMPLEXITY(numberOfProj1 <= maxOperation);
+ ASSERT_COMPLEXITY(numberOfProj2 <= maxOperation);
}
// ranges::set_intersection range overload
@@ -379,9 +379,9 @@ constexpr bool testComplexityBasic() {
std::ranges::set_intersection(r1, r2, out.data(), comp, proj1, proj2);
assert(std::ranges::equal(out, expected));
- assert(numberOfComp < maxOperation);
- assert(numberOfProj1 < maxOperation);
- assert(numberOfProj2 < maxOperation);
+ ASSERT_COMPLEXITY(numberOfComp < maxOperation);
+ ASSERT_COMPLEXITY(numberOfProj1 < maxOperation);
+ ASSERT_COMPLEXITY(numberOfProj2 < maxOperation);
}
return true;
}
diff --git a/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp b/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp
index 82792249ef5c74..a5f1ceee6fb3df 100644
--- a/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp
+++ b/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp
@@ -40,19 +40,20 @@ constexpr bool test_all() {
constexpr auto operator<=>(const A&) const = default;
};
- std::array in = {1, 2, 3};
- std::array in2 = {A{4}, A{5}, A{6}};
+ const std::array in = {1, 2, 3};
+ const std::array in2 = {A{4}, A{5}, A{6}};
std::array output = {7, 8, 9, 10, 11, 12};
auto out = output.begin();
std::array output2 = {A{7}, A{8}, A{9}, A{10}, A{11}, A{12}};
auto out2 = output2.begin();
- std::ranges::equal_to eq;
- std::ranges::less less;
- auto sum = [](int lhs, A rhs) { return lhs + rhs.x; };
- auto proj1 = [](int x) { return x * -1; };
- auto proj2 = [](A a) { return a.x * -1; };
+ const std::ranges::equal_to eq;
+ const std::ranges::less less;
+ const std::ranges::greater greater;
+ const auto sum = [](int lhs, A rhs) { return lhs + rhs.x; };
+ const auto proj1 = [](int x) { return x * -1; };
+ const auto proj2 = [](A a) { return a.x * -1; };
#if TEST_STD_VER >= 23
test(std::ranges::ends_with, in, in2, eq, proj1, proj2);
@@ -67,17 +68,17 @@ constexpr bool test_all() {
test(std::ranges::find_end, in, in2, eq, proj1, proj2);
test(std::ranges::transform, in, in2, out, sum, proj1, proj2);
test(std::ranges::transform, in, in2, out2, sum, proj1, proj2);
- test(std::ranges::partial_sort_copy, in, in2, less, proj1, proj2);
- test(std::ranges::merge, in, in2, out, less, proj1, proj2);
- test(std::ranges::merge, in, in2, out2, less, proj1, proj2);
- test(std::ranges::set_intersection, in, in2, out, less, proj1, proj2);
- test(std::ranges::set_intersection, in, in2, out2, less, proj1, proj2);
- test(std::ranges::set_difference, in, in2, out, less, proj1, proj2);
- test(std::ranges::set_difference, in, in2, out2, less, proj1, proj2);
- test(std::ranges::set_symmetric_difference, in, in2, out, less, proj1, proj2);
- test(std::ranges::set_symmetric_difference, in, in2, out2, less, proj1, proj2);
- test(std::ranges::set_union, in, in2, out, less, proj1, proj2);
- test(std::ranges::set_union, in, in2, out2, less, proj1, proj2);
+ test(std::ranges::partial_sort_copy, in, output, less, proj1, proj2);
+ test(std::ranges::merge, in, in2, out, greater, proj1, proj2);
+ test(std::ranges::merge, in, in2, out2, greater, proj1, proj2);
+ test(std::ranges::set_intersection, in, in2, out, greater, proj1, proj2);
+ test(std::ranges::set_intersection, in, in2, out2, greater, proj1, proj2);
+ test(std::ranges::set_difference, in, in2, out, greater, proj1, proj2);
+ test(std::ranges::set_difference, in, in2, out2, greater, proj1, proj2);
+ test(std::ranges::set_symmetric_difference, in, in2, out, greater, proj1, proj2);
+ test(std::ranges::set_symmetric_difference, in, in2, out2, greater, proj1, proj2);
+ test(std::ranges::set_union, in, in2, out, greater, proj1, proj2);
+ test(std::ranges::set_union, in, in2, out2, greater, proj1, proj2);
#if TEST_STD_VER > 20
test(std::ranges::starts_with, in, in2, eq, proj1, proj2);
#endif
>From 7f2beae53be0efb43c08bfd88239d9ed6410966d Mon Sep 17 00:00:00 2001
From: Iuri Chaer <ichaer at splunk.com>
Date: Thu, 1 Aug 2024 17:19:03 +0100
Subject: [PATCH 2/4] Oops, missed `std::ranges::includes()`! It also requires
sorted input.
---
.../ranges_robust_against_differing_projections.pass.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp b/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp
index a5f1ceee6fb3df..cb20b37dc108d7 100644
--- a/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp
+++ b/libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp
@@ -61,7 +61,7 @@ constexpr bool test_all() {
test(std::ranges::equal, in, in2, eq, proj1, proj2);
test(std::ranges::lexicographical_compare, in, in2, eq, proj1, proj2);
test(std::ranges::is_permutation, in, in2, eq, proj1, proj2);
- test(std::ranges::includes, in, in2, less, proj1, proj2);
+ test(std::ranges::includes, in, in2, greater, proj1, proj2);
test(std::ranges::find_first_of, in, in2, eq, proj1, proj2);
test(std::ranges::mismatch, in, in2, eq, proj1, proj2);
test(std::ranges::search, in, in2, eq, proj1, proj2);
>From 3006f134c0aa3a635debaad7fece4a99de6ba791 Mon Sep 17 00:00:00 2001
From: Iuri Chaer <ichaer at splunk.com>
Date: Fri, 2 Aug 2024 13:58:09 +0100
Subject: [PATCH 3/4] Addressing @alexfh's feedback.
---
.../set.intersection/set_intersection_complexity.pass.cpp | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp b/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
index 7c0c394d1f23fe..c21acb9d2b0623 100644
--- a/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
@@ -43,10 +43,9 @@
#include "test_iterators.h"
-// debug mode provides no complexity guarantees, testing them would be a waste of effort
-// but we still want to run this test, to ensure we don't trigger any assertions
+// Debug mode provides no complexity guarantees, testing them would be a waste of effort.
#ifdef _LIBCPP_HARDENING_MODE_DEBUG
-# define ASSERT_COMPLEXITY(expression)
+# define ASSERT_COMPLEXITY(expression) (void)(expression)
#else
# define ASSERT_COMPLEXITY(expression) assert(expression)
#endif
@@ -306,7 +305,7 @@ constexpr bool testComplexityBasic() {
std::array<int, 5> r2{2, 4, 6, 8, 10};
std::array<int, 0> expected{};
- [[maybe_unused]] const std::size_t maxOperation = 2 * (r1.size() + r2.size()) - 1;
+ const std::size_t maxOperation = 2 * (r1.size() + r2.size()) - 1;
// std::set_intersection
{
>From 22c9f1cbe05b37c2e1761e085b834abbff4a32a0 Mon Sep 17 00:00:00 2001
From: Iuri Chaer <ichaer at splunk.com>
Date: Sat, 9 Nov 2024 16:44:26 +0000
Subject: [PATCH 4/4] * s/__i/__prev/ in __is_sorted_until() * remove bogus
__libcpp_is_constant_evaluated() check in __set_intersection() * change
comment about checking number of steps in debug mode *
s/_LIBCPP_ASSERT_INTERNAL/_LIBCPP_ASSERT_SEMANTIC_REQUIREMENT/ for input
validation
---
libcxx/include/__algorithm/is_sorted_until.h | 6 +++---
libcxx/include/__algorithm/set_intersection.h | 10 ++++------
.../set_intersection_complexity.pass.cpp | 2 +-
3 files changed, 8 insertions(+), 10 deletions(-)
diff --git a/libcxx/include/__algorithm/is_sorted_until.h b/libcxx/include/__algorithm/is_sorted_until.h
index f84c990ff26751..218adf0be4a7ea 100644
--- a/libcxx/include/__algorithm/is_sorted_until.h
+++ b/libcxx/include/__algorithm/is_sorted_until.h
@@ -24,11 +24,11 @@ template <class _Compare, class _ForwardIterator, class _Sent>
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _ForwardIterator
__is_sorted_until(_ForwardIterator __first, _Sent __last, _Compare&& __comp) {
if (__first != __last) {
- _ForwardIterator __i = __first;
+ _ForwardIterator __prev = __first;
while (++__first != __last) {
- if (__comp(*__first, *__i))
+ if (__comp(*__first, *__prev))
return __first;
- __i = __first;
+ __prev = __first;
}
}
return __first;
diff --git a/libcxx/include/__algorithm/set_intersection.h b/libcxx/include/__algorithm/set_intersection.h
index 2cfd984b202a23..4e9ff169d6d603 100644
--- a/libcxx/include/__algorithm/set_intersection.h
+++ b/libcxx/include/__algorithm/set_intersection.h
@@ -99,12 +99,10 @@ __set_intersection(
std::forward_iterator_tag,
std::forward_iterator_tag) {
#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
- if (!__libcpp_is_constant_evaluated()) {
- _LIBCPP_ASSERT_INTERNAL(
- std::__is_sorted_until(__first1, __last1, __comp) == __last1, "set_intersection: input range 1 must be sorted");
- _LIBCPP_ASSERT_INTERNAL(
- std::__is_sorted_until(__first2, __last2, __comp) == __last2, "set_intersection: input range 2 must be sorted");
- }
+ _LIBCPP_ASSERT_SEMANTIC_REQUIREMENT(
+ std::__is_sorted_until(__first1, __last1, __comp) == __last1, "set_intersection: input range 1 must be sorted");
+ _LIBCPP_ASSERT_SEMANTIC_REQUIREMENT(
+ std::__is_sorted_until(__first2, __last2, __comp) == __last2, "set_intersection: input range 2 must be sorted");
#endif
_LIBCPP_CONSTEXPR std::__identity __proj;
bool __prev_may_be_equal = false;
diff --git a/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp b/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
index c21acb9d2b0623..9caa1cd27e8cfb 100644
--- a/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp
@@ -43,7 +43,7 @@
#include "test_iterators.h"
-// Debug mode provides no complexity guarantees, testing them would be a waste of effort.
+// We don't check number of operations in Debug mode because they are not stable enough due to additional validations.
#ifdef _LIBCPP_HARDENING_MODE_DEBUG
# define ASSERT_COMPLEXITY(expression) (void)(expression)
#else
More information about the libcxx-commits
mailing list