[clang] Adding `asuint` implementation to hlsl (PR #107292)

via cfe-commits cfe-commits at lists.llvm.org
Thu Sep 5 18:52:22 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-x86

Author: None (joaosaffran)

<details>
<summary>Changes</summary>

Implements support for the `asuint` HLSL function casting behaviour.
The `splitdouble` scenario will be addressed in a future PR.

Fixes: #<!-- -->70097

---
Full diff: https://github.com/llvm/llvm-project/pull/107292.diff


6 Files Affected:

- (modified) clang/include/clang/Basic/Builtins.td (+6) 
- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+13) 
- (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+17) 
- (modified) clang/lib/Sema/SemaHLSL.cpp (+46) 
- (added) clang/test/CodeGenHLSL/builtins/asuint.hlsl (+26) 
- (added) clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl (+18) 


``````````diff
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 9e2a590f265ac8..38de1df11b7b5a 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4745,6 +4745,12 @@ def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+// HLSL asuint: reinterprets the bits of its operand as unsigned integers.
+// CodeGen lowers it to a bitcast to i32 (or a vector of i32); because the
+// prototype is variadic, operand types are validated in SemaHLSL instead.
+def HLSLAsUint : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_elementwise_asuint"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLRSqrt : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_rsqrt"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e826c1c6fbbd23..035858e7e291b0 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18812,6 +18812,19 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
         {}, false, true));
   }
