[clang] [HIP] fix host min/max in header (PR #82956)

Yaxun Liu via cfe-commits cfe-commits at lists.llvm.org
Mon Feb 26 06:48:14 PST 2024


https://github.com/yxsamliu updated https://github.com/llvm/llvm-project/pull/82956

>From aa50cadf0baf84ea38379fd3276f306a27164007 Mon Sep 17 00:00:00 2001
From: "Yaxun (Sam) Liu" <yaxun.liu at amd.com>
Date: Sun, 25 Feb 2024 11:13:40 -0500
Subject: [PATCH] [HIP] fix host min/max in header

CUDA defines min/max functions for the host in the global namespace.
The HIP header needs to define them too to be compatible.
Currently only min/max(int, int) is defined. This causes
wrong results for arguments that are out of range for int.
This patch defines host min/max functions to be compatible
with CUDA.

Fixes: SWDEV-446564
Change-Id: I73363b534db3aa9ac34ee2a136e582b3f71d5dd1
---
 clang/lib/Headers/__clang_hip_math.h | 60 ++++++++++++++++++++++++++--
 1 file changed, 56 insertions(+), 4 deletions(-)

diff --git a/clang/lib/Headers/__clang_hip_math.h b/clang/lib/Headers/__clang_hip_math.h
index 11e1e7d032586f..8e048ab0d383aa 100644
--- a/clang/lib/Headers/__clang_hip_math.h
+++ b/clang/lib/Headers/__clang_hip_math.h
@@ -1306,14 +1306,66 @@ float min(float __x, float __y) { return __builtin_fminf(__x, __y); }
 __DEVICE__
 double min(double __x, double __y) { return __builtin_fmin(__x, __y); }
 
+// Define host min/max functions.
+
 #if !defined(__HIPCC_RTC__) && !defined(__OPENMP_AMDGCN__)
-__host__ inline static int min(int __arg1, int __arg2) {
-  return __arg1 < __arg2 ? __arg1 : __arg2;
+
+// Save any prior DEFINE_MIN_MAX_FUNCTIONS; the pop_macro below restores it.
+#pragma push_macro("DEFINE_MIN_MAX_FUNCTIONS")
+#define DEFINE_MIN_MAX_FUNCTIONS(ret_type, type1, type2)                       \
+  static inline ret_type min(const type1 __a, const type2 __b) {               \
+    return (__a < __b) ? __a : __b;                                            \
+  }                                                                            \
+  static inline ret_type max(const type1 __a, const type2 __b) {               \
+    return (__a > __b) ? __a : __b;                                            \
+  }
+
+// Define min and max functions for same type comparisons
+DEFINE_MIN_MAX_FUNCTIONS(int, int, int)
+DEFINE_MIN_MAX_FUNCTIONS(unsigned int, unsigned int, unsigned int)
+DEFINE_MIN_MAX_FUNCTIONS(long, long, long)
+DEFINE_MIN_MAX_FUNCTIONS(unsigned long, unsigned long, unsigned long)
+DEFINE_MIN_MAX_FUNCTIONS(long long, long long, long long)
+DEFINE_MIN_MAX_FUNCTIONS(unsigned long long, unsigned long long,
+                         unsigned long long)
+
+// Define min and max functions for all mixed type comparisons
+DEFINE_MIN_MAX_FUNCTIONS(unsigned int, int, unsigned int)
+DEFINE_MIN_MAX_FUNCTIONS(unsigned int, unsigned int, int)
+DEFINE_MIN_MAX_FUNCTIONS(unsigned long, long, unsigned long)
+DEFINE_MIN_MAX_FUNCTIONS(unsigned long, unsigned long, long)
+DEFINE_MIN_MAX_FUNCTIONS(unsigned long long, long long, unsigned long long)
+DEFINE_MIN_MAX_FUNCTIONS(unsigned long long, unsigned long long, long long)
+
+// Floating-point comparisons using built-in functions
+static inline float min(float const __a, float const __b) {
+  return __builtin_fminf(__a, __b);
+}
+static inline double min(double const __a, double const __b) {
+  return __builtin_fmin(__a, __b);
+}
+static inline double min(float const __a, double const __b) {
+  return __builtin_fmin(__a, __b);
+}
+static inline double min(double const __a, float const __b) {
+  return __builtin_fmin(__a, __b);
 }
 
-__host__ inline static int max(int __arg1, int __arg2) {
-  return __arg1 > __arg2 ? __arg1 : __arg2;
+static inline float max(float const __a, float const __b) {
+  return __builtin_fmaxf(__a, __b);
+}
+static inline double max(double const __a, double const __b) {
+  return __builtin_fmax(__a, __b);
+}
+static inline double max(float const __a, double const __b) {
+  return __builtin_fmax(__a, __b);
 }
+static inline double max(double const __a, float const __b) {
+  return __builtin_fmax(__a, __b);
+}
+
+#pragma pop_macro("DEFINE_MIN_MAX_FUNCTIONS")
+
 #endif // !defined(__HIPCC_RTC__) && !defined(__OPENMP_AMDGCN__)
 #endif
 



More information about the cfe-commits mailing list