[clang] [llvm] [HLSL] Implement support for HLSL intrinsic - saturate (PR #104619)

via cfe-commits cfe-commits at lists.llvm.org
Fri Aug 16 10:25:11 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: S. Bharadwaj Yadavalli (bharadwajy)

<details>
<summary>Changes</summary>

Implement support for HLSL intrinsic saturate.
Implement DXIL codegen for the intrinsic saturate by emitting the llvm.dx.saturate intrinsic and lowering it to the DXIL Saturate operation (opcode 7).
Implement SPIR-V codegen by lowering saturate(x) to the GLSL.std.450 FClamp(x, 0.0, 1.0) extended instruction.

Add tests for DXIL and SPIRV CodeGen.

---

Patch is 39.14 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104619.diff


15 Files Affected:

- (modified) clang/include/clang/Basic/Builtins.td (+6) 
- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+9) 
- (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1) 
- (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+40-1) 
- (modified) clang/lib/Sema/SemaHLSL.cpp (+2-1) 
- (added) clang/test/CodeGenHLSL/builtins/saturate.hlsl (+54) 
- (added) clang/test/SemaHLSL/BuiltIns/saturate-errors.hlsl (+31) 
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1) 
- (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+2-1) 
- (modified) llvm/lib/Target/DirectX/DXIL.td (+10) 
- (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+32) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+66-14) 
- (added) llvm/test/CodeGen/DirectX/saturate.ll (+276) 
- (added) llvm/test/CodeGen/DirectX/saturate_errors.ll (+14) 
- (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/saturate.ll (+83) 


``````````diff
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 0a874d8638df43..76e893e38b671c 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4743,6 +4743,12 @@ def HLSLRSqrt : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLSaturate : LangBuiltin<"HLSL_LANG"> { // backs HLSL saturate()
+  let Spellings = ["__builtin_hlsl_elementwise_saturate"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 1c0baeaee03632..01841774562f06 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18666,6 +18666,15 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getRsqrtIntrinsic(),
         ArrayRef<Value *>{Op0}, nullptr, "hlsl.rsqrt");
   }
+  case Builtin::BI__builtin_hlsl_elementwise_saturate: { // HLSL saturate()
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    if (!E->getArg(0)->getType()->hasFloatingRepresentation())
+      llvm_unreachable("saturate operand must have a float representation");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/Op0->getType(),
+        CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
+        nullptr, "hlsl.saturate");
+  }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
     return EmitRuntimeCall(CGM.CreateRuntimeFunction(
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index cd604bea2e763d..b1455b5779acf9 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -79,6 +79,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(Saturate, saturate) // getSaturateIntrinsic()
   GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
 
   //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 678cdc77f8a71b..6d38b668fe770e 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -916,7 +916,7 @@ float4 lerp(float4, float4, float4);
 /// \brief Returns the length of the specified floating-point vector.
 /// \param x [in] The vector of floats, or a scalar float.
 ///
-/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + �).
+/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + ...).
 
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
@@ -1564,6 +1564,45 @@ float3 round(float3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_roundeven)
 float4 round(float4);
 
