[libc-commits] [libc] 366805e - [LIBC] Add an optimized memcmp implementation for AArch64

Andre Vieira via libc-commits libc-commits at lists.llvm.org
Wed Jul 7 07:59:30 PDT 2021


Author: Andre Vieira
Date: 2021-07-07T15:59:14+01:00
New Revision: 366805ea175e12d98903e35854c9898964fecde2

URL: https://github.com/llvm/llvm-project/commit/366805ea175e12d98903e35854c9898964fecde2
DIFF: https://github.com/llvm/llvm-project/commit/366805ea175e12d98903e35854c9898964fecde2.diff

LOG: [LIBC] Add an optimized memcmp implementation for AArch64

Differential Revision: https://reviews.llvm.org/D105441

Added: 
    libc/src/string/aarch64/memcmp.cpp
    libc/src/string/memory_utils/elements_aarch64.h

Modified: 
    libc/src/string/CMakeLists.txt
    libc/src/string/memory_utils/elements.h
    libc/test/src/string/CMakeLists.txt
    libc/test/src/string/memcmp_test.cpp

Removed: 
    


################################################################################
diff  --git a/libc/src/string/CMakeLists.txt b/libc/src/string/CMakeLists.txt
index 97ac99cb89fe7..4788baf1966e4 100644
--- a/libc/src/string/CMakeLists.txt
+++ b/libc/src/string/CMakeLists.txt
@@ -295,7 +295,7 @@ endif()
 
 function(add_memcmp memcmp_name)
   add_implementation(memcmp ${memcmp_name}
-    SRCS ${LIBC_SOURCE_DIR}/src/string/memcmp.cpp
+    SRCS ${LIBC_MEMCMP_SRC}
     HDRS ${LIBC_SOURCE_DIR}/src/string/memcmp.h
     DEPENDS
       .memory_utils.memory_utils
@@ -307,13 +307,19 @@ function(add_memcmp memcmp_name)
 endfunction()
 
 if(${LIBC_TARGET_ARCHITECTURE_IS_X86})
+  set(LIBC_MEMCMP_SRC ${LIBC_SOURCE_DIR}/src/string/memcmp.cpp)
   add_memcmp(memcmp_x86_64_opt_sse2   COMPILE_OPTIONS -march=k8             REQUIRE SSE2)
   add_memcmp(memcmp_x86_64_opt_sse4   COMPILE_OPTIONS -march=nehalem        REQUIRE SSE4_2)
   add_memcmp(memcmp_x86_64_opt_avx2   COMPILE_OPTIONS -march=haswell        REQUIRE AVX2)
   add_memcmp(memcmp_x86_64_opt_avx512 COMPILE_OPTIONS -march=skylake-avx512 REQUIRE AVX512F)
   add_memcmp(memcmp_opt_host          COMPILE_OPTIONS ${LIBC_COMPILE_OPTIONS_NATIVE})
   add_memcmp(memcmp)
+elseif(${LIBC_TARGET_ARCHITECTURE_IS_AARCH64})
+  set(LIBC_MEMCMP_SRC ${LIBC_SOURCE_DIR}/src/string/aarch64/memcmp.cpp)
+  add_memcmp(memcmp)
+  add_memcmp(memcmp_opt_host          COMPILE_OPTIONS ${LIBC_COMPILE_OPTIONS_NATIVE})
 else()
+  set(LIBC_MEMCMP_SRC ${LIBC_SOURCE_DIR}/src/string/memcmp.cpp)
   add_memcmp(memcmp_opt_host          COMPILE_OPTIONS ${LIBC_COMPILE_OPTIONS_NATIVE})
   add_memcmp(memcmp)
 endif()

diff  --git a/libc/src/string/aarch64/memcmp.cpp b/libc/src/string/aarch64/memcmp.cpp
new file mode 100644
index 0000000000000..503c239e122b7
--- /dev/null
+++ b/libc/src/string/aarch64/memcmp.cpp
@@ -0,0 +1,70 @@
+//===-- Implementation of memcmp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/string/memcmp.h"
+#include "src/__support/common.h"
+#include "src/string/memory_utils/elements.h"
+#include <stddef.h> // size_t
+
+namespace __llvm_libc {
+namespace aarch64 {
+
+// Three-way comparison of the first `count` bytes of `lhs` and `rhs`,
+// dispatching on `count` to a fixed-size strategy:
+//   [0, 3]    exact-size compare,
+//   [4, 15]   head compare followed by an overlapping tail compare,
+//   [16, 127] 16-byte probes followed by a tail compare or a short loop,
+//   [128, ..) 16-byte head compare on `Arg::Lhs`, then 32-byte loop.
+// Returns zero when the ranges are equal; otherwise the value produced by the
+// element-wise ThreeWayCompare of the first differing element.
+static int memcmp_impl(const char *lhs, const char *rhs, size_t count) {
+  if (count == 0)
+    return 0;
+  if (count == 1)
+    return ThreeWayCompare<_1>(lhs, rhs);
+  if (count == 2)
+    return ThreeWayCompare<_2>(lhs, rhs);
+  if (count == 3)
+    return ThreeWayCompare<_3>(lhs, rhs);
+  if (count < 8)
+    return ThreeWayCompare<HeadTail<_4>>(lhs, rhs, count);
+  if (count < 16)
+    return ThreeWayCompare<HeadTail<_8>>(lhs, rhs, count);
+  if (count < 128) {
+    // Bytes [0, 16) differ: they alone decide the result.
+    if (!Equals<_16>(lhs, rhs))
+      return ThreeWayCompare<_16>(lhs, rhs);
+    // Bytes [0, 16) match and the last 16 bytes cover the remainder.
+    if (count < 32)
+      return ThreeWayCompare<Tail<_16>>(lhs, rhs, count);
+    // Bytes [16, 32) hold the first difference, so compare exactly that
+    // block. Comparing at `count - 32` instead would re-read bytes already
+    // known to be equal (for count == 32 it would be bytes [0, 16)) and could
+    // wrongly report equality.
+    if (!Equals<_16>(lhs + 16, rhs + 16))
+      return ThreeWayCompare<_16>(lhs + 16, rhs + 16);
+    // Bytes [0, 32) match and the last 32 bytes cover the remainder.
+    if (count < 64)
+      return ThreeWayCompare<Tail<_32>>(lhs, rhs, count);
+    return ThreeWayCompare<Loop<_16>>(lhs + 32, rhs + 32, count - 32);
+  }
+  return ThreeWayCompare<Align<_16, Arg::Lhs>::Then<Loop<_32>>>(lhs, rhs,
+                                                                count);
+}
+} // namespace aarch64
+
+// Public entry point: forwards to the AArch64-specific implementation.
+LLVM_LIBC_FUNCTION(int, memcmp,
+                   (const void *lhs, const void *rhs, size_t count)) {
+
+  const char *_lhs = reinterpret_cast<const char *>(lhs);
+  const char *_rhs = reinterpret_cast<const char *>(rhs);
+  return aarch64::memcmp_impl(_lhs, _rhs, count);
+}
+
+} // namespace __llvm_libc

diff  --git a/libc/src/string/memory_utils/elements.h b/libc/src/string/memory_utils/elements.h
index 2442da760217e..48fb0084d610f 100644
--- a/libc/src/string/memory_utils/elements.h
+++ b/libc/src/string/memory_utils/elements.h
@@ -211,8 +211,8 @@ template <typename T> struct HeadTail {
   }
 
   static int ThreeWayCompare(const char *lhs, const char *rhs, size_t size) {
-    if (const int result = T::ThreeWayCompare(lhs, rhs))
-      return result;
+    if (!T::Equals(lhs, rhs))
+      return T::ThreeWayCompare(lhs, rhs);
     return Tail<T>::ThreeWayCompare(lhs, rhs, size);
   }
 
@@ -251,8 +251,8 @@ template <typename T> struct Loop {
 
   static int ThreeWayCompare(const char *lhs, const char *rhs, size_t size) {
     for (size_t offset = 0; offset < size - T::kSize; offset += T::kSize)
-      if (const int result = T::ThreeWayCompare(lhs + offset, rhs + offset))
-        return result;
+      if (!T::Equals(lhs + offset, rhs + offset))
+        return T::ThreeWayCompare(lhs + offset, rhs + offset);
     return Tail<T>::ThreeWayCompare(lhs, rhs, size);
   }
 
@@ -327,8 +327,8 @@ template <typename AlignmentT, Arg AlignOn> struct Align {
     }
 
     static int ThreeWayCompare(const char *lhs, const char *rhs, size_t size) {
-      if (const int result = AlignmentT::ThreeWayCompare(lhs, rhs))
-        return result;
+      if (!AlignmentT::Equals(lhs, rhs))
+        return AlignmentT::ThreeWayCompare(lhs, rhs);
       internal::AlignHelper<AlignOn, Alignment>::Bump(lhs, rhs, size);
       return NextT::ThreeWayCompare(lhs, rhs, size);
     }
@@ -370,12 +370,18 @@ template <size_t Size> struct Builtin {
 #endif
   }
 
+#if __has_builtin(__builtin_memcmp_inline)
+#define LLVM_LIBC_MEMCMP __builtin_memcmp_inline
+#else
+#define LLVM_LIBC_MEMCMP __builtin_memcmp
+#endif
+
   static bool Equals(const char *lhs, const char *rhs) {
-    return __builtin_memcmp(lhs, rhs, kSize) == 0;
+    return LLVM_LIBC_MEMCMP(lhs, rhs, kSize) == 0;
   }
 
   static int ThreeWayCompare(const char *lhs, const char *rhs) {
-    return __builtin_memcmp(lhs, rhs, kSize);
+    return LLVM_LIBC_MEMCMP(lhs, rhs, kSize);
   }
 
   static void SplatSet(char *dst, const unsigned char value) {
@@ -428,6 +434,8 @@ template <typename T> struct Scalar {
     Store(dst, GetSplattedValue(value));
   }
 
+  static int ScalarThreeWayCompare(T a, T b);
+
 private:
   static T Load(const char *ptr) {
     T value;
@@ -440,7 +448,6 @@ template <typename T> struct Scalar {
   static T GetSplattedValue(const unsigned char value) {
     return T(~0) / T(0xFF) * T(value);
   }
-  static int ScalarThreeWayCompare(T a, T b);
 };
 
 template <>
@@ -457,23 +464,15 @@ inline int Scalar<uint16_t>::ScalarThreeWayCompare(uint16_t a, uint16_t b) {
 }
 template <>
 inline int Scalar<uint32_t>::ScalarThreeWayCompare(uint32_t a, uint32_t b) {
-  const int64_t la = Endian::ToBigEndian(a);
-  const int64_t lb = Endian::ToBigEndian(b);
-  if (la < lb)
-    return -1;
-  if (la > lb)
-    return 1;
-  return 0;
+  const uint32_t la = Endian::ToBigEndian(a);
+  const uint32_t lb = Endian::ToBigEndian(b);
+  return la > lb ? 1 : la < lb ? -1 : 0;
 }
 template <>
 inline int Scalar<uint64_t>::ScalarThreeWayCompare(uint64_t a, uint64_t b) {
-  const __int128_t la = Endian::ToBigEndian(a);
-  const __int128_t lb = Endian::ToBigEndian(b);
-  if (la < lb)
-    return -1;
-  if (la > lb)
-    return 1;
-  return 0;
+  const uint64_t la = Endian::ToBigEndian(a);
+  const uint64_t lb = Endian::ToBigEndian(b);
+  return la > lb ? 1 : la < lb ? -1 : 0;
 }
 
 using UINT8 = Scalar<uint8_t>;   // 1 Byte
@@ -494,6 +493,7 @@ using _128 = Repeated<_8, 16>;
 } // namespace scalar
 } // namespace __llvm_libc
 
+#include <src/string/memory_utils/elements_aarch64.h>
 #include <src/string/memory_utils/elements_x86.h>
 
 #endif // LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_H

diff  --git a/libc/src/string/memory_utils/elements_aarch64.h b/libc/src/string/memory_utils/elements_aarch64.h
new file mode 100644
index 0000000000000..7f722afbb6a96
--- /dev/null
+++ b/libc/src/string/memory_utils/elements_aarch64.h
@@ -0,0 +1,86 @@
+//===-- Elementary operations for aarch64 --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_AARCH64_H
+#define LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_AARCH64_H
+
+#include <src/string/memory_utils/elements.h>
+#include <stddef.h> // size_t
+#include <stdint.h> // uint8_t, uint16_t, uint32_t, uint64_t
+
+#ifdef __ARM_NEON
+#include <arm_neon.h>
+#endif
+
+namespace __llvm_libc {
+namespace aarch64 {
+
+using _1 = __llvm_libc::scalar::_1;
+using _2 = __llvm_libc::scalar::_2;
+using _3 = __llvm_libc::scalar::_3;
+using _4 = __llvm_libc::scalar::_4;
+using _8 = __llvm_libc::scalar::_8;
+using _16 = __llvm_libc::scalar::_16;
+
+#ifdef __ARM_NEON
+// 32-byte element whose comparisons are implemented with NEON intrinsics.
+struct N32 {
+  static constexpr size_t kSize = 32;
+
+  // Returns true when the 32 bytes at `lhs` and `rhs` are identical.
+  static bool Equals(const char *lhs, const char *rhs) {
+    return DifferenceMask(lhs, rhs) == 0;
+  }
+
+  // Returns 0 when the 32 bytes at `lhs` and `rhs` are identical. Otherwise
+  // loads, from each side, the 4-byte group selected by the lowest set bit of
+  // the difference mask and returns their scalar three-way comparison.
+  static int ThreeWayCompare(const char *lhs, const char *rhs) {
+    const uint64_t mask = DifferenceMask(lhs, rhs);
+    if (mask == 0)
+      return 0;
+    // `mask` is 64 bits wide, so use the `long long` flavour of ctz: `long`
+    // is not 64 bits wide on every ABI. Each mask byte summarizes four input
+    // bytes, hence (bit index / 8) * 4.
+    const size_t index = (__builtin_ctzll(mask) >> 3) << 2;
+    // The inputs carry no alignment guarantee, so copy each 4-byte group out
+    // rather than dereferencing a casted `uint32_t` pointer.
+    uint32_t l;
+    uint32_t r;
+    __builtin_memcpy(&l, lhs + index, sizeof(l));
+    __builtin_memcpy(&r, rhs + index, sizeof(r));
+    return __llvm_libc::scalar::_4::ScalarThreeWayCompare(l, r);
+  }
+
+private:
+  // XORs both 16-byte halves, then folds the result with two pairwise-max
+  // steps and returns the low 64-bit lane: each of its bytes is nonzero iff
+  // the four input bytes it summarizes contain a difference.
+  // NOTE(review): the index math in ThreeWayCompare presumably relies on a
+  // little-endian lane layout -- confirm before use on big-endian targets.
+  static uint64_t DifferenceMask(const char *lhs, const char *rhs) {
+    const uint8x16_t l_0 = vld1q_u8((const uint8_t *)lhs);
+    const uint8x16_t r_0 = vld1q_u8((const uint8_t *)rhs);
+    const uint8x16_t l_1 = vld1q_u8((const uint8_t *)(lhs + 16));
+    const uint8x16_t r_1 = vld1q_u8((const uint8_t *)(rhs + 16));
+    const uint8x16_t diff_0 = veorq_u8(l_0, r_0);
+    const uint8x16_t diff_1 = veorq_u8(l_1, r_1);
+    const uint8x16_t folded = vpmaxq_u8(diff_0, diff_1);
+    return vgetq_lane_u64(vreinterpretq_u64_u8(vpmaxq_u8(folded, folded)), 0);
+  }
+};
+
+using _32 = N32;
+#else
+using _32 = __llvm_libc::scalar::_32;
+#endif // __ARM_NEON
+
+} // namespace aarch64
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_AARCH64_H

diff  --git a/libc/test/src/string/CMakeLists.txt b/libc/test/src/string/CMakeLists.txt
index 722d85a631336..dbf390d561c02 100644
--- a/libc/test/src/string/CMakeLists.txt
+++ b/libc/test/src/string/CMakeLists.txt
@@ -52,27 +52,6 @@ add_libc_unittest(
     libc.src.string.memchr
 )
 
-add_libc_unittest(
-  memcmp_test
-  SUITE
-    libc_string_unittests
-  SRCS
-    memcmp_test.cpp
-  DEPENDS
-    libc.src.string.memcmp
-)
-
-add_libc_unittest(
-  memmove_test
-  SUITE
-    libc_string_unittests
-  SRCS
-    memmove_test.cpp
-  DEPENDS
-    libc.src.string.memcmp
-    libc.src.string.memmove
-)
-
 add_libc_unittest(
   strchr_test
   SUITE
@@ -209,3 +188,5 @@ endfunction()
 add_libc_multi_impl_test(memcpy SRCS memcpy_test.cpp)
 add_libc_multi_impl_test(memset SRCS memset_test.cpp)
 add_libc_multi_impl_test(bzero SRCS bzero_test.cpp)
+add_libc_multi_impl_test(memcmp SRCS memcmp_test.cpp)
+add_libc_multi_impl_test(memmove SRCS memmove_test.cpp)

diff  --git a/libc/test/src/string/memcmp_test.cpp b/libc/test/src/string/memcmp_test.cpp
index 2c1d9bfb9a16a..81b6709205fdc 100644
--- a/libc/test/src/string/memcmp_test.cpp
+++ b/libc/test/src/string/memcmp_test.cpp
@@ -8,6 +8,7 @@
 
 #include "src/string/memcmp.h"
 #include "utils/UnitTest/Test.h"
+#include <cstring>
 
 TEST(LlvmLibcMemcmpTest, CmpZeroByte) {
   const char *lhs = "ab";
@@ -32,3 +33,27 @@ TEST(LlvmLibcMemcmpTest, LhsAfterRhsLexically) {
   const char *rhs = "ab";
   EXPECT_EQ(__llvm_libc::memcmp(lhs, rhs, 2), 1);
 }
+
+// Sweeps every size below kMaxSize on equal buffers, then every position of a
+// single differing byte on full-size buffers.
+TEST(LlvmLibcMemcmpTest, Sweep) {
+  static constexpr size_t kMaxSize = 1024;
+  char lhs[kMaxSize];
+  char rhs[kMaxSize];
+
+  memset(lhs, 'a', sizeof(lhs));
+  memset(rhs, 'a', sizeof(rhs));
+  // size_t counter: kMaxSize is unsigned, so an `int` counter would make the
+  // loop condition a mixed-sign comparison.
+  for (size_t i = 0; i < kMaxSize; ++i)
+    EXPECT_EQ(__llvm_libc::memcmp(lhs, rhs, i), 0);
+
+  memset(lhs, 'a', sizeof(lhs));
+  memset(rhs, 'a', sizeof(rhs));
+  for (size_t i = 0; i < kMaxSize; ++i) {
+    // rhs is greater at exactly one position, so lhs must compare as smaller.
+    rhs[i] = 'b';
+    EXPECT_EQ(__llvm_libc::memcmp(lhs, rhs, kMaxSize), -1);
+    rhs[i] = 'a';
+  }
+}


        


More information about the libc-commits mailing list