[libc-commits] [libc] 1a4696d - [libc][NFC] Use new approach based on types to code memset

Guillaume Chatelet via libc-commits <libc-commits at lists.llvm.org>
Tue Apr 11 04:32:19 PDT 2023


Author: Guillaume Chatelet
Date: 2023-04-11T11:32:09Z
New Revision: 1a4696d9ec26eb4f7e9768bc279235667232e92a

URL: https://github.com/llvm/llvm-project/commit/1a4696d9ec26eb4f7e9768bc279235667232e92a
DIFF: https://github.com/llvm/llvm-project/commit/1a4696d9ec26eb4f7e9768bc279235667232e92a.diff

LOG: [libc][NFC] Use new approach based on types to code memset

Added: 
    

Modified: 
    libc/src/string/memory_utils/memset_implementations.h
    libc/src/string/memory_utils/op_aarch64.h
    libc/src/string/memory_utils/op_generic.h
    libc/test/src/string/memory_utils/op_tests.cpp

Removed: 
    


################################################################################
diff --git a/libc/src/string/memory_utils/memset_implementations.h b/libc/src/string/memory_utils/memset_implementations.h
index 16c11470572b5..dfc7df87aceb2 100644
--- a/libc/src/string/memory_utils/memset_implementations.h
+++ b/libc/src/string/memory_utils/memset_implementations.h
@@ -26,86 +26,101 @@ namespace __llvm_libc {
 inline_memset_embedded_tiny(Ptr dst, uint8_t value, size_t count) {
   LIBC_LOOP_NOUNROLL
   for (size_t offset = 0; offset < count; ++offset)
-    generic::Memset<1, 1>::block(dst + offset, value);
+    generic::Memset<uint8_t>::block(dst + offset, value);
 }
 
 #if defined(LIBC_TARGET_ARCH_IS_X86)
-template <size_t MaxSize>
 [[maybe_unused]] LIBC_INLINE static void
 inline_memset_x86(Ptr dst, uint8_t value, size_t count) {
+#if defined(__AVX512F__)
+  using uint128_t = uint8x16_t;
+  using uint256_t = uint8x32_t;
+  using uint512_t = uint8x64_t;
+#elif defined(__AVX__)
+  using uint128_t = uint8x16_t;
+  using uint256_t = uint8x32_t;
+  using uint512_t = cpp::array<uint8x32_t, 2>;
+#elif defined(__SSE2__)
+  using uint128_t = uint8x16_t;
+  using uint256_t = cpp::array<uint8x16_t, 2>;
+  using uint512_t = cpp::array<uint8x16_t, 4>;
+#else
+  using uint128_t = cpp::array<uint64_t, 2>;
+  using uint256_t = cpp::array<uint64_t, 4>;
+  using uint512_t = cpp::array<uint64_t, 8>;
+#endif
+
   if (count == 0)
     return;
   if (count == 1)
-    return generic::Memset<1, MaxSize>::block(dst, value);
+    return generic::Memset<uint8_t>::block(dst, value);
   if (count == 2)
-    return generic::Memset<2, MaxSize>::block(dst, value);
+    return generic::Memset<uint16_t>::block(dst, value);
   if (count == 3)
-    return generic::Memset<3, MaxSize>::block(dst, value);
+    return generic::Memset<uint16_t, uint8_t>::block(dst, value);
   if (count <= 8)
-    return generic::Memset<4, MaxSize>::head_tail(dst, value, count);
+    return generic::Memset<uint32_t>::head_tail(dst, value, count);
   if (count <= 16)
-    return generic::Memset<8, MaxSize>::head_tail(dst, value, count);
+    return generic::Memset<uint64_t>::head_tail(dst, value, count);
   if (count <= 32)
-    return generic::Memset<16, MaxSize>::head_tail(dst, value, count);
+    return generic::Memset<uint128_t>::head_tail(dst, value, count);
   if (count <= 64)
-    return generic::Memset<32, MaxSize>::head_tail(dst, value, count);
+    return generic::Memset<uint256_t>::head_tail(dst, value, count);
   if (count <= 128)
-    return generic::Memset<64, MaxSize>::head_tail(dst, value, count);
+    return generic::Memset<uint512_t>::head_tail(dst, value, count);
   // Aligned loop
-  generic::Memset<32, MaxSize>::block(dst, value);
+  generic::Memset<uint256_t>::block(dst, value);
   align_to_next_boundary<32>(dst, count);
-  return generic::Memset<32, MaxSize>::loop_and_tail(dst, value, count);
+  return generic::Memset<uint256_t>::loop_and_tail(dst, value, count);
 }
 #endif // defined(LIBC_TARGET_ARCH_IS_X86)
 
 #if defined(LIBC_TARGET_ARCH_IS_AARCH64)
-template <size_t MaxSize>
 [[maybe_unused]] LIBC_INLINE static void
 inline_memset_aarch64(Ptr dst, uint8_t value, size_t count) {
+  static_assert(aarch64::kNeon, "aarch64 supports vector types");
+  using uint128_t = uint8x16_t;
+  using uint256_t = uint8x32_t;
+  using uint512_t = uint8x64_t;
   if (count == 0)
     return;
   if (count <= 3) {
-    generic::Memset<1, MaxSize>::block(dst, value);
+    generic::Memset<uint8_t>::block(dst, value);
     if (count > 1)
-      generic::Memset<2, MaxSize>::tail(dst, value, count);
+      generic::Memset<uint16_t>::tail(dst, value, count);
     return;
   }
   if (count <= 8)
-    return generic::Memset<4, MaxSize>::head_tail(dst, value, count);
+    return generic::Memset<uint32_t>::head_tail(dst, value, count);
   if (count <= 16)
-    return generic::Memset<8, MaxSize>::head_tail(dst, value, count);
+    return generic::Memset<uint64_t>::head_tail(dst, value, count);
   if (count <= 32)
-    return generic::Memset<16, MaxSize>::head_tail(dst, value, count);
+    return generic::Memset<uint128_t>::head_tail(dst, value, count);
   if (count <= (32 + 64)) {
-    generic::Memset<32, MaxSize>::block(dst, value);
+    generic::Memset<uint256_t>::block(dst, value);
     if (count <= 64)
-      return generic::Memset<32, MaxSize>::tail(dst, value, count);
-    generic::Memset<32, MaxSize>::block(dst + 32, value);
-    generic::Memset<32, MaxSize>::tail(dst, value, count);
+      return generic::Memset<uint256_t>::tail(dst, value, count);
+    generic::Memset<uint256_t>::block(dst + 32, value);
+    generic::Memset<uint256_t>::tail(dst, value, count);
     return;
   }
   if (count >= 448 && value == 0 && aarch64::neon::hasZva()) {
-    generic::Memset<64, MaxSize>::block(dst, 0);
+    generic::Memset<uint512_t>::block(dst, 0);
     align_to_next_boundary<64>(dst, count);
-    return aarch64::neon::BzeroCacheLine<64>::loop_and_tail(dst, 0, count);
+    return aarch64::neon::BzeroCacheLine::loop_and_tail(dst, 0, count);
   } else {
-    generic::Memset<16, MaxSize>::block(dst, value);
+    generic::Memset<uint128_t>::block(dst, value);
     align_to_next_boundary<16>(dst, count);
-    return generic::Memset<64, MaxSize>::loop_and_tail(dst, value, count);
+    return generic::Memset<uint512_t>::loop_and_tail(dst, value, count);
   }
 }
 #endif // defined(LIBC_TARGET_ARCH_IS_AARCH64)
 
 LIBC_INLINE static void inline_memset(Ptr dst, uint8_t value, size_t count) {
 #if defined(LIBC_TARGET_ARCH_IS_X86)
-  static constexpr size_t kMaxSize = x86::kAvx512F ? 64
-                                     : x86::kAvx   ? 32
-                                     : x86::kSse2  ? 16
-                                                   : 8;
-  return inline_memset_x86<kMaxSize>(dst, value, count);
+  return inline_memset_x86(dst, value, count);
 #elif defined(LIBC_TARGET_ARCH_IS_AARCH64)
-  static constexpr size_t kMaxSize = aarch64::kNeon ? 16 : 8;
-  return inline_memset_aarch64<kMaxSize>(dst, value, count);
+  return inline_memset_aarch64(dst, value, count);
 #else
   return inline_memset_embedded_tiny(dst, value, count);
 #endif
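
The per-ISA aliases at the top of inline_memset_x86 are the heart of the change: a 32-byte block is a single native vector type when AVX is available and a cpp::array of smaller pieces otherwise, so the chosen type itself encodes how the store is decomposed. Below is a minimal standalone sketch of that pattern (not the llvm-libc sources), assuming GCC/Clang vector extensions; wide_store is an illustrative helper, not an llvm-libc function.

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Compiler vector types, declared the same way as in op_generic.h.
using uint8x16_t = uint8_t __attribute__((__vector_size__(16)));
using uint8x32_t = uint8_t __attribute__((__vector_size__(32)));

#if defined(__AVX__)
using uint256_t = uint8x32_t;                // one native 32-byte store
#elif defined(__SSE2__)
using uint256_t = std::array<uint8x16_t, 2>; // decomposed into two 16-byte stores
#else
using uint256_t = std::array<uint64_t, 4>;   // decomposed into four 8-byte stores
#endif

// Illustrative helper: write sizeof(T) bytes of `value` starting at `dst`.
template <typename T> void wide_store(char *dst, uint8_t value) {
  unsigned char bytes[sizeof(T)];
  for (size_t i = 0; i < sizeof(T); ++i)
    bytes[i] = value;                 // splat the byte value
  std::memcpy(dst, bytes, sizeof(T)); // fixed-size copy, typically lowered to wide stores
}

int main() {
  char buf[64] = {};
  wide_store<uint256_t>(buf, 0xAB);   // covers 32 bytes whatever the ISA
  return (static_cast<unsigned char>(buf[31]) == 0xAB && buf[32] == 0) ? 0 : 1;
}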

diff --git a/libc/src/string/memory_utils/op_aarch64.h b/libc/src/string/memory_utils/op_aarch64.h
index f9aabd0fbcade..e8c8b211e57b5 100644
--- a/libc/src/string/memory_utils/op_aarch64.h
+++ b/libc/src/string/memory_utils/op_aarch64.h
@@ -30,11 +30,10 @@ static inline constexpr bool kNeon = LLVM_LIBC_IS_DEFINED(__ARM_NEON);
 
 namespace neon {
 
-template <size_t Size> struct BzeroCacheLine {
-  static constexpr size_t SIZE = Size;
+struct BzeroCacheLine {
+  static constexpr size_t SIZE = 64;
 
   LIBC_INLINE static void block(Ptr dst, uint8_t) {
-    static_assert(Size == 64);
 #if __SIZEOF_POINTER__ == 4
     asm("dc zva, %w[dst]" : : [dst] "r"(dst) : "memory");
 #else
@@ -43,15 +42,13 @@ template <size_t Size> struct BzeroCacheLine {
   }
 
   LIBC_INLINE static void loop_and_tail(Ptr dst, uint8_t value, size_t count) {
-    static_assert(Size > 1, "a loop of size 1 does not need tail");
     size_t offset = 0;
     do {
       block(dst + offset, value);
       offset += SIZE;
     } while (offset < count - SIZE);
     // Unaligned store, we can't use 'dc zva' here.
-    static constexpr size_t kMaxSize = kNeon ? 16 : 8;
-    generic::Memset<Size, kMaxSize>::tail(dst, value, count);
+    generic::Memset<uint8x64_t>::tail(dst, value, count);
   }
 };
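
BzeroCacheLine::loop_and_tail above follows the usual loop-and-tail scheme: write SIZE-byte blocks until fewer than SIZE bytes remain, then finish with one store of the last SIZE bytes, which may overlap the last loop iteration and, as the comment in the diff notes, cannot use 'dc zva' because it is unaligned. A simplified standalone sketch of that control flow, with set_block as an illustrative stand-in for the block primitives:

#include <cstddef>
#include <cstdint>
#include <cstring>

constexpr size_t SIZE = 64; // cache-line-sized block, as in BzeroCacheLine

// Illustrative stand-in for the block primitive ('dc zva' or a 64-byte store).
static void set_block(char *dst, uint8_t value) { std::memset(dst, value, SIZE); }

// Requires count >= SIZE, mirroring the precondition of the real loop_and_tail.
static void loop_and_tail(char *dst, uint8_t value, size_t count) {
  size_t offset = 0;
  do {
    set_block(dst + offset, value);
    offset += SIZE;
  } while (offset < count - SIZE);
  set_block(dst + count - SIZE, value); // tail: last SIZE bytes, may overlap the loop
}

int main() {
  char buf[300];
  std::memset(buf, 0, sizeof(buf));
  loop_and_tail(buf, 0x5A, sizeof(buf));
  return (buf[0] == 0x5A && buf[299] == 0x5A) ? 0 : 1;
}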
 

diff --git a/libc/src/string/memory_utils/op_generic.h b/libc/src/string/memory_utils/op_generic.h
index fd63ac67d005a..a7c5636c2d1ca 100644
--- a/libc/src/string/memory_utils/op_generic.h
+++ b/libc/src/string/memory_utils/op_generic.h
@@ -33,8 +33,7 @@
 
 #include <stdint.h>
 
-namespace __llvm_libc::generic {
-
+namespace __llvm_libc {
 // Compiler types using the vector attributes.
 using uint8x1_t = uint8_t __attribute__((__vector_size__(1)));
 using uint8x2_t = uint8_t __attribute__((__vector_size__(2)));
@@ -43,13 +42,14 @@ using uint8x8_t = uint8_t __attribute__((__vector_size__(8)));
 using uint8x16_t = uint8_t __attribute__((__vector_size__(16)));
 using uint8x32_t = uint8_t __attribute__((__vector_size__(32)));
 using uint8x64_t = uint8_t __attribute__((__vector_size__(64)));
+} // namespace __llvm_libc
 
+namespace __llvm_libc::generic {
 // We accept three types of values as elements for generic operations:
 // - scalar : unsigned integral types
 // - vector : compiler types using the vector attributes
 // - array  : a cpp::array<T, N> where T is itself either a scalar or a vector.
 // The following traits help discriminate between these cases.
-
 template <typename T>
 constexpr bool is_scalar_v = cpp::is_integral_v<T> && cpp::is_unsigned_v<T>;
 
@@ -109,23 +109,11 @@ template <typename T> T splat(uint8_t value) {
     T Out;
     // This for loop is optimized out for vector types.
     for (size_t i = 0; i < sizeof(T); ++i)
-      Out[i] = static_cast<uint8_t>(value);
+      Out[i] = value;
     return Out;
   }
 }
 
-template <typename T> void set(Ptr dst, uint8_t value) {
-  static_assert(is_element_type_v<T>);
-  if constexpr (is_scalar_v<T> || is_vector_v<T>) {
-    store<T>(dst, splat<T>(value));
-  } else if constexpr (is_array_v<T>) {
-    using value_type = typename T::value_type;
-    const value_type Splat = splat<value_type>(value);
-    for (size_t I = 0; I < array_size_v<T>; ++I)
-      store<value_type>(dst + (I * sizeof(value_type)), Splat);
-  }
-}
-
 static_assert((UINTPTR_MAX == 4294967295U) ||
                   (UINTPTR_MAX == 18446744073709551615UL),
               "We currently only support 32- or 64-bit platforms");
@@ -149,9 +137,7 @@ constexpr bool is_decreasing_size() {
 }
 
 template <size_t Size, typename... Ts> struct Largest;
-template <size_t Size> struct Largest<Size> {
-  using type = uint8_t;
-};
+template <size_t Size> struct Largest<Size> : cpp::type_identity<uint8_t> {};
 template <size_t Size, typename T, typename... Ts>
 struct Largest<Size, T, Ts...> {
   using next = Largest<Size, Ts...>;
@@ -179,6 +165,11 @@ template <typename First, typename... Ts> struct SupportedTypes {
   using TypeFor = typename details::Largest<Size, First, Ts...>::type;
 };
 
+// Returns the sum of the sizeof of all the TS types.
+template <typename... TS> static constexpr size_t sum_sizeof() {
+  return (... + sizeof(TS));
+}
+
 // Map from sizes to structures offering static load, store and splat methods.
 // Note: On platforms lacking vector support, we use the ArrayType below and
 // decompose the operation in smaller pieces.
@@ -220,27 +211,23 @@ using getTypeFor = cpp::conditional_t<
 
 ///////////////////////////////////////////////////////////////////////////////
 // Memset
-// The MaxSize template argument gives the maximum size handled natively by the
-// platform. For instance on x86 with AVX support this would be 32. If a size
-// greater than MaxSize is requested we break the operation down in smaller
-// pieces of size MaxSize.
 ///////////////////////////////////////////////////////////////////////////////
-template <size_t Size, size_t MaxSize> struct Memset {
-  static_assert(is_power2(MaxSize));
-  static constexpr size_t SIZE = Size;
+
+template <typename T, typename... TS> struct Memset {
+  static constexpr size_t SIZE = sum_sizeof<T, TS...>();
 
   LIBC_INLINE static void block(Ptr dst, uint8_t value) {
-    if constexpr (Size == 3) {
-      Memset<1, MaxSize>::block(dst + 2, value);
-      Memset<2, MaxSize>::block(dst, value);
-    } else {
-      using T = details::getTypeFor<Size, MaxSize>;
-      if constexpr (details::is_void_v<T>) {
-        deferred_static_assert("Unimplemented Size");
-      } else {
-        set<T>(dst, value);
-      }
+    static_assert(is_element_type_v<T>);
+    if constexpr (is_scalar_v<T> || is_vector_v<T>) {
+      store<T>(dst, splat<T>(value));
+    } else if constexpr (is_array_v<T>) {
+      using value_type = typename T::value_type;
+      const auto Splat = splat<value_type>(value);
+      for (size_t I = 0; I < array_size_v<T>; ++I)
+        store<value_type>(dst + (I * sizeof(value_type)), Splat);
     }
+    if constexpr (sizeof...(TS))
+      Memset<TS...>::block(dst + sizeof(T), value);
   }
 
   LIBC_INLINE static void tail(Ptr dst, uint8_t value, size_t count) {
@@ -253,7 +240,7 @@ template <size_t Size, size_t MaxSize> struct Memset {
   }
 
   LIBC_INLINE static void loop_and_tail(Ptr dst, uint8_t value, size_t count) {
-    static_assert(SIZE > 1);
+    static_assert(SIZE > 1, "a loop of size 1 does not need tail");
     size_t offset = 0;
     do {
       block(dst + offset, value);
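
The reworked generic::Memset is parameterized on element types instead of a (Size, MaxSize) pair: Memset<T, TS...>::block stores sizeof(T) bytes using T, then recurses on the remaining types, which is how a 3-byte set becomes Memset<uint16_t, uint8_t>. A simplified, self-contained sketch of that dispatch (assuming C++17), using a byte buffer plus memcpy in place of the splat/store helpers from op_generic.h:

#include <cstddef>
#include <cstdint>
#include <cstring>

template <typename T, typename... TS> struct Memset {
  // Total number of bytes written by one call to block().
  static constexpr size_t SIZE = (sizeof(T) + ... + sizeof(TS));

  static void block(char *dst, uint8_t value) {
    unsigned char splat[sizeof(T)];
    for (size_t i = 0; i < sizeof(T); ++i)
      splat[i] = value;                       // splat the byte into a T-sized buffer
    std::memcpy(dst, splat, sizeof(T));       // one store of sizeof(T) bytes
    if constexpr (sizeof...(TS) != 0)
      Memset<TS...>::block(dst + sizeof(T), value); // recurse on the remaining types
  }

  // Write the last SIZE bytes of [dst, dst + count).
  static void tail(char *dst, uint8_t value, size_t count) {
    block(dst + count - SIZE, value);
  }
};

int main() {
  char buf[8] = {};
  Memset<uint16_t, uint8_t>::block(buf, 0x7F);             // exactly 3 bytes, the count == 3 case
  Memset<uint16_t, uint8_t>::tail(buf, 0x7F, sizeof(buf)); // bytes 5..7
  return (buf[2] == 0x7F && buf[3] == 0 && buf[7] == 0x7F) ? 0 : 1;
}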

diff --git a/libc/test/src/string/memory_utils/op_tests.cpp b/libc/test/src/string/memory_utils/op_tests.cpp
index 7f5d4d4ed460a..b63a629da3f05 100644
--- a/libc/test/src/string/memory_utils/op_tests.cpp
+++ b/libc/test/src/string/memory_utils/op_tests.cpp
@@ -119,24 +119,20 @@ using MemsetImplementations = testing::TypeList<
     builtin::Memset<64>,
 #endif
 #ifdef LLVM_LIBC_HAS_UINT64
-    generic::Memset<8, 8>,  //
-    generic::Memset<16, 8>, //
-    generic::Memset<32, 8>, //
-    generic::Memset<64, 8>, //
+    generic::Memset<uint64_t>, generic::Memset<cpp::array<uint64_t, 2>>,
 #endif
 #ifdef __AVX512F__
-    generic::Memset<64, 64>, // prevents warning about avx512f
+    generic::Memset<uint8x64_t>, generic::Memset<cpp::array<uint8x64_t, 2>>,
 #endif
-    generic::Memset<1, 1>,   //
-    generic::Memset<2, 1>,   //
-    generic::Memset<2, 2>,   //
-    generic::Memset<4, 2>,   //
-    generic::Memset<4, 4>,   //
-    generic::Memset<16, 16>, //
-    generic::Memset<32, 16>, //
-    generic::Memset<64, 16>, //
-    generic::Memset<32, 32>, //
-    generic::Memset<64, 32>  //
+#ifdef __AVX__
+    generic::Memset<uint8x32_t>, generic::Memset<cpp::array<uint8x32_t, 2>>,
+#endif
+#ifdef __SSE2__
+    generic::Memset<uint8x16_t>, generic::Memset<cpp::array<uint8x16_t, 2>>,
+#endif
+    generic::Memset<uint32_t>, generic::Memset<cpp::array<uint32_t, 2>>, //
+    generic::Memset<uint16_t>, generic::Memset<cpp::array<uint16_t, 2>>, //
+    generic::Memset<uint8_t>, generic::Memset<cpp::array<uint8_t, 2>>    //
     >;
 
 // Adapt CheckMemset signature to op implementation signatures.


        

