[llvm] [OpenMP] Update atomic helpers to just use headers (PR #122185)

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 8 15:28:37 PST 2025


================
@@ -48,51 +50,124 @@ uint32_t inc(uint32_t *Addr, uint32_t V, OrderingTy Ordering,
 /// result is stored in \p *Addr;
 /// {
 
-#define ATOMIC_COMMON_OP(TY)                                                   \
-  TY add(TY *Addr, TY V, OrderingTy Ordering);                                 \
-  TY mul(TY *Addr, TY V, OrderingTy Ordering);                                 \
-  TY load(TY *Addr, OrderingTy Ordering);                                      \
-  void store(TY *Addr, TY V, OrderingTy Ordering);                             \
-  bool cas(TY *Addr, TY ExpectedV, TY DesiredV, OrderingTy OrderingSucc,       \
-           OrderingTy OrderingFail);
-
-#define ATOMIC_FP_ONLY_OP(TY)                                                  \
-  TY min(TY *Addr, TY V, OrderingTy Ordering);                                 \
-  TY max(TY *Addr, TY V, OrderingTy Ordering);
-
-#define ATOMIC_INT_ONLY_OP(TY)                                                 \
-  TY min(TY *Addr, TY V, OrderingTy Ordering);                                 \
-  TY max(TY *Addr, TY V, OrderingTy Ordering);                                 \
-  TY bit_or(TY *Addr, TY V, OrderingTy Ordering);                              \
-  TY bit_and(TY *Addr, TY V, OrderingTy Ordering);                             \
-  TY bit_xor(TY *Addr, TY V, OrderingTy Ordering);
-
-#define ATOMIC_FP_OP(TY)                                                       \
-  ATOMIC_FP_ONLY_OP(TY)                                                        \
-  ATOMIC_COMMON_OP(TY)
-
-#define ATOMIC_INT_OP(TY)                                                      \
-  ATOMIC_INT_ONLY_OP(TY)                                                       \
-  ATOMIC_COMMON_OP(TY)
-
-// This needs to be kept in sync with the header. Also the reason we don't use
-// templates here.
-ATOMIC_INT_OP(int8_t)
-ATOMIC_INT_OP(int16_t)
-ATOMIC_INT_OP(int32_t)
-ATOMIC_INT_OP(int64_t)
-ATOMIC_INT_OP(uint8_t)
-ATOMIC_INT_OP(uint16_t)
-ATOMIC_INT_OP(uint32_t)
-ATOMIC_INT_OP(uint64_t)
-ATOMIC_FP_OP(float)
-ATOMIC_FP_OP(double)
-
-#undef ATOMIC_INT_ONLY_OP
-#undef ATOMIC_FP_ONLY_OP
-#undef ATOMIC_COMMON_OP
-#undef ATOMIC_INT_OP
-#undef ATOMIC_FP_OP
+/// Atomic compare-and-exchange on \p *Address at device scope. The strong
+/// (non-weak) form of the builtin is used, with \p OrderingSucc passed as the
+/// success ordering and \p OrderingFail as the failure ordering. \p ExpectedV
+/// is taken by value, so any value the builtin writes back into it is not
+/// visible to the caller. Returns the builtin's result.
+template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
+bool cas(Ty *Address, V ExpectedV, V DesiredV, atomic::OrderingTy OrderingSucc,
+         atomic::OrderingTy OrderingFail) {
+  const bool Exchanged = __scoped_atomic_compare_exchange(
+      Address, &ExpectedV, &DesiredV, /*weak=*/false, OrderingSucc,
+      OrderingFail, __MEMORY_SCOPE_DEVICE);
+  return Exchanged;
+}
+
+/// Atomically add \p Val to \p *Address at device scope using \p Ordering and
+/// return the value produced by the fetch-add builtin.
+template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
+V add(Ty *Address, V Val, atomic::OrderingTy Ordering) {
+  const V Fetched =
+      __scoped_atomic_fetch_add(Address, Val, Ordering, __MEMORY_SCOPE_DEVICE);
+  return Fetched;
+}
+
+/// Atomically read \p *Address with \p Ordering. The read is expressed as an
+/// atomic add of zero, so it is forwarded to add() rather than to a dedicated
+/// load builtin.
+template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
+V load(Ty *Address, atomic::OrderingTy Ordering) {
+  const V Loaded = add(Address, Ty(0), Ordering);
+  return Loaded;
+}
+
+/// Atomically write \p Val to \p *Address at device scope using \p Ordering.
+template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
+void store(Ty *Address, V Val, atomic::OrderingTy Ordering) {
+  __scoped_atomic_store_n(Address,
+                          Val,
+                          Ordering,
+                          __MEMORY_SCOPE_DEVICE);
+}
+
+/// Atomically multiply \p *Address by \p Val using a load / compare-and-swap
+/// retry loop. \p Ordering is used for the load and for a successful exchange;
+/// a failed exchange uses relaxed ordering and simply retries.
+///
+/// Returns the value that was stored at \p Address immediately before the
+/// successful exchange, i.e. the same "fetch" convention as add().
+template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
+V mul(Ty *Address, V Val, atomic::OrderingTy Ordering) {
+  Ty TypedCurrentVal, TypedNewVal;
+  bool Success;
+  do {
+    TypedCurrentVal = atomic::load(Address, Ordering);
+    TypedNewVal = TypedCurrentVal * Val;
+    Success = atomic::cas(Address, TypedCurrentVal, TypedNewVal, Ordering,
+                          atomic::relaxed);
+  } while (!Success);
+  // The exchange only succeeds when *Address still held TypedCurrentVal, so
+  // that is the previous value. The former code returned a local that was
+  // never assigned.
+  return TypedCurrentVal;
+}
+
+/// Atomic maximum for non-floating-point value types: forwards to the
+/// fetch-max builtin at device scope with \p Ordering and returns its result.
+template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
+utils::enable_if_t<!utils::is_floating_point_v<V>, V>
+max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
+  const V Fetched =
+      __scoped_atomic_fetch_max(Address, Val, Ordering, __MEMORY_SCOPE_DEVICE);
+  return Fetched;
+}
+
+/// Atomic maximum for \c float, implemented on top of the 32-bit integer
+/// overloads by reinterpreting the bits of \p Val and of the memory at
+/// \p Address. Non-negative \p Val uses a signed integer max; negative \p Val
+/// uses an unsigned integer min (this relies on how the float bit patterns
+/// order when viewed as integers).
+///
+/// The integer operation hands back the raw bit pattern that was read from
+/// \p Address, so it is punned back to \c float instead of being numerically
+/// converted.
+template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
+utils::enable_if_t<utils::is_same_v<V, float>, V>
+max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
+  if (Val >= 0)
+    return utils::convertViaPun<V>(max(
+        (int32_t *)Address, utils::convertViaPun<int32_t>(Val), Ordering));
+  return utils::convertViaPun<V>(min(
+      (uint32_t *)Address, utils::convertViaPun<uint32_t>(Val), Ordering));
+}
+
+/// Atomic maximum for \c double, implemented on top of the 64-bit integer
+/// overloads by reinterpreting the bits of \p Val and of the memory at
+/// \p Address. Non-negative \p Val uses a signed integer max; negative \p Val
+/// uses an unsigned integer min (this relies on how the double bit patterns
+/// order when viewed as integers).
+///
+/// The integer operation hands back the raw bit pattern that was read from
+/// \p Address, so it is punned back to \c double instead of being numerically
+/// converted.
+template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
+utils::enable_if_t<utils::is_same_v<V, double>, V>
+max(Ty *Address, V Val, atomic::OrderingTy Ordering) {
+  if (Val >= 0)
+    return utils::convertViaPun<V>(max(
+        (int64_t *)Address, utils::convertViaPun<int64_t>(Val), Ordering));
+  return utils::convertViaPun<V>(min(
+      (uint64_t *)Address, utils::convertViaPun<uint64_t>(Val), Ordering));
+}
+
+/// Atomic minimum for non-floating-point value types: forwards to the
+/// fetch-min builtin at device scope with \p Ordering and returns its result.
+template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
+utils::enable_if_t<!utils::is_floating_point_v<V>, V>
+min(Ty *Address, V Val, atomic::OrderingTy Ordering) {
+  const V Fetched =
+      __scoped_atomic_fetch_min(Address, Val, Ordering, __MEMORY_SCOPE_DEVICE);
+  return Fetched;
+}
+
+// TODO: Implement this with __atomic_fetch_max and remove the duplication.
+/// Atomic minimum for \c float, implemented on top of the 32-bit integer
+/// overloads by reinterpreting the bits of \p Val and of the memory at
+/// \p Address. Non-negative \p Val uses a signed integer min; negative \p Val
+/// uses an unsigned integer max (this relies on how the float bit patterns
+/// order when viewed as integers).
+///
+/// The integer operation hands back the raw bit pattern that was read from
+/// \p Address, so it is punned back to \c float instead of being numerically
+/// converted.
+template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
+utils::enable_if_t<utils::is_same_v<V, float>, V>
+min(Ty *Address, V Val, atomic::OrderingTy Ordering) {
+  if (Val >= 0)
+    return utils::convertViaPun<V>(min(
+        (int32_t *)Address, utils::convertViaPun<int32_t>(Val), Ordering));
+  return utils::convertViaPun<V>(max(
+      (uint32_t *)Address, utils::convertViaPun<uint32_t>(Val), Ordering));
+}
+
+// TODO: Implement this with __atomic_fetch_max and remove the duplication.
+template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
+utils::enable_if_t<utils::is_same_v<V, double>, V>
+min(Ty *Address, utils::remove_addrspace_t<Ty> Val,
+    atomic::OrderingTy Ordering) {
+  if (Val >= 0)
+    return min((int64_t *)Address, utils::convertViaPun<int64_t>(Val),
+               Ordering);
+  return max((uint64_t *)Address, utils::convertViaPun<uint64_t>(Val),
+             Ordering);
----------------
jdoerfert wrote:

I believe these were, and are, missing a convertViaPun on the return value.

https://github.com/llvm/llvm-project/pull/122185


More information about the llvm-commits mailing list