[clang] [llvm] [DXIL] `exp`, `any`, `lerp`, & `rcp` Intrinsic Lowering (PR #84526)

Farzon Lotfi via cfe-commits cfe-commits at lists.llvm.org
Thu Mar 14 17:01:45 PDT 2024


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/84526

>From 3f515637fc87a41db1df4ea7627679c7dd75503a Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 14 Mar 2024 18:53:33 -0400
Subject: [PATCH 1/3] [DXIL] exp, any, lerp, & rcp Intrinsic Lowering

This change implements lowering for #70076, #70100, #70072, & #70102.

- `CGBuiltin.cpp` - simplify `lerp` intrinsic
- `IntrinsicsDirectX.td` - simplify `lerp` intrinsic
- `SemaChecking.cpp` - remove unnecessary check
- `DXILIntrinsicExpansion.*` - add intrinsic to instruction expansion cases
- `DXILOpLowering.cpp` - make sure `DXILIntrinsicExpansion` happens first
- `DirectX.h` - changes to support new pass
- `DirectXTargetMachine.cpp` - changes to support new pass

Why `any` and `lerp` as instruction expansion just for DXIL?
- In SPIR-V there is an [OpAny](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpAny) instruction.
- SPIR-V has a GLSL lerp extension via [FMix](https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html#FMix).
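
As a rough illustration of what the two expansions compute, here is a standalone C++ sketch. The helper names `any_expanded` and `lerp_expanded` are hypothetical and use plain C++ types rather than the LLVM API; they only mirror the arithmetic that the pass emits (a compare against zero followed by an OR reduction for `any`, and fsub, fmul, fadd for `lerp`).

```cpp
#include <array>
#include <cstddef>

// Hypothetical helpers: they mirror the arithmetic only, not the LLVM API.

// any(v): compare each element against zero, then OR the results together.
// For floating-point elements `!=` is also true for NaN, which matches the
// unordered-or-not-equal (une) comparison used in the expansion.
template <typename T, std::size_t N>
bool any_expanded(const std::array<T, N> &v) {
  bool result = false;
  for (const T &elt : v)
    result = result || (elt != T(0));
  return result;
}

// lerp(x, y, s): x + s * (y - x), computed as subtract, multiply, add.
inline float lerp_expanded(float x, float y, float s) {
  float v = y - x;
  v = s * v;
  return x + v;
}
```

With these definitions, `lerp_expanded(2.0f, 4.0f, 0.5f)` yields the midpoint `3.0f`, and `any_expanded` on an all-zero array yields `false`.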

Why `exp` instruction expansion?
- DXIL has an `exp2` opcode, and `exp` reuses that opcode, so instruction expansion is a convenient place to do the preprocessing (scaling the input by log2(e)).
- Further, SPIR-V has a GLSL exp extension via [Exp](https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html#Exp) and [Exp2](https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html#Exp2).
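
The identity behind the `exp` expansion is exp(x) = exp2(x * log2(e)). A minimal standalone C++ sketch of it follows; `exp_via_exp2` and `kLog2e` are hypothetical names chosen for illustration, not anything defined by the patch.

```cpp
#include <cmath>

// Hypothetical helper: exp(x) rewritten as exp2(log2(e) * x), the same
// multiply-then-exp2 shape the expansion emits.
inline float exp_via_exp2(float x) {
  constexpr float kLog2e = 1.44269504088896340736f; // log2(e)
  return std::exp2(kLog2e * x);
}
```

The result should agree with `std::exp` up to floating-point rounding, since the constant is rounded to `float` before the multiply.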

Why `rcp` as instruction expansion?
This one is a bit of an outlier and might have to move to `cgbuiltins` once we better understand the SPIR-V requirements. However, I included it because [fast math mode has an AllowRecip flag](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_fp_fast_math_mode), which lets you compute the reciprocal without performing the division. We don't have that in DXIL, so I thought to include it.
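
For reference, the `rcp` expansion amounts to dividing one by the operand, with the constant splatted across the lanes in the vector case. The sketch below uses hypothetical `rcp_expanded` overloads on plain C++ types and is only meant to show that shape.

```cpp
#include <array>
#include <cstddef>

// Hypothetical scalar helper: rcp(x) expanded to the division 1.0 / x.
inline float rcp_expanded(float x) { return 1.0f / x; }

// Hypothetical vector helper: the constant 1.0 is applied to every lane.
template <std::size_t N>
std::array<float, N> rcp_expanded(const std::array<float, N> &x) {
  std::array<float, N> result{};
  for (std::size_t i = 0; i < N; ++i)
    result[i] = 1.0f / x[i];
  return result;
}
```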
---
 clang/include/clang/AST/Type.h                |   5 +
 clang/lib/CodeGen/CGBuiltin.cpp               |  35 +---
 clang/lib/Sema/SemaChecking.cpp               |  51 +++--
 .../CodeGenHLSL/builtins/lerp-builtin.hlsl    |  22 ---
 clang/test/CodeGenHLSL/builtins/lerp.hlsl     |  27 ++-
 clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl |  17 +-
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |   4 +-
 llvm/lib/Target/DirectX/CMakeLists.txt        |   1 +
 .../Target/DirectX/DXILIntrinsicExpansion.cpp | 186 ++++++++++++++++++
 .../Target/DirectX/DXILIntrinsicExpansion.h   |  33 ++++
 llvm/lib/Target/DirectX/DXILOpLowering.cpp    |   6 +-
 llvm/lib/Target/DirectX/DirectX.h             |   6 +
 .../Target/DirectX/DirectXTargetMachine.cpp   |   2 +
 llvm/test/CodeGen/DirectX/any.ll              | 105 ++++++++++
 llvm/test/CodeGen/DirectX/exp-vec.ll          |  16 ++
 llvm/test/CodeGen/DirectX/exp.ll              |  29 +++
 llvm/test/CodeGen/DirectX/lerp.ll             |  53 +++++
 llvm/test/CodeGen/DirectX/rcp.ll              |  48 +++++
 18 files changed, 557 insertions(+), 89 deletions(-)
 create mode 100644 llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
 create mode 100644 llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h
 create mode 100644 llvm/test/CodeGen/DirectX/any.ll
 create mode 100644 llvm/test/CodeGen/DirectX/exp-vec.ll
 create mode 100644 llvm/test/CodeGen/DirectX/exp.ll
 create mode 100644 llvm/test/CodeGen/DirectX/lerp.ll
 create mode 100644 llvm/test/CodeGen/DirectX/rcp.ll

diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 1942b0e67f65a3..10916053cdfbf5 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2244,6 +2244,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isFloatingType() const;     // C99 6.2.5p11 (real floating + complex)
   bool isHalfType() const;         // OpenCL 6.1.1.1, NEON (IEEE 754-2008 half)
   bool isFloat16Type() const;      // C11 extension ISO/IEC TS 18661
+  bool isFloat32Type() const;
   bool isBFloat16Type() const;
   bool isFloat128Type() const;
   bool isIbm128Type() const;
@@ -7452,6 +7453,10 @@ inline bool Type::isFloat16Type() const {
   return isSpecificBuiltinType(BuiltinType::Float16);
 }
 
+inline bool Type::isFloat32Type() const {
+  return isSpecificBuiltinType(BuiltinType::Float);
+}
+
 inline bool Type::isBFloat16Type() const {
   return isSpecificBuiltinType(BuiltinType::BFloat16);
 }
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index b9d1a4912385e1..b09bf563622089 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18021,38 +18021,11 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     Value *X = EmitScalarExpr(E->getArg(0));
     Value *Y = EmitScalarExpr(E->getArg(1));
     Value *S = EmitScalarExpr(E->getArg(2));
-    llvm::Type *Xty = X->getType();
-    llvm::Type *Yty = Y->getType();
-    llvm::Type *Sty = S->getType();
-    if (!Xty->isVectorTy() && !Yty->isVectorTy() && !Sty->isVectorTy()) {
-      if (Xty->isFloatingPointTy()) {
-        auto V = Builder.CreateFSub(Y, X);
-        V = Builder.CreateFMul(S, V);
-        return Builder.CreateFAdd(X, V, "dx.lerp");
-      }
-      llvm_unreachable("Scalar Lerp is only supported on floats.");
-    }
-    // A VectorSplat should have happened
-    assert(Xty->isVectorTy() && Yty->isVectorTy() && Sty->isVectorTy() &&
-           "Lerp of vector and scalar is not supported.");
-
-    [[maybe_unused]] auto *XVecTy =
-        E->getArg(0)->getType()->getAs<VectorType>();
-    [[maybe_unused]] auto *YVecTy =
-        E->getArg(1)->getType()->getAs<VectorType>();
-    [[maybe_unused]] auto *SVecTy =
-        E->getArg(2)->getType()->getAs<VectorType>();
-    // A HLSLVectorTruncation should have happend
-    assert(XVecTy->getNumElements() == YVecTy->getNumElements() &&
-           XVecTy->getNumElements() == SVecTy->getNumElements() &&
-           "Lerp requires vectors to be of the same size.");
-    assert(XVecTy->getElementType()->isRealFloatingType() &&
-           XVecTy->getElementType() == YVecTy->getElementType() &&
-           XVecTy->getElementType() == SVecTy->getElementType() &&
-           "Lerp requires float vectors to be of the same type.");
+    if (!E->getArg(0)->getType()->hasFloatingRepresentation())
+      llvm_unreachable("lerp operand must have a float representation");
     return Builder.CreateIntrinsic(
-        /*ReturnType=*/Xty, Intrinsic::dx_lerp, ArrayRef<Value *>{X, Y, S},
-        nullptr, "dx.lerp");
+        /*ReturnType=*/X->getType(), Intrinsic::dx_lerp,
+        ArrayRef<Value *>{X, Y, S}, nullptr, "dx.lerp");
   }
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 769de5c4a8fcfc..d88a38eb6eb97b 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5234,10 +5234,6 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
                            TheCall->getArg(1)->getEndLoc());
         retValue = true;
       }
-
-      if (!retValue)
-        TheCall->setType(VecTyA->getElementType());
-
       return retValue;
     }
   }
@@ -5251,11 +5247,12 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
   return true;
 }
 
-bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
-  QualType ExpectedType = S->Context.FloatTy;
+bool CheckArgsTypesAreCorrect(
+    Sema *S, CallExpr *TheCall, QualType ExpectedType,
+    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
   for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
     QualType PassedType = TheCall->getArg(i)->getType();
-    if (!PassedType->hasFloatingRepresentation()) {
+    if (Check(PassedType)) {
       if (auto *VecTyA = PassedType->getAs<VectorType>())
         ExpectedType = S->Context.getVectorType(
             ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
@@ -5268,6 +5265,26 @@ bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
   return false;
 }
 
+bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
+  auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
+    return !PassedType->hasFloatingRepresentation();
+  };
+  return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
+                                  checkAllFloatTypes);
+}
+
+bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
+  auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
+    clang::QualType BaseType =
+        PassedType->isVectorType()
+            ? PassedType->getAs<clang::VectorType>()->getElementType()
+            : PassedType;
+    return !BaseType->isHalfType() && !BaseType->isFloat32Type();
+  };
+  return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
+                                  checkFloatorHalf);
+}
+
 void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
                                 QualType ReturnType) {
   auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
@@ -5295,21 +5312,27 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
-  case Builtin::BI__builtin_hlsl_elementwise_isinf: {
-    if (checkArgCount(*this, TheCall, 1))
-      return true;
+  case Builtin::BI__builtin_hlsl_elementwise_rcp: {
     if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
       return true;
-    SetElementTypeAsReturnType(this, TheCall, this->Context.BoolTy);
+    if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
+      return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
-  case Builtin::BI__builtin_hlsl_elementwise_rcp:
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
-    if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
+    if (CheckFloatOrHalfRepresentations(this, TheCall))
+      return true;
+    if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
+      return true;
+    break;
+  }
+  case Builtin::BI__builtin_hlsl_elementwise_isinf: {
+    if (CheckFloatOrHalfRepresentations(this, TheCall))
       return true;
     if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
+    SetElementTypeAsReturnType(this, TheCall, this->Context.BoolTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_lerp: {
@@ -5319,7 +5342,7 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     if (SemaBuiltinElementwiseTernaryMath(TheCall))
       return true;
-    if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
+    if (CheckFloatOrHalfRepresentations(this, TheCall))
       return true;
     break;
   }
diff --git a/clang/test/CodeGenHLSL/builtins/lerp-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/lerp-builtin.hlsl
index 1f16dec68212e4..2fd5a19fc33521 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp-builtin.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp-builtin.hlsl
@@ -1,27 +1,5 @@
 // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s
 
-
-
-// CHECK-LABEL: builtin_lerp_half_scalar
-// CHECK: %3 = fsub double %conv1, %conv
-// CHECK: %4 = fmul double %conv2, %3
-// CHECK: %dx.lerp = fadd double %conv, %4
-// CHECK: %conv3 = fptrunc double %dx.lerp to half
-// CHECK: ret half %conv3
-half builtin_lerp_half_scalar (half p0) {
-  return __builtin_hlsl_lerp ( p0, p0, p0 );
-}
-
-// CHECK-LABEL: builtin_lerp_float_scalar
-// CHECK: %3 = fsub double %conv1, %conv
-// CHECK: %4 = fmul double %conv2, %3
-// CHECK: %dx.lerp = fadd double %conv, %4
-// CHECK: %conv3 = fptrunc double %dx.lerp to float
-// CHECK: ret float %conv3
-float builtin_lerp_float_scalar ( float p0) {
-  return __builtin_hlsl_lerp ( p0, p0, p0 );
-}
-
 // CHECK-LABEL: builtin_lerp_half_vector
 // CHECK: %dx.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
 // CHECK: ret <3 x half> %dx.lerp
diff --git a/clang/test/CodeGenHLSL/builtins/lerp.hlsl b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
index a6b3d9643d674c..49cd04a10115ae 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
@@ -6,13 +6,10 @@
 // RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
 // RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
 
-// NATIVE_HALF: %3 = fsub half %1, %0
-// NATIVE_HALF: %4 = fmul half %2, %3
-// NATIVE_HALF: %dx.lerp = fadd half %0, %4
+
+// NATIVE_HALF: %dx.lerp = call half @llvm.dx.lerp.f16(half %0, half %1, half %2)
 // NATIVE_HALF: ret half %dx.lerp
-// NO_HALF: %3 = fsub float %1, %0
-// NO_HALF: %4 = fmul float %2, %3
-// NO_HALF: %dx.lerp = fadd float %0, %4
+// NO_HALF: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
 // NO_HALF: ret float %dx.lerp
 half test_lerp_half(half p0) { return lerp(p0, p0, p0); }
 
@@ -20,37 +17,35 @@ half test_lerp_half(half p0) { return lerp(p0, p0, p0); }
 // NATIVE_HALF: ret <2 x half> %dx.lerp
 // NO_HALF: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
 // NO_HALF: ret <2 x float> %dx.lerp
-half2 test_lerp_half2(half2 p0, half2 p1) { return lerp(p0, p0, p0); }
+half2 test_lerp_half2(half2 p0) { return lerp(p0, p0, p0); }
 
 // NATIVE_HALF: %dx.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
 // NATIVE_HALF: ret <3 x half> %dx.lerp
 // NO_HALF: %dx.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
 // NO_HALF: ret <3 x float> %dx.lerp
-half3 test_lerp_half3(half3 p0, half3 p1) { return lerp(p0, p0, p0); }
+half3 test_lerp_half3(half3 p0) { return lerp(p0, p0, p0); }
 
 // NATIVE_HALF: %dx.lerp = call <4 x half> @llvm.dx.lerp.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
 // NATIVE_HALF: ret <4 x half> %dx.lerp
 // NO_HALF: %dx.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
 // NO_HALF: ret <4 x float> %dx.lerp
-half4 test_lerp_half4(half4 p0, half4 p1) { return lerp(p0, p0, p0); }
+half4 test_lerp_half4(half4 p0) { return lerp(p0, p0, p0); }
 
-// CHECK: %3 = fsub float %1, %0
-// CHECK: %4 = fmul float %2, %3
-// CHECK: %dx.lerp = fadd float %0, %4
+// CHECK: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
 // CHECK: ret float %dx.lerp
-float test_lerp_float(float p0, float p1) { return lerp(p0, p0, p0); }
+float test_lerp_float(float p0) { return lerp(p0, p0, p0); }
 
 // CHECK: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
 // CHECK: ret <2 x float> %dx.lerp
-float2 test_lerp_float2(float2 p0, float2 p1) { return lerp(p0, p0, p0); }
+float2 test_lerp_float2(float2 p0) { return lerp(p0, p0, p0); }
 
 // CHECK: %dx.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
 // CHECK: ret <3 x float> %dx.lerp
-float3 test_lerp_float3(float3 p0, float3 p1) { return lerp(p0, p0, p0); }
+float3 test_lerp_float3(float3 p0) { return lerp(p0, p0, p0); }
 
 // CHECK: %dx.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
 // CHECK: ret <4 x float> %dx.lerp
-float4 test_lerp_float4(float4 p0, float4 p1) { return lerp(p0, p0, p0); }
+float4 test_lerp_float4(float4 p0) { return lerp(p0, p0, p0); }
 
 // CHECK: %dx.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
 // CHECK: ret <2 x float> %dx.lerp
diff --git a/clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl
index f6ce87e7c33e3e..83751f68357edf 100644
--- a/clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl
@@ -92,5 +92,18 @@ float builtin_lerp_int_to_float_promotion(float p0, int p1) {
 
 float4 test_lerp_int4(int4 p0, int4 p1, int4 p2) {
   return __builtin_hlsl_lerp(p0, p1, p2);
-   // expected-error@-1 {{1st argument must be a floating point type (was 'int4' (aka 'vector<int, 4>'))}}
-}
\ No newline at end of file
+  // expected-error@-1 {{1st argument must be a floating point type (was 'int4' (aka 'vector<int, 4>'))}}
+}
+
+// note: DefaultVariadicArgumentPromotion --> DefaultArgumentPromotion has already promoted to double
+// we don't know anymore that the input was half when __builtin_hlsl_lerp is called so we default to float
+// for expected type
+half builtin_lerp_half_scalar (half p0) {
+  return __builtin_hlsl_lerp ( p0, p0, p0 );
+  // expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
+}
+
+float builtin_lerp_float_scalar ( float p0) {
+  return __builtin_hlsl_lerp ( p0, p0, p0 );
+  // expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 957cc8f2e15eb7..00536c71c3e2e5 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -33,9 +33,7 @@ def int_dx_isinf :
     DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
     [llvm_anyfloat_ty]>;
 
-def int_dx_lerp :
-    Intrinsic<[LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>,LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>], 
     [IntrNoMem, IntrWillReturn] >;
 
 def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index bf93280779bf8b..4c70b3f9230edb 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -19,6 +19,7 @@ add_llvm_target(DirectXCodeGen
   DirectXSubtarget.cpp
   DirectXTargetMachine.cpp
   DXContainerGlobals.cpp
+  DXILIntrinsicExpansion.cpp
   DXILMetadata.cpp
   DXILOpBuilder.cpp
   DXILOpLowering.cpp
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
new file mode 100644
index 00000000000000..0461f0490017bf
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -0,0 +1,186 @@
+//===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains DXIL intrinsic expansions for those that don't have
+//  opcodes in DirectX Intermediate Language (DXIL).
+//===----------------------------------------------------------------------===//
+
+#include "DXILIntrinsicExpansion.h"
+#include "DirectX.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MathExtras.h"
+
+#define DEBUG_TYPE "dxil-intrinsic-expansion"
+
+using namespace llvm;
+
+static bool isIntrinsicExpansion(Function &F) {
+  switch (F.getIntrinsicID()) {
+  case Intrinsic::exp:
+  case Intrinsic::dx_any:
+  case Intrinsic::dx_lerp:
+  case Intrinsic::dx_rcp:
+    return true;
+  }
+  return false;
+}
+
+static bool expandExpIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  Type *EltTy = Ty->getScalarType();
+  Constant *Log2eConst =
+      Ty->isVectorTy() ? ConstantVector::getSplat(
+                             ElementCount::getFixed(
+                                 cast<FixedVectorType>(Ty)->getNumElements()),
+                             ConstantFP::get(EltTy, numbers::log2e))
+                       : ConstantFP::get(EltTy, numbers::log2e);
+  Value *NewX = Builder.CreateFMul(Log2eConst, X);
+  auto *Exp2Call =
+      Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
+  Exp2Call->setTailCall(Orig->isTailCall());
+  Exp2Call->setAttributes(Orig->getAttributes());
+  Orig->replaceAllUsesWith(Exp2Call);
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandAnyIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  Type *EltTy = Ty->getScalarType();
+
+  if (!Ty->isVectorTy()) {
+    Value *Cond = EltTy->isFloatingPointTy()
+                      ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
+                      : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
+    Orig->replaceAllUsesWith(Cond);
+  } else {
+    auto *XVec = dyn_cast<FixedVectorType>(Ty);
+    Value *Cond =
+        EltTy->isFloatingPointTy()
+            ? Builder.CreateFCmpUNE(
+                  X, ConstantVector::getSplat(
+                         ElementCount::getFixed(XVec->getNumElements()),
+                         ConstantFP::get(EltTy, 0)))
+            : Builder.CreateICmpNE(
+                  X, ConstantVector::getSplat(
+                         ElementCount::getFixed(XVec->getNumElements()),
+                         ConstantInt::get(EltTy, 0)));
+    Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
+    for (unsigned I = 1; I < XVec->getNumElements(); I++) {
+      Value *Elt = Builder.CreateExtractElement(Cond, I);
+      Result = Builder.CreateOr(Result, Elt);
+    }
+    Orig->replaceAllUsesWith(Result);
+  }
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandLerpIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  Value *Y = Orig->getOperand(1);
+  Value *S = Orig->getOperand(2);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  auto *V = Builder.CreateFSub(Y, X);
+  V = Builder.CreateFMul(S, V);
+  auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandRcpIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  Type *EltTy = Ty->getScalarType();
+  Constant *One =
+      Ty->isVectorTy()
+          ? ConstantVector::getSplat(
+                ElementCount::getFixed(
+                    dyn_cast<FixedVectorType>(Ty)->getNumElements()),
+                ConstantFP::get(EltTy, 1.0))
+          : ConstantFP::get(EltTy, 1.0);
+  auto *Result = Builder.CreateFDiv(One, X, "dx.rcp");
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandIntrinsic(Function &F, CallInst *Orig) {
+  switch (F.getIntrinsicID()) {
+  case Intrinsic::exp:
+    return expandExpIntrinsic(Orig);
+  case Intrinsic::dx_any:
+    return expandAnyIntrinsic(Orig);
+  case Intrinsic::dx_lerp:
+    return expandLerpIntrinsic(Orig);
+  case Intrinsic::dx_rcp:
+    return expandRcpIntrinsic(Orig);
+  }
+  return false;
+}
+
+static bool expansionIntrinsics(Module &M) {
+  for (auto &F : make_early_inc_range(M.functions())) {
+    if (!isIntrinsicExpansion(F))
+      continue;
+    bool IntrinsicExpanded = false;
+    for (User *U : make_early_inc_range(F.users())) {
+      auto *IntrinsicCall = dyn_cast<CallInst>(U);
+      if (!IntrinsicCall)
+        continue;
+      IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall);
+    }
+    if (F.user_empty() && IntrinsicExpanded)
+      F.eraseFromParent();
+  }
+  return true;
+}
+
+PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
+                                              ModuleAnalysisManager &) {
+  if (expansionIntrinsics(M))
+    return PreservedAnalyses::none();
+  return PreservedAnalyses::all();
+}
+
+bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
+  return expansionIntrinsics(M);
+}
+
+char DXILIntrinsicExpansionLegacy::ID = 0;
+
+INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
+                      "DXIL Intrinsic Expansion", false, false)
+INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
+                    "DXIL Intrinsic Expansion", false, false)
+
+ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
+  return new DXILIntrinsicExpansionLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h
new file mode 100644
index 00000000000000..c86681af7a3712
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h
@@ -0,0 +1,33 @@
+//===- DXILIntrinsicExpansion.h - Prepare LLVM Module for DXIL encoding----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
+#define LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
+
+#include "DXILResource.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+
+/// A pass that transforms DXIL Intrinsics that don't have DXIL opCodes
+class DXILIntrinsicExpansion : public PassInfoMixin<DXILIntrinsicExpansion> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
+};
+
+class DXILIntrinsicExpansionLegacy : public ModulePass {
+
+public:
+  bool runOnModule(Module &M) override;
+  DXILIntrinsicExpansionLegacy() : ModulePass(ID) {}
+
+  static char ID; // Pass identification.
+};
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 6b649b76beecdf..e5c2042e7d16ae 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DXILConstants.h"
+#include "DXILIntrinsicExpansion.h"
 #include "DXILOpBuilder.h"
 #include "DirectX.h"
 #include "llvm/ADT/SmallVector.h"
@@ -94,9 +95,12 @@ class DXILOpLoweringLegacy : public ModulePass {
   DXILOpLoweringLegacy() : ModulePass(ID) {}
 
   static char ID; // Pass identification.
+  void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
+    // Specify the passes that your pass depends on
+    AU.addRequired<DXILIntrinsicExpansionLegacy>();
+  }
 };
 char DXILOpLoweringLegacy::ID = 0;
-
 } // end anonymous namespace
 
 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index eaecc3ac280c4c..11b5412c21d783 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -28,6 +28,12 @@ void initializeDXILPrepareModulePass(PassRegistry &);
 /// Pass to convert modules into DXIL-compatable modules
 ModulePass *createDXILPrepareModulePass();
 
+/// Initializer for DXIL Intrinsic Expansion
+void initializeDXILIntrinsicExpansionLegacyPass(PassRegistry &);
+
+/// Pass to expand intrinsic operations that lack DXIL opCodes
+ModulePass *createDXILIntrinsicExpansionLegacyPass();
+
 /// Initializer for DXILOpLowering
 void initializeDXILOpLoweringLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 06938f8c74f155..03c825b3977db3 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -39,6 +39,7 @@ using namespace llvm;
 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget());
   auto *PR = PassRegistry::getPassRegistry();
+  initializeDXILIntrinsicExpansionLegacyPass(*PR);
   initializeDXILPrepareModulePass(*PR);
   initializeEmbedDXILPassPass(*PR);
   initializeWriteDXILPassPass(*PR);
@@ -76,6 +77,7 @@ class DirectXPassConfig : public TargetPassConfig {
 
   FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
   void addCodeGenPrepare() override {
+    addPass(createDXILIntrinsicExpansionLegacyPass());
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILPrepareModulePass());
     addPass(createDXILTranslateMetadataPass());
diff --git a/llvm/test/CodeGen/DirectX/any.ll b/llvm/test/CodeGen/DirectX/any.ll
new file mode 100644
index 00000000000000..4ae824d39edc1f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/any.ll
@@ -0,0 +1,105 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for any are generated for float and half.
+
+; CHECK: icmp ne i1 %{{.*}}, false
+define noundef i1 @any_bool(i1 noundef %p0) {
+entry:
+  %p0.addr = alloca i8, align 1
+  %frombool = zext i1 %p0 to i8
+  store i8 %frombool, ptr %p0.addr, align 1
+  %0 = load i8, ptr %p0.addr, align 1
+  %tobool = trunc i8 %0 to i1
+  %dx.any = call i1 @llvm.dx.any.i1(i1 %tobool)
+  ret i1 %dx.any
+}
+
+; CHECK: icmp ne i64 %{{.*}}, 0
+define noundef i1 @any_int64_t(i64 noundef %p0) {
+entry:
+  %p0.addr = alloca i64, align 8
+  store i64 %p0, ptr %p0.addr, align 8
+  %0 = load i64, ptr %p0.addr, align 8
+  %dx.any = call i1 @llvm.dx.any.i64(i64 %0)
+  ret i1 %dx.any
+}
+
+; CHECK: icmp ne i32 %{{.*}}, 0
+define noundef i1 @any_int(i32 noundef %p0) {
+entry:
+  %p0.addr = alloca i32, align 4
+  store i32 %p0, ptr %p0.addr, align 4
+  %0 = load i32, ptr %p0.addr, align 4
+  %dx.any = call i1 @llvm.dx.any.i32(i32 %0)
+  ret i1 %dx.any
+}
+
+; CHECK: icmp ne i16 %{{.*}}, 0
+define noundef i1 @any_int16_t(i16 noundef %p0) {
+entry:
+  %p0.addr = alloca i16, align 2
+  store i16 %p0, ptr %p0.addr, align 2
+  %0 = load i16, ptr %p0.addr, align 2
+  %dx.any = call i1 @llvm.dx.any.i16(i16 %0)
+  ret i1 %dx.any
+}
+
+; CHECK: fcmp une double %{{.*}}, 0.000000e+00
+define noundef i1 @any_double(double noundef %p0) {
+entry:
+  %p0.addr = alloca double, align 8
+  store double %p0, ptr %p0.addr, align 8
+  %0 = load double, ptr %p0.addr, align 8
+  %dx.any = call i1 @llvm.dx.any.f64(double %0)
+  ret i1 %dx.any
+}
+
+; CHECK: fcmp une float %{{.*}}, 0.000000e+00
+define noundef i1 @any_float(float noundef %p0) {
+entry:
+  %p0.addr = alloca float, align 4
+  store float %p0, ptr %p0.addr, align 4
+  %0 = load float, ptr %p0.addr, align 4
+  %dx.any = call i1 @llvm.dx.any.f32(float %0)
+  ret i1 %dx.any
+}
+
+; CHECK: fcmp une half %{{.*}}, 0xH0000
+define noundef i1 @any_half(half noundef %p0) {
+entry:
+  %p0.addr = alloca half, align 2
+  store half %p0, ptr %p0.addr, align 2
+  %0 = load half, ptr %p0.addr, align 2
+  %dx.any = call i1 @llvm.dx.any.f16(half %0)
+  ret i1 %dx.any
+}
+
+; CHECK: icmp ne <4 x i1> %extractvec, zeroinitialize
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 0
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 1
+; CHECK: or i1  %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 2
+; CHECK: or i1  %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 3
+; CHECK: or i1  %{{.*}}, %{{.*}}
+define noundef i1 @any_bool4(<4 x i1> noundef %p0) {
+entry:
+  %p0.addr = alloca i8, align 1
+  %insertvec = shufflevector <4 x i1> %p0, <4 x i1> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison>
+  %0 = bitcast <8 x i1> %insertvec to i8
+  store i8 %0, ptr %p0.addr, align 1
+  %load_bits = load i8, ptr %p0.addr, align 1
+  %1 = bitcast i8 %load_bits to <8 x i1>
+  %extractvec = shufflevector <8 x i1> %1, <8 x i1> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %dx.any = call i1 @llvm.dx.any.v4i1(<4 x i1> %extractvec)
+  ret i1 %dx.any
+}
+
+declare i1 @llvm.dx.any.v4i1(<4 x i1>)
+declare i1 @llvm.dx.any.i1(i1)
+declare i1 @llvm.dx.any.i16(i16)
+declare i1 @llvm.dx.any.i32(i32)
+declare i1 @llvm.dx.any.i64(i64)
+declare i1 @llvm.dx.any.f16(half)
+declare i1 @llvm.dx.any.f32(float)
+declare i1 @llvm.dx.any.f64(double)
diff --git a/llvm/test/CodeGen/DirectX/exp-vec.ll b/llvm/test/CodeGen/DirectX/exp-vec.ll
new file mode 100644
index 00000000000000..3a6dc10817a4de
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/exp-vec.ll
@@ -0,0 +1,16 @@
+; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s
+
+; Make sure dxil operation function calls for exp are generated for float and half.
+
+; CHECK: fmul <4 x float> <float 0x3FF7154760000000, float 0x3FF7154760000000, float 0x3FF7154760000000, float 0x3FF7154760000000>,  %{{.*}}
+; CHECK: call <4 x float> @llvm.exp2.v4f32(<4 x float>  %{{.*}})
+define noundef <4 x float> @exp_float4(<4 x float> noundef %p0) {
+entry:
+  %p0.addr = alloca <4 x float>, align 16
+  store <4 x float> %p0, ptr %p0.addr, align 16
+  %0 = load <4 x float>, ptr %p0.addr, align 16
+  %elt.exp = call <4 x float> @llvm.exp.v4f32(<4 x float> %0)
+  ret <4 x float> %elt.exp
+}
+
+declare <4 x float> @llvm.exp.v4f32(<4 x float>)
diff --git a/llvm/test/CodeGen/DirectX/exp.ll b/llvm/test/CodeGen/DirectX/exp.ll
new file mode 100644
index 00000000000000..7f77a030cdbefd
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/exp.ll
@@ -0,0 +1,29 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for exp are generated for float and half.
+
+; CHECK: fmul float 0x3FF7154760000000, %{{.*}}
+; CHECK: call float @dx.op.unary.f32(i32 21, float %{{.*}})
+define noundef float @exp_float(float noundef %a) {
+entry:
+  %a.addr = alloca float, align 4
+  store float %a, ptr %a.addr, align 4
+  %0 = load float, ptr %a.addr, align 4
+  %elt.exp = call float @llvm.exp.f32(float %0)
+  ret float %elt.exp
+}
+
+; CHECK: fmul half 0xH3DC5, %{{.*}}
+; CHECK: call half @dx.op.unary.f16(i32 21, half %{{.*}})
+; Function Attrs: noinline nounwind optnone
+define noundef half @exp_half(half noundef %a) {
+entry:
+  %a.addr = alloca half, align 2
+  store half %a, ptr %a.addr, align 2
+  %0 = load half, ptr %a.addr, align 2
+  %elt.exp = call half @llvm.exp.f16(half %0)
+  ret half %elt.exp
+}
+
+declare half @llvm.exp.f16(half)
+declare float @llvm.exp.f32(float)
diff --git a/llvm/test/CodeGen/DirectX/lerp.ll b/llvm/test/CodeGen/DirectX/lerp.ll
new file mode 100644
index 00000000000000..5e183b58fdad4c
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/lerp.ll
@@ -0,0 +1,53 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for lerp are generated for float and half.
+
+; CHECK: fsub half %{{.*}}, %{{.*}}
+; CHECK: fmul half %{{.*}}, %{{.*}}
+; CHECK: fadd half %{{.*}}, %{{.*}}
+define noundef half @lerp_half(half noundef %p0) {
+entry:
+  %p0.addr = alloca half, align 2
+  store half %p0, ptr %p0.addr, align 2
+  %0 = load half, ptr %p0.addr, align 2
+  %1 = load half, ptr %p0.addr, align 2
+  %2 = load half, ptr %p0.addr, align 2
+  %dx.lerp = call half @llvm.dx.lerp.f16(half %0, half %1, half %2)
+  ret half %dx.lerp
+}
+
+; CHECK: fsub float %{{.*}}, %{{.*}}
+; CHECK: fmul float %{{.*}}, %{{.*}}
+; CHECK: fadd float %{{.*}}, %{{.*}}
+define noundef float @lerp_float(float noundef %p0, float noundef %p1) {
+entry:
+  %p1.addr = alloca float, align 4
+  %p0.addr = alloca float, align 4
+  store float %p1, ptr %p1.addr, align 4
+  store float %p0, ptr %p0.addr, align 4
+  %0 = load float, ptr %p0.addr, align 4
+  %1 = load float, ptr %p0.addr, align 4
+  %2 = load float, ptr %p0.addr, align 4
+  %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
+  ret float %dx.lerp
+}
+
+; CHECK: fsub <4 x float> %{{.*}}, %{{.*}}
+; CHECK: fmul <4 x float> %{{.*}}, %{{.*}}
+; CHECK: fadd <4 x float> %{{.*}}, %{{.*}}
+define noundef <4 x float> @lerp_float4(<4 x float> noundef %p0, <4 x float> noundef %p1) {
+entry:
+  %p1.addr = alloca <4 x float>, align 16
+  %p0.addr = alloca <4 x float>, align 16
+  store <4 x float> %p1, ptr %p1.addr, align 16
+  store <4 x float> %p0, ptr %p0.addr, align 16
+  %0 = load <4 x float>, ptr %p0.addr, align 16
+  %1 = load <4 x float>, ptr %p0.addr, align 16
+  %2 = load <4 x float>, ptr %p0.addr, align 16
+  %dx.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+  ret <4 x float> %dx.lerp
+}
+
+declare half @llvm.dx.lerp.f16(half, half, half)
+declare float @llvm.dx.lerp.f32(float, float, float)
+declare <4 x float> @llvm.dx.lerp.v4f32(<4 x float>, <4 x float>, <4 x float>)
diff --git a/llvm/test/CodeGen/DirectX/rcp.ll b/llvm/test/CodeGen/DirectX/rcp.ll
new file mode 100644
index 00000000000000..bb54d5f5ab6230
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/rcp.ll
@@ -0,0 +1,48 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for rcp are generated for float, double, and half.
+
+; CHECK: fdiv <4 x float> <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>, %{{.*}}
+define noundef <4 x float> @rcp_float4(<4 x float> noundef %p0) {
+entry:
+  %p0.addr = alloca <4 x float>, align 16
+  store <4 x float> %p0, ptr %p0.addr, align 16
+  %0 = load <4 x float>, ptr %p0.addr, align 16
+  %dx.rcp = call <4 x float> @llvm.dx.rcp.v4f32(<4 x float> %0)
+  ret <4 x float> %dx.rcp
+}
+
+; CHECK: fdiv <4 x double> <double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00>, %{{.*}}
+define noundef <4 x double> @rcp_double4(<4 x double> noundef %p0) {
+entry:
+  %p0.addr = alloca <4 x double>, align 16
+  store <4 x double> %p0, ptr %p0.addr, align 16
+  %0 = load <4 x double>, ptr %p0.addr, align 16
+  %dx.rcp = call <4 x double> @llvm.dx.rcp.v4f64(<4 x double> %0)
+  ret <4 x double> %dx.rcp
+}
+
+; CHECK: fdiv <4 x half> <half  0xH3C00, half  0xH3C00, half  0xH3C00, half  0xH3C00>, %{{.*}} 
+define noundef <4 x half> @rcp_half4(<4 x half> noundef %p0) {
+entry:
+  %p0.addr = alloca <4 x half>, align 16
+  store <4 x half> %p0, ptr %p0.addr, align 16
+  %0 = load <4 x half>, ptr %p0.addr, align 16
+  %dx.rcp = call <4 x half> @llvm.dx.rcp.v4f16(<4 x half> %0)
+  ret <4 x half> %dx.rcp
+}
+
+; CHECK: fdiv half 0xH3C00, %{{.*}} 
+define noundef half @rcp_half(half noundef %p0) {
+entry:
+  %p0.addr = alloca half, align 2
+  store half %p0, ptr %p0.addr, align 2
+  %0 = load half, ptr %p0.addr, align 2
+  %dx.rcp = call half @llvm.dx.rcp.f16(half %0)
+  ret half %dx.rcp
+}
+
+declare half @llvm.dx.rcp.f16(half)
+declare <4 x half> @llvm.dx.rcp.v4f16(<4 x half>)
+declare <4 x float> @llvm.dx.rcp.v4f32(<4 x float>)
+declare <4 x double> @llvm.dx.rcp.v4f64(<4 x double>)
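
Reviewer note on the constants in the `exp` and `rcp` tests above: the `exp` expansion relies on the identity exp(x) = exp2(x * log2(e)). The constants `0x3FF7154760000000` and `0xH3DC5` appear to be log2(e) rounded to float and to half; LLVM IR prints a float constant as the bit pattern of the equivalent double. Below is a minimal Python sketch that checks this reading. The helper names are hypothetical and are not part of the patch.

```python
import math
import struct

# log2(e), the scale factor that turns exp(x) into exp2(x * log2(e)).
LOG2E = math.log2(math.e)


def float_constant_as_ir_hex(value: float) -> str:
    """Round to float32, widen back to double, and return the double's bits.

    This mirrors how LLVM IR spells float constants that have no exact
    short decimal form (assumption: illustration only, not LLVM code).
    """
    as_f32 = struct.unpack("<f", struct.pack("<f", value))[0]
    return "0x" + struct.pack(">d", as_f32).hex().upper()


def half_constant_as_ir_hex(value: float) -> str:
    """Round to IEEE half and return its 16 bits in LLVM's 0xH form."""
    return "0xH" + struct.pack(">e", value).hex().upper()


def exp_via_exp2(x: float) -> float:
    """Compute exp(x) the way the expansion does: exp2(x * log2(e))."""
    return 2.0 ** (x * LOG2E)


if __name__ == "__main__":
    print(float_constant_as_ir_hex(LOG2E))  # float operand of the fmul
    print(half_constant_as_ir_hex(LOG2E))   # half operand of the fmul
    print(half_constant_as_ir_hex(1.0))     # numerator in the rcp fdiv
    print(exp_via_exp2(1.5), math.exp(1.5))
```

The same helper also accounts for `0xH3C00` in rcp.ll, which is 1.0 as a half.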

>From 94ac3a971ee402c4dddef3588d23739e08ba248b Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 14 Mar 2024 19:21:21 -0400
Subject: [PATCH 2/3] add labels for test cases

---
 llvm/test/CodeGen/DirectX/any.ll     | 8 ++++++++
 llvm/test/CodeGen/DirectX/exp-vec.ll | 1 +
 llvm/test/CodeGen/DirectX/exp.ll     | 2 ++
 llvm/test/CodeGen/DirectX/lerp.ll    | 3 +++
 llvm/test/CodeGen/DirectX/rcp.ll     | 4 ++++
 5 files changed, 18 insertions(+)

diff --git a/llvm/test/CodeGen/DirectX/any.ll b/llvm/test/CodeGen/DirectX/any.ll
index 4ae824d39edc1f..e8d87075d65cfe 100644
--- a/llvm/test/CodeGen/DirectX/any.ll
+++ b/llvm/test/CodeGen/DirectX/any.ll
@@ -2,6 +2,7 @@
 
 ; Make sure dxil operation function calls for any are generated for float and half.
 
+; CHECK-LABEL: any_bool
 ; CHECK: icmp ne i1 %{{.*}}, false
 define noundef i1 @any_bool(i1 noundef %p0) {
 entry:
@@ -14,6 +15,7 @@ entry:
   ret i1 %dx.any
 }
 
+; CHECK-LABEL: any_int64_t
 ; CHECK: icmp ne i64 %{{.*}}, 0
 define noundef i1 @any_int64_t(i64 noundef %p0) {
 entry:
@@ -24,6 +26,7 @@ entry:
   ret i1 %dx.any
 }
 
+; CHECK-LABEL: any_int
 ; CHECK: icmp ne i32 %{{.*}}, 0
 define noundef i1 @any_int(i32 noundef %p0) {
 entry:
@@ -34,6 +37,7 @@ entry:
   ret i1 %dx.any
 }
 
+; CHECK-LABEL: any_int16_t
 ; CHECK: icmp ne i16 %{{.*}}, 0
 define noundef i1 @any_int16_t(i16 noundef %p0) {
 entry:
@@ -44,6 +48,7 @@ entry:
   ret i1 %dx.any
 }
 
+; CHECK-LABEL: any_double
 ; CHECK: fcmp une double %{{.*}}, 0.000000e+00
 define noundef i1 @any_double(double noundef %p0) {
 entry:
@@ -54,6 +59,7 @@ entry:
   ret i1 %dx.any
 }
 
+; CHECK-LABEL: any_float
 ; CHECK: fcmp une float %{{.*}}, 0.000000e+00
 define noundef i1 @any_float(float noundef %p0) {
 entry:
@@ -64,6 +70,7 @@ entry:
   ret i1 %dx.any
 }
 
+; CHECK-LABEL: any_half
 ; CHECK: fcmp une half %{{.*}}, 0xH0000
 define noundef i1 @any_half(half noundef %p0) {
 entry:
@@ -74,6 +81,7 @@ entry:
   ret i1 %dx.any
 }
 
+; CHECK-LABEL: any_bool4
 ; CHECK: icmp ne <4 x i1> %extractvec, zeroinitialize
 ; CHECK: extractelement <4 x i1> %{{.*}}, i64 0
 ; CHECK: extractelement <4 x i1> %{{.*}}, i64 1
diff --git a/llvm/test/CodeGen/DirectX/exp-vec.ll b/llvm/test/CodeGen/DirectX/exp-vec.ll
index 3a6dc10817a4de..c9371557190549 100644
--- a/llvm/test/CodeGen/DirectX/exp-vec.ll
+++ b/llvm/test/CodeGen/DirectX/exp-vec.ll
@@ -2,6 +2,7 @@
 
 ; Make sure dxil operation function calls for exp are generated for float and half.
 
+; CHECK-LABEL: exp_float4
 ; CHECK: fmul <4 x float> <float 0x3FF7154760000000, float 0x3FF7154760000000, float 0x3FF7154760000000, float 0x3FF7154760000000>,  %{{.*}}
 ; CHECK: call <4 x float> @llvm.exp2.v4f32(<4 x float>  %{{.*}})
 define noundef <4 x float> @exp_float4(<4 x float> noundef %p0) {
diff --git a/llvm/test/CodeGen/DirectX/exp.ll b/llvm/test/CodeGen/DirectX/exp.ll
index 7f77a030cdbefd..fdafc1438cf0e8 100644
--- a/llvm/test/CodeGen/DirectX/exp.ll
+++ b/llvm/test/CodeGen/DirectX/exp.ll
@@ -2,6 +2,7 @@
 
 ; Make sure dxil operation function calls for exp are generated for float and half.
 
+; CHECK-LABEL: exp_float
 ; CHECK: fmul float 0x3FF7154760000000, %{{.*}}
 ; CHECK: call float @dx.op.unary.f32(i32 21, float %{{.*}})
 define noundef float @exp_float(float noundef %a) {
@@ -13,6 +14,7 @@ entry:
   ret float %elt.exp
 }
 
+; CHECK-LABEL: exp_half
 ; CHECK: fmul half 0xH3DC5, %{{.*}}
 ; CHECK: call half @dx.op.unary.f16(i32 21, half %{{.*}})
 ; Function Attrs: noinline nounwind optnone
diff --git a/llvm/test/CodeGen/DirectX/lerp.ll b/llvm/test/CodeGen/DirectX/lerp.ll
index 5e183b58fdad4c..ebd7e133b5163c 100644
--- a/llvm/test/CodeGen/DirectX/lerp.ll
+++ b/llvm/test/CodeGen/DirectX/lerp.ll
@@ -2,6 +2,7 @@
 
 ; Make sure dxil operation function calls for lerp are generated for float and half.
 
+; CHECK-LABEL: lerp_half
 ; CHECK: fsub half %{{.*}}, %{{.*}}
 ; CHECK: fmul half %{{.*}}, %{{.*}}
 ; CHECK: fadd half %{{.*}}, %{{.*}}
@@ -16,6 +17,7 @@ entry:
   ret half %dx.lerp
 }
 
+; CHECK-LABEL: lerp_float
 ; CHECK: fsub float %{{.*}}, %{{.*}}
 ; CHECK: fmul float %{{.*}}, %{{.*}}
 ; CHECK: fadd float %{{.*}}, %{{.*}}
@@ -32,6 +34,7 @@ entry:
   ret float %dx.lerp
 }
 
+; CHECK-LABEL: lerp_float4
 ; CHECK: fsub <4 x float> %{{.*}}, %{{.*}}
 ; CHECK: fmul <4 x float> %{{.*}}, %{{.*}}
 ; CHECK: fadd <4 x float> %{{.*}}, %{{.*}}
diff --git a/llvm/test/CodeGen/DirectX/rcp.ll b/llvm/test/CodeGen/DirectX/rcp.ll
index bb54d5f5ab6230..65abe832db53fe 100644
--- a/llvm/test/CodeGen/DirectX/rcp.ll
+++ b/llvm/test/CodeGen/DirectX/rcp.ll
@@ -2,6 +2,7 @@
 
 ; Make sure dxil operation function calls for rcp are generated for float, double, and half.
 
+; CHECK-LABEL: rcp_float4
 ; CHECK: fdiv <4 x float> <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>, %{{.*}}
 define noundef <4 x float> @rcp_float4(<4 x float> noundef %p0) {
 entry:
@@ -12,6 +13,7 @@ entry:
   ret <4 x float> %dx.rcp
 }
 
+; CHECK-LABEL: rcp_double4
 ; CHECK: fdiv <4 x double> <double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00>, %{{.*}}
 define noundef <4 x double> @rcp_double4(<4 x double> noundef %p0) {
 entry:
@@ -22,6 +24,7 @@ entry:
   ret <4 x double> %dx.rcp
 }
 
+; CHECK-LABEL: rcp_half4
 ; CHECK: fdiv <4 x half> <half  0xH3C00, half  0xH3C00, half  0xH3C00, half  0xH3C00>, %{{.*}} 
 define noundef <4 x half> @rcp_half4(<4 x half> noundef %p0) {
 entry:
@@ -32,6 +35,7 @@ entry:
   ret <4 x half> %dx.rcp
 }
 
+; CHECK-LABEL: rcp_half
 ; CHECK: fdiv half 0xH3C00, %{{.*}} 
 define noundef half @rcp_half(half noundef %p0) {
 entry:
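
Reviewer note on what the labelled `any`, `lerp`, and `rcp` CHECK sequences compute: the Python sketch below is a reference model, not the pass itself. It assumes the usual HLSL definition lerp(x, y, s) = x + s * (y - x), which is consistent with the fsub/fmul/fadd order the tests expect; the CHECK lines only pin the instruction sequence, so the operand order here is an assumption. All function names are hypothetical.

```python
from functools import reduce
from operator import or_


def any_expanded(lanes):
    """Model of the vector `any` expansion.

    Each lane is compared against zero (icmp ne / fcmp une), then the
    per-lane results are combined with a chain of `or` instructions.
    Expects at least one lane.
    """
    nonzero = [lane != 0 for lane in lanes]
    return reduce(or_, nonzero)


def lerp_expanded(x, y, s):
    """Model of the `lerp` expansion, assuming x + s * (y - x)."""
    diff = y - x        # fsub
    scaled = s * diff   # fmul
    return x + scaled   # fadd


def rcp_expanded(x):
    """Model of the `rcp` expansion: fdiv 1.0, x."""
    return 1.0 / x


if __name__ == "__main__":
    print(any_expanded([0, 0, 3, 0]))
    print(lerp_expanded(2.0, 6.0, 0.25))
    print(rcp_expanded(4.0))
```

Because `fcmp une` is true for unordered operands, a NaN lane counts as nonzero. Python's `!=` behaves the same way, so the model agrees on that case.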

>From 2dc11f267e25b87084f2f8056513ca22f8164fdd Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 14 Mar 2024 19:57:09 -0400
Subject: [PATCH 3/3] enforce float\half type checks

---
 clang/include/clang/Basic/Builtins.td          |  2 +-
 clang/test/SemaHLSL/BuiltIns/frac-errors.hlsl  | 12 ++++++++++++
 clang/test/SemaHLSL/BuiltIns/isinf-errors.hlsl | 11 +++++++++++
 clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl | 11 +++++++++++
 4 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 15fcdb3ced95c1..58a2d22e7641fc 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4598,7 +4598,7 @@ def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
 
 def HLSLRSqrt : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_rsqrt"];
-  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Attributes = [NoThrow, Const];
   let Prototype = "void(...)";
 }
 
diff --git a/clang/test/SemaHLSL/BuiltIns/frac-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/frac-errors.hlsl
index 904880397559d1..f82b5942fd46b6 100644
--- a/clang/test/SemaHLSL/BuiltIns/frac-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/frac-errors.hlsl
@@ -25,3 +25,15 @@ float2 builtin_frac_int2_to_float2_promotion(int2 p1) {
   return __builtin_hlsl_elementwise_frac(p1);
   // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
 }
+
+// builtins are variadic functions and so are subject to DefaultVariadicArgumentPromotion
+half builtin_frac_half_scalar (half p0) {
+  return __builtin_hlsl_elementwise_frac (p0);
+  // expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
+}
+
+float builtin_frac_float_scalar ( float p0) {
+  return __builtin_hlsl_elementwise_frac (p0);
+  // expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
+}
+
diff --git a/clang/test/SemaHLSL/BuiltIns/isinf-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/isinf-errors.hlsl
index 6c3ab743e814d4..7ddfd56638273b 100644
--- a/clang/test/SemaHLSL/BuiltIns/isinf-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/isinf-errors.hlsl
@@ -25,3 +25,14 @@ bool2 builtin_isinf_int2_to_float2_promotion(int2 p1) {
   return __builtin_hlsl_elementwise_isinf(p1);
   // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
 }
+
+// builtins are variadic functions and so are subject to DefaultVariadicArgumentPromotion
+half builtin_isinf_half_scalar (half p0) {
+  return __builtin_hlsl_elementwise_isinf (p0);
+  // expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
+}
+
+float builtin_isinf_float_scalar ( float p0) {
+  return __builtin_hlsl_elementwise_isinf (p0);
+  // expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl
index fe32e13f0632fe..c027a698c5e58f 100644
--- a/clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl
@@ -25,3 +25,14 @@ float2 builtin_rsqrt_int2_to_float2_promotion(int2 p1) {
   return __builtin_hlsl_elementwise_rsqrt(p1);
   // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
 }
+
+// builtins are variadic functions and so are subject to DefaultVariadicArgumentPromotion
+half builtin_rsqrt_half_scalar (half p0) {
+  return __builtin_hlsl_elementwise_rsqrt (p0);
+  // expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
+}
+
+float builtin_rsqrt_float_scalar ( float p0) {
+  return __builtin_hlsl_elementwise_rsqrt (p0);
+  // expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
+}


