[libcxx-commits] [libcxx] [libc++] Make sure ranges::find_if_not handles boolean-testables correctly (PR #69378)

Louis Dionne via libcxx-commits libcxx-commits at lists.llvm.org
Tue Oct 24 11:59:48 PDT 2023


https://github.com/ldionne updated https://github.com/llvm/llvm-project/pull/69378

>From 733616cfe9ea1c6d53965209475f4ca3e1d6fe22 Mon Sep 17 00:00:00 2001
From: Louis Dionne <ldionne.2 at gmail.com>
Date: Tue, 17 Oct 2023 13:29:22 -0700
Subject: [PATCH 1/3] [libc++] Make sure ranges::find_if_not handles
 boolean-testables correctly

We would fail to implicitly convert the result of the predicate to bool,
which means we'd potentially perform a copy or move construction of the
boolean-testable, which isn't allowed. We already had tests aiming to
ensure correct handling of these types, but they failed to catch copy
and move construction because of guaranteed RVO.

Fixes #69074
---
 libcxx/include/__algorithm/ranges_find_if_not.h          | 4 ++--
 .../alg.find/ranges.find_if_not.pass.cpp                 | 2 +-
 libcxx/test/support/boolean_testable.h                   | 9 ++++++---
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/libcxx/include/__algorithm/ranges_find_if_not.h b/libcxx/include/__algorithm/ranges_find_if_not.h
index 6beade1462e099c..a18bea43165e0d8 100644
--- a/libcxx/include/__algorithm/ranges_find_if_not.h
+++ b/libcxx/include/__algorithm/ranges_find_if_not.h
@@ -39,14 +39,14 @@ struct __fn {
             indirect_unary_predicate<projected<_Ip, _Proj>> _Pred>
   _LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr _Ip
   operator()(_Ip __first, _Sp __last, _Pred __pred, _Proj __proj = {}) const {
-    auto __pred2 = [&](auto&& __e) { return !std::invoke(__pred, std::forward<decltype(__e)>(__e)); };
+    auto __pred2 = [&](auto&& __e) -> bool { return !std::invoke(__pred, std::forward<decltype(__e)>(__e)); };
     return ranges::__find_if_impl(std::move(__first), std::move(__last), __pred2, __proj);
   }
 
   template <input_range _Rp, class _Proj = identity, indirect_unary_predicate<projected<iterator_t<_Rp>, _Proj>> _Pred>
   _LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr borrowed_iterator_t<_Rp>
   operator()(_Rp&& __r, _Pred __pred, _Proj __proj = {}) const {
-    auto __pred2 = [&](auto&& __e) { return !std::invoke(__pred, std::forward<decltype(__e)>(__e)); };
+    auto __pred2 = [&](auto&& __e) -> bool { return !std::invoke(__pred, std::forward<decltype(__e)>(__e)); };
     return ranges::__find_if_impl(ranges::begin(__r), ranges::end(__r), __pred2, __proj);
   }
 };
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/ranges.find_if_not.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/ranges.find_if_not.pass.cpp
index 95860745f56204e..03d43ebb752bff2 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/ranges.find_if_not.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/ranges.find_if_not.pass.cpp
@@ -227,7 +227,7 @@ constexpr bool test() {
     }
     {
       int a[] = {1, 2, 3, 4};
-      auto ret = std::ranges::find_if_not(a, [](const int& b) { return BooleanTestable{b != 3}; });
+      auto ret = std::ranges::find_if_not(a, [](const int& i) { return BooleanTestable{i != 3}; });
       assert(ret == a + 2);
     }
   }
diff --git a/libcxx/test/support/boolean_testable.h b/libcxx/test/support/boolean_testable.h
index e810e4e0461dc69..22bcefe04f9a53c 100644
--- a/libcxx/test/support/boolean_testable.h
+++ b/libcxx/test/support/boolean_testable.h
@@ -11,6 +11,8 @@
 
 #include "test_macros.h"
 
