[clang-tools-extra] [libc++] Add assertions for potential OOB reads in std::nth_element (PR #67023)

Daniel Kutenin via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 18 03:56:02 PDT 2023


https://github.com/danlark1 updated https://github.com/llvm/llvm-project/pull/67023

From 059bbfab50592026ce2785c5f7d98eaf5c9f8bd6 Mon Sep 17 00:00:00 2001
From: Daniel Kutenin <kutdanila at yandex.ru>
Date: Thu, 21 Sep 2023 14:55:11 +0100
Subject: [PATCH 1/7] Add bound checking in nth_element

---
 libcxx/include/__algorithm/nth_element.h | 28 +++++++++++++++++++-----
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/libcxx/include/__algorithm/nth_element.h b/libcxx/include/__algorithm/nth_element.h
index dbacf58f9ecdbc4..37e43ab0db8ca4f 100644
--- a/libcxx/include/__algorithm/nth_element.h
+++ b/libcxx/include/__algorithm/nth_element.h
@@ -116,10 +116,18 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
                     return;
                 }
                 while (true) {
-                    while (!__comp(*__first, *__i))
+                    while (!__comp(*__first, *__i)) {
                         ++__i;
-                    while (__comp(*__first, *--__j))
-                        ;
+                        _LIBCPP_ASSERT_UNCATEGORIZED(
+                            __i != __last,
+                            "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
+                    }
+                    do {
+                        _LIBCPP_ASSERT_UNCATEGORIZED(
+                            __j != __first,
+                            "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
+                        --__j;
+                    } while (__comp(*__first, *__j));
                     if (__i >= __j)
                         break;
                     _Ops::iter_swap(__i, __j);
@@ -146,11 +154,19 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
             while (true)
             {
                 // __m still guards upward moving __i
-                while (__comp(*__i, *__m))
+                while (__comp(*__i, *__m)) {
                     ++__i;
+                    _LIBCPP_ASSERT_UNCATEGORIZED(
+                        __i != __last,
+                        "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
+                }
                 // It is now known that a guard exists for downward moving __j
-                while (!__comp(*--__j, *__m))
-                    ;
+                do {
+                    _LIBCPP_ASSERT_UNCATEGORIZED(
+                        __j != __first,
+                        "Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
+                    --__j;
+                } while (!__comp(*__j, *__m));
                 if (__i >= __j)
                     break;
                 _Ops::iter_swap(__i, __j);

From 8e128c3ce6d8dc8afb94ba2465a2585fe3b8525a Mon Sep 17 00:00:00 2001
From: Daniel Kutenin <kutdanila at yandex.ru>
Date: Thu, 21 Sep 2023 15:22:18 +0100
Subject: [PATCH 2/7] Update nth_element out of bound test

---
 .../assert.sort.invalid_comparator.pass.cpp   | 77 +++++++++++++++----
 1 file changed, 61 insertions(+), 16 deletions(-)

diff --git a/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp b/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
index e5e417fe7bda2d4..b02ae2118ec5f47 100644
--- a/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
+++ b/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
@@ -50,27 +50,37 @@
 #include "bad_comparator_values.h"
 #include "check_assertion.h"
 
-void check_oob_sort_read() {
-    std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
-    for (auto line : std::views::split(DATA, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
-        auto values = std::views::split(line, ' ');
-        auto it = values.begin();
-        std::size_t left = std::stol(std::string((*it).data(), (*it).size()));
-        it = std::next(it);
-        std::size_t right = std::stol(std::string((*it).data(), (*it).size()));
-        it = std::next(it);
-        bool result = static_cast<bool>(std::stol(std::string((*it).data(), (*it).size())));
-        comparison_results[left][right] = result;
-    }
-    auto predicate = [&](std::size_t* left, std::size_t* right) {
+class ComparisonResults {
+public:
+    ComparisonResults(std::string_view data) {
+        for (auto line : std::views::split(data, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
+            auto values = std::views::split(line, ' ');
+            auto it = values.begin();
+            std::size_t left = std::stol(std::string((*it).data(), (*it).size()));
+            it = std::next(it);
+            std::size_t right = std::stol(std::string((*it).data(), (*it).size()));
+            it = std::next(it);
+            bool result = static_cast<bool>(std::stol(std::string((*it).data(), (*it).size())));
+            comparison_results[left][right] = result;
+        }
+    }
+
+    bool compare(size_t* left, size_t* right) {
         assert(left != nullptr && right != nullptr && "something is wrong with the test");
         assert(comparison_results.contains(*left) && comparison_results[*left].contains(*right) && "malformed input data?");
         return comparison_results[*left][*right];
-    };
+    }
 
+    size_t size() const { return comparison_results.size(); }
+private:
+    std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
+};
+
+void check_oob_sort_read() {
+    ComparisonResults results(SORT_DATA);
     std::vector<std::unique_ptr<std::size_t>> elements;
     std::set<std::size_t*> valid_ptrs;
-    for (std::size_t i = 0; i != comparison_results.size(); ++i) {
+    for (std::size_t i = 0; i != results.size(); ++i) {
         elements.push_back(std::make_unique<std::size_t>(i));
         valid_ptrs.insert(elements.back().get());
     }
@@ -81,7 +91,7 @@ void check_oob_sort_read() {
         // because we're reading OOB.
         assert(valid_ptrs.contains(left));
         assert(valid_ptrs.contains(right));
-        return predicate(left, right);
+        return results.compare(left, right);
     };
 
     // Check the classic sorting algorithms
@@ -165,6 +175,39 @@ void check_oob_sort_read() {
     }
 }
 
+void check_oob_nth_element_read() {
+    ComparisonResults results(NTH_ELEMENT_DATA);
+    std::vector<std::unique_ptr<std::size_t>> elements;
+    std::set<std::size_t*> valid_ptrs;
+    for (std::size_t i = 0; i != results.size(); ++i) {
+        elements.push_back(std::make_unique<std::size_t>(i));
+        valid_ptrs.insert(elements.back().get());
+    }
+
+    auto checked_predicate = [&](size_t* left, size_t* right) {
+        // If the pointers passed to the comparator are not in the set of pointers we
+        // set up above, then we're being passed garbage values from the algorithm
+        // because we're reading OOB.
+        assert(valid_ptrs.contains(left));
+        assert(valid_ptrs.contains(right));
+        return results.compare(left, right);
+    };
+
+    {
+        std::vector<std::size_t*> copy;
+        for (auto const& e : elements)
+            copy.push_back(e.get());
+        TEST_LIBCPP_ASSERT_FAILURE(std::nth_element(copy.begin(), copy.end(), copy.end(), checked_predicate), "Would read out of bounds");
+    }
+
+    {
+        std::vector<std::size_t*> copy;
+        for (auto const& e : elements)
+            copy.push_back(e.get());
+        TEST_LIBCPP_ASSERT_FAILURE(std::ranges::nth_element(copy, copy.end(), checked_predicate), "Would read out of bounds");
+    }
+}
+
 struct FloatContainer {
   float value;
   bool operator<(const FloatContainer& other) const {
@@ -214,6 +257,8 @@ int main(int, char**) {
 
     check_oob_sort_read();
 
+    check_oob_nth_element_read();
+
     check_nan_floats();
 
     check_irreflexive();

From b38ff937535b3ae9480a2e4f84128a22815b6933 Mon Sep 17 00:00:00 2001
From: Daniel Kutenin <kutdanila at yandex.ru>
Date: Thu, 21 Sep 2023 15:23:37 +0100
Subject: [PATCH 3/7] Update data variables for bad comparators

---
 .../alg.sorting/bad_comparator_values.h       | 69 ++++++++++++++++++-
 1 file changed, 68 insertions(+), 1 deletion(-)

diff --git a/libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h b/libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h
index 19ea023419ea90a..c0ffd16cd4ac4a1 100644
--- a/libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h
+++ b/libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h
@@ -11,7 +11,74 @@
 
 #include <string_view>
 
-inline constexpr std::string_view DATA = R"(
+inline constexpr std::string_view NTH_ELEMENT_DATA = R"(
+0 0 0
+0 1 0
+0 2 0
+0 3 0
+0 4 1
+0 5 0
+0 6 0
+0 7 0
+1 0 0
+1 1 0
+1 2 0
+1 3 1
+1 4 1
+1 5 1
+1 6 1
+1 7 1
+2 0 1
+2 1 1
+2 2 1
+2 3 1
+2 4 1
+2 5 1
+2 6 1
+2 7 1
+3 0 1
+3 1 1
+3 2 1
+3 3 1
+3 4 1
+3 5 1
+3 6 1
+3 7 1
+4 0 1
+4 1 1
+4 2 1
+4 3 1
+4 4 1
+4 5 1
+4 6 1
+4 7 1
+5 0 1
+5 1 1
+5 2 1
+5 3 1
+5 4 1
+5 5 1
+5 6 1
+5 7 1
+6 0 1
+6 1 1
+6 2 1
+6 3 1
+6 4 1
+6 5 1
+6 6 1
+6 7 1
+7 0 1
+7 1 1
+7 2 1
+7 3 1
+7 4 1
+7 5 1
+7 6 1
+7 7 1
+)";
+
+inline constexpr std::string_view SORT_DATA = R"(
 0 0 0
 0 1 1
 0 2 1

From bb3ed4e694a44881efdc2f98519388d03fe42968 Mon Sep 17 00:00:00 2001
From: Danila Kutenin <kutdanila at yandex.ru>
Date: Thu, 21 Sep 2023 14:41:22 +0000
Subject: [PATCH 4/7] Fix nth_element iterator

---
 .../assert.sort.invalid_comparator.pass.cpp            | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp b/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
index b02ae2118ec5f47..1e741344b1fca6b 100644
--- a/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
+++ b/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
@@ -77,10 +77,10 @@ class ComparisonResults {
 };
 
 void check_oob_sort_read() {
-    ComparisonResults results(SORT_DATA);
+    ComparisonResults comparison_results(SORT_DATA);
     std::vector<std::unique_ptr<std::size_t>> elements;
     std::set<std::size_t*> valid_ptrs;
-    for (std::size_t i = 0; i != results.size(); ++i) {
+    for (std::size_t i = 0; i != comparison_results.size(); ++i) {
         elements.push_back(std::make_unique<std::size_t>(i));
         valid_ptrs.insert(elements.back().get());
     }
@@ -91,7 +91,7 @@ void check_oob_sort_read() {
         // because we're reading OOB.
         assert(valid_ptrs.contains(left));
         assert(valid_ptrs.contains(right));
-        return results.compare(left, right);
+        return comparison_results.compare(left, right);
     };
 
     // Check the classic sorting algorithms
@@ -197,14 +197,14 @@ void check_oob_nth_element_read() {
         std::vector<std::size_t*> copy;
         for (auto const& e : elements)
             copy.push_back(e.get());
-        TEST_LIBCPP_ASSERT_FAILURE(std::nth_element(copy.begin(), copy.end(), copy.end(), checked_predicate), "Would read out of bounds");
+        TEST_LIBCPP_ASSERT_FAILURE(std::nth_element(copy.begin(), copy.begin(), copy.end(), checked_predicate), "Would read out of bounds");
     }
 
     {
         std::vector<std::size_t*> copy;
         for (auto const& e : elements)
             copy.push_back(e.get());
-        TEST_LIBCPP_ASSERT_FAILURE(std::ranges::nth_element(copy, copy.end(), checked_predicate), "Would read out of bounds");
+        TEST_LIBCPP_ASSERT_FAILURE(std::ranges::nth_element(copy, copy.begin(), checked_predicate), "Would read out of bounds");
     }
 }
 

From 6779af5d41d3707fc6447640db07eea0ad6af4ad Mon Sep 17 00:00:00 2001
From: Danila Kutenin <kutdanila at yandex.ru>
Date: Thu, 21 Sep 2023 16:15:39 +0000
Subject: [PATCH 5/7] Add assert include to nth_element.h

---
 libcxx/include/__algorithm/nth_element.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/libcxx/include/__algorithm/nth_element.h b/libcxx/include/__algorithm/nth_element.h
index 37e43ab0db8ca4f..2f26bd74d2c5d23 100644
--- a/libcxx/include/__algorithm/nth_element.h
+++ b/libcxx/include/__algorithm/nth_element.h
@@ -13,6 +13,7 @@
 #include <__algorithm/comp_ref_type.h>
 #include <__algorithm/iterator_operations.h>
 #include <__algorithm/sort.h>
+#include <__assert>
 #include <__config>
 #include <__debug_utils/randomize_range.h>
 #include <__iterator/iterator_traits.h>

From 1d6fa6a01885c9f3a1eb02d7bc0cbe7f4821969d Mon Sep 17 00:00:00 2001
From: Daniel Kutenin <kutdanila at yandex.ru>
Date: Wed, 27 Sep 2023 16:10:19 +0100
Subject: [PATCH 6/7] Resolve comments

---
 .../assert.sort.invalid_comparator.pass.cpp   | 20 ++++---------------
 1 file changed, 4 insertions(+), 16 deletions(-)

diff --git a/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp b/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
index 1e741344b1fca6b..96c2821c4a654c0 100644
--- a/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
+++ b/libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
@@ -52,7 +52,7 @@
 
 class ComparisonResults {
 public:
-    ComparisonResults(std::string_view data) {
+    explicit ComparisonResults(std::string_view data) {
         for (auto line : std::views::split(data, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
             auto values = std::views::split(line, ' ');
             auto it = values.begin();
@@ -65,10 +65,10 @@ class ComparisonResults {
         }
     }
 
-    bool compare(size_t* left, size_t* right) {
+    bool compare(size_t* left, size_t* right) const {
         assert(left != nullptr && right != nullptr && "something is wrong with the test");
-        assert(comparison_results.contains(*left) && comparison_results[*left].contains(*right) && "malformed input data?");
-        return comparison_results[*left][*right];
+        assert(comparison_results.contains(*left) && comparison_results.at(*left).contains(*right) && "malformed input data?");
+        return comparison_results.at(*left).at(*right);
     }
 
     size_t size() const { return comparison_results.size(); }
@@ -127,12 +127,6 @@ void check_oob_sort_read() {
         std::vector<std::size_t*> results(copy.size(), nullptr);
        TEST_LIBCPP_ASSERT_FAILURE(std::partial_sort_copy(copy.begin(), copy.end(), results.begin(), results.end(), checked_predicate), "not a valid strict-weak ordering");
     }
-    {
-        std::vector<std::size_t*> copy;
-        for (auto const& e : elements)
-            copy.push_back(e.get());
-        std::nth_element(copy.begin(), copy.end(), copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
-    }
 
     // Check the Ranges sorting algorithms
     {
@@ -167,12 +161,6 @@ void check_oob_sort_read() {
         std::vector<std::size_t*> results(copy.size(), nullptr);
         TEST_LIBCPP_ASSERT_FAILURE(std::ranges::partial_sort_copy(copy, results, checked_predicate), "not a valid strict-weak ordering");
     }
-    {
-        std::vector<std::size_t*> copy;
-        for (auto const& e : elements)
-            copy.push_back(e.get());
-        std::ranges::nth_element(copy, copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
-    }
 }
 
 void check_oob_nth_element_read() {

From 7d892dcb63fb0d8d75b039fc2de31fb08034b168 Mon Sep 17 00:00:00 2001
From: Daniel Kutenin <kutdanila at yandex.ru>
Date: Wed, 18 Oct 2023 11:55:48 +0100
Subject: [PATCH 7/7] Disable cognitive complexity warning for nth_element

---
 libcxx/include/__algorithm/nth_element.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/libcxx/include/__algorithm/nth_element.h b/libcxx/include/__algorithm/nth_element.h
index 2f26bd74d2c5d23..ebd1cbf76143d46 100644
--- a/libcxx/include/__algorithm/nth_element.h
+++ b/libcxx/include/__algorithm/nth_element.h
@@ -43,6 +43,7 @@ __nth_element_find_guard(_RandomAccessIterator& __i, _RandomAccessIterator& __j,
 
 template <class _AlgPolicy, class _Compare, class _RandomAccessIterator>
 _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 void
+// NOLINTNEXTLINE(readability-function-cognitive-complexity)
 __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last, _Compare __comp)
 {
     using _Ops = _IterOps<_AlgPolicy>;



More information about the cfe-commits mailing list