[libcxx-commits] [libcxx] [libc++] <experimental/simd> Add compound assignment operators for simd reference (PR #86761)

via libcxx-commits libcxx-commits at lists.llvm.org
Sun Apr 7 21:09:51 PDT 2024


https://github.com/joy2myself updated https://github.com/llvm/llvm-project/pull/86761

>From c3eab8ed4a37cf568b55a75ac8c9b16304a453bc Mon Sep 17 00:00:00 2001
From: Yin Zhang <zhangyin2018 at iscas.ac.cn>
Date: Tue, 9 Jan 2024 17:04:16 +0800
Subject: [PATCH] [libc++] <experimental/simd> Add compound assignment
 operators for simd reference

---
 libcxx/docs/Status/ParallelismProjects.csv    |   1 +
 .../include/experimental/__simd/reference.h   |  23 ++++
 .../reference_operators.pass.cpp              | 105 ++++++++++++++++++
 .../test/std/experimental/simd/test_utils.h   |  11 +-
 4 files changed, 135 insertions(+), 5 deletions(-)
 create mode 100644 libcxx/test/std/experimental/simd/simd.reference/reference_operators.pass.cpp

diff --git a/libcxx/docs/Status/ParallelismProjects.csv b/libcxx/docs/Status/ParallelismProjects.csv
index 06da008ac5fe90..89e727af488f07 100644
--- a/libcxx/docs/Status/ParallelismProjects.csv
+++ b/libcxx/docs/Status/ParallelismProjects.csv
@@ -16,6 +16,7 @@ Section,Description,Dependencies,Assignee,Complete
 | `[parallel.simd.whereexpr] <https://wg21.link/N4808>`_, "Where expression class templates", None, Yin Zhang, |In Progress|
 | `[parallel.simd.reference] <https://wg21.link/N4808>`_, "`Element references operator value_type() <https://github.com/llvm/llvm-project/pull/68960>`_", None, Yin Zhang, |Complete|
 | `[parallel.simd.reference] <https://wg21.link/N4808>`_, "`Element references operator= <https://github.com/llvm/llvm-project/pull/70020>`_", None, Yin Zhang, |Complete|
+| `[parallel.simd.reference] <https://wg21.link/N4808>`_, "`Element references compound assignment operators <https://github.com/llvm/llvm-project/pull/86761>`_", None, Yin Zhang, |Complete|
 | `[parallel.simd.class] <https://wg21.link/N4808>`_, "`Class template simd declaration and alias <https://reviews.llvm.org/D144362>`_", [parallel.simd.abi], Yin Zhang, |Complete|
 | `[parallel.simd.class] <https://wg21.link/N4808>`_, "`simd<>::size() <https://reviews.llvm.org/D144363>`_", [parallel.simd.traits] simd_size[_v], Yin Zhang, |Complete|
 | `[parallel.simd.class] <https://wg21.link/N4808>`_, "`simd default constructor <https://github.com/llvm/llvm-project/pull/70424>`_", None, Yin Zhang, |Complete|
diff --git a/libcxx/include/experimental/__simd/reference.h b/libcxx/include/experimental/__simd/reference.h
index 7efbba96ec71b1..6fc7cd93050470 100644
--- a/libcxx/include/experimental/__simd/reference.h
+++ b/libcxx/include/experimental/__simd/reference.h
@@ -12,6 +12,9 @@
 
 #include <__type_traits/is_assignable.h>
 #include <__type_traits/is_same.h>
+#include <__type_traits/is_void.h>
+#include <__type_traits/void_t.h>
+#include <__utility/declval.h>
 #include <__utility/forward.h>
 #include <cstddef>
 #include <experimental/__config>
@@ -55,6 +58,26 @@ class __simd_reference {
     __set(static_cast<value_type>(std::forward<_Up>(__v)));
     return {__s_, __idx_};
   }
+
+#  define _LIBCPP_SIMD_REFERENCE_OP_(__op) /* Defines operator op=, enabled only when value_type& op= _Up compiles. */ \
+    template <                                                                                                         \
+        class _Up,                                                                                                     \
+        class = void_t<decltype(std::declval<value_type&>() __op## = std::declval<_Up>())>>                            \
+        __simd_reference _LIBCPP_HIDE_FROM_ABI operator __op##=(_Up&& __v) && noexcept {                               \
+      __set(__get() __op static_cast<value_type>(std::forward<_Up>(__v)));                                             \
+      return {__s_, __idx_};                                                                                           \
+    }
+  _LIBCPP_SIMD_REFERENCE_OP_(+)
+  _LIBCPP_SIMD_REFERENCE_OP_(-)
+  _LIBCPP_SIMD_REFERENCE_OP_(*)
+  _LIBCPP_SIMD_REFERENCE_OP_(/)
+  _LIBCPP_SIMD_REFERENCE_OP_(%)
+  _LIBCPP_SIMD_REFERENCE_OP_(&)
+  _LIBCPP_SIMD_REFERENCE_OP_(|)
+  _LIBCPP_SIMD_REFERENCE_OP_(^)
+  _LIBCPP_SIMD_REFERENCE_OP_(<<)
+  _LIBCPP_SIMD_REFERENCE_OP_(>>)
+#  undef _LIBCPP_SIMD_REFERENCE_OP_
 };
 
 } // namespace parallelism_v2
