[libcxx-commits] [libcxx] [libc++] Add assertions for potential OOB reads in std::nth_element (PR #67023)
via libcxx-commits
libcxx-commits at lists.llvm.org
Thu Sep 21 07:44:29 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-libcxx
<details>
<summary>Changes</summary>
Same as https://reviews.llvm.org/D147089 (which added these out-of-bounds-read assertions to std::sort), but for std::nth_element.
---
Full diff: https://github.com/llvm/llvm-project/pull/67023.diff
3 Files Affected:
- (modified) libcxx/include/__algorithm/nth_element.h (+22-6)
- (modified) libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp (+60-15)
- (modified) libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h (+68-1)
``````````diff
diff --git a/libcxx/include/__algorithm/nth_element.h b/libcxx/include/__algorithm/nth_element.h
index dbacf58f9ecdbc4..37e43ab0db8ca4f 100644
--- a/libcxx/include/__algorithm/nth_element.h
+++ b/libcxx/include/__algorithm/nth_element.h
@@ -116,10 +116,18 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
return;
}
while (true) {
- while (!__comp(*__first, *__i))
+ while (!__comp(*__first, *__i)) {
++__i;
- while (__comp(*__first, *--__j))
- ;
+ _LIBCPP_ASSERT_UNCATEGORIZED(
+ __i != __last,
+ "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
+ }
+ do {
+ _LIBCPP_ASSERT_UNCATEGORIZED(
+ __j != __first,
+ "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
+ --__j;
+ } while (__comp(*__first, *__j));
if (__i >= __j)
break;
_Ops::iter_swap(__i, __j);
@@ -146,11 +154,19 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
while (true)
{
// __m still guards upward moving __i
- while (__comp(*__i, *__m))
+ while (__comp(*__i, *__m)) {
++__i;
+ _LIBCPP_ASSERT_UNCATEGORIZED(
+ __i != __last,
+ "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
+ }
// It is now known that a guard exists for downward moving __j
- while (!__comp(*--__j, *__m))
- ;
+ do {
+ _LIBCPP_ASSERT_UNCATEGORIZED(
+ __j != __first,
+ "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
+ --__j;
+ } while (!__comp(*__j, *__m));
if (__i >= __j)
break;
_Ops::iter_swap(__i, __j);
diff --git a/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp b/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
index e5e417fe7bda2d4..1e741344b1fca6b 100644
--- a/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
+++ b/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
@@ -50,24 +50,34 @@
#include "bad_comparator_values.h"
#include "check_assertion.h"
-void check_oob_sort_read() {
- std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
- for (auto line : std::views::split(DATA, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
- auto values = std::views::split(line, ' ');
- auto it = values.begin();
- std::size_t left = std::stol(std::string((*it).data(), (*it).size()));
- it = std::next(it);
- std::size_t right = std::stol(std::string((*it).data(), (*it).size()));
- it = std::next(it);
- bool result = static_cast<bool>(std::stol(std::string((*it).data(), (*it).size())));
- comparison_results[left][right] = result;
- }
- auto predicate = [&](std::size_t* left, std::size_t* right) {
+class ComparisonResults {
+public:
+ ComparisonResults(std::string_view data) {
+ for (auto line : std::views::split(data, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
+ auto values = std::views::split(line, ' ');
+ auto it = values.begin();
+ std::size_t left = std::stol(std::string((*it).data(), (*it).size()));
+ it = std::next(it);
+ std::size_t right = std::stol(std::string((*it).data(), (*it).size()));
+ it = std::next(it);
+ bool result = static_cast<bool>(std::stol(std::string((*it).data(), (*it).size())));
+ comparison_results[left][right] = result;
+ }
+ }
+
+ bool compare(size_t* left, size_t* right) {
assert(left != nullptr && right != nullptr && "something is wrong with the test");
assert(comparison_results.contains(*left) && comparison_results[*left].contains(*right) && "malformed input data?");
return comparison_results[*left][*right];
- };
+ }
+ size_t size() const { return comparison_results.size(); }
+private:
+ std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
+};
+
+void check_oob_sort_read() {
+ ComparisonResults comparison_results(SORT_DATA);
std::vector<std::unique_ptr<std::size_t>> elements;
std::set<std::size_t*> valid_ptrs;
for (std::size_t i = 0; i != comparison_results.size(); ++i) {
@@ -81,7 +91,7 @@ void check_oob_sort_read() {
// because we're reading OOB.
assert(valid_ptrs.contains(left));
assert(valid_ptrs.contains(right));
- return predicate(left, right);
+ return comparison_results.compare(left, right);
};
// Check the classic sorting algorithms
@@ -165,6 +175,39 @@ void check_oob_sort_read() {
}
}
+void check_oob_nth_element_read() {
+ ComparisonResults results(NTH_ELEMENT_DATA);
+ std::vector<std::unique_ptr<std::size_t>> elements;
+ std::set<std::size_t*> valid_ptrs;
+ for (std::size_t i = 0; i != results.size(); ++i) {
+ elements.push_back(std::make_unique<std::size_t>(i));
+ valid_ptrs.insert(elements.back().get());
+ }
+
+ auto checked_predicate = [&](size_t* left, size_t* right) {
+ // If the pointers passed to the comparator are not in the set of pointers we
+ // set up above, then we're being passed garbage values from the algorithm
+ // because we're reading OOB.
+ assert(valid_ptrs.contains(left));
+ assert(valid_ptrs.contains(right));
+ return results.compare(left, right);
+ };
+
+ {
+ std::vector<std::size_t*> copy;
+ for (auto const& e : elements)
+ copy.push_back(e.get());
+ TEST_LIBCPP_ASSERT_FAILURE(std::nth_element(copy.begin(), copy.begin(), copy.end(), checked_predicate), "Would read out of bounds");
+ }
+
+ {
+ std::vector<std::size_t*> copy;
+ for (auto const& e : elements)
+ copy.push_back(e.get());
+ TEST_LIBCPP_ASSERT_FAILURE(std::ranges::nth_element(copy, copy.begin(), checked_predicate), "Would read out of bounds");
+ }
+}
+
struct FloatContainer {
float value;
bool operator<(const FloatContainer& other) const {
@@ -214,6 +257,8 @@ int main(int, char**) {
check_oob_sort_read();
+ check_oob_nth_element_read();
+
check_nan_floats();
check_irreflexive();
diff --git a/libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h b/libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h
index 19ea023419ea90a..c0ffd16cd4ac4a1 100644
--- a/libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h
+++ b/libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h
@@ -11,7 +11,74 @@
#include <string_view>
-inline constexpr std::string_view DATA = R"(
+inline constexpr std::string_view NTH_ELEMENT_DATA = R"(
+0 0 0
+0 1 0
+0 2 0
+0 3 0
+0 4 1
+0 5 0
+0 6 0
+0 7 0
+1 0 0
+1 1 0
+1 2 0
+1 3 1
+1 4 1
+1 5 1
+1 6 1
+1 7 1
+2 0 1
+2 1 1
+2 2 1
+2 3 1
+2 4 1
+2 5 1
+2 6 1
+2 7 1
+3 0 1
+3 1 1
+3 2 1
+3 3 1
+3 4 1
+3 5 1
+3 6 1
+3 7 1
+4 0 1
+4 1 1
+4 2 1
+4 3 1
+4 4 1
+4 5 1
+4 6 1
+4 7 1
+5 0 1
+5 1 1
+5 2 1
+5 3 1
+5 4 1
+5 5 1
+5 6 1
+5 7 1
+6 0 1
+6 1 1
+6 2 1
+6 3 1
+6 4 1
+6 5 1
+6 6 1
+6 7 1
+7 0 1
+7 1 1
+7 2 1
+7 3 1
+7 4 1
+7 5 1
+7 6 1
+7 7 1
+)";
+
+inline constexpr std::string_view SORT_DATA = R"(
0 0 0
0 1 1
0 2 1
``````````
</details>
https://github.com/llvm/llvm-project/pull/67023
More information about the libcxx-commits mailing list