[clang] [HLSL] Add various overloads for MiniEngine (PR #139800)
via cfe-commits
cfe-commits at lists.llvm.org
Tue May 13 14:48:15 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang
Author: Ashley Coleman (V-FEXrt)
<details>
<summary>Changes</summary>
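Partial implementation of https://github.com/llvm/wg-hlsl/issues/264.
Adds mixed scalar/vector overloads for several intrinsic functions used by MiniEngine (`clamp`, `dot`, `lerp`, and `pow`), and refactors the existing `min`/`max` scalar/vector overloads onto shared macros.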
---
Full diff: https://github.com/llvm/llvm-project/pull/139800.diff
5 Files Affected:
- (modified) clang/lib/Headers/hlsl/hlsl_compat_overloads.h (+59-45)
- (modified) clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl (+6)
- (added) clang/test/CodeGenHLSL/builtins/dot-overloads.hlsl (+29)
- (modified) clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl (+38)
- (modified) clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl (+13)
``````````diff
diff --git a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
index 4874206d349c0..a2f8f658b292e 100644
--- a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
+++ b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
@@ -158,6 +158,42 @@ namespace hlsl {
return fn((float4)V1, (float4)V2, (float4)V3); \
}
+#define _DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS(fn) \
+ template <typename T, uint N> \
+ constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>> fn( \
+ vector<T, N> V1, T V2) { \
+ return fn(V1, (vector<T, N>)V2); \
+ } \
+ template <typename T, uint N> \
+ constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>> fn( \
+ T V1, vector<T, N> V2) { \
+ return fn((vector<T, N>)V1, V2); \
+ }
+
+#define _DXC_COMPAT_TERNARY_VECTOR_SCALAR_OVERLOADS(fn) \
+ template <typename T, uint N> \
+ constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>> fn( \
+ T V1, vector<T, N> V2, vector<T, N> V3) { \
+ return fn((vector<T, N>)V1, V2, V3); \
+ } \
+ template <typename T, uint N> \
+ constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>> fn( \
+ vector<T, N> V1, T V2, vector<T, N> V3) { \
+ return fn(V1, (vector<T, N>)V2, V3); \
+ } \
+ template <typename T, uint N> \
+ constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>> fn( \
+ vector<T, N> V1, vector<T, N> V2, T V3) { \
+ return fn(V1, V2, (vector<T, N>)V3); \
+ }
+
+#define _DXC_COMPAT_TERNARY_SINGLE_VECTOR_SCALAR_OVERLOADS(fn) \
+ template <typename T, uint N> \
+ constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>> fn( \
+ vector<T, N> V1, T V2, T V3) { \
+ return fn(V1, (vector<T, N>)V2, (vector<T, N>)V3); \
+ }
+
//===----------------------------------------------------------------------===//
// acos builtins overloads
//===----------------------------------------------------------------------===//
@@ -197,23 +233,8 @@ _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(ceil)
// clamp builtins overloads
//===----------------------------------------------------------------------===//
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-clamp(vector<T, N> p0, vector<T, N> p1, T p2) {
- return clamp(p0, p1, (vector<T, N>)p2);
-}
-
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-clamp(vector<T, N> p0, T p1, vector<T, N> p2) {
- return clamp(p0, (vector<T, N>)p1, p2);
-}
-
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-clamp(vector<T, N> p0, T p1, T p2) {
- return clamp(p0, (vector<T, N>)p1, (vector<T, N>)p2);
-}
+_DXC_COMPAT_TERNARY_VECTOR_SCALAR_OVERLOADS(clamp)
+_DXC_COMPAT_TERNARY_SINGLE_VECTOR_SCALAR_OVERLOADS(clamp)
//===----------------------------------------------------------------------===//
// cos builtins overloads
@@ -236,6 +257,22 @@ _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(cosh)
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(degrees)
_DXC_COMPAT_UNARY_INTEGER_OVERLOADS(degrees)
+//===----------------------------------------------------------------------===//
+// dot builtins overloads
+//===----------------------------------------------------------------------===//
+
+template <typename T, uint N>
+constexpr __detail::enable_if_t<(N > 1 && N <= 4), T> dot(vector<T, N> V1,
+ T V2) {
+ return dot(V1, (vector<T, N>)V2);
+}
+
+template <typename T, uint N>
+constexpr __detail::enable_if_t<(N > 1 && N <= 4), T> dot(T V1,
+ vector<T, N> V2) {
+ return dot((vector<T, N>)V1, V2);
+}
+
//===----------------------------------------------------------------------===//
// exp builtins overloads
//===----------------------------------------------------------------------===//
@@ -277,14 +314,10 @@ constexpr bool4 isinf(double4 V) { return isinf((float4)V); }
// lerp builtins overloads
//===----------------------------------------------------------------------===//
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-lerp(vector<T, N> x, vector<T, N> y, T s) {
- return lerp(x, y, (vector<T, N>)s);
-}
-
_DXC_COMPAT_TERNARY_DOUBLE_OVERLOADS(lerp)
_DXC_COMPAT_TERNARY_INTEGER_OVERLOADS(lerp)
+_DXC_COMPAT_TERNARY_VECTOR_SCALAR_OVERLOADS(lerp)
+_DXC_COMPAT_TERNARY_SINGLE_VECTOR_SCALAR_OVERLOADS(lerp)
//===----------------------------------------------------------------------===//
// log builtins overloads
@@ -311,33 +344,13 @@ _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(log2)
// max builtins overloads
//===----------------------------------------------------------------------===//
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-max(vector<T, N> p0, T p1) {
- return max(p0, (vector<T, N>)p1);
-}
-
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-max(T p0, vector<T, N> p1) {
- return max((vector<T, N>)p0, p1);
-}
+_DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS(max)
//===----------------------------------------------------------------------===//
// min builtins overloads
//===----------------------------------------------------------------------===//
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-min(vector<T, N> p0, T p1) {
- return min(p0, (vector<T, N>)p1);
-}
-
-template <typename T, uint N>
-constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
-min(T p0, vector<T, N> p1) {
- return min((vector<T, N>)p0, p1);
-}
+_DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS(min)
//===----------------------------------------------------------------------===//
// normalize builtins overloads
@@ -352,6 +365,7 @@ _DXC_COMPAT_UNARY_INTEGER_OVERLOADS(normalize)
_DXC_COMPAT_BINARY_DOUBLE_OVERLOADS(pow)
_DXC_COMPAT_BINARY_INTEGER_OVERLOADS(pow)
+_DXC_COMPAT_BINARY_VECTOR_SCALAR_OVERLOADS(pow)
//===----------------------------------------------------------------------===//
// rsqrt builtins overloads
diff --git a/clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl
index c0e1e914831aa..5bf23db7671ec 100644
--- a/clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/clamp-overloads.hlsl
@@ -90,6 +90,12 @@ double4 test_clamp_double4_mismatch1(double4 p0, double p1) { return clamp(p0, p
// CHECK: [[CLAMP:%.*]] = call reassoc nnan ninf nsz arcp afn {{.*}} <4 x double> @llvm.[[TARGET]].nclamp.v4f64(<4 x double> %{{.*}}, <4 x double> [[CONV1]], <4 x double> %{{.*}})
// CHECK: ret <4 x double> [[CLAMP]]
double4 test_clamp_double4_mismatch2(double4 p0, double p1) { return clamp(p0, p1,p0); }
+// CHECK: define [[FNATTRS]] [[FFNATTRS]] <4 x double> {{.*}}test_clamp_double4_mismatch3
+// CHECK: [[CONV0:%.*]] = insertelement <4 x double> poison, double %{{.*}}, i64 0
+// CHECK: [[CONV1:%.*]] = shufflevector <4 x double> [[CONV0]], <4 x double> poison, <4 x i32> zeroinitializer
+// CHECK: [[CLAMP:%.*]] = call reassoc nnan ninf nsz arcp afn {{.*}} <4 x double> @llvm.[[TARGET]].nclamp.v4f64(<4 x double> [[CONV1]], <4 x double> %{{.*}}, <4 x double> %{{.*}})
+// CHECK: ret <4 x double> [[CLAMP]]
+double4 test_clamp_double4_mismatch3(double4 p0, double p1) { return clamp(p1, p0, p0); }
// CHECK: define [[FNATTRS]] <3 x i32> {{.*}}test_overloads3
// CHECK: [[CONV0:%.*]] = insertelement <3 x i32> poison, i32 %{{.*}}, i64 0
diff --git a/clang/test/CodeGenHLSL/builtins/dot-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/dot-overloads.hlsl
new file mode 100644
index 0000000000000..33f0c7625b2eb
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/dot-overloads.hlsl
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -triple dxil-pc-shadermodel6.3-library %s \
+// RUN: -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK \
+// RUN: -DTARGET=dx -DFNATTRS=noundef -DFFNATTRS="nofpclass(nan inf)"
+
+// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -triple spirv-unknown-vulkan-compute %s \
+// RUN: -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK \
+// RUN: -DTARGET=spv -DFNATTRS="spir_func noundef" -DFFNATTRS="nofpclass(nan inf)"
+
+// CHECK: define [[FNATTRS]] [[FFNATTRS]] float {{.*}}test_dot_float4_mismatch1
+// CHECK: [[CONV0:%.*]] = insertelement <4 x float> poison, float %{{.*}}, i64 0
+// CHECK: [[CONV1:%.*]] = shufflevector <4 x float> [[CONV0]], <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK: [[DOT:%.*]] = call {{.*}} float @llvm.[[TARGET]].fdot.v4f32(<4 x float> %{{.*}}, <4 x float> [[CONV1]])
+// CHECK: ret float [[DOT]]
+float test_dot_float4_mismatch1(float4 p0, float p1) { return dot(p0, p1); }
+
+// CHECK: define [[FNATTRS]] [[FFNATTRS]] float {{.*}}test_dot_float4_mismatch2
+// CHECK: [[CONV0:%.*]] = insertelement <4 x float> poison, float %{{.*}}, i64 0
+// CHECK: [[CONV1:%.*]] = shufflevector <4 x float> [[CONV0]], <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK: [[DOT:%.*]] = call {{.*}} float @llvm.[[TARGET]].fdot.v4f32(<4 x float> [[CONV1]], <4 x float> %{{.*}})
+// CHECK: ret float [[DOT]]
+float test_dot_float4_mismatch2(float4 p0, float p1) { return dot(p1, p0); }
+
+// CHECK: define [[FNATTRS]] i32 {{.*}}test_dot_int2_mismatch1
+// CHECK: [[CONV0:%.*]] = insertelement <2 x i32> poison, i32 %{{.*}}, i64 0
+// CHECK: [[CONV1:%.*]] = shufflevector <2 x i32> [[CONV0]], <2 x i32> poison, <2 x i32> zeroinitializer
+// CHECK: [[DOT:%.*]] = call {{.*}} i32 @llvm.[[TARGET]].sdot.v2i32(<2 x i32> %{{.*}}, <2 x i32> [[CONV1]])
+// CHECK: ret i32 [[DOT]]
+int test_dot_int2_mismatch1(int2 p0, int p1) { return dot(p0, p1); }
+
diff --git a/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl
index 3cb14f8555cab..0158370a847d1 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl
@@ -179,3 +179,41 @@ half3 test_lerp_half_scalar(half3 x, half3 y, half s) { return lerp(x, y, s); }
float3 test_lerp_float_scalar(float3 x, float3 y, float s) {
return lerp(x, y, s);
}
+
+// CHECK: define [[FNATTRS]] <2 x float> @_Z23test_lerp_float_scalar1Dv2_ff(
+// CHECK: [[SPLATINSERT:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0
+// CHECK: [[SPLAT:%.*]] = shufflevector <2 x float> [[SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: [[LERP:%.*]] = call {{.*}} <2 x float> @llvm.[[TARGET]].lerp.v2f32(<2 x float> {{.*}}, <2 x float> {{.*}}, <2 x float> [[SPLAT]])
+// CHECK: ret <2 x float> [[LERP]]
+float2 test_lerp_float_scalar1(float2 v, float s) {
+ return lerp(v, v, s);
+}
+
+// CHECK: define [[FNATTRS]] <2 x float> @_Z23test_lerp_float_scalar2Dv2_ff(
+// CHECK: [[SPLATINSERT:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0
+// CHECK: [[SPLAT:%.*]] = shufflevector <2 x float> [[SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: [[LERP:%.*]] = call {{.*}} <2 x float> @llvm.[[TARGET]].lerp.v2f32(<2 x float> {{.*}}, <2 x float> [[SPLAT]], <2 x float> {{.*}})
+// CHECK: ret <2 x float> [[LERP]]
+float2 test_lerp_float_scalar2(float2 v, float s) {
+ return lerp(v, s, v);
+}
+
+// CHECK: define [[FNATTRS]] <2 x float> @_Z23test_lerp_float_scalar3Dv2_ff(
+// CHECK: [[SPLATINSERT:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0
+// CHECK: [[SPLAT:%.*]] = shufflevector <2 x float> [[SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: [[LERP:%.*]] = call {{.*}} <2 x float> @llvm.[[TARGET]].lerp.v2f32(<2 x float> [[SPLAT]], <2 x float> {{.*}}, <2 x float> {{.*}})
+// CHECK: ret <2 x float> [[LERP]]
+float2 test_lerp_float_scalar3(float2 v, float s) {
+ return lerp(s, v, v);
+}
+
+// CHECK: define [[FNATTRS]] <2 x float> @_Z23test_lerp_float_scalar4Dv2_ff(
+// CHECK: [[SPLATINSERT0:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0
+// CHECK: [[SPLAT0:%.*]] = shufflevector <2 x float> [[SPLATINSERT0]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: [[SPLATINSERT1:%.*]] = insertelement <2 x float> poison, float %{{.*}}, i64 0
+// CHECK: [[SPLAT1:%.*]] = shufflevector <2 x float> [[SPLATINSERT1]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: [[LERP:%.*]] = call {{.*}} <2 x float> @llvm.[[TARGET]].lerp.v2f32(<2 x float> {{.*}}, <2 x float> [[SPLAT0]], <2 x float> [[SPLAT1]])
+// CHECK: ret <2 x float> [[LERP]]
+float2 test_lerp_float_scalar4(float2 v, float s) {
+ return lerp(v, s, s);
+}
diff --git a/clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl
index 39003aef7b7b5..3fc8cfcd9a8cb 100644
--- a/clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/pow-overloads.hlsl
@@ -126,3 +126,16 @@ float3 test_pow_uint64_t3(uint64_t3 p0, uint64_t3 p1) { return pow(p0, p1); }
// CHECK: [[POW:%.*]] = call [[FLOATATTRS]] noundef <4 x float> @llvm.pow.v4f32(<4 x float> [[CONV0]], <4 x float> [[CONV1]])
// CHECK: ret <4 x float> [[POW]]
float4 test_pow_uint64_t4(uint64_t4 p0, uint64_t4 p1) { return pow(p0, p1); }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x float> {{.*}}test_pow_float2_mismatch1
+// CHECK: [[CONV0:%.*]] = insertelement <2 x float> poison, float {{.*}}, i64 0
+// CHECK: [[CONV1:%.*]] = shufflevector <2 x float> [[CONV0]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: [[POW:%.*]] = call [[FLOATATTRS]] noundef <2 x float> @llvm.pow.v2f32(<2 x float> {{.*}}, <2 x float> [[CONV1]])
+// CHECK: ret <2 x float> [[POW]]
+float2 test_pow_float2_mismatch1(float2 p0, float p1) { return pow(p0, p1); }
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x float> {{.*}}test_pow_float2_mismatch2
+// CHECK: [[CONV0:%.*]] = insertelement <2 x float> poison, float {{.*}}, i64 0
+// CHECK: [[CONV1:%.*]] = shufflevector <2 x float> [[CONV0]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: [[POW:%.*]] = call [[FLOATATTRS]] noundef <2 x float> @llvm.pow.v2f32(<2 x float> [[CONV1]], <2 x float> {{.*}})
+// CHECK: ret <2 x float> [[POW]]
+float2 test_pow_float2_mismatch2(float2 p0, float p1) { return pow(p1, p0); }
``````````
</details>
https://github.com/llvm/llvm-project/pull/139800
More information about the cfe-commits mailing list