[clang] 4cea2d0 - [HLSL][DXIL] implement `sqrt` intrinsic (#86560)

via cfe-commits cfe-commits at lists.llvm.org
Mon Mar 25 15:02:35 PDT 2024


Author: Farzon Lotfi
Date: 2024-03-25T18:02:30-04:00
New Revision: 4cea2d049f511d16cbc4605f7ecc908c851e169e

URL: https://github.com/llvm/llvm-project/commit/4cea2d049f511d16cbc4605f7ecc908c851e169e
DIFF: https://github.com/llvm/llvm-project/commit/4cea2d049f511d16cbc4605f7ecc908c851e169e.diff

LOG: [HLSL][DXIL] implement `sqrt` intrinsic (#86560)

completes #86187
- fix hlsl_intrinsic to cover the correct cases
- move to using `__builtin_elementwise_sqrt`
- add lowering of `Intrinsic::sqrt` to dxilop 24.

Added: 
    llvm/test/CodeGen/DirectX/sqrt.ll
    llvm/test/CodeGen/DirectX/sqrt_error.ll

Modified: 
    clang/lib/Headers/hlsl/hlsl_intrinsics.h
    clang/test/CodeGenHLSL/builtins/sqrt.hlsl
    llvm/lib/Target/DirectX/DXIL.td

Removed: 
    


################################################################################
diff  --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index d9f27c0db57ce4..fcb64fde1b91f1 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1366,14 +1366,26 @@ float4 sin(float4);
 /// \param Val The input value.
 
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_sqrtf16)
-half sqrt(half In);
-
-_HLSL_BUILTIN_ALIAS(__builtin_sqrtf)
-float sqrt(float In);
-
-_HLSL_BUILTIN_ALIAS(__builtin_sqrt)
-double sqrt(double In);
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+half sqrt(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+half2 sqrt(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+half3 sqrt(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+half4 sqrt(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+float sqrt(float);
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+float2 sqrt(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+float3 sqrt(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
+float4 sqrt(float4);
 
 //===----------------------------------------------------------------------===//
 // trunc builtins

diff  --git a/clang/test/CodeGenHLSL/builtins/sqrt.hlsl b/clang/test/CodeGenHLSL/builtins/sqrt.hlsl
index 2c2a09617cf86a..adbbf69a8e0685 100644
--- a/clang/test/CodeGenHLSL/builtins/sqrt.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/sqrt.hlsl
@@ -1,29 +1,53 @@
-// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
-// RUN:   dxil-pc-shadermodel6.2-library %s -fnative-half-type \
-// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \ 
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
 
-using hlsl::sqrt;
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: %{{.*}} = call half @llvm.sqrt.f16(
+// NATIVE_HALF: ret half %{{.*}}
+// NO_HALF: define noundef float @"?test_sqrt_half@@YA$halff@$halff@@Z"(
+// NO_HALF: %{{.*}} = call float @llvm.sqrt.f32(
+// NO_HALF: ret float %{{.*}}
+half test_sqrt_half(half p0) { return sqrt(p0); }
+// NATIVE_HALF: define noundef <2 x half> @
+// NATIVE_HALF: %{{.*}} = call <2 x half> @llvm.sqrt.v2f16
+// NATIVE_HALF: ret <2 x half> %{{.*}}
+// NO_HALF: define noundef <2 x float> @
+// NO_HALF: %{{.*}} = call <2 x float> @llvm.sqrt.v2f32(
+// NO_HALF: ret <2 x float> %{{.*}}
+half2 test_sqrt_half2(half2 p0) { return sqrt(p0); }
+// NATIVE_HALF: define noundef <3 x half> @
+// NATIVE_HALF: %{{.*}} = call <3 x half> @llvm.sqrt.v3f16
+// NATIVE_HALF: ret <3 x half> %{{.*}}
+// NO_HALF: define noundef <3 x float> @
+// NO_HALF: %{{.*}} = call <3 x float> @llvm.sqrt.v3f32(
+// NO_HALF: ret <3 x float> %{{.*}}
+half3 test_sqrt_half3(half3 p0) { return sqrt(p0); }
+// NATIVE_HALF: define noundef <4 x half> @
+// NATIVE_HALF: %{{.*}} = call <4 x half> @llvm.sqrt.v4f16
+// NATIVE_HALF: ret <4 x half> %{{.*}}
+// NO_HALF: define noundef <4 x float> @
+// NO_HALF: %{{.*}} = call <4 x float> @llvm.sqrt.v4f32(
+// NO_HALF: ret <4 x float> %{{.*}}
+half4 test_sqrt_half4(half4 p0) { return sqrt(p0); }
 
-double sqrt_d(double x)
-{
-  return sqrt(x);
-}
-
-// CHECK: define noundef double @"?sqrt_d@@YANN@Z"(
-// CHECK: call double @llvm.sqrt.f64(double %0)
-
-float sqrt_f(float x)
-{
-  return sqrt(x);
-}
-
-// CHECK: define noundef float @"?sqrt_f@@YAMM@Z"(
-// CHECK: call float @llvm.sqrt.f32(float %0)
-
-half sqrt_h(half x)
-{
-  return sqrt(x);
-}
-
-// CHECK: define noundef half @"?sqrt_h@@YA$f16@$f16@@Z"(
-// CHECK: call half @llvm.sqrt.f16(half %0)
+// CHECK: define noundef float @
+// CHECK: %{{.*}} = call float @llvm.sqrt.f32(
+// CHECK: ret float %{{.*}}
+float test_sqrt_float(float p0) { return sqrt(p0); }
+// CHECK: define noundef <2 x float> @
+// CHECK: %{{.*}} = call <2 x float> @llvm.sqrt.v2f32
+// CHECK: ret <2 x float> %{{.*}}
+float2 test_sqrt_float2(float2 p0) { return sqrt(p0); }
+// CHECK: define noundef <3 x float> @
+// CHECK: %{{.*}} = call <3 x float> @llvm.sqrt.v3f32
+// CHECK: ret <3 x float> %{{.*}}
+float3 test_sqrt_float3(float3 p0) { return sqrt(p0); }
+// CHECK: define noundef <4 x float> @
+// CHECK: %{{.*}} = call <4 x float> @llvm.sqrt.v4f32
+// CHECK: ret <4 x float> %{{.*}}
+float4 test_sqrt_float4(float4 p0) { return sqrt(p0); }

diff  --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 2e6d58e14fd325..b4bc9eda4f50f5 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -274,6 +274,10 @@ def Frac : DXILOpMapping<22, unary, int_dx_frac,
                          "Returns a fraction from 0 to 1 that represents the "
                          "decimal part of the input.",
                          [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+def Sqrt : DXILOpMapping<24, unary, int_sqrt,
+                         "Returns the square root of the specified floating-point"
+                         "value, per component.",
+                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
 def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
                          "Returns the reciprocal of the square root of the specified value."
                          "rsqrt(x) = 1 / sqrt(x).",

diff  --git a/llvm/test/CodeGen/DirectX/sqrt.ll b/llvm/test/CodeGen/DirectX/sqrt.ll
new file mode 100644
index 00000000000000..76a572efd20557
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sqrt.ll
@@ -0,0 +1,20 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for sqrt are generated for float and half.
+
+define noundef float @sqrt_float(float noundef %a) #0 {
+entry:
+; CHECK:call float @dx.op.unary.f32(i32 24, float %{{.*}})
+  %elt.sqrt = call float @llvm.sqrt.f32(float %a)
+  ret float %elt.sqrt
+}
+
+define noundef half @sqrt_half(half noundef %a) #0 {
+entry:
+; CHECK:call half @dx.op.unary.f16(i32 24, half %{{.*}})
+  %elt.sqrt = call half @llvm.sqrt.f16(half %a)
+  ret half %elt.sqrt
+}
+
+declare half @llvm.sqrt.f16(half)
+declare float @llvm.sqrt.f32(float)

diff  --git a/llvm/test/CodeGen/DirectX/sqrt_error.ll b/llvm/test/CodeGen/DirectX/sqrt_error.ll
new file mode 100644
index 00000000000000..fffa2e19b80fa9
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sqrt_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation sqrt does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload Type
+
+define noundef double @sqrt_double(double noundef %a) {
+entry:
+  %elt.sqrt = call double @llvm.sqrt.f64(double %a)
+  ret double %elt.sqrt
+}


        


More information about the cfe-commits mailing list