+  case Builtin::BI__builtin_hlsl_elementwise_asuint: {
+    Value *Op = EmitScalarExpr(E->getArg(0)->IgnoreImpCasts());
+
+    llvm::Type *DestTy = llvm::Type::getInt32Ty(this->getLLVMContext());
+
+    if (Op->getType()->isVectorTy()) {
+      const VectorType *VecTy = E->getArg(0)->getType()->getAs<VectorType>();
+      DestTy = llvm::VectorType::get(
+          DestTy, ElementCount::getFixed(VecTy->getNumElements()));
+    }
+
+    return Builder.CreateBitCast(Op, DestTy);
+  }
   case Builtin::BI__builtin_hlsl_wave_is_first_lane: {
     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
     return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 5c08a45a35377d..f40469937ddc38 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -387,6 +387,23 @@ float3 asin(float3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_asin)
 float4 asin(float4);
 
+//===----------------------------------------------------------------------===//
+// asuint builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn uint asuint(T Val)
+/// \brief Interprets the bit pattern of \a Val as an unsigned integer.
+/// \param Val The input value.
+/// \return The bits of \a Val reinterpreted as a \c uint (or \c uintN for
+/// vector inputs), without any numeric conversion.
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint asuint(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint2 asuint(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint3 asuint(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_asuint)
+uint4 asuint(float4);
+
 //===----------------------------------------------------------------------===//
 // atan builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 778d524a005482..3adf571c75f7b0 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -12,6 +12,7 @@
 #include "clang/AST/Decl.h"
 #include "clang/AST/Expr.h"
 #include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/AST/Type.h"
 #include "clang/Basic/DiagnosticSema.h"
 #include "clang/Basic/LLVM.h"
 #include "clang/Basic/SourceLocation.h"
@@ -1401,6 +1402,25 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
   return true;
 }
 
+/// Validates the type the user actually wrote for \p Arg, looking through any
+/// implicit conversions inserted by overload resolution. Vector arguments are
+/// validated by their element type. \p Check returns true for a bad type, in
+/// which case an incompatible-type diagnostic naming \p ExpectedType is
+/// emitted and true is returned; otherwise false is returned.
+bool CheckArgTypeWithoutImplicits(
+    Sema *S, Expr *Arg, QualType ExpectedType,
+    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
+  const QualType WrittenTy = Arg->IgnoreImpCasts()->getType();
+
+  QualType ElemTy = WrittenTy;
+  if (const auto *VecTy = WrittenTy->getAs<clang::VectorType>())
+    ElemTy = VecTy->getElementType();
+
+  if (!Check(ElemTy))
+    return false;
+
+  S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
+      << WrittenTy << ExpectedType << 1 << 0 << 0;
+  return true;
+}
+
 bool CheckArgsTypesAreCorrect(
     Sema *S, CallExpr *TheCall, QualType ExpectedType,
     llvm::function_ref<bool(clang::QualType PassedType)> Check) {
@@ -1427,6 +1447,22 @@ bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
                                   checkAllFloatTypes);
 }
 
+/// Rejects \p Arg unless its written (pre-implicit-conversion) element type is
+/// a 32-bit float or a 32-bit non-bool integer. Returns true on error.
+bool CheckArgIsFloatOrIntWithoutImplicits(Sema *S, Expr *Arg) {
+  // asuint is lowered to a plain bitcast to i32 / <N x i32>, and a bitcast
+  // requires identical bit widths, so bool, 16-bit and 64-bit integers must
+  // be rejected here rather than producing invalid IR in CodeGen.
+  auto checkFloatOrInt32 = [S](clang::QualType PassedType) -> bool {
+    if (PassedType->isFloat32Type())
+      return false;
+    const bool IsInt32 = PassedType->isIntegerType() &&
+                         !PassedType->isBooleanType() &&
+                         S->Context.getTypeSize(PassedType) == 32;
+    return !IsInt32;
+  };
+
+  return CheckArgTypeWithoutImplicits(S, Arg, S->Context.FloatTy,
+                                      checkFloatOrInt32);
+}
+
+/// Rejects \p Arg unless its written (pre-implicit-conversion) element type is
+/// an integer type. Returns true on error.
+bool CheckArgIsIntegerWithoutImplicits(Sema *S, Expr *Arg) {
+  auto checkInteger = [](clang::QualType PassedType) -> bool {
+    return !PassedType->isIntegerType();
+  };
+
+  // Name an integer type as the expected type so the diagnostic matches
+  // what this check actually requires (previously it reported 'float').
+  return CheckArgTypeWithoutImplicits(S, Arg, S->Context.IntTy, checkInteger);
+}
+
 bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
   auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
     clang::QualType BaseType =
@@ -1581,6 +1617,16 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_elementwise_asuint: {
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+
+    Expr *Arg = TheCall->getArg(0);
+    if (CheckArgIsFloatOrIntWithoutImplicits(&SemaRef, Arg))
+      return true;
+
+    break;
+  }
   case Builtin::BI__builtin_elementwise_acos:
   case Builtin::BI__builtin_elementwise_asin:
   case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/asuint.hlsl b/clang/test/CodeGenHLSL/builtins/asuint.hlsl
new file mode 100644
index 00000000000000..2ae7d8219ac671
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/asuint.hlsl
@@ -0,0 +1,26 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+// asuint lowers to a bitcast; for 32-bit integer inputs it is a no-op.
+
+// CHECK-LABEL: test_asuint_uint
+// CHECK: ret i32 %0
+export uint test_asuint_uint(uint p0) {
+  return asuint(p0);
+}
+
+// CHECK-LABEL: test_asuint4_int
+// CHECK: %splat.splatinsert = insertelement <4 x i32> poison, i32 %0, i64 0
+export uint4 test_asuint4_int(int p0) {
+  return asuint(p0);
+}
+
+// CHECK-LABEL: test_asuint_float
+// CHECK: %1 = bitcast float %0 to i32
+export uint test_asuint_float(float p0) {
+  return asuint(p0);
+}
+
+// CHECK-LABEL: test_asuint_float4
+// CHECK: %1 = bitcast <4 x float> %0 to <4 x i32>
+export uint4 test_asuint_float4(float4 p0) {
+  return asuint(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
new file mode 100644
index 00000000000000..e9da975bff1b5e
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+
+export uint4 test_asuint_too_many_arg(float p0, float p1) {
+  return __builtin_hlsl_elementwise_asuint(p0, p1);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+
+export uint fn(double p1) {
+    return asuint(p1);
+    // expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
+}
+
+export uint fn(half p1) {
+    return asuint(p1);
+    // expected-error@-1 {{passing 'half' to parameter of incompatible type 'float'}}
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/107292


More information about the cfe-commits mailing list