[libcxx-commits] [libcxx] [libc++] Optimize the std::mismatch tail (PR #83440)

Nikolas Klauser via libcxx-commits libcxx-commits at lists.llvm.org
Sat Mar 23 07:31:42 PDT 2024


https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/83440

>From 8ddae7d6ac81fb5c44900832d347663a745bcc23 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Tue, 27 Feb 2024 11:45:41 +0100
Subject: [PATCH] [libc++] Optimize mismatch tail

---
 .../benchmarks/algorithms/mismatch.bench.cpp  | 15 +++++++--
 libcxx/include/__algorithm/mismatch.h         | 33 ++++++++++++++++++-
 .../mismatch/mismatch.pass.cpp                | 28 ++++++++++++++++
 3 files changed, 72 insertions(+), 4 deletions(-)

diff --git a/libcxx/benchmarks/algorithms/mismatch.bench.cpp b/libcxx/benchmarks/algorithms/mismatch.bench.cpp
index 9274932a764c55..06289068bb0492 100644
--- a/libcxx/benchmarks/algorithms/mismatch.bench.cpp
+++ b/libcxx/benchmarks/algorithms/mismatch.bench.cpp
@@ -10,6 +10,15 @@
 #include <benchmark/benchmark.h>
 #include <random>
 
+void BenchmarkSizes(benchmark::internal::Benchmark* Benchmark) { // adds the argument (input size) list to the given benchmark
+  Benchmark->DenseRange(1, 8); // small sizes first: dense range with bounds 1 and 8
+  for (size_t i = 16; i != 1 << 20; i *= 2) { // i doubles starting at 16; the loop stops before i reaches 1 << 20
+    Benchmark->Arg(i - 1); // one element short of the power of two
+    Benchmark->Arg(i); // exactly the power of two
+    Benchmark->Arg(i + 1); // one element past the power of two
+  }
+}
+
 // TODO: Look into benchmarking aligned and unaligned memory explicitly
 // (currently things happen to be aligned because they are malloced that way)
 template <class T>
@@ -24,8 +33,8 @@ static void bm_mismatch(benchmark::State& state) {
     benchmark::DoNotOptimize(std::mismatch(vec1.begin(), vec1.end(), vec2.begin()));
   }
 }
-BENCHMARK(bm_mismatch<char>)->DenseRange(1, 8)->Range(16, 1 << 20);
-BENCHMARK(bm_mismatch<short>)->DenseRange(1, 8)->Range(16, 1 << 20);
-BENCHMARK(bm_mismatch<int>)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_mismatch<char>)->Apply(BenchmarkSizes);
+BENCHMARK(bm_mismatch<short>)->Apply(BenchmarkSizes);
+BENCHMARK(bm_mismatch<int>)->Apply(BenchmarkSizes);
 
 BENCHMARK_MAIN();
diff --git a/libcxx/include/__algorithm/mismatch.h b/libcxx/include/__algorithm/mismatch.h
index 4eb693a1f2e9d8..7bda9ba5508366 100644
--- a/libcxx/include/__algorithm/mismatch.h
+++ b/libcxx/include/__algorithm/mismatch.h
@@ -64,7 +64,10 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
   constexpr size_t __unroll_count = 4;
   constexpr size_t __vec_size     = __native_vector_size<_Tp>;
   using __vec                     = __simd_vector<_Tp, __vec_size>;
+
   if (!__libcpp_is_constant_evaluated()) {
+    auto __orig_first1 = __first1;
+    auto __last2       = __first2 + (__last1 - __first1);
     while (static_cast<size_t>(__last1 - __first1) >= __unroll_count * __vec_size) [[__unlikely__]] {
       __vec __lhs[__unroll_count];
       __vec __rhs[__unroll_count];
@@ -84,8 +87,36 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
       __first1 += __unroll_count * __vec_size;
       __first2 += __unroll_count * __vec_size;
     }
+
+    // check the remaining 0-3 vectors
+    while (static_cast<size_t>(__last1 - __first1) >= __vec_size) {
+      if (auto __cmp_res = std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2);
+          !std::__all_of(__cmp_res)) {
+        auto __offset = std::__find_first_not_set(__cmp_res);
+        return {__first1 + __offset, __first2 + __offset};
+      }
+      __first1 += __vec_size;
+      __first2 += __vec_size;
+    }
+
+    if (__last1 - __first1 == 0)
+      return {__first1, __first2};
+
+    // Check if we can load elements in front of the current pointer. If that's the case, load a vector at
+    // (last - vector_size) to check the remaining elements
+    if (static_cast<size_t>(__first1 - __orig_first1) >= __vec_size) {
+      __first1 = __last1 - __vec_size;
+      __first2 = __last2 - __vec_size;
+      auto __offset =
+          std::__find_first_not_set(std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2));
+      return {__first1 + __offset, __first2 + __offset};
+    } // else loop over the elements individually
+
+    // TODO: Consider vectorizing the loop tail further with
+    // - smaller vectors
+    // - loading bytes out of range if it's known to be safe
   }
-  // TODO: Consider vectorizing the tail
+
   return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
 }
 
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
index e7f3994d977dcd..55c9eea863c3ff 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
@@ -184,5 +184,33 @@ int main(int, char**) {
     }
   }
 
+  { // check the tail of the vectorized loop: try every length from 1 to 255 so that different tail lengths are reached
+    for (size_t vec_size = 1; vec_size != 256; ++vec_size) {
+      {
+        std::vector<char> lhs(vec_size); // value-initialized, so both ranges start out all-zero and equal
+        std::vector<char> rhs(vec_size);
+
+        check<char*>(lhs, rhs, lhs.size()); // equal ranges; third argument is presumably the expected mismatch offset
+        lhs.back() = 1; // make only the last element differ, changed on the left-hand side
+        check<char*>(lhs, rhs, lhs.size() - 1);
+        lhs.back() = 0;
+        rhs.back() = 1; // same again, with the differing value on the right-hand side
+        check<char*>(lhs, rhs, lhs.size() - 1);
+        rhs.back() = 0;
+      }
+      {
+        std::vector<int> lhs(vec_size); // same scenario with a wider element type
+        std::vector<int> rhs(vec_size);
+
+        check<int*>(lhs, rhs, lhs.size()); // equal ranges
+        lhs.back() = 1; // last element differs, changed on the left-hand side
+        check<int*>(lhs, rhs, lhs.size() - 1);
+        lhs.back() = 0;
+        rhs.back() = 1; // last element differs, changed on the right-hand side
+        check<int*>(lhs, rhs, lhs.size() - 1);
+        rhs.back() = 0;
+      }
+    }
+  }
   return 0;
 }



More information about the libcxx-commits mailing list