[libcxx-commits] [libcxx] [libc++] Optimize std::find if types are integral (PR #70345)

Nikolas Klauser via libcxx-commits libcxx-commits at lists.llvm.org
Thu Dec 14 07:47:59 PST 2023


https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/70345

>From 77f44c44ea07d31eeee9102091025ba4c327498a Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Thu, 26 Oct 2023 11:54:29 +0200
Subject: [PATCH] [libc++] Optimize std::find if types are integral

---
 libcxx/include/__algorithm/find.h             |  21 ++
 .../alg.nonmodifying/alg.find/find.pass.cpp   | 268 ++++++++++++++++++
 2 files changed, 289 insertions(+)

diff --git a/libcxx/include/__algorithm/find.h b/libcxx/include/__algorithm/find.h
index 0118489d94699d..075773db85c1c2 100644
--- a/libcxx/include/__algorithm/find.h
+++ b/libcxx/include/__algorithm/find.h
@@ -19,7 +19,10 @@
 #include <__functional/invoke.h>
 #include <__fwd/bit_reference.h>
 #include <__string/constexpr_c_functions.h>
+#include <__type_traits/common_type.h>
+#include <__type_traits/is_integral.h>
 #include <__type_traits/is_same.h>
+#include <__type_traits/is_signed.h>
+#include <limits>
 
 #ifndef _LIBCPP_HAS_NO_WIDE_CHARACTERS
 #  include <cwchar>
@@ -73,6 +76,24 @@ __find_impl(_Tp* __first, _Tp* __last, const _Up& __value, _Proj&) {
 }
 #endif // _LIBCPP_HAS_NO_WIDE_CHARACTERS
 
+// Overload for integral _Tp and _Up of the same signedness that are not trivially equality
+// comparable (presumably because they differ in width). If __value lies outside the range of _Tp,
+// no element can compare equal to it, so __last is returned right away. Otherwise converting
+// __value to _Tp preserves its value, and the search is forwarded to std::__find_impl with
+// matching element and value types (to allow vectorization).
+// TODO: This should also be possible to get right with different signedness
+template <class _Tp,
+          class _Up,
+          class _Proj,
+          __enable_if_t<__is_identity<_Proj>::value && !__libcpp_is_trivially_equality_comparable<_Tp, _Up>::value &&
+                            is_integral<_Tp>::value && is_integral<_Up>::value &&
+                            is_signed<_Tp>::value == is_signed<_Up>::value,
+                        int> = 0>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp*
+__find_impl(_Tp* __first, _Tp* __last, const _Up& __value, _Proj& __proj) {
+  using __common_t = typename common_type<_Tp, _Up>::type;
+  // The range check must use the limits of _Tp: every value of _Up is representable in
+  // __common_t, so comparing against __common_t's own limits can never fail, and _Tp(__value)
+  // would then silently truncate (e.g. searching a `signed char` range for 300 would find 44).
+  if (__common_t(__value) < __common_t(numeric_limits<_Tp>::min()) ||
+      __common_t(__value) > __common_t(numeric_limits<_Tp>::max()))
+    return __last;
+  return std::__find_impl(__first, __last, _Tp(__value), __proj);
+}
+
 // __bit_iterator implementation
 template <bool _ToFind, class _Cp, bool _IsConst>
 _LIBCPP_CONSTEXPR_SINCE_CXX20 _LIBCPP_HIDE_FROM_ABI __bit_iterator<_Cp, _IsConst>
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/find.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/find.pass.cpp
index b55a852c10cafa..816a5fb806a87b 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/find.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.find/find.pass.cpp
@@ -113,6 +113,272 @@ struct TestTypes {
   }
 };
 
