[clang] [HLSL] Update Sema Checking Diagnostics for builtins (PR #138429)

via cfe-commits cfe-commits at lists.llvm.org
Sat May 3 17:43:42 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-hlsl

Author: Sarah Spall (spall)

<details>
<summary>Changes</summary>

Update how Sema checking is done for HLSL builtins to allow for better error messages, mainly by using 'err_builtin_invalid_arg_type'.
This follows the formula outlined in issue #134721.
Closes #134721

---

Patch is 61.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138429.diff


21 Files Affected:

- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+1-1) 
- (modified) clang/lib/Sema/SemaHLSL.cpp (+131-185) 
- (modified) clang/test/SemaHLSL/BuiltIns/AddUint64-errors.hlsl (+3-3) 
- (modified) clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl (+10) 
- (modified) clang/test/SemaHLSL/BuiltIns/clamp-errors.hlsl (+16-11) 
- (modified) clang/test/SemaHLSL/BuiltIns/cross-errors.hlsl (+9-3) 
- (modified) clang/test/SemaHLSL/BuiltIns/degrees-errors.hlsl (+3-3) 
- (modified) clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl (+17-16) 
- (modified) clang/test/SemaHLSL/BuiltIns/frac-errors.hlsl (+4-4) 
- (modified) clang/test/SemaHLSL/BuiltIns/half-float-only-errors.hlsl (+2-2) 
- (modified) clang/test/SemaHLSL/BuiltIns/half-float-only-errors2.hlsl (+2-2) 
- (modified) clang/test/SemaHLSL/BuiltIns/isinf-errors.hlsl (+6-6) 
- (modified) clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl (+10-10) 
- (modified) clang/test/SemaHLSL/BuiltIns/logical-operator-errors.hlsl (+6) 
- (modified) clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl (+11-11) 
- (modified) clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl (+4-4) 
- (modified) clang/test/SemaHLSL/BuiltIns/radians-errors.hlsl (+4-4) 
- (modified) clang/test/SemaHLSL/BuiltIns/rcp-errors.hlsl (+4-5) 
- (modified) clang/test/SemaHLSL/BuiltIns/reversebits-errors.hlsl (+1-1) 
- (modified) clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl (+5-6) 
- (modified) clang/test/SemaHLSL/BuiltIns/step-errors.hlsl (+4-4) 


``````````diff
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index ccb14e9927adf..c94d34e0259be 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12649,7 +12649,7 @@ def err_builtin_invalid_arg_type: Error<
   // An 'or' if non-empty second and third components are combined
   "%plural{0:|:%plural{0:|:or }2}3"
   // Third component: floating-point types
-  "%select{|floating-point}3"
+  "%select{|floating-point|16 or 32 bit floating-point}3"
   // A space after a non-empty third component
   "%plural{0:|: }3"
   "%plural{[0,3]:type|:types}1 (was %4)">;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 70aacaa2aadbe..6486ef765f32e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2005,68 +2005,6 @@ void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
   DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
 }
 
-// Helper function for CheckHLSLBuiltinFunctionCall
-static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
-  assert(TheCall->getNumArgs() > 1);
-  ExprResult A = TheCall->getArg(0);
-
-  QualType ArgTyA = A.get()->getType();
-
-  auto *VecTyA = ArgTyA->getAs<VectorType>();
-  SourceLocation BuiltinLoc = TheCall->getBeginLoc();
-
-  bool AllBArgAreVectors = true;
-  for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) {
-    ExprResult B = TheCall->getArg(i);
-    QualType ArgTyB = B.get()->getType();
-    auto *VecTyB = ArgTyB->getAs<VectorType>();
-    if (VecTyB == nullptr)
-      AllBArgAreVectors &= false;
-    if (VecTyA && VecTyB == nullptr) {
-      // Note: if we get here 'B' is scalar which
-      // requires a VectorSplat on ArgN
-      S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
-          << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-          << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-      return true;
-    }
-    if (VecTyA && VecTyB) {
-      bool retValue = false;
-      if (!S->Context.hasSameUnqualifiedType(VecTyA->getElementType(),
-                                             VecTyB->getElementType())) {
-        // Note: type promotion is intended to be handeled via the intrinsics
-        //  and not the builtin itself.
-        S->Diag(TheCall->getBeginLoc(),
-                diag::err_vec_builtin_incompatible_vector)
-            << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-            << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-        retValue = true;
-      }
-      if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
-        // You should only be hitting this case if you are calling the builtin
-        // directly. HLSL intrinsics should avoid this case via a
-        // HLSLVectorTruncation.
-        S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
-            << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-            << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
-        retValue = true;
-      }
-      if (retValue)
-        return retValue;
-    }
-  }
-
-  if (VecTyA == nullptr && AllBArgAreVectors) {
-    // Note: if we get here 'A' is a scalar which
-    // requires a VectorSplat on Arg0
-    S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
-        << TheCall->getDirectCallee() << /*useAllTerminology*/ true
-        << SourceRange(A.get()->getBeginLoc(), A.get()->getEndLoc());
-    return true;
-  }
-  return false;
-}
-
 static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() > 1);
   QualType ArgTy0 = TheCall->getArg(0)->getType();
