[Openmp-commits] [llvm] [openmp] [OpenMP][OMPX] Add shfl_down_sync (PR #93311)

Shilei Tian via Openmp-commits openmp-commits at lists.llvm.org
Fri May 24 08:19:44 PDT 2024


https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/93311

From 9c59d0cbbb9d0f8b8ac6afc90537e1625e9de802 Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Fri, 24 May 2024 11:19:35 -0400
Subject: [PATCH] [OpenMP][OMPX] Add shfl_down_sync

---
 offload/DeviceRTL/include/Utils.h             |  2 +
 offload/DeviceRTL/src/Mapping.cpp             | 24 ++++++-
 offload/DeviceRTL/src/Utils.cpp               | 15 ++--
 .../offloading/ompx_bare_shfl_down_sync.cpp   | 68 +++++++++++++++++++
 openmp/runtime/src/include/ompx.h.var         | 52 ++++++++++++++
 5 files changed, 155 insertions(+), 6 deletions(-)
 create mode 100644 offload/test/offloading/ompx_bare_shfl_down_sync.cpp

diff --git a/offload/DeviceRTL/include/Utils.h b/offload/DeviceRTL/include/Utils.h
index d43b7f5c95de1..82e2397b5958b 100644
--- a/offload/DeviceRTL/include/Utils.h
+++ b/offload/DeviceRTL/include/Utils.h
@@ -25,6 +25,10 @@ int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane);
 

 int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta, int32_t Width);
 
+/// 64-bit overload of shuffleDown: shuffles the low and high 32-bit halves of
+/// \p Var independently and repacks them into a single 64-bit value.
+int64_t shuffleDown(uint64_t Mask, int64_t Var, uint32_t Delta, int32_t Width);
+
 uint64_t ballotSync(uint64_t Mask, int32_t Pred);
 

 /// Return \p LowBits and \p HighBits packed into a single 64 bit value.
diff --git a/offload/DeviceRTL/src/Mapping.cpp b/offload/DeviceRTL/src/Mapping.cpp
index 4f39d2a299ee6..c1ce878746a69 100644
--- a/offload/DeviceRTL/src/Mapping.cpp
+++ b/offload/DeviceRTL/src/Mapping.cpp
@@ -364,8 +364,30 @@ _TGT_KERNEL_LANGUAGE(block_id, getBlockIdInKernel)
 _TGT_KERNEL_LANGUAGE(block_dim, getNumberOfThreadsInBlock)
 _TGT_KERNEL_LANGUAGE(grid_dim, getNumberOfBlocksInKernel)
 

-extern "C" uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
+extern "C" {
+uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
   return utils::ballotSync(mask, pred);
 }
 

+int ompx_shfl_down_sync_i(uint64_t mask, int var, unsigned delta, int width) {
+  return utils::shuffleDown(mask, static_cast<int32_t>(var), delta, width);
+}
+
+float ompx_shfl_down_sync_f(uint64_t mask, float var, unsigned delta,
+                            int width) {
+  return utils::convertViaPun<float>(utils::shuffleDown(
+      mask, utils::convertViaPun<int32_t>(var), delta, width));
+}
+
+long ompx_shfl_down_sync_l(uint64_t mask, long var, unsigned delta, int width) {
+  return utils::shuffleDown(mask, static_cast<int64_t>(var), delta, width);
+}
+
+double ompx_shfl_down_sync_d(uint64_t mask, double var, unsigned delta,
+                             int width) {
+  return utils::convertViaPun<double>(utils::shuffleDown(
+      mask, utils::convertViaPun<int64_t>(var), delta, width));
+}
+} // extern "C"

 #pragma omp end declare target
diff --git a/offload/DeviceRTL/src/Utils.cpp b/offload/DeviceRTL/src/Utils.cpp
index 606e3bec0d33c..53cc803234867 100644
--- a/offload/DeviceRTL/src/Utils.cpp
+++ b/offload/DeviceRTL/src/Utils.cpp
@@ -113,6 +113,16 @@ int32_t utils::shuffleDown(uint64_t Mask, int32_t Var, uint32_t Delta,
   return impl::shuffleDown(Mask, Var, Delta, Width);
 }
 

+/// Shuffles a 64-bit value by shuffling its two 32-bit halves separately.
+int64_t utils::shuffleDown(uint64_t Mask, int64_t Var, uint32_t Delta,
+                           int32_t Width) {
+  uint32_t LowBits, HighBits;
+  utils::unpack(Var, LowBits, HighBits);
+  const uint32_t High = impl::shuffleDown(Mask, HighBits, Delta, Width);
+  const uint32_t Low = impl::shuffleDown(Mask, LowBits, Delta, Width);
+  return utils::pack(Low, High);
+}
+
 uint64_t utils::ballotSync(uint64_t Mask, int32_t Pred) {
   return impl::ballotSync(Mask, Pred);
 }
@@ -125,11 +135,7 @@ int32_t __kmpc_shuffle_int32(int32_t Val, int16_t Delta, int16_t SrcLane) {
 }
 

 int64_t __kmpc_shuffle_int64(int64_t Val, int16_t Delta, int16_t Width) {
-  uint32_t lo, hi;
-  utils::unpack(Val, lo, hi);
-  hi = impl::shuffleDown(lanes::All, hi, Delta, Width);
-  lo = impl::shuffleDown(lanes::All, lo, Delta, Width);
-  return utils::pack(lo, hi);
+  return utils::shuffleDown(lanes::All, Val, Delta, Width);
 }
 }

