[libcxx-commits] [libcxx] [libc++] Optimize the std::mismatch tail (PR #83440)
Nikolas Klauser via libcxx-commits
libcxx-commits at lists.llvm.org
Sat Mar 23 07:31:42 PDT 2024
https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/83440
>From 8ddae7d6ac81fb5c44900832d347663a745bcc23 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Tue, 27 Feb 2024 11:45:41 +0100
Subject: [PATCH] [libc++] Optimize mismatch tail
---
.../benchmarks/algorithms/mismatch.bench.cpp | 15 +++++++--
libcxx/include/__algorithm/mismatch.h | 33 ++++++++++++++++++-
.../mismatch/mismatch.pass.cpp | 28 ++++++++++++++++
3 files changed, 72 insertions(+), 4 deletions(-)
diff --git a/libcxx/benchmarks/algorithms/mismatch.bench.cpp b/libcxx/benchmarks/algorithms/mismatch.bench.cpp
index 9274932a764c55..06289068bb0492 100644
--- a/libcxx/benchmarks/algorithms/mismatch.bench.cpp
+++ b/libcxx/benchmarks/algorithms/mismatch.bench.cpp
@@ -10,6 +10,15 @@
#include <benchmark/benchmark.h>
#include <random>
+// Registers the mismatch input sizes: 1 through 8, then n-1, n and n+1 for every power of two n in [16, 1 << 19].
+void BenchmarkSizes(benchmark::internal::Benchmark* Benchmark) {
+ Benchmark->DenseRange(1, 8);
+ for (size_t power = 16; power != 1 << 20; power <<= 1) {
+ for (size_t size = power - 1; size != power + 2; ++size) // one below, at, and one above the power of two
+ Benchmark->Arg(size);
+ }
+}
+
// TODO: Look into benchmarking aligned and unaligned memory explicitly
// (currently things happen to be aligned because they are malloced that way)
template <class T>
@@ -24,8 +33,8 @@ static void bm_mismatch(benchmark::State& state) {
benchmark::DoNotOptimize(std::mismatch(vec1.begin(), vec1.end(), vec2.begin()));
}
}
-BENCHMARK(bm_mismatch<char>)->DenseRange(1, 8)->Range(16, 1 << 20);
-BENCHMARK(bm_mismatch<short>)->DenseRange(1, 8)->Range(16, 1 << 20);
-BENCHMARK(bm_mismatch<int>)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_mismatch<char>)->Apply(BenchmarkSizes); // sizes from BenchmarkSizes: 1-8, then 2^k-1, 2^k, 2^k+1 for k in [4, 19]
+BENCHMARK(bm_mismatch<short>)->Apply(BenchmarkSizes);
+BENCHMARK(bm_mismatch<int>)->Apply(BenchmarkSizes);
BENCHMARK_MAIN();
diff --git a/libcxx/include/__algorithm/mismatch.h b/libcxx/include/__algorithm/mismatch.h
index 4eb693a1f2e9d8..7bda9ba5508366 100644
--- a/libcxx/include/__algorithm/mismatch.h
+++ b/libcxx/include/__algorithm/mismatch.h
@@ -64,7 +64,10 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
constexpr size_t __unroll_count = 4;
constexpr size_t __vec_size = __native_vector_size<_Tp>;
using __vec = __simd_vector<_Tp, __vec_size>;
+
if (!__libcpp_is_constant_evaluated()) {
+ auto __orig_first1 = __first1;
+ auto __last2 = __first2 + (__last1 - __first1);
while (static_cast<size_t>(__last1 - __first1) >= __unroll_count * __vec_size) [[__unlikely__]] {
__vec __lhs[__unroll_count];
__vec __rhs[__unroll_count];
@@ -84,8 +87,36 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
__first1 += __unroll_count * __vec_size;
__first2 += __unroll_count * __vec_size;
}
+
+ // check the remaining 0-3 vectors
+ while (static_cast<size_t>(__last1 - __first1) >= __vec_size) {
+ if (auto __cmp_res = std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2);
+ !std::__all_of(__cmp_res)) {
+ auto __offset = std::__find_first_not_set(__cmp_res);
+ return {__first1 + __offset, __first2 + __offset};
+ }
+ __first1 += __vec_size;
+ __first2 += __vec_size;
+ }
+
+ if (__last1 - __first1 == 0)
+ return {__first1, __first2};
+
+ // Check if we can load elements in front of the current pointer. If that's the case, load a vector at
+ // (last - vector_size) to check the remaining elements
+ if (static_cast<size_t>(__first1 - __orig_first1) >= __vec_size) {
+ __first1 = __last1 - __vec_size;
+ __first2 = __last2 - __vec_size;
+ auto __offset =
+ std::__find_first_not_set(std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2));
+ return {__first1 + __offset, __first2 + __offset};
+ } // else loop over the elements individually
+
+ // TODO: Consider vectorizing the loop tail further with
+ // - smaller vectors
+ // - loading bytes out of range if it's known to be safe
}
- // TODO: Consider vectorizing the tail
+
return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
}
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
index e7f3994d977dcd..55c9eea863c3ff 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
@@ -184,5 +184,33 @@ int main(int, char**) {
}
}
+ { // check the tail of the vectorized loop for every container length in [1, 255], with the mismatch (if any) in the last element
+ for (size_t vec_size = 1; vec_size != 256; ++vec_size) { // vec_size is the container length, so every leftover-element count is hit
+ {
+ std::vector<char> lhs(vec_size);
+ std::vector<char> rhs(vec_size);
+
+ check<char*>(lhs, rhs, lhs.size());
+ lhs.back() = 1;
+ check<char*>(lhs, rhs, lhs.size() - 1);
+ lhs.back() = 0; // restore equality before introducing the mismatch on the other side
+ rhs.back() = 1;
+ check<char*>(lhs, rhs, lhs.size() - 1);
+ rhs.back() = 0;
+ }
+ {
+ std::vector<int> lhs(vec_size);
+ std::vector<int> rhs(vec_size);
+
+ check<int*>(lhs, rhs, lhs.size());
+ lhs.back() = 1;
+ check<int*>(lhs, rhs, lhs.size() - 1);
+ lhs.back() = 0; // restore equality before introducing the mismatch on the other side
+ rhs.back() = 1;
+ check<int*>(lhs, rhs, lhs.size() - 1);
+ rhs.back() = 0;
+ }
+ }
+ }
return 0;
}
More information about the libcxx-commits
mailing list