@@ -2094,63 +2032,46 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
   return false;
 }
 
-static bool CheckArgTypeIsCorrect(
-    Sema *S, Expr *Arg, QualType ExpectedType,
-    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
-  QualType PassedType = Arg->getType();
-  if (Check(PassedType)) {
-    if (auto *VecTyA = PassedType->getAs<VectorType>())
-      ExpectedType = S->Context.getVectorType(
-          ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
-    S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
-        << PassedType << ExpectedType << 1 << 0 << 0;
-    return true;
-  }
-  return false;
-}
-
 static bool CheckAllArgTypesAreCorrect(
-    Sema *S, CallExpr *TheCall, QualType ExpectedType,
-    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
-  for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
-    Expr *Arg = TheCall->getArg(i);
-    if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
+                            clang::QualType PassedType)>
+        Check) {
+  for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
+    Expr *Arg = TheCall->getArg(I);
+    if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
       return true;
-    }
   }
   return false;
 }
 
-static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasFloatingRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkAllFloatTypes);
-}
+static bool CheckFloatOrHalfVecRepresentation(Sema *S, SourceLocation Loc,
+                                              int ArgOrdinal,
+                                              clang::QualType PassedType) {
+  QualType EltTy = PassedType;
+  if (auto *VecTy = EltTy->getAs<VectorType>())
+    EltTy = VecTy->getElementType();
 
-static bool CheckUnsignedIntRepresentations(Sema *S, CallExpr *TheCall) {
-  auto checkUnsignedInteger = [](clang::QualType PassedType) -> bool {
-    clang::QualType BaseType =
-        PassedType->isVectorType()
-            ? PassedType->getAs<clang::VectorType>()->getElementType()
-            : PassedType;
-    return !BaseType->isUnsignedIntegerType();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
-                                    checkUnsignedInteger);
+  if (!PassedType->getAs<VectorType>() ||
+      !(EltTy->isHalfType() || EltTy->isFloat32Type()))
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 4 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
 }
 
-static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
-  auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
-    clang::QualType BaseType =
-        PassedType->isVectorType()
-            ? PassedType->getAs<clang::VectorType>()->getElementType()
-            : PassedType;
-    return !BaseType->isHalfType() && !BaseType->isFloat32Type();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkFloatorHalf);
+static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->getAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
 }
 
 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
@@ -2164,30 +2085,49 @@ static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
   return true;
 }
 
-static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
-  auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
-    if (const auto *VecTy = PassedType->getAs<VectorType>())
-      return VecTy->getElementType()->isDoubleType();
-    return false;
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                    checkDoubleVector);
+static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
+                                 clang::QualType PassedType) {
+  if (const auto *VecTy = PassedType->getAs<VectorType>())
+    if (VecTy->getElementType()->isDoubleType())
+      return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+             << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
+             << PassedType;
+  return false;
 }
-static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasIntegerRepresentation() &&
-           !PassedType->hasFloatingRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy,
-                                    checkAllSignedTypes);
+
+static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
+                                             int ArgOrdinal,
+                                             clang::QualType PassedType) {
+  if (!PassedType->hasIntegerRepresentation() &&
+      !PassedType->hasFloatingRepresentation())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
+           << /* fp */ 1 << PassedType;
+  return false;
 }
 
