[llvm] 6a38e19 - [HLSL] Implement support for HLSL intrinsic - saturate (#104619)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 20 09:46:36 PDT 2024
Author: S. Bharadwaj Yadavalli
Date: 2024-08-20T12:46:33-04:00
New Revision: 6a38e19c92ed09eeecb70d5f61c3b822acb4964d
URL: https://github.com/llvm/llvm-project/commit/6a38e19c92ed09eeecb70d5f61c3b822acb4964d
DIFF: https://github.com/llvm/llvm-project/commit/6a38e19c92ed09eeecb70d5f61c3b822acb4964d.diff
LOG: [HLSL] Implement support for HLSL intrinsic - saturate (#104619)
Implement support for HLSL intrinsic saturate.
Implement DXIL codegen for the intrinsic saturate by emitting llvm.dx.saturate
and lowering it to the DXIL Saturate op (opcode 7).
Implement SPIRV codegen by transforming saturate(x) to clamp(x, 0.0f,
1.0f).
Add tests for DXIL and SPIRV CodeGen.
Added:
clang/test/CodeGenHLSL/builtins/saturate.hlsl
clang/test/SemaHLSL/BuiltIns/saturate-errors.hlsl
llvm/test/CodeGen/DirectX/saturate.ll
llvm/test/CodeGen/DirectX/saturate_errors.ll
llvm/test/CodeGen/SPIRV/hlsl-intrinsics/saturate.ll
Modified:
clang/include/clang/Basic/Builtins.td
clang/lib/CodeGen/CGBuiltin.cpp
clang/lib/CodeGen/CGHLSLRuntime.h
clang/lib/Headers/hlsl/hlsl_intrinsics.h
clang/lib/Sema/SemaHLSL.cpp
llvm/include/llvm/IR/IntrinsicsDirectX.td
llvm/include/llvm/IR/IntrinsicsSPIRV.td
llvm/lib/Target/DirectX/DXIL.td
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Removed:
################################################################################
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 036366cdadf4aa..ac33672a32b336 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4745,6 +4745,12 @@ def HLSLRSqrt : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+// Element-wise saturate: clamps each floating-point element of the operand
+// to the range [0, 1].
+def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_elementwise_saturate"];
+  let Prototype = "void(...)";
+  let Attributes = [NoThrow, Const];
+}
+
// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index f424ddaa175400..495fb3e1e5b697 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18667,6 +18667,15 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getRsqrtIntrinsic(),
ArrayRef<Value *>{Op0}, nullptr, "hlsl.rsqrt");
}
+  case Builtin::BI__builtin_hlsl_elementwise_saturate: {
+    // saturate(x) is emitted as a single call to the target's saturate
+    // intrinsic; the result type is the operand's (scalar or vector) FP type.
+    const Expr *Arg = E->getArg(0);
+    assert(Arg->getType()->hasFloatingRepresentation() &&
+           "saturate operand must have a float representation");
+    Value *Op0 = EmitScalarExpr(Arg);
+    llvm::Intrinsic::ID SatID = CGM.getHLSLRuntime().getSaturateIntrinsic();
+    return Builder.CreateIntrinsic(/*ReturnType=*/Op0->getType(), SatID,
+                                   ArrayRef<Value *>{Op0}, nullptr,
+                                   "hlsl.saturate");
+  }
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index cd604bea2e763d..b1455b5779acf9 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -79,6 +79,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
+ // Provides getSaturateIntrinsic(), used to lower __builtin_hlsl_elementwise_saturate.
+ GENERATE_HLSL_INTRINSIC_FUNCTION(Saturate, saturate)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 678cdc77f8a71b..6d38b668fe770e 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -916,7 +916,7 @@ float4 lerp(float4, float4, float4);
/// \brief Returns the length of the specified floating-point vector.
/// \param x [in] The vector of floats, or a scalar float.
///
-/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + …).
+/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + ...).
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
@@ -1564,6 +1564,45 @@ float3 round(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_roundeven)
float4 round(float4);
+//===----------------------------------------------------------------------===//
+// saturate builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T saturate(T Val)
+/// \brief Returns input value, \a Val, clamped within the range of 0.0
+/// to 1.0.
+/// \param Val The input value.
+///
+/// Provided for half, float, and double scalars and for 2-, 3-, and
+/// 4-element vectors of those types; vectors are clamped element-wise.
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+half saturate(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+half2 saturate(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+half3 saturate(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+half4 saturate(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+float saturate(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+float2 saturate(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+float3 saturate(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+float4 saturate(float4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+double saturate(double);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+double2 saturate(double2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+double3 saturate(double3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+double4 saturate(double4);
+
//===----------------------------------------------------------------------===//
// sin builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index e3e926465e799e..df01549cc2eeb6 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -356,7 +356,7 @@ static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
return true;
}
-void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
+void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
auto *VD = cast<ValueDecl>(D);
if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) {
Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
@@ -1045,6 +1045,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_elementwise_saturate:
case Builtin::BI__builtin_hlsl_elementwise_rcp: {
if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
return true;
diff --git a/clang/test/CodeGenHLSL/builtins/saturate.hlsl b/clang/test/CodeGenHLSL/builtins/saturate.hlsl
new file mode 100644
index 00000000000000..65a3cd74621cc0
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/saturate.hlsl
@@ -0,0 +1,95 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-unknown-vulkan-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=SPIRV,SPIRV_HALF
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-unknown-vulkan-library %s \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=SPIRV,SPIRV_NO_HALF
+
+// DXIL functions use Microsoft mangling ("?name..."), SPIR-V functions use
+// Itanium mangling ("_Z..."); every define check pins the function name.
+
+// NATIVE_HALF: define noundef half @"?test_saturate_half
+// NATIVE_HALF: call half @llvm.dx.saturate.f16(
+// NO_HALF: define noundef float @"?test_saturate_half
+// NO_HALF: call float @llvm.dx.saturate.f32(
+// SPIRV_HALF: define spir_func noundef half @_Z18test_saturate_halfDh(half
+// SPIRV_HALF: call half @llvm.spv.saturate.f16(half
+// SPIRV_NO_HALF: define spir_func noundef float @_Z18test_saturate_halfDh(float
+// SPIRV_NO_HALF: call float @llvm.spv.saturate.f32(float
+half test_saturate_half(half p0) { return saturate(p0); }
+// NATIVE_HALF: define noundef <2 x half> @"?test_saturate_half2
+// NATIVE_HALF: call <2 x half> @llvm.dx.saturate.v2f16
+// NO_HALF: define noundef <2 x float> @"?test_saturate_half2
+// NO_HALF: call <2 x float> @llvm.dx.saturate.v2f32(
+// SPIRV_HALF: define spir_func noundef <2 x half> @_Z19test_saturate_half2Dv2_Dh(
+// SPIRV_HALF: call <2 x half> @llvm.spv.saturate.v2f16(<2 x half>
+// SPIRV_NO_HALF: define spir_func noundef <2 x float> @_Z19test_saturate_half2Dv2_Dh(<2 x float>
+// SPIRV_NO_HALF: call <2 x float> @llvm.spv.saturate.v2f32(<2 x float>
+half2 test_saturate_half2(half2 p0) { return saturate(p0); }
+// NATIVE_HALF: define noundef <3 x half> @"?test_saturate_half3
+// NATIVE_HALF: call <3 x half> @llvm.dx.saturate.v3f16
+// NO_HALF: define noundef <3 x float> @"?test_saturate_half3
+// NO_HALF: call <3 x float> @llvm.dx.saturate.v3f32(
+// SPIRV_HALF: define spir_func noundef <3 x half> @_Z19test_saturate_half3Dv3_Dh(
+// SPIRV_HALF: call <3 x half> @llvm.spv.saturate.v3f16(<3 x half>
+// SPIRV_NO_HALF: define spir_func noundef <3 x float> @_Z19test_saturate_half3Dv3_Dh(<3 x float>
+// SPIRV_NO_HALF: call <3 x float> @llvm.spv.saturate.v3f32(<3 x float>
+half3 test_saturate_half3(half3 p0) { return saturate(p0); }
+// NATIVE_HALF: define noundef <4 x half> @"?test_saturate_half4
+// NATIVE_HALF: call <4 x half> @llvm.dx.saturate.v4f16
+// NO_HALF: define noundef <4 x float> @"?test_saturate_half4
+// NO_HALF: call <4 x float> @llvm.dx.saturate.v4f32(
+// SPIRV_HALF: define spir_func noundef <4 x half> @_Z19test_saturate_half4Dv4_Dh(
+// SPIRV_HALF: call <4 x half> @llvm.spv.saturate.v4f16(<4 x half>
+// SPIRV_NO_HALF: define spir_func noundef <4 x float> @_Z19test_saturate_half4Dv4_Dh(<4 x float>
+// SPIRV_NO_HALF: call <4 x float> @llvm.spv.saturate.v4f32(<4 x float>
+half4 test_saturate_half4(half4 p0) { return saturate(p0); }
+
+// CHECK: define noundef float @"?test_saturate_float
+// CHECK: call float @llvm.dx.saturate.f32(
+// SPIRV: define spir_func noundef float @_Z19test_saturate_floatf(float
+// SPIRV: call float @llvm.spv.saturate.f32(float
+float test_saturate_float(float p0) { return saturate(p0); }
+// CHECK: define noundef <2 x float> @"?test_saturate_float2
+// CHECK: call <2 x float> @llvm.dx.saturate.v2f32
+// SPIRV: define spir_func noundef <2 x float> @_Z20test_saturate_float2Dv2_f(<2 x float>
+// SPIRV: call <2 x float> @llvm.spv.saturate.v2f32(<2 x float>
+float2 test_saturate_float2(float2 p0) { return saturate(p0); }
+// CHECK: define noundef <3 x float> @"?test_saturate_float3
+// CHECK: call <3 x float> @llvm.dx.saturate.v3f32
+// SPIRV: define spir_func noundef <3 x float> @_Z20test_saturate_float3Dv3_f(<3 x float>
+// SPIRV: call <3 x float> @llvm.spv.saturate.v3f32(<3 x float>
+float3 test_saturate_float3(float3 p0) { return saturate(p0); }
+// CHECK: define noundef <4 x float> @"?test_saturate_float4
+// CHECK: call <4 x float> @llvm.dx.saturate.v4f32
+// SPIRV: define spir_func noundef <4 x float> @_Z20test_saturate_float4Dv4_f(<4 x float>
+// SPIRV: call <4 x float> @llvm.spv.saturate.v4f32(<4 x float>
+float4 test_saturate_float4(float4 p0) { return saturate(p0); }
+
+// CHECK: define noundef double @"?test_saturate_double
+// CHECK: call double @llvm.dx.saturate.f64(
+// SPIRV: define spir_func noundef double @_Z20test_saturate_doubled(double
+// SPIRV: call double @llvm.spv.saturate.f64(double
+double test_saturate_double(double p0) { return saturate(p0); }
+// CHECK: define noundef <2 x double> @"?test_saturate_double2
+// CHECK: call <2 x double> @llvm.dx.saturate.v2f64
+// SPIRV: define spir_func noundef <2 x double> @_Z21test_saturate_double2Dv2_d(<2 x double>
+// SPIRV: call <2 x double> @llvm.spv.saturate.v2f64(<2 x double>
+double2 test_saturate_double2(double2 p0) { return saturate(p0); }
+// CHECK: define noundef <3 x double> @"?test_saturate_double3
+// CHECK: call <3 x double> @llvm.dx.saturate.v3f64
+// SPIRV: define spir_func noundef <3 x double> @_Z21test_saturate_double3Dv3_d(<3 x double>
+// SPIRV: call <3 x double> @llvm.spv.saturate.v3f64(<3 x double>
+double3 test_saturate_double3(double3 p0) { return saturate(p0); }
+// CHECK: define noundef <4 x double> @"?test_saturate_double4
+// CHECK: call <4 x double> @llvm.dx.saturate.v4f64
+// SPIRV: define spir_func noundef <4 x double> @_Z21test_saturate_double4Dv4_d(<4 x double>
+// SPIRV: call <4 x double> @llvm.spv.saturate.v4f64(<4 x double>
+double4 test_saturate_double4(double4 p0) { return saturate(p0); }
diff --git a/clang/test/SemaHLSL/BuiltIns/saturate-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/saturate-errors.hlsl
new file mode 100644
index 00000000000000..721b28f86f950f
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/saturate-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected -Werror
+
+// Invalid uses of saturate: wrong argument counts, a result whose vector size
+// does not match the return type, and non-floating-point arguments whose
+// overload resolution is ambiguous.
+
+float2 test_no_arg() {
+  return saturate();
+  // expected-error@-1 {{no matching function for call to 'saturate'}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+  return saturate(p0, p0, p0, p0);
+  // expected-error@-1 {{no matching function for call to 'saturate'}}
+}
+
+float2 test_saturate_vector_size_mismatch(float3 p0) {
+  return saturate(p0);
+  // expected-error@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<float, 2>'}}
+}
+
+float2 test_saturate_float2_int_splat(int p0) {
+  return saturate(p0);
+  // expected-error@-1 {{call to 'saturate' is ambiguous}}
+}
+
+float2 test_saturate_int_vect_to_float_vec_promotion(int2 p0) {
+  return saturate(p0);
+  // expected-error@-1 {{call to 'saturate' is ambiguous}}
+}
+
+float test_saturate_bool_type_promotion(bool p0) {
+  return saturate(p0);
+  // expected-error@-1 {{call to 'saturate' is ambiguous}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index c9102aa3dd972b..a0807a01ea5ab2 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -34,6 +34,7 @@ def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+def int_dx_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_dot2 :
Intrinsic<[LLVMVectorElementType<0>],
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 1b5e463822749e..4e130ad0c907d9 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -61,9 +61,10 @@ let TargetPrefix = "spv" in {
def int_spv_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_spv_frac : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
- def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
+ def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
+  def int_spv_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 34c7f84b1ca5b2..bed525a5e5699b 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -330,6 +330,16 @@ def Abs : DXILOp<6, unary> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
+def Saturate : DXILOp<7, unary> {
+  let Doc = "Clamps a half, single or double precision floating point value to [0.0f...1.0f].";
+  let LLVMIntrinsic = int_dx_saturate;
+  let arguments = [overloadTy];
+  let result = overloadTy;
+  let overloads = [Overloads<DXIL1_0, [halfTy, floatTy, doubleTy]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
def IsInf : DXILOp<9, isSpecialFloat> {
let Doc = "Determines if the specified value is infinite.";
let LLVMIntrinsic = int_dx_isinf;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7cb19279518989..ecb3cee4e781af 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -247,6 +247,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectNormalize(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ /// Lowers the spv_saturate intrinsic to GLSL.std.450 FClamp(x, 0.0, 1.0).
+ bool selectSaturate(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -259,6 +262,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
Register buildZerosValF(const SPIRVType *ResType, MachineInstr &I) const;
Register buildOnesVal(bool AllOnes, const SPIRVType *ResType,
MachineInstr &I) const;
+ /// Builds the floating-point constant 1.0 of ResType (a splat for vectors).
+ Register buildOnesValF(const SPIRVType *ResType, MachineInstr &I) const;
bool wrapIntoSpecConstantOp(MachineInstr &I,
SmallVector<Register> &CompositeArgs) const;
@@ -1285,6 +1289,34 @@ static unsigned getBoolCmpOpcode(unsigned PredNum) {
}
}
+/// Returns the IEEE semantics for the scalar element type of \p LLVMFloatTy.
+/// Half and double map to their IEEE formats; a null type and any other type
+/// fall back to single precision.
+static const fltSemantics &getFPSemantics(const Type *LLVMFloatTy) {
+  if (!LLVMFloatTy)
+    return APFloat::IEEEsingle();
+  switch (LLVMFloatTy->getScalarType()->getTypeID()) {
+  case Type::HalfTyID:
+    return APFloat::IEEEhalf();
+  case Type::DoubleTyID:
+    return APFloat::IEEEdouble();
+  case Type::FloatTyID:
+  default:
+    return APFloat::IEEEsingle();
+  }
+}
+
+/// Returns +0.0 in the floating-point format matching \p LLVMFloatTy.
+static APFloat getZeroFP(const Type *LLVMFloatTy) {
+  return APFloat::getZero(getFPSemantics(LLVMFloatTy));
+}
+
+/// Returns 1.0 in the floating-point format matching \p LLVMFloatTy.
+static APFloat getOneFP(const Type *LLVMFloatTy) {
+  return APFloat::getOne(getFPSemantics(LLVMFloatTy));
+}
+
bool SPIRVInstructionSelector::selectAnyOrAll(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
@@ -1446,6 +1478,28 @@ bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
+/// Lowers spv_saturate by emitting GLSL.std.450 FClamp(x, 0.0, 1.0), since
+/// SPIR-V has no dedicated saturate instruction. The 0.0/1.0 bounds are built
+/// with ResType, so vector results get vector bounds.
+/// NOTE(review): GLSL.std.450 FClamp inherits FMin/FMax's undefined result
+/// for NaN operands; confirm this is acceptable for HLSL saturate semantics.
+bool SPIRVInstructionSelector::selectSaturate(Register ResVReg,
+                                              const SPIRVType *ResType,
+                                              MachineInstr &I) const {
+  // Operand 2 carries the value being clamped.
+  assert(I.getNumOperands() == 3);
+  const MachineOperand &Src = I.getOperand(2);
+  assert(Src.isReg());
+
+  const Register Lo = buildZerosValF(ResType, I);
+  const Register Hi = buildOnesValF(ResType, I);
+  const auto GLSLSet =
+      static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450);
+
+  MachineBasicBlock &MBB = *I.getParent();
+  auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst));
+  MIB.addDef(ResVReg);
+  MIB.addUse(GR.getSPIRVTypeID(ResType));
+  MIB.addImm(GLSLSet);
+  MIB.addImm(GL::FClamp);
+  MIB.addUse(Src.getReg());
+  MIB.addUse(Lo);
+  MIB.addUse(Hi);
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -1724,20 +1778,6 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
}
-static APFloat getZeroFP(const Type *LLVMFloatTy) {
- if (!LLVMFloatTy)
- return APFloat::getZero(APFloat::IEEEsingle());
- switch (LLVMFloatTy->getScalarType()->getTypeID()) {
- case Type::HalfTyID:
- return APFloat::getZero(APFloat::IEEEhalf());
- default:
- case Type::FloatTyID:
- return APFloat::getZero(APFloat::IEEEsingle());
- case Type::DoubleTyID:
- return APFloat::getZero(APFloat::IEEEdouble());
- }
-}
-
Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
MachineInstr &I) const {
// OpenCL uses nulls for Zero. In HLSL we don't use null constants.
@@ -1748,6 +1788,16 @@ Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
return GR.getOrCreateConstFP(VZero, I, ResType, TII, ZeroAsNull);
}
+/// Returns a register holding the floating-point constant 1.0 of \p ResType;
+/// for a vector \p ResType the constant is splatted across every component.
+Register SPIRVInstructionSelector::buildOnesValF(const SPIRVType *ResType,
+                                                 MachineInstr &I) const {
+  // Mirrors buildZerosValF. The ZeroAsNull flag (set only for OpenCL) is
+  // forwarded unchanged; by its name it concerns zero values and presumably
+  // has no effect on a non-zero constant such as 1.0.
+  bool ZeroAsNull = STI.isOpenCLEnv();
+  APFloat VOne = getOneFP(GR.getTypeForSPIRVType(ResType));
+  if (ResType->getOpcode() == SPIRV::OpTypeVector)
+    return GR.getOrCreateConstVector(VOne, I, ResType, TII, ZeroAsNull);
+  return GR.getOrCreateConstFP(VOne, I, ResType, TII, ZeroAsNull);
+}
+
Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -2181,6 +2231,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
Size = 0;
BuildMI(BB, I, I.getDebugLoc(), TII.get(Op)).addUse(PtrReg).addImm(Size);
} break;
+ case Intrinsic::spv_saturate:
+ return selectSaturate(ResVReg, ResType, I);
default: {
std::string DiagMsg;
raw_string_ostream OS(DiagMsg);
diff --git a/llvm/test/CodeGen/DirectX/saturate.ll b/llvm/test/CodeGen/DirectX/saturate.ll
new file mode 100644
index 00000000000000..a8557351756f2b
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/saturate.ll
@@ -0,0 +1,39 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; Make sure the intrinsic dx.saturate is lowered to the appropriate DXIL op
+; (opcode 7, dx.op.unary) for half/float/double data types.
+; NOTE(review): attribute groups #0 and #1 are referenced below but never
+; defined in this file.
+
+; CHECK-LABEL: test_saturate_half
+define noundef half @test_saturate_half(half noundef %p0) #0 {
+entry:
+ ; CHECK: call half @dx.op.unary.f16(i32 7, half %p0)
+ %hlsl.saturate = call half @llvm.dx.saturate.f16(half %p0)
+ ; CHECK: ret half
+ ret half %hlsl.saturate
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare half @llvm.dx.saturate.f16(half) #1
+
+; CHECK-LABEL: test_saturate_float
+define noundef float @test_saturate_float(float noundef %p0) #0 {
+entry:
+ ; CHECK: call float @dx.op.unary.f32(i32 7, float %p0)
+ %hlsl.saturate = call float @llvm.dx.saturate.f32(float %p0)
+ ; CHECK: ret float
+ ret float %hlsl.saturate
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare float @llvm.dx.saturate.f32(float) #1
+
+; CHECK-LABEL: test_saturate_double
+define noundef double @test_saturate_double(double noundef %p0) #0 {
+entry:
+ ; CHECK: call double @dx.op.unary.f64(i32 7, double %p0)
+ %hlsl.saturate = call double @llvm.dx.saturate.f64(double %p0)
+ ; CHECK: ret double
+ ret double %hlsl.saturate
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare double @llvm.dx.saturate.f64(double) #1
+
diff --git a/llvm/test/CodeGen/DirectX/saturate_errors.ll b/llvm/test/CodeGen/DirectX/saturate_errors.ll
new file mode 100644
index 00000000000000..0dd2e04ab56751
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/saturate_errors.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
+
+; DXIL operation saturate does not support i32 overload
+; int_dx_saturate is defined over llvm_anyfloat_ty only, so an i32 call does
+; not match the intrinsic's signature.
+; CHECK: invalid intrinsic signature
+
+define noundef i32 @test_saturate_i32(i32 noundef %p0) #0 {
+entry:
+ %hlsl.saturate = call i32 @llvm.dx.saturate.i32(i32 %p0)
+ ret i32 %hlsl.saturate
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/saturate.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/saturate.ll
new file mode 100644
index 00000000000000..0b05b615c4ad17
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/saturate.ll
@@ -0,0 +1,83 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure SPIRV operation function calls for saturate are lowered correctly.
+; saturate(x) is emitted as GLSL.std.450 FClamp(x, 0, 1); the vector bounds
+; are splat OpConstantComposite values, so all four components are checked.
+
+; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450"
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+; CHECK-DAG: %[[#float_64:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#vec4_float_64:]] = OpTypeVector %[[#float_64]] 4
+; CHECK-DAG: %[[#zero_float_16:]] = OpConstant %[[#float_16]] 0
+; CHECK-DAG: %[[#vec4_zero_float_16:]] = OpConstantComposite %[[#vec4_float_16]] %[[#zero_float_16]] %[[#zero_float_16]] %[[#zero_float_16]] %[[#zero_float_16]]
+; CHECK-DAG: %[[#one_float_16:]] = OpConstant %[[#float_16]] 15360
+; CHECK-DAG: %[[#vec4_one_float_16:]] = OpConstantComposite %[[#vec4_float_16]] %[[#one_float_16]] %[[#one_float_16]] %[[#one_float_16]] %[[#one_float_16]]
+; CHECK-DAG: %[[#zero_float_32:]] = OpConstant %[[#float_32]] 0
+; CHECK-DAG: %[[#vec4_zero_float_32:]] = OpConstantComposite %[[#vec4_float_32]] %[[#zero_float_32]] %[[#zero_float_32]] %[[#zero_float_32]] %[[#zero_float_32]]
+; CHECK-DAG: %[[#one_float_32:]] = OpConstant %[[#float_32]] 1
+; CHECK-DAG: %[[#vec4_one_float_32:]] = OpConstantComposite %[[#vec4_float_32]] %[[#one_float_32]] %[[#one_float_32]] %[[#one_float_32]] %[[#one_float_32]]
+
+; CHECK-DAG: %[[#zero_float_64:]] = OpConstant %[[#float_64]] 0
+; CHECK-DAG: %[[#vec4_zero_float_64:]] = OpConstantComposite %[[#vec4_float_64]] %[[#zero_float_64]] %[[#zero_float_64]] %[[#zero_float_64]] %[[#zero_float_64]]
+; CHECK-DAG: %[[#one_float_64:]] = OpConstant %[[#float_64]] 1
+; CHECK-DAG: %[[#vec4_one_float_64:]] = OpConstantComposite %[[#vec4_float_64]] %[[#one_float_64]] %[[#one_float_64]] %[[#one_float_64]] %[[#one_float_64]]
+
+define noundef half @saturate_half(half noundef %a) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#float_16]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#float_16]]
+  ; CHECK: %[[#]] = OpExtInst %[[#float_16]] %[[#op_ext_glsl]] FClamp %[[#arg0]] %[[#zero_float_16]] %[[#one_float_16]]
+  %hlsl.saturate = call half @llvm.spv.saturate.f16(half %a)
+  ret half %hlsl.saturate
+}
+
+define noundef float @saturate_float(float noundef %a) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#float_32]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#float_32]]
+  ; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_glsl]] FClamp %[[#arg0]] %[[#zero_float_32]] %[[#one_float_32]]
+  %hlsl.saturate = call float @llvm.spv.saturate.f32(float %a)
+  ret float %hlsl.saturate
+}
+
+define noundef double @saturate_double(double noundef %a) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#float_64]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#float_64]]
+  ; CHECK: %[[#]] = OpExtInst %[[#float_64]] %[[#op_ext_glsl]] FClamp %[[#arg0]] %[[#zero_float_64]] %[[#one_float_64]]
+  %hlsl.saturate = call double @llvm.spv.saturate.f64(double %a)
+  ret double %hlsl.saturate
+}
+
+define noundef <4 x half> @saturate_half4(<4 x half> noundef %a) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#vec4_float_16]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_16]] %[[#op_ext_glsl]] FClamp %[[#arg0]] %[[#vec4_zero_float_16]] %[[#vec4_one_float_16]]
+  %hlsl.saturate = call <4 x half> @llvm.spv.saturate.v4f16(<4 x half> %a)
+  ret <4 x half> %hlsl.saturate
+}
+
+define noundef <4 x float> @saturate_float4(<4 x float> noundef %a) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] FClamp %[[#arg0]] %[[#vec4_zero_float_32]] %[[#vec4_one_float_32]]
+  %hlsl.saturate = call <4 x float> @llvm.spv.saturate.v4f32(<4 x float> %a)
+  ret <4 x float> %hlsl.saturate
+}
+
+define noundef <4 x double> @saturate_double4(<4 x double> noundef %a) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#vec4_float_64]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_64]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_64]] %[[#op_ext_glsl]] FClamp %[[#arg0]] %[[#vec4_zero_float_64]] %[[#vec4_one_float_64]]
+  %hlsl.saturate = call <4 x double> @llvm.spv.saturate.v4f64(<4 x double> %a)
+  ret <4 x double> %hlsl.saturate
+}
+
+declare half @llvm.spv.saturate.f16(half)
+declare float @llvm.spv.saturate.f32(float)
+declare double @llvm.spv.saturate.f64(double)
+declare <4 x half> @llvm.spv.saturate.v4f16(<4 x half>)
+declare <4 x float> @llvm.spv.saturate.v4f32(<4 x float>)
+declare <4 x double> @llvm.spv.saturate.v4f64(<4 x double>)
More information about the llvm-commits
mailing list