diff --git a/offload/test/offloading/ompx_bare_shfl_down_sync.cpp b/offload/test/offloading/ompx_bare_shfl_down_sync.cpp
new file mode 100644
index 0000000000000..c2f38080c5770
--- /dev/null
+++ b/offload/test/offloading/ompx_bare_shfl_down_sync.cpp
@@ -0,0 +1,77 @@
+// RUN: %libomptarget-compilexx-run-and-check-generic
+//
+// UNSUPPORTED: x86_64-pc-linux-gnu
+// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: aarch64-unknown-linux-gnu
+// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
+// UNSUPPORTED: s390x-ibm-linux-gnu
+// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
+
+#ifdef __AMDGCN_WAVEFRONT_SIZE
+#define WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
+#else
+#define WARP_SIZE 32
+#endif
+
+#include <cassert>
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <limits>
+#include <ompx.h>
+#include <type_traits>
+
+template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
+bool equal(T LHS, T RHS) {
+  return LHS == RHS;
+}
+
+template <typename T,
+          std::enable_if_t<std::is_floating_point<T>::value, bool> = true>
+bool equal(T LHS, T RHS) {
+  return std::abs(LHS - RHS) < std::numeric_limits<T>::epsilon();
+}
+
+template <typename T> void test() {
+  constexpr const int num_blocks = 1;
+  constexpr const int block_size = 256;
+  constexpr const int N = num_blocks * block_size;
+  T *data = new T[N];
+  // WARP_SIZE comes from a target-specific macro and may differ between the
+  // host and device compilations, so the kernel reports the warp size it was
+  // compiled for back to the host.
+  int warp_size = 0;
+
+  for (int i = 0; i < N; ++i)
+    data[i] = i;
+
+#pragma omp target teams ompx_bare num_teams(num_blocks)                       \
+    thread_limit(block_size) map(tofrom : data[0 : N], warp_size)
+  {
+    int tid = ompx_thread_id_x();
+    if (tid == 0)
+      warp_size = WARP_SIZE;
+    data[tid] = ompx::shfl_down_sync(~0U, data[tid], 1);
+  }
+
+  // Each lane receives the value of the next lane in its warp; the last lane
+  // of each warp has no source lane and keeps its own value.
+  assert(warp_size > 0);
+  for (int i = 0; i < N; ++i) {
+    const bool is_last_lane = (i % warp_size) == warp_size - 1;
+    assert(equal(data[i], static_cast<T>(is_last_lane ? i : i + 1)));
+  }
+
+  delete[] data;
+}
+
+int main(int argc, char *argv[]) {
+  test<int32_t>();
+  test<int64_t>();
+  test<float>();
+  test<double>();
+  // CHECK: PASS
+  printf("PASS\n");
+
+  return 0;
+}
diff --git a/openmp/runtime/src/include/ompx.h.var b/openmp/runtime/src/include/ompx.h.var
index 19851880c3ac3..7f41d6ef92219 100644
--- a/openmp/runtime/src/include/ompx.h.var
+++ b/openmp/runtime/src/include/ompx.h.var
@@ -9,6 +9,12 @@
 #ifndef __OMPX_H
 #define __OMPX_H
 

+#if defined(__AMDGCN_WAVEFRONT_SIZE)
+#define __WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
+#else // Fall back to a 32-lane warp when no wavefront size is predefined.
+#define __WARP_SIZE 32
+#endif // defined(__AMDGCN_WAVEFRONT_SIZE)
+
 typedef unsigned long uint64_t;
 

 #ifdef __cplusplus
@@ -87,6 +93,22 @@ static inline uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
   __builtin_trap();
 }
 

+/// ompx_shfl_down_sync_{i,f,l,d}: declare-variant stubs that trap if called.
+///{
+#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(TYPE, TY)                \
+  static inline TYPE ompx_shfl_down_sync_##TY(uint64_t mask, TYPE var,         \
+                                              unsigned delta, int width) {     \
+    __builtin_trap();                                                          \
+  }
+
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(int, i)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(float, f)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(long, l)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(double, d)
+
+#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL
+///}
+
 #pragma omp end declare variant
 

 /// ompx_{sync_block}_{,divergent}
@@ -117,6 +139,20 @@ _TGT_KERNEL_LANGUAGE_DECL_GRID_C(grid_dim)
 

 uint64_t ompx_ballot_sync(uint64_t mask, int pred);
 

+/// ompx_shfl_down_sync_{i,f,l,d}; no default width since C parses this too.
+///{
+#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(TYPE, TY)                          \
+  TYPE ompx_shfl_down_sync_##TY(uint64_t mask, TYPE var, unsigned delta,       \
+                                int width);
+
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(int, i)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(float, f)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(long, l)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(double, d)
+
+#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC
+///}
+
 #ifdef __cplusplus
 }
 #endif
@@ -172,6 +208,22 @@ static inline uint64_t ballot_sync(uint64_t mask, int pred) {
   return ompx_ballot_sync(mask, pred);
 }
 

+/// shfl_down_sync: calls ompx_shfl_down_sync_*; default width is __WARP_SIZE
+///{
+#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(TYPE, TY)                          \
+  static inline TYPE shfl_down_sync(uint64_t mask, TYPE var, unsigned delta,   \
+                                    int width = __WARP_SIZE) {                 \
+    return ompx_shfl_down_sync_##TY(mask, var, delta, width);                  \
+  }
+
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(int, i)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(float, f)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(long, l)
+_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(double, d)
+
+#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC
+///}
+
 } // namespace ompx
 #endif




More information about the Openmp-commits mailing list