-static bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) {
-  auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool {
-    return !PassedType->hasUnsignedIntegerRepresentation();
-  };
-  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
-                                    checkAllUnsignedTypes);
+static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
+                                              int ArgOrdinal,
+                                              clang::QualType PassedType) {
+  QualType EltTy = PassedType;
+  if (auto *VecTy = EltTy->getAs<VectorType>())
+    EltTy = VecTy->getElementType();
+
+  if (!PassedType->getAs<VectorType>() || !EltTy->isUnsignedIntegerType())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
+           << PassedType;
+  return false;
+}
+
+static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  if (!PassedType->hasUnsignedIntegerRepresentation())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
+           << /* no fp */ 0 << PassedType;
+  return false;
 }
 
 static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
@@ -2343,23 +2283,12 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_adduint64: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
-    if (CheckUnsignedIntRepresentations(&SemaRef, TheCall))
-      return true;
-
-    // CheckVectorElementCallArgs(...) guarantees both args are the same type.
-    assert(TheCall->getArg(0)->getType() == TheCall->getArg(1)->getType() &&
-           "Both args must be of the same type");
 
-    // ensure both args are vectors
-    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
-    if (!VTy) {
-      SemaRef.Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_non_vector)
-          << TheCall->getDirectCallee() << /*all*/ 1;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckUnsignedIntVecRepresentation))
       return true;
-    }
 
+    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
     // ensure arg integers are 32-bits
     uint64_t ElementBitCount = getASTContext()
                                    .getTypeSizeInChars(VTy->getElementType())
@@ -2380,6 +2309,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     }
 
+    // ensure first arg and second arg have the same type
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
+      return true;
+
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
     // return type is the same as the input type
@@ -2431,10 +2364,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_or: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
       return true;
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
+      return true;
 
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
@@ -2446,37 +2379,41 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_any: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
+    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
+      return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_asdouble: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
+    if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+                            0))
+      return true;
+    if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+                            1))
+      return true;
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
 
     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().DoubleTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
-    if (SemaRef.checkArgCount(TheCall, 3))
-      return true;
-    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0) ||
-        CheckAllArgsHaveSameType(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
             TheCall, /*ArgTyRestr=*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
-                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
-                : Sema::EltwiseBuiltinArgTyRestriction::None))
+            Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_cross: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
+
+    // ensure args are a half3 or float3
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfVecRepresentation))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
     // ensure both args have 3 elements
     int NumElementsArg1 =
@@ -2507,13 +2444,9 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     break;
   }
   case Builtin::BI__builtin_hlsl_dot: {
-    if (SemaRef.checkArgCount(TheCall, 2))
-      return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinVectorToScalarMath(TheCall))
       return true;
-    if (CheckNoDoubleVectors(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, CheckNoDoubleVectors))
       return true;
     break;
   }
@@ -2560,8 +2493,15 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   }
   case Builtin::BI__builtin_hlsl_elementwise_saturate:
   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
-    if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
       return true;
+    if (!TheCall->getArg(0)
+             ->getType()
+             ->hasFloatingRepresentation()) // half or float or double
+      return SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+                          diag::err_builtin_invalid_arg_type)
+             << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
+             << /* fp */ 1 << TheCall->getArg(0)->getType();
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
     break;
@@ -2570,14 +2510,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_elementwise_radians:
   case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_isinf: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
@@ -2587,34 +2533,28 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_lerp: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
-    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0) ||
-        CheckAllArgsHaveSameType(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
       return true;
-    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
+    if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_mad: {
-    if (SemaRef.checkArgCount(TheCall, 3))
-      return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
-      return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
             TheCall, /*ArgTyRestr=*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
-                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
-                : Sema::EltwiseBuiltinArgTyRestriction::None))
+            Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_normalize: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
-      return true;
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
-
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatOrHalfRepresentation))
+      return true;
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
     // return type is the same as the input type
@@ -2622,17 +2562,19 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
-    if (CheckFloatingOrIntRepresentation(&SemaRef, TheCall))
-      return true;
     if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
       return true;
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   CheckFloatingOrIntRepresentation))
+      return true;
     SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().IntTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_step: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/138429


More information about the cfe-commits mailing list