[libclc] 21508fa - libclc: clspv: fix fma, add vstore and fix inlining issues

Kévin Petit via cfe-commits cfe-commits at lists.llvm.org
Tue May 9 08:52:56 PDT 2023


Author: Kévin Petit
Date: 2023-05-09T16:52:13+01:00
New Revision: 21508fa76914a5e4281dc5bc77cac7f2e8bc3aef

URL: https://github.com/llvm/llvm-project/commit/21508fa76914a5e4281dc5bc77cac7f2e8bc3aef
DIFF: https://github.com/llvm/llvm-project/commit/21508fa76914a5e4281dc5bc77cac7f2e8bc3aef.diff

LOG: libclc: clspv: fix fma, add vstore and fix inlining issues

https://reviews.llvm.org/D147773

Patch by Romaric Jodin <rjodin at google.com>

Added: 
    libclc/clspv/lib/shared/vstore_half.cl
    libclc/clspv/lib/shared/vstore_half.inc

Modified: 
    libclc/CMakeLists.txt
    libclc/clspv/lib/SOURCES
    libclc/clspv/lib/math/fma.cl
    libclc/generic/include/clc/clcfunc.h

Removed: 
    


################################################################################
diff --git a/libclc/CMakeLists.txt b/libclc/CMakeLists.txt
index 89f08b889ea1e..0eda12670b710 100644
--- a/libclc/CMakeLists.txt
+++ b/libclc/CMakeLists.txt
@@ -271,11 +271,11 @@ foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
 			set( spvflags --spirv-max-version=1.1 )
 		elseif( ${ARCH} STREQUAL "clspv" )
 			set( t "spir--" )
-			set( build_flags )
+			set( build_flags "-Wno-unknown-assumption")
 			set( opt_flags -O3 )
 		elseif( ${ARCH} STREQUAL "clspv64" )
 			set( t "spir64--" )
-			set( build_flags )
+			set( build_flags "-Wno-unknown-assumption")
 			set( opt_flags -O3 )
 		else()
 			set( build_flags )

diff --git a/libclc/clspv/lib/SOURCES b/libclc/clspv/lib/SOURCES
index 98bc71a869b2a..7c369aa379e98 100644
--- a/libclc/clspv/lib/SOURCES
+++ b/libclc/clspv/lib/SOURCES
@@ -1,5 +1,6 @@
 math/fma.cl
 math/nextafter.cl
+shared/vstore_half.cl
 subnormal_config.cl
 ../../generic/lib/geometric/distance.cl
 ../../generic/lib/geometric/length.cl
@@ -45,6 +46,12 @@ subnormal_config.cl
 ../../generic/lib/math/frexp.cl
 ../../generic/lib/math/half_cos.cl
 ../../generic/lib/math/half_divide.cl
+../../generic/lib/math/half_exp.cl
+../../generic/lib/math/half_exp10.cl
+../../generic/lib/math/half_exp2.cl
+../../generic/lib/math/half_log.cl
+../../generic/lib/math/half_log10.cl
+../../generic/lib/math/half_log2.cl
 ../../generic/lib/math/half_powr.cl
 ../../generic/lib/math/half_recip.cl
 ../../generic/lib/math/half_sin.cl

diff --git a/libclc/clspv/lib/math/fma.cl b/libclc/clspv/lib/math/fma.cl
index fdc8b8b296876..4f2806933eda9 100644
--- a/libclc/clspv/lib/math/fma.cl
+++ b/libclc/clspv/lib/math/fma.cl
@@ -34,6 +34,92 @@ struct fp {
   uint sign;
 };
 
+static uint2 u2_set(uint hi, uint lo) {
+  uint2 res;
+  res.lo = lo;
+  res.hi = hi;
+  return res;
+}
+
+static uint2 u2_set_u(uint val) { return u2_set(0, val); }
+
+static uint2 u2_mul(uint a, uint b) {
+  uint2 res;
+  res.hi = mul_hi(a, b);
+  res.lo = a * b;
+  return res;
+}
+
+static uint2 u2_sll(uint2 val, uint shift) {
+  if (shift == 0)
+    return val;
+  if (shift < 32) {
+    val.hi <<= shift;
+    val.hi |= val.lo >> (32 - shift);
+    val.lo <<= shift;
+  } else {
+    val.hi = val.lo << (shift - 32);
+    val.lo = 0;
+  }
+  return val;
+}
+
+static uint2 u2_srl(uint2 val, uint shift) {
+  if (shift == 0)
+    return val;
+  if (shift < 32) {
+    val.lo >>= shift;
+    val.lo |= val.hi << (32 - shift);
+    val.hi >>= shift;
+  } else {
+    val.lo = val.hi >> (shift - 32);
+    val.hi = 0;
+  }
+  return val;
+}
+
+static uint2 u2_or(uint2 a, uint b) {
+  a.lo |= b;
+  return a;
+}
+
+static uint2 u2_and(uint2 a, uint2 b) {
+  a.lo &= b.lo;
+  a.hi &= b.hi;
+  return a;
+}
+
+static uint2 u2_add(uint2 a, uint2 b) {
+  uint carry = (hadd(a.lo, b.lo) >> 31) & 0x1;
+  a.lo += b.lo;
+  a.hi += b.hi + carry;
+  return a;
+}
+
+static uint2 u2_add_u(uint2 a, uint b) { return u2_add(a, u2_set_u(b)); }
+
+static uint2 u2_inv(uint2 a) {
+  a.lo = ~a.lo;
+  a.hi = ~a.hi;
+  return u2_add_u(a, 1);
+}
+
+static uint u2_clz(uint2 a) {
+  uint leading_zeroes = clz(a.hi);
+  if (leading_zeroes == 32) {
+    leading_zeroes += clz(a.lo);
+  }
+  return leading_zeroes;
+}
+
+static bool u2_eq(uint2 a, uint2 b) { return a.lo == b.lo && a.hi == b.hi; }
+
+static bool u2_zero(uint2 a) { return u2_eq(a, u2_set_u(0)); }
+
+static bool u2_gt(uint2 a, uint2 b) {
+  return a.hi > b.hi || (a.hi == b.hi && a.lo > b.lo);
+}
+
 _CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
   /* special cases */
   if (isnan(a) || isnan(b) || isnan(c) || isinf(a) || isinf(b)) {
@@ -63,12 +149,9 @@ _CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
   st_b.exponent = b == .0f ? 0 : ((as_uint(b) & 0x7f800000) >> 23) - 127;
   st_c.exponent = c == .0f ? 0 : ((as_uint(c) & 0x7f800000) >> 23) - 127;
 
-  st_a.mantissa.lo = a == .0f ? 0 : (as_uint(a) & 0x7fffff) | 0x800000;
-  st_b.mantissa.lo = b == .0f ? 0 : (as_uint(b) & 0x7fffff) | 0x800000;
-  st_c.mantissa.lo = c == .0f ? 0 : (as_uint(c) & 0x7fffff) | 0x800000;
-  st_a.mantissa.hi = 0;
-  st_b.mantissa.hi = 0;
-  st_c.mantissa.hi = 0;
+  st_a.mantissa = u2_set_u(a == .0f ? 0 : (as_uint(a) & 0x7fffff) | 0x800000);
+  st_b.mantissa = u2_set_u(b == .0f ? 0 : (as_uint(b) & 0x7fffff) | 0x800000);
+  st_c.mantissa = u2_set_u(c == .0f ? 0 : (as_uint(c) & 0x7fffff) | 0x800000);
 
   st_a.sign = as_uint(a) & 0x80000000;
   st_b.sign = as_uint(b) & 0x80000000;
@@ -81,15 +164,13 @@ _CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
   // add another bit to detect subtraction underflow
   struct fp st_mul;
   st_mul.sign = st_a.sign ^ st_b.sign;
-  st_mul.mantissa.hi = mul_hi(st_a.mantissa.lo, st_b.mantissa.lo);
-  st_mul.mantissa.lo = st_a.mantissa.lo * st_b.mantissa.lo;
-  uint upper_14bits = (st_mul.mantissa.lo >> 18) & 0x3fff;
-  st_mul.mantissa.lo <<= 14;
-  st_mul.mantissa.hi <<= 14;
-  st_mul.mantissa.hi |= upper_14bits;
-  st_mul.exponent = (st_mul.mantissa.lo != 0 || st_mul.mantissa.hi != 0)
-                        ? st_a.exponent + st_b.exponent
-                        : 0;
+  st_mul.mantissa = u2_sll(u2_mul(st_a.mantissa.lo, st_b.mantissa.lo), 14);
+  st_mul.exponent =
+      !u2_zero(st_mul.mantissa) ? st_a.exponent + st_b.exponent : 0;
+
+  // FIXME: Detecting a == 0 || b == 0 above crashed GCN isel
+  if (st_mul.exponent == 0 && u2_zero(st_mul.mantissa))
+    return c;
 
 // Mantissa is 23 fractional bits, shift it the same way as product mantissa
 #define C_ADJUST 37ul
@@ -97,146 +178,80 @@ _CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
   // both exponents are bias adjusted
  int exp_diff = st_mul.exponent - st_c.exponent;
 
-  uint abs_exp_diff = abs(exp_diff);
-  st_c.mantissa.hi = (st_c.mantissa.lo << 5);
-  st_c.mantissa.lo = 0;
-  uint2 cutoff_bits = (uint2)(0, 0);
-  uint2 cutoff_mask = (uint2)(0, 0);
-  if (abs_exp_diff < 32) {
-    cutoff_mask.lo = (1u << abs(exp_diff)) - 1u;
-  } else if (abs_exp_diff < 64) {
-    cutoff_mask.lo = 0xffffffff;
-    uint remaining = abs_exp_diff - 32;
-    cutoff_mask.hi = (1u << remaining) - 1u;
+  st_c.mantissa = u2_sll(st_c.mantissa, C_ADJUST);
+  uint2 cutoff_bits = u2_set_u(0);
+  uint2 cutoff_mask = u2_add(u2_sll(u2_set_u(1), abs(exp_diff)),
+                             u2_set(0xffffffff, 0xffffffff));
+  if (exp_diff > 0) {
+    cutoff_bits =
+        exp_diff >= 64 ? st_c.mantissa : u2_and(st_c.mantissa, cutoff_mask);
+    st_c.mantissa =
+        exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_c.mantissa, exp_diff);
   } else {
-    cutoff_mask = (uint2)(0, 0);
-  }
-  uint2 tmp = (exp_diff > 0) ? st_c.mantissa : st_mul.mantissa;
-  if (abs_exp_diff > 0) {
-    cutoff_bits = abs_exp_diff >= 64 ? tmp : (tmp & cutoff_mask);
-    if (abs_exp_diff < 32) {
-      // shift some of the hi bits into the shifted lo bits.
-      uint shift_mask = (1u << abs_exp_diff) - 1;
-      uint upper_saved_bits = tmp.hi & shift_mask;
-      upper_saved_bits = upper_saved_bits << (32 - abs_exp_diff);
-      tmp.hi >>= abs_exp_diff;
-      tmp.lo >>= abs_exp_diff;
-      tmp.lo |= upper_saved_bits;
-    } else if (abs_exp_diff < 64) {
-      tmp.lo = (tmp.hi >> (abs_exp_diff - 32));
-      tmp.hi = 0;
-    } else {
-      tmp = (uint2)(0, 0);
-    }
+    cutoff_bits = -exp_diff >= 64 ? st_mul.mantissa
+                                  : u2_and(st_mul.mantissa, cutoff_mask);
+    st_mul.mantissa =
+        -exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_mul.mantissa, -exp_diff);
   }
-  if (exp_diff > 0)
-    st_c.mantissa = tmp;
-  else
-    st_mul.mantissa = tmp;
 
   struct fp st_fma;
   st_fma.sign = st_mul.sign;
   st_fma.exponent = max(st_mul.exponent, st_c.exponent);
-  st_fma.mantissa = (uint2)(0, 0);
   if (st_c.sign == st_mul.sign) {
-    uint carry = (hadd(st_mul.mantissa.lo, st_c.mantissa.lo) >> 31) & 0x1;
-    st_fma.mantissa = st_mul.mantissa + st_c.mantissa;
-    st_fma.mantissa.hi += carry;
+    st_fma.mantissa = u2_add(st_mul.mantissa, st_c.mantissa);
   } else {
     // cutoff bits borrow one
-    uint cutoff_borrow = ((cutoff_bits.lo != 0 || cutoff_bits.hi != 0) &&
-                          (st_mul.exponent > st_c.exponent))
-                             ? 1
-                             : 0;
-    uint borrow = 0;
-    if (st_c.mantissa.lo > st_mul.mantissa.lo) {
-      borrow = 1;
-    } else if (st_c.mantissa.lo == UINT_MAX && cutoff_borrow == 1) {
-      borrow = 1;
-    } else if ((st_c.mantissa.lo + cutoff_borrow) > st_mul.mantissa.lo) {
-      borrow = 1;
-    }
-
-    st_fma.mantissa.lo = st_mul.mantissa.lo - st_c.mantissa.lo - cutoff_borrow;
-    st_fma.mantissa.hi = st_mul.mantissa.hi - st_c.mantissa.hi - borrow;
+    st_fma.mantissa =
+        u2_add(u2_add(st_mul.mantissa, u2_inv(st_c.mantissa)),
+               (!u2_zero(cutoff_bits) && (st_mul.exponent > st_c.exponent)
+                    ? u2_set(0xffffffff, 0xffffffff)
+                    : u2_set_u(0)));
   }
 
   // underflow: st_c.sign != st_mul.sign, and magnitude switches the sign
-  if (st_fma.mantissa.hi > INT_MAX) {
-    st_fma.mantissa = ~st_fma.mantissa;
-    uint carry = (hadd(st_fma.mantissa.lo, 1u) >> 31) & 0x1;
-    st_fma.mantissa.lo += 1;
-    st_fma.mantissa.hi += carry;
-
+  if (u2_gt(st_fma.mantissa, u2_set(0x7fffffff, 0xffffffff))) {
+    st_fma.mantissa = u2_inv(st_fma.mantissa);
     st_fma.sign = st_mul.sign ^ 0x80000000;
   }
 
   // detect overflow/underflow
-  uint leading_zeroes = clz(st_fma.mantissa.hi);
-  if (leading_zeroes == 32) {
-    leading_zeroes += clz(st_fma.mantissa.lo);
-  }
-  int overflow_bits = 3 - leading_zeroes;
+  int overflow_bits = 3 - u2_clz(st_fma.mantissa);
 
   // adjust exponent
   st_fma.exponent += overflow_bits;
 
   // handle underflow
   if (overflow_bits < 0) {
-    uint shift = -overflow_bits;
-    if (shift < 32) {
-      uint shift_mask = (1u << shift) - 1;
-      uint saved_lo_bits = (st_fma.mantissa.lo >> (32 - shift)) & shift_mask;
-      st_fma.mantissa.lo <<= shift;
-      st_fma.mantissa.hi <<= shift;
-      st_fma.mantissa.hi |= saved_lo_bits;
-    } else if (shift < 64) {
-      st_fma.mantissa.hi = (st_fma.mantissa.lo << (64 - shift));
-      st_fma.mantissa.lo = 0;
-    } else {
-      st_fma.mantissa = (uint2)(0, 0);
-    }
-
+    st_fma.mantissa = u2_sll(st_fma.mantissa, -overflow_bits);
     overflow_bits = 0;
   }
 
   // rounding
-  // overflow_bits is now in the range of [0, 3] making the shift greater than
-  // 32 bits.
-  uint2 trunc_mask;
-  uint trunc_shift = C_ADJUST + overflow_bits - 32;
-  trunc_mask.hi = (1u << trunc_shift) - 1;
-  trunc_mask.lo = UINT_MAX;
-  uint2 trunc_bits = st_fma.mantissa & trunc_mask;
-  trunc_bits.lo |= (cutoff_bits.hi != 0 || cutoff_bits.lo != 0) ? 1 : 0;
-  uint2 last_bit;
-  last_bit.lo = 0;
-  last_bit.hi = st_fma.mantissa.hi & (1u << trunc_shift);
-  uint grs_shift = C_ADJUST - 3 + overflow_bits - 32;
-  uint2 grs_bits;
-  grs_bits.lo = 0;
-  grs_bits.hi = 0x4u << grs_shift;
+  uint2 trunc_mask = u2_add(u2_sll(u2_set_u(1), C_ADJUST + overflow_bits),
+                            u2_set(0xffffffff, 0xffffffff));
+  uint2 trunc_bits =
+      u2_or(u2_and(st_fma.mantissa, trunc_mask), !u2_zero(cutoff_bits));
+  uint2 last_bit =
+      u2_and(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits));
+  uint2 grs_bits = u2_sll(u2_set_u(4), C_ADJUST - 3 + overflow_bits);
 
   // round to nearest even
-  if ((trunc_bits.hi > grs_bits.hi ||
-       (trunc_bits.hi == grs_bits.hi && trunc_bits.lo > grs_bits.lo)) ||
-      (trunc_bits.hi == grs_bits.hi && trunc_bits.lo == grs_bits.lo &&
-       last_bit.hi != 0)) {
-    uint shift = C_ADJUST + overflow_bits - 32;
-    st_fma.mantissa.hi += 1u << shift;
+  if (u2_gt(trunc_bits, grs_bits) ||
+      (u2_eq(trunc_bits, grs_bits) && !u2_zero(last_bit))) {
+    st_fma.mantissa =
+        u2_add(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits));
   }
 
-        // Shift mantissa back to bit 23
-  st_fma.mantissa.lo = (st_fma.mantissa.hi >> (C_ADJUST + overflow_bits - 32));
-  st_fma.mantissa.hi = 0;
+  // Shift mantissa back to bit 23
+  st_fma.mantissa = u2_srl(st_fma.mantissa, C_ADJUST + overflow_bits);
 
   // Detect rounding overflow
-  if (st_fma.mantissa.lo > 0xffffff) {
+  if (u2_gt(st_fma.mantissa, u2_set_u(0xffffff))) {
     ++st_fma.exponent;
-    st_fma.mantissa.lo >>= 1;
+    st_fma.mantissa = u2_srl(st_fma.mantissa, 1);
   }
 
-  if (st_fma.mantissa.lo == 0) {
+  if (u2_zero(st_fma.mantissa)) {
     return 0.0f;
   }
 

diff --git a/libclc/clspv/lib/shared/vstore_half.cl b/libclc/clspv/lib/shared/vstore_half.cl
new file mode 100644
index 0000000000000..b05fcfe75fb7a
--- /dev/null
+++ b/libclc/clspv/lib/shared/vstore_half.cl
@@ -0,0 +1,135 @@
+#include <clc/clc.h>
+
+#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable
+
+#define ROUND_VEC1(out, in, ROUNDF) out = ROUNDF(in);
+#define ROUND_VEC2(out, in, ROUNDF)                                            \
+  ROUND_VEC1(out.lo, in.lo, ROUNDF);                                           \
+  ROUND_VEC1(out.hi, in.hi, ROUNDF);
+#define ROUND_VEC3(out, in, ROUNDF)                                            \
+  ROUND_VEC1(out.s0, in.s0, ROUNDF);                                           \
+  ROUND_VEC1(out.s1, in.s1, ROUNDF);                                           \
+  ROUND_VEC1(out.s2, in.s2, ROUNDF);
+#define ROUND_VEC4(out, in, ROUNDF)                                            \
+  ROUND_VEC2(out.lo, in.lo, ROUNDF);                                           \
+  ROUND_VEC2(out.hi, in.hi, ROUNDF);
+#define ROUND_VEC8(out, in, ROUNDF)                                            \
+  ROUND_VEC4(out.lo, in.lo, ROUNDF);                                           \
+  ROUND_VEC4(out.hi, in.hi, ROUNDF);
+#define ROUND_VEC16(out, in, ROUNDF)                                           \
+  ROUND_VEC8(out.lo, in.lo, ROUNDF);                                           \
+  ROUND_VEC8(out.hi, in.hi, ROUNDF);
+
+#define __FUNC(SUFFIX, VEC_SIZE, TYPE, AS, ROUNDF)                             \
+  void _CLC_OVERLOAD vstore_half_##VEC_SIZE(TYPE, size_t, AS half *);          \
+  _CLC_OVERLOAD _CLC_DEF void vstore_half##SUFFIX(TYPE vec, size_t offset,     \
+                                                  AS half *mem) {              \
+    TYPE rounded_vec;                                                          \
+    ROUND_VEC##VEC_SIZE(rounded_vec, vec, ROUNDF);                             \
+    vstore_half_##VEC_SIZE(rounded_vec, offset, mem);                          \
+  }                                                                            \
+  void _CLC_OVERLOAD vstorea_half_##VEC_SIZE(TYPE, size_t, AS half *);         \
+  _CLC_OVERLOAD _CLC_DEF void vstorea_half##SUFFIX(TYPE vec, size_t offset,    \
+                                                   AS half *mem) {             \
+    TYPE rounded_vec;                                                          \
+    ROUND_VEC##VEC_SIZE(rounded_vec, vec, ROUNDF);                             \
+    vstorea_half_##VEC_SIZE(rounded_vec, offset, mem);                         \
+  }
+
+_CLC_DEF _CLC_OVERLOAD float __clc_rtz(float x) {
+  /* Handle nan corner case */
+  if (isnan(x))
+    return x;
+  /* RTZ does not produce Inf for large numbers */
+  if (fabs(x) > 65504.0f && !isinf(x))
+    return copysign(65504.0f, x);
+
+  const int exp = (as_uint(x) >> 23 & 0xff) - 127;
+  /* Manage range rounded to +- zero explicitely */
+  if (exp < -24)
+    return copysign(0.0f, x);
+
+  /* Remove lower 13 bits to make sure the number is rounded down */
+  int mask = 0xffffe000;
+  /* Denormals cannot be flushed, and they use different bit for rounding */
+  if (exp < -14)
+    mask <<= min(-(exp + 14), 10);
+
+  return as_float(as_uint(x) & mask);
+}
+
+_CLC_DEF _CLC_OVERLOAD float __clc_rti(float x) {
+  /* Handle nan corner case */
+  if (isnan(x))
+    return x;
+
+  const float inf = copysign(INFINITY, x);
+  uint ux = as_uint(x);
+
+  /* Manage +- infinity explicitely */
+  if (as_float(ux & 0x7fffffff) > 0x1.ffcp+15f) {
+    return inf;
+  }
+  /* Manage +- zero explicitely */
+  if ((ux & 0x7fffffff) == 0) {
+    return copysign(0.0f, x);
+  }
+
+  const int exp = (as_uint(x) >> 23 & 0xff) - 127;
+  /* Manage range rounded to smallest half denormal explicitely */
+  if (exp < -24) {
+    return copysign(0x1.0p-24f, x);
+  }
+
+  /* Set lower 13 bits */
+  int mask = (1 << 13) - 1;
+  /* Denormals cannot be flushed, and they use different bit for rounding */
+  if (exp < -14) {
+    mask = (1 << (13 + min(-(exp + 14), 10))) - 1;
+  }
+
+  const float next = nextafter(as_float(ux | mask), inf);
+  return ((ux & mask) == 0) ? as_float(ux) : next;
+}
+_CLC_DEF _CLC_OVERLOAD float __clc_rtn(float x) {
+  return ((as_uint(x) & 0x80000000) == 0) ? __clc_rtz(x) : __clc_rti(x);
+}
+_CLC_DEF _CLC_OVERLOAD float __clc_rtp(float x) {
+  return ((as_uint(x) & 0x80000000) == 0) ? __clc_rti(x) : __clc_rtz(x);
+}
+_CLC_DEF _CLC_OVERLOAD float __clc_rte(float x) {
+  /* Mantisa + implicit bit */
+  const uint mantissa = (as_uint(x) & 0x7fffff) | (1u << 23);
+  const int exp = (as_uint(x) >> 23 & 0xff) - 127;
+  int shift = 13;
+  if (exp < -14) {
+    /* The default assumes lower 13 bits are rounded,
+     * but it might be more for denormals.
+     * Shifting beyond last == 0b, and qr == 00b is not necessary */
+    shift += min(-(exp + 14), 15);
+  }
+  int mask = (1 << shift) - 1;
+  const uint grs = mantissa & mask;
+  const uint last = mantissa & (1 << shift);
+  /* IEEE round up rule is: grs > 101b or grs == 100b and last == 1.
+   * exp > 15 should round to inf. */
+  bool roundup = (grs > (1 << (shift - 1))) ||
+                 (grs == (1 << (shift - 1)) && last != 0) || (exp > 15);
+  return roundup ? __clc_rti(x) : __clc_rtz(x);
+}
+
+#define __XFUNC(SUFFIX, VEC_SIZE, TYPE, AS)                                    \
+  __FUNC(SUFFIX, VEC_SIZE, TYPE, AS, __clc_rte)                                \
+  __FUNC(SUFFIX##_rtz, VEC_SIZE, TYPE, AS, __clc_rtz)                          \
+  __FUNC(SUFFIX##_rtn, VEC_SIZE, TYPE, AS, __clc_rtn)                          \
+  __FUNC(SUFFIX##_rtp, VEC_SIZE, TYPE, AS, __clc_rtp)                          \
+  __FUNC(SUFFIX##_rte, VEC_SIZE, TYPE, AS, __clc_rte)
+
+#define FUNC(SUFFIX, VEC_SIZE, TYPE, AS) __XFUNC(SUFFIX, VEC_SIZE, TYPE, AS)
+
+#define __CLC_BODY "vstore_half.inc"
+#include <clc/math/gentype.inc>
+#undef __CLC_BODY
+#undef FUNC
+#undef __XFUNC
+#undef __FUNC
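
The helpers above implement the OpenCL rounding modes for the float-to-half
conversion done by vstore_half*: __clc_rtz truncates by masking off the 13
mantissa bits fp16 cannot hold (more for denormals), __clc_rti rounds the
magnitude up via nextafter, and __clc_rte applies the usual
guard/round/sticky rule to the dropped bits. A minimal host-side C sketch
(not part of the patch; normal half range only) of the round-to-nearest-even
decision:

  #include <stdint.h>
  #include <stdio.h>
  #include <string.h>

  /* Decide whether __clc_rte() would round the magnitude up when dropping
   * the 13 low mantissa bits (fp32 has 23, fp16 has 10).  Denormal and
   * overflow handling from the real helper is omitted for brevity. */
  static int rte_rounds_up(float x) {
    uint32_t u;
    memcpy(&u, &x, sizeof u);
    const uint32_t mantissa = (u & 0x7fffff) | (1u << 23); /* implicit bit */
    const int shift = 13;
    const uint32_t mask = (1u << shift) - 1;
    const uint32_t grs = mantissa & mask;          /* guard/round/sticky */
    const uint32_t last = mantissa & (1u << shift); /* last kept bit */
    return (grs > (1u << (shift - 1))) ||
           (grs == (1u << (shift - 1)) && last != 0);
  }

  int main(void) {
    /* 1 + 2^-11 is exactly halfway between two half values with an even
     * lower neighbour, so it rounds down; 1 + 3*2^-11 is halfway with an
     * odd lower neighbour, so it rounds up. */
    printf("%d\n", rte_rounds_up(1.00048828125f));  /* prints 0 */
    printf("%d\n", rte_rounds_up(1.00146484375f));  /* prints 1 */
    return 0;
  }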

diff --git a/libclc/clspv/lib/shared/vstore_half.inc b/libclc/clspv/lib/shared/vstore_half.inc
new file mode 100644
index 0000000000000..83704cca3a010
--- /dev/null
+++ b/libclc/clspv/lib/shared/vstore_half.inc
@@ -0,0 +1,15 @@
+// This does exist only for fp32
+#if __CLC_FPSIZE == 32
+#ifdef __CLC_VECSIZE
+
+FUNC(__CLC_VECSIZE, __CLC_VECSIZE, __CLC_GENTYPE, __private);
+FUNC(__CLC_VECSIZE, __CLC_VECSIZE, __CLC_GENTYPE, __local);
+FUNC(__CLC_VECSIZE, __CLC_VECSIZE, __CLC_GENTYPE, __global);
+
+#undef __CLC_OFFSET
+#else
+FUNC(, 1, __CLC_GENTYPE, __private);
+FUNC(, 1, __CLC_GENTYPE, __local);
+FUNC(, 1, __CLC_GENTYPE, __global);
+#endif
+#endif
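
vstore_half.inc is instantiated through libclc's usual gentype machinery:
vstore_half.cl defines __CLC_BODY and includes clc/math/gentype.inc, which
re-includes this body once per generic type and vector size, and the guard
above keeps it to fp32 only. Roughly (a sketch, not literal preprocessor
output), one instantiation such as FUNC(4, 4, float4, __global) expands into
overloads that round each element and forward to the underscore-suffixed
declarations, which clspv is presumably expected to provide; the
vstorea_half4 wrapper is analogous:

  void _CLC_OVERLOAD vstore_half_4(float4, size_t, __global half *);
  _CLC_OVERLOAD _CLC_DEF void vstore_half4(float4 vec, size_t offset,
                                           __global half *mem) {
    float4 rounded_vec;
    /* ROUND_VEC4 -> ROUND_VEC2 -> ROUND_VEC1, default mode is __clc_rte */
    rounded_vec.lo.lo = __clc_rte(vec.lo.lo);
    rounded_vec.lo.hi = __clc_rte(vec.lo.hi);
    rounded_vec.hi.lo = __clc_rte(vec.hi.lo);
    rounded_vec.hi.hi = __clc_rte(vec.hi.hi);
    vstore_half_4(rounded_vec, offset, mem);
  }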

diff --git a/libclc/generic/include/clc/clcfunc.h b/libclc/generic/include/clc/clcfunc.h
index abb5484d6248e..ad9eb23f29333 100644
--- a/libclc/generic/include/clc/clcfunc.h
+++ b/libclc/generic/include/clc/clcfunc.h
@@ -4,9 +4,11 @@
 
 // avoid inlines for SPIR-V related targets since we'll optimise later in the
 // chain
-#if defined(CLC_SPIRV) || defined(CLC_SPIRV64) || defined(CLC_CLSPV) || \
-    defined(CLC_CLSPV64)
+#if defined(CLC_SPIRV) || defined(CLC_SPIRV64)
 #define _CLC_DEF
+#elif defined(CLC_CLSPV) || defined(CLC_CLSPV64)
+#define _CLC_DEF                                                               \
+  __attribute__((noinline)) __attribute__((assume("clspv_libclc_builtin")))
 #else
 #define _CLC_DEF __attribute__((always_inline))
 #endif
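
With this change the SPIR-V targets keep an empty _CLC_DEF, while the clspv
targets now mark every libclc definition as noinline and tag it with a string
assumption. The assumption string stays attached to the function through
compilation, which presumably lets clspv recognise and special-case
libclc-provided builtins later in the chain; it is also why the CMake change
above adds -Wno-unknown-assumption, since clang warns about assumption
strings it does not recognise. For illustration only (a sketch using a
hypothetical builtin, with _CLC_OVERLOAD expanded to its usual definition), a
clspv-target definition now looks roughly like:

  __attribute__((noinline))
  __attribute__((assume("clspv_libclc_builtin")))
  __attribute__((overloadable))   /* _CLC_OVERLOAD */
  float my_example_builtin(float x) { return x; }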


        