+#include <utility>
+
 #if TEST_STD_VER > 17
 
 class BooleanTestable {
@@ -24,11 +26,12 @@ class BooleanTestable {
   }
 
   friend constexpr BooleanTestable operator!=(const BooleanTestable& lhs, const BooleanTestable& rhs) {
-    return !(lhs == rhs);
+    return lhs.value_ != rhs.value_;
   }
 
-  constexpr BooleanTestable operator!() {
-    return BooleanTestable{!value_};
+  constexpr BooleanTestable&& operator!() && {
+    value_ = !value_;
+    return std::move(*this);
   }
 
   // this class should behave like a bool, so the constructor shouldn't be explicit

>From 825f4df50cbe25cec01011f667e131480e06452b Mon Sep 17 00:00:00 2001
From: Louis Dionne <ldionne.2 at gmail.com>
Date: Tue, 24 Oct 2023 11:08:05 -0700
Subject: [PATCH 2/3] Improve tests and in particular test projections

---
 .../include/__algorithm/ranges_upper_bound.h  |   4 +-
 ...robust_against_nonbool_predicates.pass.cpp | 148 +++++++++++++++---
 2 files changed, 128 insertions(+), 24 deletions(-)

diff --git a/libcxx/include/__algorithm/ranges_upper_bound.h b/libcxx/include/__algorithm/ranges_upper_bound.h
index a12a0e39b084909..7b571fb3448f94c 100644
--- a/libcxx/include/__algorithm/ranges_upper_bound.h
+++ b/libcxx/include/__algorithm/ranges_upper_bound.h
@@ -39,7 +39,7 @@ struct __fn {
             indirect_strict_weak_order<const _Type*, projected<_Iter, _Proj>> _Comp = ranges::less>
   _LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr _Iter
   operator()(_Iter __first, _Sent __last, const _Type& __value, _Comp __comp = {}, _Proj __proj = {}) const {
-    auto __comp_lhs_rhs_swapped = [&](const auto& __lhs, const auto& __rhs) {
+    auto __comp_lhs_rhs_swapped = [&](const auto& __lhs, const auto& __rhs) -> bool {
       return !std::invoke(__comp, __rhs, __lhs);
     };
 
@@ -52,7 +52,7 @@ struct __fn {
             indirect_strict_weak_order<const _Type*, projected<iterator_t<_Range>, _Proj>> _Comp = ranges::less>
   _LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr borrowed_iterator_t<_Range>
   operator()(_Range&& __r, const _Type& __value, _Comp __comp = {}, _Proj __proj = {}) const {
-    auto __comp_lhs_rhs_swapped = [&](const auto& __lhs, const auto& __rhs) {
+    auto __comp_lhs_rhs_swapped = [&](const auto& __lhs, const auto& __rhs) -> bool {
       return !std::invoke(__comp, __rhs, __lhs);
     };
 
diff --git a/libcxx/test/std/algorithms/ranges_robust_against_nonbool_predicates.pass.cpp b/libcxx/test/std/algorithms/ranges_robust_against_nonbool_predicates.pass.cpp
index 04773f0f5bc802e..de030fd009b9a42 100644
--- a/libcxx/test/std/algorithms/ranges_robust_against_nonbool_predicates.pass.cpp
+++ b/libcxx/test/std/algorithms/ranges_robust_against_nonbool_predicates.pass.cpp
@@ -20,6 +20,7 @@
 #include <initializer_list>
 #include <iterator>
 #include <ranges>
+#include <vector>
 
 #include "boolean_testable.h"
 #include "test_macros.h"
@@ -32,27 +33,129 @@ constexpr auto binary_pred = [](int i, int j) { return BooleanTestable(i < j); }
 static_assert(!std::same_as<decltype(binary_pred(1, 2)), bool>);
 static_assert(std::convertible_to<decltype(binary_pred(1, 2)), bool>);
 
-// Invokes both the (iterator, sentinel, ...) and the (range, ...) overloads of the given niebloid.
+constexpr auto projection = [](int i) { return i; };
 
-// (in, ...)
-template <class Func, std::ranges::range Input, class... Args>
-constexpr void test(Func&& func, Input& in, Args&&... args) {
-  (void)func(in.begin(), in.end(), std::forward<Args>(args)...);
-  (void)func(in, std::forward<Args>(args)...);
+// Invokes both the (iterator, sentinel, ...) and the (range, ...) overloads of the given niebloid,
+// with and without a projection.
+
+// (in, pred)
+template <class Func, std::ranges::range Input, class Predicate>
+constexpr void test(Func&& func, Input& in, Predicate&& pred) {
+  (void)func(in.begin(), in.end(), pred);
+  (void)func(in, pred);
+
+  (void)func(in.begin(), in.end(), pred, projection);
+  (void)func(in, pred, projection);
 }
 
-// (in1, in2, ...)
-template <class Func, std::ranges::range Input, class... Args>
-constexpr void test(Func&& func, Input& in1, Input& in2, Args&&... args) {
-  (void)func(in1.begin(), in1.end(), in2.begin(), in2.end(), std::forward<Args>(args)...);
-  (void)func(in1, in2, std::forward<Args>(args)...);
+// (in1, in2, pred)
+template <class Func, std::ranges::range Input, class Predicate>
+constexpr void test(Func&& func, Input& in1, Input& in2, Predicate&& pred) {
+  (void)func(in1.begin(), in1.end(), in2.begin(), in2.end(), pred);
+  (void)func(in1, in2, pred);
+
+  (void)func(in1.begin(), in1.end(), in2.begin(), in2.end(), pred, projection);
+  (void)func(in1, in2, pred, projection);
 }
 
-// (in, mid, ...)
-template <class Func, std::ranges::range Input, class... Args>
-constexpr void test_mid(Func&& func, Input& in, std::ranges::iterator_t<Input> mid, Args&&... args) {
-  (void)func(in.begin(), mid, in.end(), std::forward<Args>(args)...);
-  (void)func(in, mid, std::forward<Args>(args)...);
+// (in, val, pred)
+template <class Func, std::ranges::range Input, class Predicate>
+constexpr void test(Func&& func, Input& in, std::ranges::range_value_t<Input> const& val, Predicate&& pred) {
+  (void)func(in.begin(), in.end(), val, pred);
+  (void)func(in, val, pred);
+
+  (void)func(in.begin(), in.end(), val, pred, projection);
+  (void)func(in, val, pred, projection);
+}
+
+// (in, n, val, pred) -- the count `n` precedes the value `val` in every call form
+template <class Func, std::ranges::range Input, class Predicate>
+constexpr void
+test(Func&& func,
+     Input& in,
+     std::ranges::range_difference_t<Input> n,
+     std::ranges::range_value_t<Input> const& val,
+     Predicate&& pred) {
+  (void)func(in.begin(), in.end(), n, val, pred);
+  (void)func(in, n, val, pred);
+
+  (void)func(in.begin(), in.end(), n, val, pred, projection);
+  (void)func(in, n, val, pred, projection);
+}
+
+// (in, out, pred)
+template <class Func,
+          std::ranges::range Input,
+          std::output_iterator<std::ranges::range_value_t<Input>> Output,
+          class Predicate>
+  requires(!std::same_as<Output, std::ranges::iterator_t<Input>>) // disambiguate with (in, mid, pred)
+constexpr void test(Func&& func, Input& in, Output out, Predicate&& pred) {
+  (void)func(in.begin(), in.end(), out, pred);
+  (void)func(in, out, pred);
+
+  (void)func(in.begin(), in.end(), out, pred, projection);
+  (void)func(in, out, pred, projection);
+}
+
+// (in, mid, pred)
+template <class Func, std::ranges::range Input, class Predicate>
+constexpr void test(Func&& func, Input& in, std::ranges::iterator_t<Input> mid, Predicate&& pred) {
+  (void)func(in.begin(), mid, in.end(), pred);
+  (void)func(in, mid, pred);
+
+  (void)func(in.begin(), mid, in.end(), pred, projection);
+  (void)func(in, mid, pred, projection);
+}
+
+// (in, pred, val) -- e.g. replace_if
+template <class Func, std::ranges::range Input, class Predicate>
+constexpr void test(Func&& func, Input& in, Predicate&& pred, std::ranges::range_value_t<Input> const& val) {
+  (void)func(in.begin(), in.end(), pred, val);
+  (void)func(in, pred, val);
+
+  (void)func(in.begin(), in.end(), pred, val, projection);
+  (void)func(in, pred, val, projection);
+}
+
+// (in, out, pred, val) -- e.g. replace_copy_if
+template <class Func,
+          std::ranges::range Input,
+          std::output_iterator<std::ranges::range_value_t<Input>> Output,
+          class Predicate>
+constexpr void
+test(Func&& func, Input& in, Output out, Predicate&& pred, std::ranges::range_value_t<Input> const& val) {
+  (void)func(in.begin(), in.end(), out, pred, val);
+  (void)func(in, out, pred, val);
+
+  (void)func(in.begin(), in.end(), out, pred, val, projection);
+  (void)func(in, out, pred, val, projection);
+}
+
+// (in, out1, out2, pred) -- e.g. partition_copy
+template <class Func,
+          std::ranges::range Input,
+          std::output_iterator<std::ranges::range_value_t<Input>> Output1,
+          std::output_iterator<std::ranges::range_value_t<Input>> Output2,
+          class Predicate>
+constexpr void test(Func&& func, Input& in, Output1 out1, Output2 out2, Predicate&& pred) {
+  (void)func(in.begin(), in.end(), out1, out2, pred);
+  (void)func(in, out1, out2, pred);
+
+  (void)func(in.begin(), in.end(), out1, out2, pred, projection);
+  (void)func(in, out1, out2, pred, projection);
+}
+
+// (in1, in2, out, pred) -- e.g. merge
+template <class Func,
+          std::ranges::range Input,
+          std::output_iterator<std::ranges::range_value_t<Input>> Output,
+          class Predicate>
+constexpr void test(Func&& func, Input& in1, Input& in2, Output out, Predicate&& pred) {
+  (void)func(in1.begin(), in1.end(), in2.begin(), in2.end(), out, pred);
+  (void)func(in1, in2, out, pred);
+
+  (void)func(in1.begin(), in1.end(), in2.begin(), in2.end(), out, pred, projection, projection);
+  (void)func(in1, in2, out, pred, projection, projection);
 }
 
 constexpr bool test_all() {
@@ -60,9 +163,10 @@ constexpr bool test_all() {
   std::array in2 = {4, 5, 6};
   auto mid       = in.begin() + 1;
 
-  std::array output = {7, 8, 9, 10, 11, 12};
-  auto out          = output.begin();
-  auto out2         = output.begin() + 1;
+  std::vector<int> output1;
+  std::vector<int> output2;
+  auto out  = std::back_insert_iterator(output1);
+  auto out2 = std::back_insert_iterator(output2);
 
   int x     = 2;
   int count = 1;
@@ -137,10 +241,10 @@ constexpr bool test_all() {
   test(std::ranges::sort, in, binary_pred);
   if (!std::is_constant_evaluated())
     test(std::ranges::stable_sort, in, binary_pred);
-  test_mid(std::ranges::partial_sort, in, mid, binary_pred);
-  test_mid(std::ranges::nth_element, in, mid, binary_pred);
+  test(std::ranges::partial_sort, in, mid, binary_pred);
+  test(std::ranges::nth_element, in, mid, binary_pred);
   if (!std::is_constant_evaluated())
-    test_mid(std::ranges::inplace_merge, in, mid, binary_pred);
+    test(std::ranges::inplace_merge, in, mid, binary_pred);
   test(std::ranges::make_heap, in, binary_pred);
   test(std::ranges::push_heap, in, binary_pred);
   test(std::ranges::pop_heap, in, binary_pred);

>From 515fd3c55ee29106790a0e80fbfa226e3803a69c Mon Sep 17 00:00:00 2001
From: Louis Dionne <ldionne.2 at gmail.com>
Date: Tue, 24 Oct 2023 11:59:26 -0700
Subject: [PATCH 3/3] Fix the same problem for chunk_by and add tests for other
 views

---
 libcxx/include/__ranges/chunk_by_view.h       | 12 +--
 ...robust_against_nonbool_predicates.pass.cpp | 75 +++++++++++++++++++
 2 files changed, 81 insertions(+), 6 deletions(-)
 create mode 100644 libcxx/test/std/ranges/range.adaptors/robust_against_nonbool_predicates.pass.cpp

diff --git a/libcxx/include/__ranges/chunk_by_view.h b/libcxx/include/__ranges/chunk_by_view.h
index cfb149b443571e8..b73a1a21365f10d 100644
--- a/libcxx/include/__ranges/chunk_by_view.h
+++ b/libcxx/include/__ranges/chunk_by_view.h
@@ -16,7 +16,6 @@
 #include <__config>
 #include <__functional/bind_back.h>
 #include <__functional/invoke.h>
-#include <__functional/not_fn.h>
 #include <__functional/reference_wrapper.h>
 #include <__iterator/concepts.h>
 #include <__iterator/default_sentinel.h>
@@ -69,10 +68,11 @@ class chunk_by_view : public view_interface<chunk_by_view<_View, _Pred>> {
   _LIBCPP_HIDE_FROM_ABI constexpr iterator_t<_View> __find_next(iterator_t<_View> __current) {
     _LIBCPP_ASSERT_UNCATEGORIZED(
         __pred_.__has_value(), "Trying to call __find_next() on a chunk_by_view that does not have a valid predicate.");
-
-    return ranges::next(ranges::adjacent_find(__current, ranges::end(__base_), std::not_fn(std::ref(*__pred_))),
-                        1,
-                        ranges::end(__base_));
+    auto __reversed_pred = [this]<class _Tp, class _Up>(_Tp&& __x, _Up&& __y) -> bool {
+      return !std::invoke(*__pred_, std::forward<_Tp>(__x), std::forward<_Up>(__y));
+    };
+    return ranges::next(
+        ranges::adjacent_find(__current, ranges::end(__base_), __reversed_pred), 1, ranges::end(__base_));
   }
 
   _LIBCPP_HIDE_FROM_ABI constexpr iterator_t<_View> __find_prev(iterator_t<_View> __current)
@@ -85,7 +85,7 @@ class chunk_by_view : public view_interface<chunk_by_view<_View, _Pred>> {
 
     auto __first = ranges::begin(__base_);
     reverse_view __reversed{subrange{__first, __current}};
-    auto __reversed_pred = [this]<class _Tp, class _Up>(_Tp&& __x, _Up&& __y) {
+    auto __reversed_pred = [this]<class _Tp, class _Up>(_Tp&& __x, _Up&& __y) -> bool {
       return !std::invoke(*__pred_, std::forward<_Up>(__y), std::forward<_Tp>(__x));
     };
     return ranges::prev(ranges::adjacent_find(__reversed, __reversed_pred).base(), 1, std::move(__first));
diff --git a/libcxx/test/std/ranges/range.adaptors/robust_against_nonbool_predicates.pass.cpp b/libcxx/test/std/ranges/range.adaptors/robust_against_nonbool_predicates.pass.cpp
new file mode 100644
index 000000000000000..7b78969c847cd03
--- /dev/null
+++ b/libcxx/test/std/ranges/range.adaptors/robust_against_nonbool_predicates.pass.cpp
@@ -0,0 +1,75 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// UNSUPPORTED: c++03, c++11, c++14, c++17
+
+// <ranges>
+//
+// Range adaptors that take predicates should support predicates that return a non-boolean
+// value as long as the returned type is implicitly convertible to bool.
+
+#include <ranges>
+
+#include <array>
+#include <cassert>
+#include <concepts>
+
+#include "boolean_testable.h"
+#include "test_macros.h"
+
+constexpr auto unary_pred = [](int i) { return BooleanTestable(i > 0); };
+static_assert(!std::same_as<decltype(unary_pred(1)), bool>);
+static_assert(std::convertible_to<decltype(unary_pred(1)), bool>);
+
+constexpr auto binary_pred = [](int i, int j) { return BooleanTestable(i < j); };
+static_assert(!std::same_as<decltype(binary_pred(1, 2)), bool>);
+static_assert(std::convertible_to<decltype(binary_pred(1, 2)), bool>);
+
+template <std::ranges::view View>
+constexpr void use(View view) {
+  // Just use the view in a few ways. Our goal here is to trigger the instantiation
+  // of various functions related to the view and its iterators in the hopes that we
+  // instantiate functions that might have incorrect implementations w.r.t. predicates.
+  auto first = std::ranges::begin(view);
+  auto last  = std::ranges::end(view);
+  assert(first != last);
+  ++first;
+  --first;
+  (void)(first == last);
+  (void)std::ranges::empty(view);
+}
+
+constexpr bool test_all() {
+  std::array in = {1, 2, 3, -1, -2, -3};
+
+  {
+    auto view = std::views::chunk_by(in, binary_pred);
+    use(view);
+  }
+  {
+    auto view = std::views::drop_while(in, unary_pred);
+    use(view);
+  }
+  {
+    auto view = std::views::filter(in, unary_pred);
+    use(view);
+  }
+  {
+    auto view = std::views::take_while(in, unary_pred);
+    use(view);
+  }
+
+  return true;
+}
+
+int main(int, char**) {
+  test_all();
+  static_assert(test_all());
+
+  return 0;
+}



More information about the libcxx-commits mailing list