diff --git a/libcxx/test/std/experimental/simd/simd.reference/reference_operators.pass.cpp b/libcxx/test/std/experimental/simd/simd.reference/reference_operators.pass.cpp
new file mode 100644
index 00000000000000..a671f32303f0ff
--- /dev/null
+++ b/libcxx/test/std/experimental/simd/simd.reference/reference_operators.pass.cpp
@@ -0,0 +1,105 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// UNSUPPORTED: c++03, c++11, c++14
+
+// <experimental/simd>
+//
+// [simd.reference]
+// template<class U> reference operator+=(U&& x) && noexcept;
+// template<class U> reference operator-=(U&& x) && noexcept;
+// template<class U> reference operator*=(U&& x) && noexcept;
+// template<class U> reference operator/=(U&& x) && noexcept;
+// template<class U> reference operator%=(U&& x) && noexcept;
+// template<class U> reference operator|=(U&& x) && noexcept;
+// template<class U> reference operator&=(U&& x) && noexcept;
+// template<class U> reference operator^=(U&& x) && noexcept;
+// template<class U> reference operator<<=(U&& x) && noexcept;
+// template<class U> reference operator>>=(U&& x) && noexcept;
+
+#include "../test_utils.h"
+#include <experimental/simd>
+
+namespace ex = std::experimental::parallelism_v2;
+
+#define LIBCXX_SIMD_REFERENCE_OP_(op, name) /* Checks op= on simd element references against scalar T arithmetic. */   \
+  template <class T, class SimdAbi>                                                                                    \
+  struct SimdReferenceOperatorHelper##name {                                                                           \
+    template <class U>                                                                                                 \
+    void operator()() const {                                                                                          \
+      ex::simd<T, SimdAbi> origin_simd([](T i) { return i; });                                                         \
+      for (size_t i = 0; i < origin_simd.size(); ++i) {                                                                \
+        static_assert(noexcept(origin_simd[i] op## = static_cast<U>(2)));                                              \
+        origin_simd[i] op## = static_cast<U>(2);                                                                       \
+        assert((T)origin_simd[i] == (T)(static_cast<T>(i) op static_cast<T>(std::forward<U>(2))));                     \
+      }                                                                                                                \
+    }                                                                                                                  \
+  };
+LIBCXX_SIMD_REFERENCE_OP_(+, Plus)
+LIBCXX_SIMD_REFERENCE_OP_(-, Minus)
+LIBCXX_SIMD_REFERENCE_OP_(*, Multiplies)
+LIBCXX_SIMD_REFERENCE_OP_(/, Divides)
+LIBCXX_SIMD_REFERENCE_OP_(%, Modulus)
+LIBCXX_SIMD_REFERENCE_OP_(&, BitAnd)
+LIBCXX_SIMD_REFERENCE_OP_(|, BitOr)
+LIBCXX_SIMD_REFERENCE_OP_(^, BitXor)
+LIBCXX_SIMD_REFERENCE_OP_(<<, ShiftLeft)
+LIBCXX_SIMD_REFERENCE_OP_(>>, ShiftRight)
+#undef LIBCXX_SIMD_REFERENCE_OP_
+
+#define LIBCXX_SIMD_MASK_REFERENCE_OP_(op, name) /* Checks op= on simd_mask element references. */                     \
+  template <class T, class SimdAbi>                                                                                    \
+  struct MaskReferenceOperatorHelper##name {                                                                           \
+    template <class U>                                                                                                 \
+    void operator()() const {                                                                                          \
+      ex::simd_mask<T, SimdAbi> origin_simd_mask(true);                                                                \
+      for (size_t i = 0; i < origin_simd_mask.size(); ++i) {                                                           \
+        static_assert(noexcept(origin_simd_mask[i] op## = static_cast<U>(i % 2)));                                     \
+        origin_simd_mask[i] op## = static_cast<U>(i % 2);                                                              \
+        assert((bool)origin_simd_mask[i] == (bool)(true op static_cast<bool>(std::forward<U>(i % 2))));                \
+      }                                                                                                                \
+    }                                                                                                                  \
+  };
+LIBCXX_SIMD_MASK_REFERENCE_OP_(&, BitAnd)
+LIBCXX_SIMD_MASK_REFERENCE_OP_(|, BitOr)
+LIBCXX_SIMD_MASK_REFERENCE_OP_(^, BitXor)
+#undef LIBCXX_SIMD_MASK_REFERENCE_OP_
+
+template <class T, std::size_t>
+struct CheckReferenceArithOperators { // Runs the +=, -=, *=, /= helpers over simd_test_types.
+  template <class SimdAbi>
+  void operator()() {
+    types::for_each(simd_test_types(), SimdReferenceOperatorHelperPlus<T, SimdAbi>());
+    types::for_each(simd_test_types(), SimdReferenceOperatorHelperMinus<T, SimdAbi>());
+    types::for_each(simd_test_types(), SimdReferenceOperatorHelperMultiplies<T, SimdAbi>());
+    types::for_each(simd_test_types(), SimdReferenceOperatorHelperDivides<T, SimdAbi>());
+  }
+};
+
+template <class T, std::size_t>
+struct CheckReferenceIntOperators { // Runs %=, &=, |=, ^=, <<=, >>= and Mask* helpers over simd_test_integer_types.
+  template <class SimdAbi>
+  void operator()() {
+    types::for_each(simd_test_integer_types(), SimdReferenceOperatorHelperModulus<T, SimdAbi>());
+    types::for_each(simd_test_integer_types(), SimdReferenceOperatorHelperBitAnd<T, SimdAbi>());
+    types::for_each(simd_test_integer_types(), SimdReferenceOperatorHelperBitOr<T, SimdAbi>());
+    types::for_each(simd_test_integer_types(), SimdReferenceOperatorHelperBitXor<T, SimdAbi>());
+    types::for_each(simd_test_integer_types(), SimdReferenceOperatorHelperShiftLeft<T, SimdAbi>());
+    types::for_each(simd_test_integer_types(), SimdReferenceOperatorHelperShiftRight<T, SimdAbi>());
+
+    types::for_each(simd_test_integer_types(), MaskReferenceOperatorHelperBitAnd<T, SimdAbi>());
+    types::for_each(simd_test_integer_types(), MaskReferenceOperatorHelperBitOr<T, SimdAbi>());
+    types::for_each(simd_test_integer_types(), MaskReferenceOperatorHelperBitXor<T, SimdAbi>());
+  }
+};
+
+int main(int, char**) { // Arithmetic checks via test_all_simd_abi; integer-only checks over types::integer_types.
+  test_all_simd_abi<CheckReferenceArithOperators>();
+  types::for_each(types::integer_types(), TestAllSimdAbiFunctor<CheckReferenceIntOperators>());
+  return 0;
+}
diff --git a/libcxx/test/std/experimental/simd/test_utils.h b/libcxx/test/std/experimental/simd/test_utils.h
index 9b441608575138..3c227a43c2f4dd 100644
--- a/libcxx/test/std/experimental/simd/test_utils.h
+++ b/libcxx/test/std/experimental/simd/test_utils.h
@@ -50,15 +50,16 @@ using arithmetic_no_bool_types = types::concatenate_t<types::integer_types, type
 
 // For interfaces with vectorizable type template parameters, we only use some common or boundary types
 // as template parameters for testing to ensure that the compilation time of a single test does not exceed.
-using simd_test_types =
+using simd_test_integer_types =
     types::type_list<char,
                      unsigned,
-                     int,
+                     int
 #ifndef TEST_HAS_NO_INT128
-                     __int128_t,
+                     ,
+                     __int128_t
 #endif
-                     float,
-                     double>;
+                     >;
+using simd_test_types = types::concatenate_t<simd_test_integer_types, types::type_list<float, double>>;
 
 template <template <class T, std::size_t N> class Func>
 void test_all_simd_abi() {



More information about the libcxx-commits mailing list