[clang] [HLSL] Implement min and max overloads using templates (PR #131666)
via cfe-commits
cfe-commits at lists.llvm.org
Mon Mar 17 12:50:29 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang
Author: Sarah Spall (spall)
<details>
<summary>Changes</summary>
Replace the macro-based implementation of the min and max overloads with a template-based one.
Enable new overloads of the forms:
vector<T,N> min/max(vector<T,N> p0, U p1)
vector<T,N> min/max(U p0, vector<T,N> p1)
vector<T,N> min/max(vector<T,N> p0, vector<R,N> p1)
U min/max(U p0, V p1)
Add new tests.
Closes #131170
---
Full diff: https://github.com/llvm/llvm-project/pull/131666.diff
6 Files Affected:
- (modified) clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h (-38)
- (modified) clang/lib/Headers/hlsl/hlsl_compat_overloads.h (+63-1)
- (modified) clang/test/CodeGenHLSL/builtins/max.hlsl (+16)
- (modified) clang/test/CodeGenHLSL/builtins/min.hlsl (+16)
- (added) clang/test/SemaHLSL/BuiltIns/max-errors-16bit.hlsl (+12)
- (added) clang/test/SemaHLSL/BuiltIns/min-errors-16bit.hlsl (+12)
``````````diff
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 62054b368691d..585e905c7bf5d 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -35,26 +35,6 @@ namespace hlsl {
#define _HLSL_16BIT_AVAILABILITY_STAGE(environment, version, stage)
#endif
-#define GEN_VEC_SCALAR_OVERLOADS(FUNC_NAME, BASE_TYPE, AVAIL) \
- GEN_BOTH_OVERLOADS(FUNC_NAME, BASE_TYPE, BASE_TYPE##2, AVAIL) \
- GEN_BOTH_OVERLOADS(FUNC_NAME, BASE_TYPE, BASE_TYPE##3, AVAIL) \
- GEN_BOTH_OVERLOADS(FUNC_NAME, BASE_TYPE, BASE_TYPE##4, AVAIL)
-
-#define GEN_BOTH_OVERLOADS(FUNC_NAME, BASE_TYPE, VECTOR_TYPE, AVAIL) \
- IF_TRUE_##AVAIL( \
- _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)) constexpr VECTOR_TYPE \
- FUNC_NAME(VECTOR_TYPE p0, BASE_TYPE p1) { \
- return __builtin_elementwise_##FUNC_NAME(p0, (VECTOR_TYPE)p1); \
- } \
- IF_TRUE_##AVAIL( \
- _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)) constexpr VECTOR_TYPE \
- FUNC_NAME(BASE_TYPE p0, VECTOR_TYPE p1) { \
- return __builtin_elementwise_##FUNC_NAME((VECTOR_TYPE)p0, p1); \
- }
-
-#define IF_TRUE_0(EXPR)
-#define IF_TRUE_1(EXPR) EXPR
-
//===----------------------------------------------------------------------===//
// abs builtins
//===----------------------------------------------------------------------===//
@@ -1563,7 +1543,6 @@ half3 max(half3, half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
half4 max(half4, half4);
-GEN_VEC_SCALAR_OVERLOADS(max, half, 1)
#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
@@ -1578,7 +1557,6 @@ int16_t3 max(int16_t3, int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
int16_t4 max(int16_t4, int16_t4);
-GEN_VEC_SCALAR_OVERLOADS(max, int16_t, 1)
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
@@ -1592,7 +1570,6 @@ uint16_t3 max(uint16_t3, uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
uint16_t4 max(uint16_t4, uint16_t4);
-GEN_VEC_SCALAR_OVERLOADS(max, uint16_t, 1)
#endif
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
@@ -1603,7 +1580,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
int3 max(int3, int3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
int4 max(int4, int4);
-GEN_VEC_SCALAR_OVERLOADS(max, int, 0)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
uint max(uint, uint);
@@ -1613,7 +1589,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
uint3 max(uint3, uint3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
uint4 max(uint4, uint4);
-GEN_VEC_SCALAR_OVERLOADS(max, uint, 0)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
int64_t max(int64_t, int64_t);
@@ -1623,7 +1598,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
int64_t3 max(int64_t3, int64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
int64_t4 max(int64_t4, int64_t4);
-GEN_VEC_SCALAR_OVERLOADS(max, int64_t, 0)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
uint64_t max(uint64_t, uint64_t);
@@ -1633,7 +1607,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
uint64_t3 max(uint64_t3, uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
uint64_t4 max(uint64_t4, uint64_t4);
-GEN_VEC_SCALAR_OVERLOADS(max, uint64_t, 0)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
float max(float, float);
@@ -1643,7 +1616,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
float3 max(float3, float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
float4 max(float4, float4);
-GEN_VEC_SCALAR_OVERLOADS(max, float, 0)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
double max(double, double);
@@ -1653,7 +1625,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
double3 max(double3, double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_max)
double4 max(double4, double4);
-GEN_VEC_SCALAR_OVERLOADS(max, double, 0)
//===----------------------------------------------------------------------===//
// min builtins
@@ -1676,7 +1647,6 @@ half3 min(half3, half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
half4 min(half4, half4);
-GEN_VEC_SCALAR_OVERLOADS(min, half, 1)
#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
@@ -1691,7 +1661,6 @@ int16_t3 min(int16_t3, int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
int16_t4 min(int16_t4, int16_t4);
-GEN_VEC_SCALAR_OVERLOADS(min, int16_t, 1)
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
@@ -1705,7 +1674,6 @@ uint16_t3 min(uint16_t3, uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
uint16_t4 min(uint16_t4, uint16_t4);
-GEN_VEC_SCALAR_OVERLOADS(min, uint16_t, 1)
#endif
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
@@ -1716,7 +1684,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
int3 min(int3, int3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
int4 min(int4, int4);
-GEN_VEC_SCALAR_OVERLOADS(min, int, 0)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
uint min(uint, uint);
@@ -1726,7 +1693,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
uint3 min(uint3, uint3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
uint4 min(uint4, uint4);
-GEN_VEC_SCALAR_OVERLOADS(min, uint, 0)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
float min(float, float);
@@ -1736,7 +1702,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
float3 min(float3, float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
float4 min(float4, float4);
-GEN_VEC_SCALAR_OVERLOADS(min, float, 0)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
int64_t min(int64_t, int64_t);
@@ -1746,7 +1711,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
int64_t3 min(int64_t3, int64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
int64_t4 min(int64_t4, int64_t4);
-GEN_VEC_SCALAR_OVERLOADS(min, int64_t, 0)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
uint64_t min(uint64_t, uint64_t);
@@ -1756,7 +1720,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
uint64_t3 min(uint64_t3, uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
uint64_t4 min(uint64_t4, uint64_t4);
-GEN_VEC_SCALAR_OVERLOADS(min, uint64_t, 0)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
double min(double, double);
@@ -1766,7 +1729,6 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
double3 min(double3, double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
double4 min(double4, double4);
-GEN_VEC_SCALAR_OVERLOADS(min, double, 0)
//===----------------------------------------------------------------------===//
// normalize builtins
diff --git a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
index 97f3cade32676..93a9aa31d0018 100644
--- a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
+++ b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
@@ -1,4 +1,4 @@
-//===--- hlsl_compat_overloads.h - Extra HLSL overloads for intrinsics --===//
+//===--- hlsl_compat_overloads.h - Extra HLSL overloads for intrinsics ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -54,5 +54,67 @@ clamp(U p0, V p1, W p2) {
return clamp(p0, (U)p1, (U)p2);
}
+//===----------------------------------------------------------------------===//
+// max builtin overloads
+//===----------------------------------------------------------------------===//
+
+template <typename T, typename U, uint N>
+constexpr __detail::enable_if_t<
+ __detail::is_arithmetic<U>::Value && (N > 1 && N <= 4), vector<T, N>>
+max(vector<T, N> p0, U p1) {
+ return max(p0, (vector<T, N>)p1);
+}
+
+template <typename T, typename U, uint N>
+constexpr __detail::enable_if_t<
+ __detail::is_arithmetic<U>::Value && (N > 1 && N <= 4), vector<T, N>>
+max(U p0, vector<T, N> p1) {
+ return max((vector<T, N>)p0, p1);
+}
+
+template <typename T, typename R, uint N>
+constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
+max(vector<T, N> p0, vector<R, N> p1) {
+ return max(p0, (vector<T, N>)p1);
+}
+
+template <typename U, typename V>
+constexpr __detail::enable_if_t<
+ __detail::is_arithmetic<U>::Value && __detail::is_arithmetic<V>::Value, U>
+max(U p0, V p1) {
+ return max(p0, (U)p1);
+}
+
+//===----------------------------------------------------------------------===//
+// min builtin overloads
+//===----------------------------------------------------------------------===//
+
+template <typename T, typename U, uint N>
+constexpr __detail::enable_if_t<
+ __detail::is_arithmetic<U>::Value && (N > 1 && N <= 4), vector<T, N>>
+min(vector<T, N> p0, U p1) {
+ return min(p0, (vector<T, N>)p1);
+}
+
+template <typename T, typename U, uint N>
+constexpr __detail::enable_if_t<
+ __detail::is_arithmetic<U>::Value && (N > 1 && N <= 4), vector<T, N>>
+min(U p0, vector<T, N> p1) {
+ return min((vector<T, N>)p0, p1);
+}
+
+template <typename T, typename R, uint N>
+constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
+min(vector<T, N> p0, vector<R, N> p1) {
+ return min(p0, (vector<T, N>)p1);
+}
+
+template <typename U, typename V>
+constexpr __detail::enable_if_t<
+ __detail::is_arithmetic<U>::Value && __detail::is_arithmetic<V>::Value, U>
+min(U p0, V p1) {
+ return min(p0, (U)p1);
+}
+
} // namespace hlsl
#endif // _HLSL_COMPAT_OVERLOADS_H_
diff --git a/clang/test/CodeGenHLSL/builtins/max.hlsl b/clang/test/CodeGenHLSL/builtins/max.hlsl
index 6b5fb6ae59534..39aa1cfd05585 100644
--- a/clang/test/CodeGenHLSL/builtins/max.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/max.hlsl
@@ -163,3 +163,19 @@ double4 test_max_double4_mismatch(double4 p0, double p1) { return max(p0, p1); }
// CHECK-LABEL: define noundef nofpclass(nan inf) <4 x double> {{.*}}test_max_double4_mismatch2
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.maxnum.v4f64
double4 test_max_double4_mismatch2(double4 p0, double p1) { return max(p1, p0); }
+
+// CHECK-LABEL: define noundef <2 x i32> {{.*}}test_overloads1
+// CHECK: call <2 x i32> @llvm.smax.v2i32
+int2 test_overloads1(int2 p0, float p1) { return max(p0, p1); }
+
+// CHECK-LABEL: define noundef <2 x i32> {{.*}}test_overloads2
+// CHECK: call <2 x i32> @llvm.smax.v2i32
+int2 test_overloads2(int2 p0, float p1) { return max(p1, p0); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x float> {{.*}}test_overloads3
+// CHECK: call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.maxnum.v3f32
+float3 test_overloads3(float3 p0, int3 p1) { return max(p0, p1); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) double {{.*}}test_overloads4
+// CHECK: call reassoc nnan ninf nsz arcp afn double @llvm.maxnum.f64(
+double test_overloads4(double p0, int p1) { return max(p0, p1); }
diff --git a/clang/test/CodeGenHLSL/builtins/min.hlsl b/clang/test/CodeGenHLSL/builtins/min.hlsl
index 551db52878e37..c678c03cabb31 100644
--- a/clang/test/CodeGenHLSL/builtins/min.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/min.hlsl
@@ -163,3 +163,19 @@ double4 test_min_double4_mismatch(double4 p0, double p1) { return min(p0, p1); }
// CHECK-LABEL: define noundef nofpclass(nan inf) <4 x double> {{.*}}test_min_double4_mismatch2
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.minnum.v4f64
double4 test_min_double4_mismatch2(double4 p0, double p1) { return min(p1, p0); }
+
+// CHECK-LABEL: define noundef <2 x i32> {{.*}}test_overloads1
+// CHECK: call <2 x i32> @llvm.smin.v2i32
+int2 test_overloads1(int2 p0, float p1) { return min(p0, p1); }
+
+// CHECK-LABEL: define noundef <2 x i32> {{.*}}test_overloads2
+// CHECK: call <2 x i32> @llvm.smin.v2i32
+int2 test_overloads2(int2 p0, float p1) { return min(p1, p0); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x float> {{.*}}test_overloads3
+// CHECK: call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.minnum.v3f32
+float3 test_overloads3(float3 p0, int3 p1) { return min(p0, p1); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) double {{.*}}test_overloads4
+// CHECK: call reassoc nnan ninf nsz arcp afn double @llvm.minnum.f64(
+double test_overloads4(double p0, int p1) { return min(p0, p1); }
diff --git a/clang/test/SemaHLSL/BuiltIns/max-errors-16bit.hlsl b/clang/test/SemaHLSL/BuiltIns/max-errors-16bit.hlsl
new file mode 100644
index 0000000000000..74c729c5ef4be
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/max-errors-16bit.hlsl
@@ -0,0 +1,12 @@
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1 | FileCheck %s -DTEST_TYPE=half
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1 | FileCheck %s -DTEST_TYPE=half3
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1 | FileCheck %s -DTEST_TYPE=int16_t
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1 | FileCheck %s -DTEST_TYPE=int16_t3
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1 | FileCheck %s -DTEST_TYPE=uint16_t
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1 | FileCheck %s -DTEST_TYPE=uint16_t3
+
+// check we error on 16 bit type if shader model is too old
+// CHECK: '-enable-16bit-types' option requires target HLSL Version >= 2018 and shader model >= 6.2, but HLSL Version is 'hlsl202x' and shader model is '6.0'
+TEST_TYPE test_error(TEST_TYPE p0, int p1) {
+ return max(p0, p1);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/min-errors-16bit.hlsl b/clang/test/SemaHLSL/BuiltIns/min-errors-16bit.hlsl
new file mode 100644
index 0000000000000..ea5ed36763c75
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/min-errors-16bit.hlsl
@@ -0,0 +1,12 @@
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1 | FileCheck %s -DTEST_TYPE=half
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1 | FileCheck %s -DTEST_TYPE=half3
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1 | FileCheck %s -DTEST_TYPE=int16_t
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1 | FileCheck %s -DTEST_TYPE=int16_t3
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1 | FileCheck %s -DTEST_TYPE=uint16_t
+// RUN: not %clang_dxc -enable-16bit-types -T cs_6_0 -HV 202x %s 2>&1 | FileCheck %s -DTEST_TYPE=uint16_t3
+
+// check we error on 16 bit type if shader model is too old
+// CHECK: '-enable-16bit-types' option requires target HLSL Version >= 2018 and shader model >= 6.2, but HLSL Version is 'hlsl202x' and shader model is '6.0'
+TEST_TYPE test_error(TEST_TYPE p0, int p1) {
+ return min(p0, p1);
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/131666
More information about the cfe-commits mailing list