+// Expected outcome of a std::find call on a one-element range: either no match, or a match at the
+// first (and only) element.
+enum class match { none, first };
+
+// Stores `value` in a one-element array of T, searches it for `value` converted to U, and checks
+// that std::find reports the element exactly when `expected` is not match::none.
+template <class T, class U>
+TEST_CONSTEXPR_CXX20 void check_find_in_single_element(T value, match expected) {
+  T arr[1]       = {value};
+  const U needle = static_cast<U>(arr[0]);
+  T* const found = std::find(arr, arr + 1, needle);
+  T* const want  = expected == match::none ? arr + 1 : arr;
+  assert(found == want);
+}
+
+// Checks std::find on the smallest and largest values of T, each searched for after conversion to U.
+template <class T, class U>
+TEST_CONSTEXPR_CXX20 void test_integer_promotion(match min_expected, match max_expected) {
+  check_find_in_single_element<T, U>(std::numeric_limits<T>::min(), min_expected);
+  check_find_in_single_element<T, U>(std::numeric_limits<T>::max(), max_expected);
+}
+
+TEST_CONSTEXPR_CXX20 void test_integer_promotions() {
+  { // signed char
+    test_integer_promotion<wchar_t, signed char>(match::none, match::none);
+    test_integer_promotion<char8_t, signed char>(match::first, match::none);
+    test_integer_promotion<char16_t, signed char>(match::first, match::none);
+    test_integer_promotion<char32_t, signed char>(match::first, match::first);
+
+    test_integer_promotion<signed char, signed char>(match::first, match::first);
+    test_integer_promotion<short, signed char>(match::none, match::none);
+    test_integer_promotion<int, signed char>(match::none, match::none);
+    test_integer_promotion<long, signed char>(match::none, match::none);
+    test_integer_promotion<long long, signed char>(match::none, match::none);
+
+    test_integer_promotion<unsigned char, signed char>(match::first, match::none);
+    test_integer_promotion<unsigned short, signed char>(match::first, match::none);
+    test_integer_promotion<unsigned int, signed char>(match::first, match::first);
+    test_integer_promotion<unsigned long, signed char>(match::first, match::first);
+    test_integer_promotion<unsigned long long, signed char>(match::first, match::first);
+  }
+  { // unsigned char
+    test_integer_promotion<wchar_t, unsigned char>(match::none, match::none);
+    test_integer_promotion<char8_t, unsigned char>(match::first, match::first);
+    test_integer_promotion<char16_t, unsigned char>(match::first, match::none);
+    test_integer_promotion<char32_t, unsigned char>(match::first, match::none);
+
+    test_integer_promotion<signed char, unsigned char>(match::none, match::first);
+    test_integer_promotion<short, unsigned char>(match::none, match::none);
+    test_integer_promotion<int, unsigned char>(match::none, match::none);
+    test_integer_promotion<long, unsigned char>(match::none, match::none);
+    test_integer_promotion<long long, unsigned char>(match::none, match::none);
+
+    test_integer_promotion<unsigned char, unsigned char>(match::first, match::first);
+    test_integer_promotion<unsigned short, unsigned char>(match::first, match::none);
+    test_integer_promotion<unsigned int, unsigned char>(match::first, match::none);
+    test_integer_promotion<unsigned long, unsigned char>(match::first, match::none);
+    test_integer_promotion<unsigned long long, unsigned char>(match::first, match::none);
+  }
+  { // char8_t
+    test_integer_promotion<wchar_t, char8_t>(match::none, match::none);
+    test_integer_promotion<char8_t, char8_t>(match::first, match::first);
+    test_integer_promotion<char16_t, char8_t>(match::first, match::none);
+    test_integer_promotion<char32_t, char8_t>(match::first, match::none);
+
+    test_integer_promotion<signed char, char8_t>(match::none, match::first);
+    test_integer_promotion<short, char8_t>(match::none, match::none);
+    test_integer_promotion<int, char8_t>(match::none, match::none);
+    test_integer_promotion<long, char8_t>(match::none, match::none);
+    test_integer_promotion<long long, char8_t>(match::none, match::none);
+
+    test_integer_promotion<unsigned char, char8_t>(match::first, match::first);
+    test_integer_promotion<unsigned short, char8_t>(match::first, match::none);
+    test_integer_promotion<unsigned int, char8_t>(match::first, match::none);
+    test_integer_promotion<unsigned long, char8_t>(match::first, match::none);
+    test_integer_promotion<unsigned long long, char8_t>(match::first, match::none);
+  }
+  { // char16_t
+    test_integer_promotion<wchar_t, char16_t>(match::none, match::none);
+    test_integer_promotion<char8_t, char16_t>(match::first, match::first);
+    test_integer_promotion<char16_t, char16_t>(match::first, match::first);
+    test_integer_promotion<char32_t, char16_t>(match::first, match::none);
+
+    test_integer_promotion<signed char, char16_t>(match::none, match::first);
+    test_integer_promotion<short, char16_t>(match::none, match::first);
+    test_integer_promotion<int, char16_t>(match::none, match::none);
+    test_integer_promotion<long, char16_t>(match::none, match::none);
+    test_integer_promotion<long long, char16_t>(match::none, match::none);
+
+    test_integer_promotion<unsigned char, char16_t>(match::first, match::first);
+    test_integer_promotion<unsigned short, char16_t>(match::first, match::first);
+    test_integer_promotion<unsigned int, char16_t>(match::first, match::none);
+    test_integer_promotion<unsigned long, char16_t>(match::first, match::none);
+    test_integer_promotion<unsigned long long, char16_t>(match::first, match::none);
+  }
+  { // char32_t
+    test_integer_promotion<wchar_t, char32_t>(match::first, match::first);
+    test_integer_promotion<char8_t, char32_t>(match::first, match::first);
+    test_integer_promotion<char16_t, char32_t>(match::first, match::first);
+    test_integer_promotion<char32_t, char32_t>(match::first, match::first);
+
+    test_integer_promotion<signed char, char32_t>(match::first, match::first);
+    test_integer_promotion<short, char32_t>(match::first, match::first);
+    test_integer_promotion<int, char32_t>(match::first, match::first);
+    test_integer_promotion<long, char32_t>(match::none, match::none);
+    test_integer_promotion<long long, char32_t>(match::none, match::none);
+
+    test_integer_promotion<unsigned char, char32_t>(match::first, match::first);
+    test_integer_promotion<unsigned short, char32_t>(match::first, match::first);
+    test_integer_promotion<unsigned int, char32_t>(match::first, match::first);
+    test_integer_promotion<unsigned long, char32_t>(match::first, match::none);
+    test_integer_promotion<unsigned long long, char32_t>(match::first, match::none);
+  }
+  { // short
+    test_integer_promotion<wchar_t, short>(match::none, match::none);
+    test_integer_promotion<char8_t, short>(match::first, match::first);
+    test_integer_promotion<char16_t, short>(match::first, match::none);
+    test_integer_promotion<char32_t, short>(match::first, match::first);
+
+    test_integer_promotion<signed char, short>(match::first, match::first);
+    test_integer_promotion<short, short>(match::first, match::first);
+    test_integer_promotion<int, short>(match::none, match::none);
+    test_integer_promotion<long, short>(match::none, match::none);
+    test_integer_promotion<long long, short>(match::none, match::none);
+
+    test_integer_promotion<unsigned char, short>(match::first, match::first);
+    test_integer_promotion<unsigned short, short>(match::first, match::none);
+    test_integer_promotion<unsigned int, short>(match::first, match::first);
+    test_integer_promotion<unsigned long, short>(match::first, match::first);
+    test_integer_promotion<unsigned long long, short>(match::first, match::first);
+  }
+  { // unsigned short
+    test_integer_promotion<wchar_t, unsigned short>(match::none, match::none);
+    test_integer_promotion<char8_t, unsigned short>(match::first, match::first);
+    test_integer_promotion<char16_t, unsigned short>(match::first, match::first);
+    test_integer_promotion<char32_t, unsigned short>(match::first, match::none);
+
+    test_integer_promotion<signed char, unsigned short>(match::none, match::first);
+    test_integer_promotion<short, unsigned short>(match::none, match::first);
+    test_integer_promotion<int, unsigned short>(match::none, match::none);
+    test_integer_promotion<long, unsigned short>(match::none, match::none);
+    test_integer_promotion<long long, unsigned short>(match::none, match::none);
+
+    test_integer_promotion<unsigned char, unsigned short>(match::first, match::first);
+    test_integer_promotion<unsigned short, unsigned short>(match::first, match::first);
+    test_integer_promotion<unsigned int, unsigned short>(match::first, match::none);
+    test_integer_promotion<unsigned long, unsigned short>(match::first, match::none);
+    test_integer_promotion<unsigned long long, unsigned short>(match::first, match::none);
+  }
+  { // int
+    test_integer_promotion<wchar_t, int>(match::first, match::first);
+    test_integer_promotion<char8_t, int>(match::first, match::first);
+    test_integer_promotion<char16_t, int>(match::first, match::first);
+    test_integer_promotion<char32_t, int>(match::first, match::first);
+
+    test_integer_promotion<signed char, int>(match::first, match::first);
+    test_integer_promotion<short, int>(match::first, match::first);
+    test_integer_promotion<int, int>(match::first, match::first);
+    test_integer_promotion<long, int>(match::none, match::none);
+    test_integer_promotion<long long, int>(match::none, match::none);
+
+    test_integer_promotion<unsigned char, int>(match::first, match::first);
+    test_integer_promotion<unsigned short, int>(match::first, match::first);
+    test_integer_promotion<unsigned int, int>(match::first, match::first);
+    test_integer_promotion<unsigned long, int>(match::first, match::first);
+    test_integer_promotion<unsigned long long, int>(match::first, match::first);
+  }
+  { // unsigned int
+    test_integer_promotion<wchar_t, unsigned int>(match::first, match::first);
+    test_integer_promotion<char8_t, unsigned int>(match::first, match::first);
+    test_integer_promotion<char16_t, unsigned int>(match::first, match::first);
+    test_integer_promotion<char32_t, unsigned int>(match::first, match::first);
+
+    test_integer_promotion<signed char, unsigned int>(match::first, match::first);
+    test_integer_promotion<short, unsigned int>(match::first, match::first);
+    test_integer_promotion<int, unsigned int>(match::first, match::first);
+    test_integer_promotion<long, unsigned int>(match::none, match::none);
+    test_integer_promotion<long long, unsigned int>(match::none, match::none);
+
+    test_integer_promotion<unsigned char, unsigned int>(match::first, match::first);
+    test_integer_promotion<unsigned short, unsigned int>(match::first, match::first);
+    test_integer_promotion<unsigned int, unsigned int>(match::first, match::first);
+    test_integer_promotion<unsigned long, unsigned int>(match::first, match::none);
+    test_integer_promotion<unsigned long long, unsigned int>(match::first, match::none);
+  }
+  { // long
+    test_integer_promotion<wchar_t, long>(match::first, match::first);
+    test_integer_promotion<char8_t, long>(match::first, match::first);
+    test_integer_promotion<char16_t, long>(match::first, match::first);
+    test_integer_promotion<char32_t, long>(match::first, match::first);
+
+    test_integer_promotion<signed char, long>(match::first, match::first);
+    test_integer_promotion<short, long>(match::first, match::first);
+    test_integer_promotion<int, long>(match::first, match::first);
+    test_integer_promotion<long, long>(match::first, match::first);
+    test_integer_promotion<long long, long>(match::first, match::first);
+
+    test_integer_promotion<unsigned char, long>(match::first, match::first);
+    test_integer_promotion<unsigned short, long>(match::first, match::first);
+    test_integer_promotion<unsigned int, long>(match::first, match::first);
+    test_integer_promotion<unsigned long, long>(match::first, match::first);
+    test_integer_promotion<unsigned long long, long>(match::first, match::first);
+  }
+  { // unsigned long
+    test_integer_promotion<wchar_t, unsigned long>(match::first, match::first);
+    test_integer_promotion<char8_t, unsigned long>(match::first, match::first);
+    test_integer_promotion<char16_t, unsigned long>(match::first, match::first);
+    test_integer_promotion<char32_t, unsigned long>(match::first, match::first);
+
+    test_integer_promotion<signed char, unsigned long>(match::first, match::first);
+    test_integer_promotion<short, unsigned long>(match::first, match::first);
+    test_integer_promotion<int, unsigned long>(match::first, match::first);
+    test_integer_promotion<long, unsigned long>(match::first, match::first);
+    test_integer_promotion<long long, unsigned long>(match::first, match::first);
+
+    test_integer_promotion<unsigned char, unsigned long>(match::first, match::first);
+    test_integer_promotion<unsigned short, unsigned long>(match::first, match::first);
+    test_integer_promotion<unsigned int, unsigned long>(match::first, match::first);
+    test_integer_promotion<unsigned long, unsigned long>(match::first, match::first);
+    test_integer_promotion<unsigned long long, unsigned long>(match::first, match::first);
+  }
+  { // long long
+    test_integer_promotion<wchar_t, long long>(match::first, match::first);
+    test_integer_promotion<char8_t, long long>(match::first, match::first);
+    test_integer_promotion<char16_t, long long>(match::first, match::first);
+    test_integer_promotion<char32_t, long long>(match::first, match::first);
+
+    test_integer_promotion<signed char, long long>(match::first, match::first);
+    test_integer_promotion<short, long long>(match::first, match::first);
+    test_integer_promotion<int, long long>(match::first, match::first);
+    test_integer_promotion<long, long long>(match::first, match::first);
+    test_integer_promotion<long long, long long>(match::first, match::first);
+
+    test_integer_promotion<unsigned char, long long>(match::first, match::first);
+    test_integer_promotion<unsigned short, long long>(match::first, match::first);
+    test_integer_promotion<unsigned int, long long>(match::first, match::first);
+    test_integer_promotion<unsigned long, long long>(match::first, match::first);
+    test_integer_promotion<unsigned long long, long long>(match::first, match::first);
+  }
+  { // unsigned long long
+    test_integer_promotion<wchar_t, unsigned long long>(match::first, match::first);
+    test_integer_promotion<char8_t, unsigned long long>(match::first, match::first);
+    test_integer_promotion<char16_t, unsigned long long>(match::first, match::first);
+    test_integer_promotion<char32_t, unsigned long long>(match::first, match::first);
+
+    test_integer_promotion<signed char, unsigned long long>(match::first, match::first);
+    test_integer_promotion<short, unsigned long long>(match::first, match::first);
+    test_integer_promotion<int, unsigned long long>(match::first, match::first);
+    test_integer_promotion<long, unsigned long long>(match::first, match::first);
+    test_integer_promotion<long long, unsigned long long>(match::first, match::first);
+
+    test_integer_promotion<unsigned char, unsigned long long>(match::first, match::first);
+    test_integer_promotion<unsigned short, unsigned long long>(match::first, match::first);
+    test_integer_promotion<unsigned int, unsigned long long>(match::first, match::first);
+    test_integer_promotion<unsigned long, unsigned long long>(match::first, match::first);
+    test_integer_promotion<unsigned long long, unsigned long long>(match::first, match::first);
+  }
+}
+
 TEST_CONSTEXPR_CXX20 bool test() {
   types::for_each(types::integer_types(), TestTypes<char>());
   types::for_each(types::integer_types(), TestTypes<int>());
@@ -122,6 +388,8 @@ TEST_CONSTEXPR_CXX20 bool test() {
   Test<TriviallyComparable<wchar_t>, TriviallyComparable<wchar_t>>().operator()<TriviallyComparable<wchar_t>*>();
 #endif
 
+  test_integer_promotions();
+
   return true;
 }
 



More information about the libcxx-commits mailing list