+//===----------------------------------------------------------------------===//
+// saturate builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T saturate(T Val)
+/// \brief Returns input value, \a Val, clamped to the range [0.0, 1.0].
+/// \param Val The input value.
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+half saturate(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+half2 saturate(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+half3 saturate(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+half4 saturate(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+float saturate(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+float2 saturate(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+float3 saturate(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+float4 saturate(float4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+double saturate(double);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+double2 saturate(double2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+double3 saturate(double3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
+double4 saturate(double4);
+
 //===----------------------------------------------------------------------===//
 // sin builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index e3e926465e799e..df01549cc2eeb6 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -356,7 +356,7 @@ static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
   return true;
 }
 
-void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {  
+void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
   auto *VD = cast<ValueDecl>(D);
   if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) {
     Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
@@ -1045,6 +1045,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_elementwise_saturate: // float-only, like rcp
   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
     if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/saturate.hlsl b/clang/test/CodeGenHLSL/builtins/saturate.hlsl
new file mode 100644
index 00000000000000..970d7b7371b1eb
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/saturate.hlsl
@@ -0,0 +1,54 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// NATIVE_HALF: define noundef half @"?test_saturate_half
+// NATIVE_HALF: call half @llvm.dx.saturate.f16(
+// NO_HALF: define noundef float @"?test_saturate_half
+// NO_HALF: call float @llvm.dx.saturate.f32(
+half test_saturate_half(half p0) { return saturate(p0); }
+// NATIVE_HALF: define noundef <2 x half> @"?test_saturate_half2
+// NATIVE_HALF: call <2 x half> @llvm.dx.saturate.v2f16(
+// NO_HALF: define noundef <2 x float> @"?test_saturate_half2
+// NO_HALF: call <2 x float> @llvm.dx.saturate.v2f32(
+half2 test_saturate_half2(half2 p0) { return saturate(p0); }
+// NATIVE_HALF: define noundef <3 x half> @"?test_saturate_half3
+// NATIVE_HALF: call <3 x half> @llvm.dx.saturate.v3f16(
+// NO_HALF: define noundef <3 x float> @"?test_saturate_half3
+// NO_HALF: call <3 x float> @llvm.dx.saturate.v3f32(
+half3 test_saturate_half3(half3 p0) { return saturate(p0); }
+// NATIVE_HALF: define noundef <4 x half> @"?test_saturate_half4
+// NATIVE_HALF: call <4 x half> @llvm.dx.saturate.v4f16(
+// NO_HALF: define noundef <4 x float> @"?test_saturate_half4
+// NO_HALF: call <4 x float> @llvm.dx.saturate.v4f32(
+half4 test_saturate_half4(half4 p0) { return saturate(p0); }
+
+// CHECK: define noundef float @"?test_saturate_float
+// CHECK: call float @llvm.dx.saturate.f32(
+float test_saturate_float(float p0) { return saturate(p0); }
+// CHECK: define noundef <2 x float> @"?test_saturate_float2
+// CHECK: call <2 x float> @llvm.dx.saturate.v2f32(
+float2 test_saturate_float2(float2 p0) { return saturate(p0); }
+// CHECK: define noundef <3 x float> @"?test_saturate_float3
+// CHECK: call <3 x float> @llvm.dx.saturate.v3f32(
+float3 test_saturate_float3(float3 p0) { return saturate(p0); }
+// CHECK: define noundef <4 x float> @"?test_saturate_float4
+// CHECK: call <4 x float> @llvm.dx.saturate.v4f32(
+float4 test_saturate_float4(float4 p0) { return saturate(p0); }
+
+// CHECK: define noundef double @"?test_saturate_double
+// CHECK: call double @llvm.dx.saturate.f64(
+double test_saturate_double(double p0) { return saturate(p0); }
+// CHECK: define noundef <2 x double> @"?test_saturate_double2
+// CHECK: call <2 x double> @llvm.dx.saturate.v2f64(
+double2 test_saturate_double2(double2 p0) { return saturate(p0); }
+// CHECK: define noundef <3 x double> @"?test_saturate_double3
+// CHECK: call <3 x double> @llvm.dx.saturate.v3f64(
+double3 test_saturate_double3(double3 p0) { return saturate(p0); }
+// CHECK: define noundef <4 x double> @"?test_saturate_double4
+// CHECK: call <4 x double> @llvm.dx.saturate.v4f64(
+double4 test_saturate_double4(double4 p0) { return saturate(p0); }
diff --git a/clang/test/SemaHLSL/BuiltIns/saturate-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/saturate-errors.hlsl
new file mode 100644
index 00000000000000..721b28f86f950f
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/saturate-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected -Werror
+
+float2 test_no_arg() {
+  return saturate();
+  // expected-error@-1 {{no matching function for call to 'saturate'}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+  return saturate(p0, p0, p0, p0);
+  // expected-error@-1 {{no matching function for call to 'saturate'}}
+}
+
+float2 test_saturate_vector_size_mismatch(float3 p0) {
+  return saturate(p0);
+  // expected-error@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<float, 2>'}}
+}
+
+float2 test_saturate_float2_int_splat(int p0) {
+  return saturate(p0);
+  // expected-error@-1 {{call to 'saturate' is ambiguous}}
+}
+
+float2 test_saturate_int_vect_to_float_vec_promotion(int2 p0) {
+  return saturate(p0);
+  // expected-error@-1 {{call to 'saturate' is ambiguous}}
+}
+
+float test_saturate_bool_type_promotion(bool p0) {
+  return saturate(p0);
+  // expected-error@-1 {{call to 'saturate' is ambiguous}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index c9102aa3dd972b..a0807a01ea5ab2 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -34,6 +34,7 @@ def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
 def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
 def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; 
+def int_dx_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 
 def int_dx_dot2 : 
     Intrinsic<[LLVMVectorElementType<0>], 
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 1b5e463822749e..4e130ad0c907d9 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -61,9 +61,10 @@ let TargetPrefix = "spv" in {
   def int_spv_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
   def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
   def int_spv_frac : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
-  def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>], 
+  def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
     [IntrNoMem, IntrWillReturn] >;
   def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
   def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
   def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
+  def int_spv_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 67015cff78a79a..ac378db2c9b499 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -325,6 +325,16 @@ def Abs :  DXILOp<6, unary> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
+def Saturate :  DXILOp<7, unary> {
+  let Doc = "Clamps a half, single or double precision floating point value to [0.0f...1.0f].";
+  let LLVMIntrinsic = int_dx_saturate;
+  let arguments = [overloadTy];
+  let result = overloadTy;
+  let overloads = [Overloads<DXIL1_0, [halfTy, floatTy, doubleTy]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
 def IsInf :  DXILOp<9, isSpecialFloat> {
   let Doc = "Determines if the specified value is infinite.";
   let LLVMIntrinsic = int_dx_isinf;
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e63633b8a1e1ab..4285b5e5d5a48c 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -15,6 +15,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
@@ -46,6 +47,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::dx_normalize:
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
+  case Intrinsic::dx_saturate:
     return true;
   }
   return false;
@@ -362,6 +364,36 @@ static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
   return true;
 }
 
+/// Split a vector llvm.dx.saturate call into one scalar call per element;
+/// scalar calls are left unchanged (DXIL.td maps them to the Saturate op).
+static bool expandSaturateIntrinsic(CallInst *SaturateCall) {
+  FunctionType *FT = SaturateCall->getFunctionType();
+  Type *FTRetTy = FT->getReturnType();
+  assert(FTRetTy == FT->getParamType(0) &&
+         "Unexpected different operand and return types of call to saturate");
+  if (FTRetTy->isVectorTy()) {
+    IRBuilder<> Builder(SaturateCall->getParent());
+    Builder.SetInsertPoint(SaturateCall);
+    auto *FTVecRetTy = cast<FixedVectorType>(FTRetTy);
+    assert(SaturateCall->getCalledFunction()->getIntrinsicID() ==
+           Intrinsic::dx_saturate);
+    Value *SrcVec = SaturateCall->getArgOperand(0);
+    Type *EltTy = FTVecRetTy->getScalarType();
+    Function *ScalarSatCallee = Intrinsic::getDeclaration(
+        SaturateCall->getModule(), Intrinsic::dx_saturate, {EltTy});
+    Value *Result = SrcVec; // Every lane is overwritten by the loop below.
+    for (unsigned I = 0; I < FTVecRetTy->getNumElements(); I++) {
+      Value *Elt = Builder.CreateExtractElement(SrcVec, I);
+      CallInst *CallSaturate =
+          Builder.CreateCall(ScalarSatCallee, {Elt}, "dx_saturate");
+      Result = Builder.CreateInsertElement(Result, CallSaturate, I);
+    }
+    SaturateCall->replaceAllUsesWith(Result);
+    SaturateCall->eraseFromParent();
+  }
+  return true;
+}
+
 static bool expandIntrinsic(Function &F, CallInst *Orig) {
   switch (F.getIntrinsicID()) {
   case Intrinsic::abs:
@@ -388,6 +420,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
     return expandIntegerDot(Orig, F.getIntrinsicID());
+  case Intrinsic::dx_saturate:
+    return expandSaturateIntrinsic(Orig);
   }
   return false;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7cb19279518989..ecb3cee4e781af 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -247,6 +247,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectNormalize(Register ResVReg, const SPIRVType *ResType,
                        MachineInstr &I) const;
 
+  bool selectSaturate(Register ResVReg, const SPIRVType *ResType,
+                      MachineInstr &I) const;
+
   bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I) const;
 
@@ -259,6 +262,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
   Register buildZerosValF(const SPIRVType *ResType, MachineInstr &I) const;
   Register buildOnesVal(bool AllOnes, const SPIRVType *ResType,
                         MachineInstr &I) const;
+  Register buildOnesValF(const SPIRVType *ResType, MachineInstr &I) const;
 
   bool wrapIntoSpecConstantOp(MachineInstr &I,
                               SmallVector<Register> &CompositeArgs) const;
@@ -1285,6 +1289,34 @@ static unsigned getBoolCmpOpcode(unsigned PredNum) {
   }
 }
 
+static APFloat getZeroFP(const Type *LLVMFloatTy) {
+  if (!LLVMFloatTy)
+    return APFloat::getZero(APFloat::IEEEsingle());
+  switch (LLVMFloatTy->getScalarType()->getTypeID()) {
+  case Type::HalfTyID:
+    return APFloat::getZero(APFloat::IEEEhalf());
+  default:
+  case Type::FloatTyID:
+    return APFloat::getZero(APFloat::IEEEsingle());
+  case Type::DoubleTyID:
+    return APFloat::getZero(APFloat::IEEEdouble());
+  }
+}
+
+static APFloat getOneFP(const Type *LLVMFloatTy) {
+  if (!LLVMFloatTy)
+    return APFloat::getOne(APFloat::IEEEsingle());
+  switch (LLVMFloatTy->getScalarType()->getTypeID()) {
+  case Type::HalfTyID:
+    return APFloat::getOne(APFloat::IEEEhalf());
+  default:
+  case Type::FloatTyID:
+    return APFloat::getOne(APFloat::IEEEsingle());
+  case Type::DoubleTyID:
+    return APFloat::getOne(APFloat::IEEEdouble());
+  }
+}
+
 bool SPIRVInstructionSelector::selectAnyOrAll(Register ResVReg,
                                               const SPIRVType *ResType,
                                               MachineInstr &I,
@@ -1446,6 +1478,28 @@ bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
       .constrainAllUses(TII, TRI, RBI);
 }
 
+/// Lower saturate(x) to GLSL.std.450 FClamp(x, 0.0, 1.0), since SPIR-V
+/// has no saturate instruction.
+bool SPIRVInstructionSelector::selectSaturate(Register ResVReg,
+                                              const SPIRVType *ResType,
+                                              MachineInstr &I) const {
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(2).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+  Register VZero = buildZerosValF(ResType, I);
+  Register VOne = buildOnesValF(ResType, I);
+
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+      .addImm(GL::FClamp)
+      .addUse(I.getOperand(2).getReg())
+      .addUse(VZero)
+      .addUse(VOne)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
                                                 const SPIRVType *ResType,
                                                 MachineInstr &I) const {
@@ -1724,20 +1778,6 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
   return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
 }
 
-static APFloat getZeroFP(const Type *LLVMFloatTy) {
-  if (!LLVMFloatTy)
-    return APFloat::getZero(APFloat::IEEEsingle());
-  switch (LLVMFloatTy->getScalarType()->getTypeID()) {
-  case Type::HalfTyID:
-    return APFloat::getZero(APFloat::IEEEhalf());
-  default:
-  case Type::FloatTyID:
-    return APFloat::getZero(APFloat::IEEEsingle());
-  case Type::DoubleTyID:
-    return APFloat::getZero(APFloat::IEEEdouble());
-  }
-}
-
 Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
                                                   MachineInstr &I) const {
   // OpenCL uses nulls for Zero. In HLSL we don't use null constants.
@@ -1748,6 +1788,16 @@ Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
   return GR.getOrCreateConstFP(VZero, I, ResType, TII, ZeroAsNull);
 }
 
+Register SPIRVInstructionSelector::buildOnesValF(const SPIRVType *ResType,
+                                                 MachineInstr &I) const {
+  // OpenCL uses nulls for Zero. In HLSL we don't use null constants.
+  bool ZeroAsNull = STI.isOpenCLEnv();
+  APFloat VOne = getOneFP(GR.getTypeForSPIRVType(ResType));
+  if (ResType->getOpcode() ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/104619


More information about the cfe-commits mailing list