[clang] Adding `asuint` implementation to hlsl (PR #107292)

via cfe-commits cfe-commits at lists.llvm.org
Thu Sep 12 10:06:08 PDT 2024


https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/107292

>From c6434c06d17a2442863f8843e75dc870966fb97c Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Tue, 3 Sep 2024 19:06:22 +0000
Subject: [PATCH 1/5] Adding `asuint`  implementation to hlsl

---
 clang/include/clang/Basic/Builtins.td       |  6 +++
 clang/lib/CodeGen/CGBuiltin.cpp             | 17 +++++++
 clang/lib/Headers/hlsl/hlsl_intrinsics.h    | 38 ++++++++++-----
 clang/lib/Sema/SemaHLSL.cpp                 | 19 ++++++++
 clang/test/CodeGenHLSL/builtins/asuint.hlsl | 53 +++++++++++++++++++++
 5 files changed, 122 insertions(+), 11 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/builtins/asuint.hlsl

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 3dc04f68b3172a..b055c50689eff6 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4751,6 +4751,12 @@ def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLAsUint : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_elementwise_asuint"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLRSqrt : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_rsqrt"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 9950c06a0b9a6b..4f43370a04424d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -27,9 +27,11 @@
 #include "clang/AST/Decl.h"
 #include "clang/AST/OSLog.h"
 #include "clang/AST/OperationKinds.h"
+#include "clang/Basic/Builtins.h"
 #include "clang/Basic/TargetBuiltins.h"
 #include "clang/Basic/TargetInfo.h"
 #include "clang/Basic/TargetOptions.h"
+#include "clang/Basic/TokenKinds.h"
 #include "clang/CodeGen/CGFunctionInfo.h"
 #include "clang/Frontend/FrontendDiagnostic.h"
 #include "llvm/ADT/APFloat.h"
@@ -39,6 +41,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
@@ -62,6 +65,7 @@
 #include "llvm/Support/ConvertUTF.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/ScopedPrinter.h"
+#include "llvm/Support/raw_ostream.h"
 #include "llvm/TargetParser/AArch64TargetParser.h"
 #include "llvm/TargetParser/RISCVISAInfo.h"
 #include "llvm/TargetParser/X86TargetParser.h"
@@ -18866,6 +18870,19 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
         {}, false, true));
   }
+  case Builtin::BI__builtin_hlsl_elementwise_asuint: {
+    Value *Op = EmitScalarExpr(E->getArg(0));
+    E->dump();
+    llvm::Type *DestTy = llvm::Type::getInt32Ty(this->getLLVMContext());
+
+    if (Op -> getType()->isVectorTy()){
+      auto VecTy = E->getArg(0)->getType()->getAs<VectorType>();
+      DestTy = llvm::VectorType::get(DestTy, VecTy->getNumElements(),
+                                     VecTy->isSizelessVectorType());
+    }
+
+    return Builder.CreateBitCast(Op, DestTy);
+  }
   case Builtin::BI__builtin_hlsl_wave_is_first_lane: {
     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
     return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 7a1edd93984de7..1a7e1f646619e0 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -367,17 +367,6 @@ bool any(double4);
 /// \brief Returns the arcsine of the input value, \a Val.
 /// \param Val The input value.
 
-#ifdef __HLSL_ENABLE_16_BIT
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
-half asin(half);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
-half2 asin(half2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
-half3 asin(half3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
-half4 asin(half4);
-#endif
-
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
 float asin(float);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
@@ -387,6 +376,33 @@ float3 asin(float3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
 float4 asin(float4);
 
+//===----------------------------------------------------------------------===//
+// asin builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn uint asin(T Val)
+/// \brief Reinterprest.
+/// \param Val The input value.
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint asuint(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint2 asuint(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint3 asuint(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint4 asuint(float4);
+
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint asuint(double);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint2 asuint(double2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint3 asuint(double3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint4 asuint(double4);
+
 //===----------------------------------------------------------------------===//
 // atan builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 4e44813fe515ce..ed99a6e97d8368 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -29,6 +29,8 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Type.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/DXILABI.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -1754,6 +1756,23 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_elementwise_asuint: {
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+
+    if(ArgTyA->isVectorType()){
+      auto VecTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
+      auto ReturnType = this->getASTContext().getVectorType(TheCall->getCallReturnType(this->getASTContext()), VecTy->getNumElements(),
+                                          VectorKind::Generic);
+
+      TheCall->setType(ReturnType);
+    }
+
+    break;
+  }
   case Builtin::BI__builtin_elementwise_acos:
   case Builtin::BI__builtin_elementwise_asin:
   case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/asuint.hlsl b/clang/test/CodeGenHLSL/builtins/asuint.hlsl
new file mode 100644
index 00000000000000..33acb00ae11182
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/asuint.hlsl
@@ -0,0 +1,53 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+// // CHECK-LABEL: builtin_test_asuint_float
+// // CHECK: bitcast float %0 to i32
+// // CHECK: ret <4 x i32> %dx.clamp
+// export uint builtin_test_asuint_float(float p0) {
+//   return __builtin_hlsl_elementwise_asuint(p0);
+// }
+
+
+// // CHECK-LABEL: builtin_test_asuint_float
+// // CHECK: bitcast float %0 to i32
+// // CHECK: ret <4 x i32> %dx.clamp
+// export uint builtin_test_asuint_double(double p0) {
+//   return __builtin_hlsl_elementwise_asuint(p0);
+// }
+
+
+// // CHECK-LABEL: builtin_test_asuint_float
+// // CHECK: bitcast float %0 to i32
+// // CHECK: ret <4 x i32> %dx.clamp
+// export uint builtin_test_asuint_half(half p0) {
+//   return __builtin_hlsl_elementwise_asuint(p0);
+// }
+
+
+// // CHECK-LABEL: builtin_test_asuint_float
+// // CHECK: bitcast float %0 to i32
+// // CHECK: ret <4 x i32> %dx.clamp
+// export uint4 builtin_test_asuint_float_vector(float p0) {
+//   return __builtin_hlsl_elementwise_asuint(p0);
+// }
+
+
+// CHECK-LABEL: builtin_test_asuint_float
+// CHECK: bitcast float %0 to i32
+// CHECK: ret <4 x i32> %dx.clamp
+export uint4 builtin_test_asuint_floa4t(float p0) {
+  return asuint(p0);
+}
+
+// export uint4 builtin_test_asuint4_uint(uint p0) {
+//   return __builtin_hlsl_elementwise_asuint(p0);
+// }
+
+
+// export uint4 builtin_test_asuint4_int(int p0) {
+//   return __builtin_hlsl_elementwise_asuint(p0);
+// }
+
+// export uint builtin_test_asuint_float(float p0) {
+//   return __builtin_hlsl_elementwise_asuint(p0);
+// }
\ No newline at end of file

>From 55a6cad600d2cca7a07b98c160f76c63709cba0d Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Tue, 3 Sep 2024 19:06:22 +0000
Subject: [PATCH 2/5] Adding `asuint`  implementation to hlsl

---
 clang/lib/CodeGen/CGBuiltin.cpp               | 16 ++---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      | 27 ++++----
 clang/lib/Sema/SemaHLSL.cpp                   | 43 +++++++++----
 clang/test/CodeGenHLSL/builtins/asuint.hlsl   | 63 ++++++-------------
 .../test/SemaHLSL/BuiltIns/asuint-errors.hlsl | 18 ++++++
 5 files changed, 86 insertions(+), 81 deletions(-)
 create mode 100644 clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 4f43370a04424d..d6d740604190c1 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -27,11 +27,9 @@
 #include "clang/AST/Decl.h"
 #include "clang/AST/OSLog.h"
 #include "clang/AST/OperationKinds.h"
-#include "clang/Basic/Builtins.h"
 #include "clang/Basic/TargetBuiltins.h"
 #include "clang/Basic/TargetInfo.h"
 #include "clang/Basic/TargetOptions.h"
-#include "clang/Basic/TokenKinds.h"
 #include "clang/CodeGen/CGFunctionInfo.h"
 #include "clang/Frontend/FrontendDiagnostic.h"
 #include "llvm/ADT/APFloat.h"
@@ -41,7 +39,6 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
-#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
@@ -65,7 +62,6 @@
 #include "llvm/Support/ConvertUTF.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/ScopedPrinter.h"
-#include "llvm/Support/raw_ostream.h"
 #include "llvm/TargetParser/AArch64TargetParser.h"
 #include "llvm/TargetParser/RISCVISAInfo.h"
 #include "llvm/TargetParser/X86TargetParser.h"
@@ -18871,14 +18867,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         {}, false, true));
   }
   case Builtin::BI__builtin_hlsl_elementwise_asuint: {
-    Value *Op = EmitScalarExpr(E->getArg(0));
-    E->dump();
+    Value *Op = EmitScalarExpr(E->getArg(0)->IgnoreImpCasts());
+
     llvm::Type *DestTy = llvm::Type::getInt32Ty(this->getLLVMContext());
 
-    if (Op -> getType()->isVectorTy()){
-      auto VecTy = E->getArg(0)->getType()->getAs<VectorType>();
-      DestTy = llvm::VectorType::get(DestTy, VecTy->getNumElements(),
-                                     VecTy->isSizelessVectorType());
+    if (Op->getType()->isVectorTy()) {
+      const VectorType *VecTy = E->getArg(0)->getType()->getAs<VectorType>();
+      DestTy = llvm::VectorType::get(
+          DestTy, ElementCount::getFixed(VecTy->getNumElements()));
     }
 
     return Builder.CreateBitCast(Op, DestTy);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 1a7e1f646619e0..2ba2677b65af3c 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -367,6 +367,17 @@ bool any(double4);
 /// \brief Returns the arcsine of the input value, \a Val.
 /// \param Val The input value.
 
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
+half asin(half);
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
+half2 asin(half2);
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
+half3 asin(half3);
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
+half4 asin(half4);
+#endif
+
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
 float asin(float);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
@@ -377,11 +388,11 @@ _HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
 float4 asin(float4);
 
 //===----------------------------------------------------------------------===//
-// asin builtins
+// asuint builtins
 //===----------------------------------------------------------------------===//
 
-/// \fn uint asin(T Val)
-/// \brief Reinterprest.
+/// \fn uint asuint(T Val)
+/// \brief Interprets the bit pattern of x as an unsigned integer.
 /// \param Val The input value.
 
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
@@ -393,16 +404,6 @@ uint3 asuint(float3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
 uint4 asuint(float4);
 
-
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
-uint asuint(double);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
-uint2 asuint(double2);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
-uint3 asuint(double3);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
-uint4 asuint(double4);
-
 //===----------------------------------------------------------------------===//
 // atan builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index ed99a6e97d8368..2e41daa4af7ef8 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -9,7 +9,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/Sema/SemaHLSL.h"
-#include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
 #include "clang/AST/DeclCXX.h"
@@ -29,8 +28,6 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/Type.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/DXILABI.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -1473,6 +1470,25 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
   return true;
 }
 
+bool CheckArgTypeWithoutImplicits(
+    Sema *S, Expr *Arg, QualType ExpectedType,
+    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
+
+  QualType ArgTy = Arg->IgnoreImpCasts()->getType();
+
+  clang::QualType BaseType =
+      ArgTy->isVectorType()
+          ? ArgTy->getAs<clang::VectorType>()->getElementType()
+          : ArgTy;
+
+  if (Check(BaseType)) {
+    S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
+        << ArgTy << ExpectedType << 1 << 0 << 0;
+    return true;
+  }
+  return false;
+}
+
 bool CheckArgsTypesAreCorrect(
     Sema *S, CallExpr *TheCall, QualType ExpectedType,
     llvm::function_ref<bool(clang::QualType PassedType)> Check) {
@@ -1499,6 +1515,14 @@ bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
                                   checkAllFloatTypes);
 }
 
+bool CheckArgIsFloatOrIntWithoutImplicits(Sema *S, Expr *Arg) {
+  auto checkFloat = [](clang::QualType PassedType) -> bool {
+    return !PassedType->isFloat32Type() && !PassedType->isIntegerType();
+  };
+
+  return CheckArgTypeWithoutImplicits(S, Arg, S->Context.FloatTy, checkFloat);
+}
+
 bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
   auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
     clang::QualType BaseType =
@@ -1760,16 +1784,9 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
 
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-
-    if(ArgTyA->isVectorType()){
-      auto VecTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
-      auto ReturnType = this->getASTContext().getVectorType(TheCall->getCallReturnType(this->getASTContext()), VecTy->getNumElements(),
-                                          VectorKind::Generic);
-
-      TheCall->setType(ReturnType);
-    }
+    Expr *Arg = TheCall->getArg(0);
+    if (CheckArgIsFloatOrIntWithoutImplicits(&SemaRef, Arg))
+      return true;
 
     break;
   }
diff --git a/clang/test/CodeGenHLSL/builtins/asuint.hlsl b/clang/test/CodeGenHLSL/builtins/asuint.hlsl
index 33acb00ae11182..2ae7d8219ac671 100644
--- a/clang/test/CodeGenHLSL/builtins/asuint.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/asuint.hlsl
@@ -1,53 +1,26 @@
 // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s
 
-// // CHECK-LABEL: builtin_test_asuint_float
-// // CHECK: bitcast float %0 to i32
-// // CHECK: ret <4 x i32> %dx.clamp
-// export uint builtin_test_asuint_float(float p0) {
-//   return __builtin_hlsl_elementwise_asuint(p0);
-// }
 
-
-// // CHECK-LABEL: builtin_test_asuint_float
-// // CHECK: bitcast float %0 to i32
-// // CHECK: ret <4 x i32> %dx.clamp
-// export uint builtin_test_asuint_double(double p0) {
-//   return __builtin_hlsl_elementwise_asuint(p0);
-// }
-
-
-// // CHECK-LABEL: builtin_test_asuint_float
-// // CHECK: bitcast float %0 to i32
-// // CHECK: ret <4 x i32> %dx.clamp
-// export uint builtin_test_asuint_half(half p0) {
-//   return __builtin_hlsl_elementwise_asuint(p0);
-// }
-
-
-// // CHECK-LABEL: builtin_test_asuint_float
-// // CHECK: bitcast float %0 to i32
-// // CHECK: ret <4 x i32> %dx.clamp
-// export uint4 builtin_test_asuint_float_vector(float p0) {
-//   return __builtin_hlsl_elementwise_asuint(p0);
-// }
-
-
-// CHECK-LABEL: builtin_test_asuint_float
-// CHECK: bitcast float %0 to i32
-// CHECK: ret <4 x i32> %dx.clamp
-export uint4 builtin_test_asuint_floa4t(float p0) {
+// CHECK-LABEL: test_asuint4_uint
+// CHECK: ret i32 %0
+export uint test_asuint4_uint(uint p0) {
   return asuint(p0);
 }
 
-// export uint4 builtin_test_asuint4_uint(uint p0) {
-//   return __builtin_hlsl_elementwise_asuint(p0);
-// }
-
+// CHECK-LABEL: test_asuint4_int
+// CHECK: %splat.splatinsert = insertelement <4 x i32> poison, i32 %0, i64 0
+export uint4 test_asuint4_int(int p0) {
+  return asuint(p0);
+}
 
-// export uint4 builtin_test_asuint4_int(int p0) {
-//   return __builtin_hlsl_elementwise_asuint(p0);
-// }
+// CHECK-LABEL: test_asuint_float
+// CHECK: %1 = bitcast float %0 to i32
+export uint test_asuint_float(float p0) {
+  return asuint(p0);
+}
 
-// export uint builtin_test_asuint_float(float p0) {
-//   return __builtin_hlsl_elementwise_asuint(p0);
-// }
\ No newline at end of file
+// CHECK-LABEL: test_asuint_float
+// CHECK: %1 = bitcast <4 x float> %0 to <4 x i32>
+export uint4 test_asuint_float4(float4 p0) {
+  return asuint(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
new file mode 100644
index 00000000000000..e9da975bff1b5e
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+
+export uint4 test_asuint_too_many_arg(float p0, float p1) {
+  return __builtin_hlsl_elementwise_asuint(p0, p1);
+  // expected-error at -1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+
+export uint fn(double p1) {
+    return asuint(p1);
+    // expected-error at -1 {{passing 'double' to parameter of incompatible type 'float'}}
+}
+
+export uint fn(half p1) {
+    return asuint(p1);
+    // expected-error at -1 {{passing 'half' to parameter of incompatible type 'float'}}
+}

>From 529c8dc39d2fb435fa11fb7cbf6aacfd6f915412 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Fri, 6 Sep 2024 18:44:38 +0000
Subject: [PATCH 3/5] Adding `asuint`  implementation to hlsl

---
 clang/include/clang/Basic/Builtins.td       |  7 +++-
 clang/lib/CodeGen/CGBuiltin.cpp             |  2 +-
 clang/lib/Headers/hlsl/hlsl_intrinsics.h    | 18 +++++++++
 clang/lib/Sema/SemaHLSL.cpp                 | 32 ++++-----------
 clang/test/CodeGenHLSL/builtins/asuint.hlsl | 43 ++++++++++++++-------
 5 files changed, 61 insertions(+), 41 deletions(-)

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b055c50689eff6..43170566465938 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -53,6 +53,9 @@ class MSInt32_64Template : Template<["msint32_t", "int64_t"],
 class FloatDoubleTemplate : Template<["float", "double"],
                                      ["f",     ""]>;
 
+class HLSLFloatAndIntTemplate : Template<
+            ["unsigned int", "int", "float"],
+            ["",             "si",    "f"]>;
 // FIXME: These assume that char -> i8, short -> i16, int -> i32,
 // long long -> i64.
 class SyncBuiltinsTemplate :
@@ -4751,10 +4754,10 @@ def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
-def HLSLAsUint : LangBuiltin<"HLSL_LANG"> {
+def HLSLAsUint : LangBuiltin<"HLSL_LANG">, HLSLFloatAndIntTemplate {
   let Spellings = ["__builtin_hlsl_elementwise_asuint"];
   let Attributes = [NoThrow, Const];
-  let Prototype = "void(...)";
+  let Prototype = "unsigned int (T)";
 }
 
 def HLSLRSqrt : LangBuiltin<"HLSL_LANG"> {
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index d6d740604190c1..4ee8a54f75beff 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18867,7 +18867,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         {}, false, true));
   }
   case Builtin::BI__builtin_hlsl_elementwise_asuint: {
-    Value *Op = EmitScalarExpr(E->getArg(0)->IgnoreImpCasts());
+    Value *Op = EmitScalarExpr(E->getArg(0));
 
     llvm::Type *DestTy = llvm::Type::getInt32Ty(this->getLLVMContext());
 
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 2ba2677b65af3c..821009a4bde4ef 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -404,6 +404,24 @@ uint3 asuint(float3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
 uint4 asuint(float4);
 
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint asuint(uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint2 asuint(uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint3 asuint(uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint4 asuint(uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint asuint(int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint2 asuint(int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint3 asuint(int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint4 asuint(int4);
+
 //===----------------------------------------------------------------------===//
 // atan builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 2e41daa4af7ef8..730b1e7542f5fc 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1470,25 +1470,6 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
   return true;
 }
 
-bool CheckArgTypeWithoutImplicits(
-    Sema *S, Expr *Arg, QualType ExpectedType,
-    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
-
-  QualType ArgTy = Arg->IgnoreImpCasts()->getType();
-
-  clang::QualType BaseType =
-      ArgTy->isVectorType()
-          ? ArgTy->getAs<clang::VectorType>()->getElementType()
-          : ArgTy;
-
-  if (Check(BaseType)) {
-    S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
-        << ArgTy << ExpectedType << 1 << 0 << 0;
-    return true;
-  }
-  return false;
-}
-
 bool CheckArgsTypesAreCorrect(
     Sema *S, CallExpr *TheCall, QualType ExpectedType,
     llvm::function_ref<bool(clang::QualType PassedType)> Check) {
@@ -1515,12 +1496,16 @@ bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
                                   checkAllFloatTypes);
 }
 
-bool CheckArgIsFloatOrIntWithoutImplicits(Sema *S, Expr *Arg) {
+bool CheckNotFloatAndInt(Sema *S, CallExpr *TheCall) {
   auto checkFloat = [](clang::QualType PassedType) -> bool {
-    return !PassedType->isFloat32Type() && !PassedType->isIntegerType();
+    clang::QualType BaseType =
+        PassedType->isVectorType()
+            ? PassedType->getAs<clang::VectorType>()->getElementType()
+            : PassedType;
+    return !(BaseType->isFloat32Type() || BaseType->isIntegerType());
   };
 
-  return CheckArgTypeWithoutImplicits(S, Arg, S->Context.FloatTy, checkFloat);
+  return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy, checkFloat);
 }
 
 bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
@@ -1784,8 +1769,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
 
-    Expr *Arg = TheCall->getArg(0);
-    if (CheckArgIsFloatOrIntWithoutImplicits(&SemaRef, Arg))
+    if (CheckNotFloatAndInt(&SemaRef, TheCall))
       return true;
 
     break;
diff --git a/clang/test/CodeGenHLSL/builtins/asuint.hlsl b/clang/test/CodeGenHLSL/builtins/asuint.hlsl
index 2ae7d8219ac671..1edf07942d3a09 100644
--- a/clang/test/CodeGenHLSL/builtins/asuint.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/asuint.hlsl
@@ -1,26 +1,41 @@
-// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
 
+// CHECK: define {{.*}}test_uint{{.*}}(i32 {{.*}} [[VAL:%.*]]){{.*}} 
+// CHECK-NOT: bitcast
+// CHECK: ret i32 [[VAL]]
+export uint test_uint(uint p0) {
+  return asuint(p0);
+}
+
+// CHECK: define {{.*}}test_int{{.*}}(i32 {{.*}} [[VAL:%.*]]){{.*}} 
+// CHECK-NOT: bitcast
+// CHECK: ret i32 [[VAL]]
+export uint test_int(int p0) {
+  return asuint(p0);
+}
 
-// CHECK-LABEL: test_asuint4_uint
-// CHECK: ret i32 %0
-export uint test_asuint4_uint(uint p0) {
+// CHECK: define {{.*}}test_float{{.*}}(float {{.*}} [[VAL:%.*]]){{.*}} 
+// CHECK: bitcast float [[VAL]] to i32
+export uint test_float(float p0) {
   return asuint(p0);
 }
 
-// CHECK-LABEL: test_asuint4_int
-// CHECK: %splat.splatinsert = insertelement <4 x i32> poison, i32 %0, i64 0
-export uint4 test_asuint4_int(int p0) {
+// CHECK: define {{.*}}test_vector_uint{{.*}}(<4 x i32> {{.*}} [[VAL:%.*]]){{.*}} 
+// CHECK-NOT: bitcast
+// CHECK: ret <4 x i32> [[VAL]]
+export uint4 test_vector_uint(uint4 p0) {
   return asuint(p0);
 }
 
-// CHECK-LABEL: test_asuint_float
-// CHECK: %1 = bitcast float %0 to i32
-export uint test_asuint_float(float p0) {
+// CHECK: define {{.*}}test_vector_int{{.*}}(<4 x i32> {{.*}} [[VAL:%.*]]){{.*}} 
+// CHECK-NOT: bitcast
+// CHECK: ret <4 x i32> [[VAL]]
+export uint4 test_vector_int(int4 p0) {
   return asuint(p0);
 }
 
-// CHECK-LABEL: test_asuint_float
-// CHECK: %1 = bitcast <4 x float> %0 to <4 x i32>
-export uint4 test_asuint_float4(float4 p0) {
+// CHECK: define {{.*}}test_vector_float{{.*}}(<4 x float> {{.*}} [[VAL:%.*]]){{.*}} 
+// CHECK: bitcast <4 x float> [[VAL]] to <4 x i32>
+export uint4 test_vector_float(float4 p0) {
   return asuint(p0);
-}
\ No newline at end of file
+}

>From 16a02160b00ee5b4d50575207e8c9fdc7b39eb23 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Mon, 9 Sep 2024 18:35:20 +0000
Subject: [PATCH 4/5] implementing __hlsl_bit_cast_32

---
 clang/include/clang/Basic/Builtins.td         | 15 +++----
 clang/lib/CodeGen/CGBuiltin.cpp               | 18 ++++++---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      | 40 +++++++++----------
 clang/lib/Sema/SemaHLSL.cpp                   |  4 +-
 clang/test/CodeGenHLSL/builtins/asuint.hlsl   | 12 +++---
 ...nt-errors.hlsl => bit-cast-32-errors.hlsl} | 16 ++++----
 6 files changed, 54 insertions(+), 51 deletions(-)
 rename clang/test/SemaHLSL/BuiltIns/{asuint-errors.hlsl => bit-cast-32-errors.hlsl} (51%)

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 43170566465938..da298500acffae 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -53,9 +53,6 @@ class MSInt32_64Template : Template<["msint32_t", "int64_t"],
 class FloatDoubleTemplate : Template<["float", "double"],
                                      ["f",     ""]>;
 
-class HLSLFloatAndIntTemplate : Template<
-            ["unsigned int", "int", "float"],
-            ["",             "si",    "f"]>;
 // FIXME: These assume that char -> i8, short -> i16, int -> i32,
 // long long -> i64.
 class SyncBuiltinsTemplate :
@@ -4754,12 +4751,6 @@ def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
-def HLSLAsUint : LangBuiltin<"HLSL_LANG">, HLSLFloatAndIntTemplate {
-  let Spellings = ["__builtin_hlsl_elementwise_asuint"];
-  let Attributes = [NoThrow, Const];
-  let Prototype = "unsigned int (T)";
-}
-
 def HLSLRSqrt : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_rsqrt"];
   let Attributes = [NoThrow, Const];
@@ -4772,6 +4763,12 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLBitCast32 : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_bit_cast_32"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_select"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 4ee8a54f75beff..d7693d6c905487 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18866,10 +18866,20 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
         {}, false, true));
   }
-  case Builtin::BI__builtin_hlsl_elementwise_asuint: {
+  case Builtin::BI__builtin_hlsl_wave_is_first_lane: {
+    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
+    return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
+  }
+  case Builtin::BI__builtin_hlsl_bit_cast_32: {
     Value *Op = EmitScalarExpr(E->getArg(0));
 
-    llvm::Type *DestTy = llvm::Type::getInt32Ty(this->getLLVMContext());
+    llvm::Type *DestTy = ConvertType(E->getCallReturnType(getContext()));
+
+    if (DestTy->isVectorTy()) {
+      const VectorType *VecTy =
+          E->getCallReturnType(getContext())->getAs<VectorType>();
+      DestTy = ConvertType(VecTy->getElementType());
+    }
 
     if (Op->getType()->isVectorTy()) {
       const VectorType *VecTy = E->getArg(0)->getType()->getAs<VectorType>();
@@ -18879,10 +18889,6 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
 
     return Builder.CreateBitCast(Op, DestTy);
   }
-  case Builtin::BI__builtin_hlsl_wave_is_first_lane: {
-    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
-    return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
-  }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     llvm::Type *Xty = Op0->getType();
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 821009a4bde4ef..6a5ff52cd8e0cc 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -395,32 +395,32 @@ float4 asin(float4);
 /// \brief Interprets the bit pattern of x as an unsigned integer.
 /// \param Val The input value.
 
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
-uint asuint(float);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
-uint2 asuint(float2);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
-uint3 asuint(float3);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
-uint4 asuint(float4);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
+uint asuint(int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
+uint2 asuint(int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
+uint3 asuint(int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
+uint4 asuint(int4);
 
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
 uint asuint(uint);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
 uint2 asuint(uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
 uint3 asuint(uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
 uint4 asuint(uint4);
 
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
-uint asuint(int);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
-uint2 asuint(int2);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
-uint3 asuint(int3);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
-uint4 asuint(int4);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
+uint asuint(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
+uint2 asuint(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
+uint3 asuint(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
+uint4 asuint(float4);
 
 //===----------------------------------------------------------------------===//
 // atan builtins
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 730b1e7542f5fc..b7e3476fd6f6c5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -9,6 +9,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/Sema/SemaHLSL.h"
+#include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
 #include "clang/AST/DeclCXX.h"
@@ -1765,13 +1766,12 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
-  case Builtin::BI__builtin_hlsl_elementwise_asuint: {
+  case Builtin::BI__builtin_hlsl_bit_cast_32: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
 
     if (CheckNotFloatAndInt(&SemaRef, TheCall))
       return true;
-
     break;
   }
   case Builtin::BI__builtin_elementwise_acos:
diff --git a/clang/test/CodeGenHLSL/builtins/asuint.hlsl b/clang/test/CodeGenHLSL/builtins/asuint.hlsl
index 1edf07942d3a09..ac3dae26d6caed 100644
--- a/clang/test/CodeGenHLSL/builtins/asuint.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/asuint.hlsl
@@ -3,39 +3,39 @@
 // CHECK: define {{.*}}test_uint{{.*}}(i32 {{.*}} [[VAL:%.*]]){{.*}} 
 // CHECK-NOT: bitcast
 // CHECK: ret i32 [[VAL]]
-export uint test_uint(uint p0) {
+uint test_uint(uint p0) {
   return asuint(p0);
 }
 
 // CHECK: define {{.*}}test_int{{.*}}(i32 {{.*}} [[VAL:%.*]]){{.*}} 
 // CHECK-NOT: bitcast
 // CHECK: ret i32 [[VAL]]
-export uint test_int(int p0) {
+uint test_int(int p0) {
   return asuint(p0);
 }
 
 // CHECK: define {{.*}}test_float{{.*}}(float {{.*}} [[VAL:%.*]]){{.*}} 
 // CHECK: bitcast float [[VAL]] to i32
-export uint test_float(float p0) {
+uint test_float(float p0) {
   return asuint(p0);
 }
 
 // CHECK: define {{.*}}test_vector_uint{{.*}}(<4 x i32> {{.*}} [[VAL:%.*]]){{.*}} 
 // CHECK-NOT: bitcast
 // CHECK: ret <4 x i32> [[VAL]]
-export uint4 test_vector_uint(uint4 p0) {
+uint4 test_vector_uint(uint4 p0) {
   return asuint(p0);
 }
 
 // CHECK: define {{.*}}test_vector_int{{.*}}(<4 x i32> {{.*}} [[VAL:%.*]]){{.*}} 
 // CHECK-NOT: bitcast
 // CHECK: ret <4 x i32> [[VAL]]
-export uint4 test_vector_int(int4 p0) {
+uint4 test_vector_int(int4 p0) {
   return asuint(p0);
 }
 
 // CHECK: define {{.*}}test_vector_float{{.*}}(<4 x float> {{.*}} [[VAL:%.*]]){{.*}} 
 // CHECK: bitcast <4 x float> [[VAL]] to <4 x i32>
-export uint4 test_vector_float(float4 p0) {
+uint4 test_vector_float(float4 p0) {
   return asuint(p0);
 }
diff --git a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/bit-cast-32-errors.hlsl
similarity index 51%
rename from clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
rename to clang/test/SemaHLSL/BuiltIns/bit-cast-32-errors.hlsl
index e9da975bff1b5e..ff0108e353c1b7 100644
--- a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/bit-cast-32-errors.hlsl
@@ -1,18 +1,18 @@
 // RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
 
 
-export uint4 test_asuint_too_many_arg(float p0, float p1) {
-  return __builtin_hlsl_elementwise_asuint(p0, p1);
+uint4 test_asuint_too_many_arg(float p0, float p1) {
+  return __builtin_hlsl_bit_cast_32(p0, p1);
   // expected-error at -1 {{too many arguments to function call, expected 1, have 2}}
 }
 
-
-export uint fn(double p1) {
-    return asuint(p1);
+uint test_asuint_double(double p1) {
+    return __builtin_hlsl_bit_cast_32(p1);
     // expected-error at -1 {{passing 'double' to parameter of incompatible type 'float'}}
 }
 
-export uint fn(half p1) {
-    return asuint(p1);
-    // expected-error at -1 {{passing 'half' to parameter of incompatible type 'float'}}
+
+uint test_asuint_half(half p1) {
+    return __builtin_hlsl_bit_cast_32(p1);
+    // expected-error at -1 {{passing 'double' to parameter of incompatible type 'float'}}
 }

>From d2f55b360d65b8fea16aaab6bba9d7da284d6e49 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 12 Sep 2024 06:24:02 +0000
Subject: [PATCH 5/5] changing implementation of asuint to use bit_cast

---
 clang/include/clang/Basic/Builtins.td         |  6 ----
 clang/lib/CodeGen/CGBuiltin.cpp               | 19 ----------
 clang/lib/Headers/CMakeLists.txt              |  1 +
 clang/lib/Headers/hlsl/hlsl_details.h         | 35 +++++++++++++++++++
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      | 35 +++++--------------
 clang/lib/Sema/SemaHLSL.cpp                   | 20 -----------
 .../test/SemaHLSL/BuiltIns/asuint-errors.hlsl | 18 ++++++++++
 .../SemaHLSL/BuiltIns/bit-cast-32-errors.hlsl | 18 ----------
 8 files changed, 63 insertions(+), 89 deletions(-)
 create mode 100644 clang/lib/Headers/hlsl/hlsl_details.h
 create mode 100644 clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
 delete mode 100644 clang/test/SemaHLSL/BuiltIns/bit-cast-32-errors.hlsl

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index da298500acffae..3dc04f68b3172a 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4763,12 +4763,6 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
-def HLSLBitCast32 : LangBuiltin<"HLSL_LANG"> {
-  let Spellings = ["__builtin_hlsl_bit_cast_32"];
-  let Attributes = [NoThrow, Const];
-  let Prototype = "void(...)";
-}
-
 def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_select"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index d7693d6c905487..9950c06a0b9a6b 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18870,25 +18870,6 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
     return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
   }
-  case Builtin::BI__builtin_hlsl_bit_cast_32: {
-    Value *Op = EmitScalarExpr(E->getArg(0));
-
-    llvm::Type *DestTy = ConvertType(E->getCallReturnType(getContext()));
-
-    if (DestTy->isVectorTy()) {
-      const VectorType *VecTy =
-          E->getCallReturnType(getContext())->getAs<VectorType>();
-      DestTy = ConvertType(VecTy->getElementType());
-    }
-
-    if (Op->getType()->isVectorTy()) {
-      const VectorType *VecTy = E->getArg(0)->getType()->getAs<VectorType>();
-      DestTy = llvm::VectorType::get(
-          DestTy, ElementCount::getFixed(VecTy->getNumElements()));
-    }
-
-    return Builder.CreateBitCast(Op, DestTy);
-  }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     llvm::Type *Xty = Op0->getType();
diff --git a/clang/lib/Headers/CMakeLists.txt b/clang/lib/Headers/CMakeLists.txt
index e928b5b142827b..53b16afcf92402 100644
--- a/clang/lib/Headers/CMakeLists.txt
+++ b/clang/lib/Headers/CMakeLists.txt
@@ -87,6 +87,7 @@ set(hlsl_h
 set(hlsl_subdir_files
   hlsl/hlsl_basic_types.h
   hlsl/hlsl_intrinsics.h
+  hlsl/hlsl_details.h
   )
 set(hlsl_files
   ${hlsl_h}
diff --git a/clang/lib/Headers/hlsl/hlsl_details.h b/clang/lib/Headers/hlsl/hlsl_details.h
new file mode 100644
index 00000000000000..1aff524da6855f
--- /dev/null
+++ b/clang/lib/Headers/hlsl/hlsl_details.h
@@ -0,0 +1,35 @@
+//===----- hlsl_details.h - Internal helpers for HLSL intrinsics ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _HLSL_HLSL_DETAILS_H_
+#define _HLSL_HLSL_DETAILS_H_
+
+namespace __details {
+#define HLSL_INLINE_ATTRIBUTE                                                  \
+  __attribute__((__always_inline__, __nodebug__)) static inline
+
+template <bool B, typename T> struct enable_if {};
+
+template <typename T> struct enable_if<true, T> { using Type = T; };
+
+template <typename U, typename T, int N>
+HLSL_INLINE_ATTRIBUTE
+    typename enable_if<sizeof(U) == sizeof(T), vector<U, N> >::Type
+    bit_cast(vector<T, N> V) {
+  return __builtin_bit_cast(vector<U, N>, V);
+}
+
+template <typename U, typename T>
+HLSL_INLINE_ATTRIBUTE typename enable_if<sizeof(U) == sizeof(T), U>::Type
+bit_cast(T F) {
+  return __builtin_bit_cast(U, F);
+}
+
+} // namespace __details
+
+#endif //_HLSL_HLSL_DETAILS_H_
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 6a5ff52cd8e0cc..1dadf8c4c0e1ca 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -10,6 +10,7 @@
 #define _HLSL_HLSL_INTRINSICS_H_
 
 namespace hlsl {
+#include "hlsl_details.h"
 
 // Note: Functions in this file are sorted alphabetically, then grouped by base
 // element type, and the element types are sorted by size, then singed integer,
@@ -395,32 +396,14 @@ float4 asin(float4);
 /// \brief Interprets the bit pattern of x as an unsigned integer.
 /// \param Val The input value.
 
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
-uint asuint(int);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
-uint2 asuint(int2);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
-uint3 asuint(int3);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
-uint4 asuint(int4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
-uint asuint(uint);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
-uint2 asuint(uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
-uint3 asuint(uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
-uint4 asuint(uint4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
-uint asuint(float);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
-uint2 asuint(float2);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
-uint3 asuint(float3);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_bit_cast_32)
-uint4 asuint(float4);
+template <typename T, int N>
+HLSL_INLINE_ATTRIBUTE vector<uint, N> asuint(vector<T, N> V) {
+  return __details::bit_cast<uint, T, N>(V);
+}
+
+template <typename T> HLSL_INLINE_ATTRIBUTE uint asuint(T F) {
+  return __details::bit_cast<uint, T>(F);
+}
 
 //===----------------------------------------------------------------------===//
 // atan builtins
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index b7e3476fd6f6c5..4e44813fe515ce 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1497,18 +1497,6 @@ bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
                                   checkAllFloatTypes);
 }
 
-bool CheckNotFloatAndInt(Sema *S, CallExpr *TheCall) {
-  auto checkFloat = [](clang::QualType PassedType) -> bool {
-    clang::QualType BaseType =
-        PassedType->isVectorType()
-            ? PassedType->getAs<clang::VectorType>()->getElementType()
-            : PassedType;
-    return !(BaseType->isFloat32Type() || BaseType->isIntegerType());
-  };
-
-  return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy, checkFloat);
-}
-
 bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
   auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
     clang::QualType BaseType =
@@ -1766,14 +1754,6 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
-  case Builtin::BI__builtin_hlsl_bit_cast_32: {
-    if (SemaRef.checkArgCount(TheCall, 1))
-      return true;
-
-    if (CheckNotFloatAndInt(&SemaRef, TheCall))
-      return true;
-    break;
-  }
   case Builtin::BI__builtin_elementwise_acos:
   case Builtin::BI__builtin_elementwise_asin:
   case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
new file mode 100644
index 00000000000000..0760cea409640c
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+
+uint4 test_asuint_too_many_arg(float p0, float p1) {
+  return asuint(p0, p1);
+  // expected-error at -1 {{no matching function for call to 'asuint'}}
+}
+
+uint test_asuint_double(double p1) {
+    return asuint(p1);
+    // expected-error at -1{clang/.*/include/hlsl/hlsl_details.h} {{no matching function for call to 'bit_cast'}}
+}
+
+
+uint test_asuint_half(half p1) {
+    return asuint(p1);
+    // expected-error at -1{clang/.*/include/hlsl/hlsl_details.h} {{no matching function for call to 'bit_cast'}}
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/bit-cast-32-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/bit-cast-32-errors.hlsl
deleted file mode 100644
index ff0108e353c1b7..00000000000000
--- a/clang/test/SemaHLSL/BuiltIns/bit-cast-32-errors.hlsl
+++ /dev/null
@@ -1,18 +0,0 @@
-// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
-
-
-uint4 test_asuint_too_many_arg(float p0, float p1) {
-  return __builtin_hlsl_bit_cast_32(p0, p1);
-  // expected-error at -1 {{too many arguments to function call, expected 1, have 2}}
-}
-
-uint test_asuint_double(double p1) {
-    return __builtin_hlsl_bit_cast_32(p1);
-    // expected-error at -1 {{passing 'double' to parameter of incompatible type 'float'}}
-}
-
-
-uint test_asuint_half(half p1) {
-    return __builtin_hlsl_bit_cast_32(p1);
-    // expected-error at -1 {{passing 'double' to parameter of incompatible type 'float'}}
-}



More information about the cfe-commits mailing list