[clang] [llvm] [HLSL][DXIL] Implement `refract` intrinsic (PR #136026)
Sarah Spall via cfe-commits
cfe-commits at lists.llvm.org
Tue Jun 24 12:03:05 PDT 2025
================
@@ -16,88 +16,90 @@ namespace clang {
SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {}
+/// Checks if the first `NumArgsToCheck` arguments of a function call are of
+/// vector type. If any of the arguments is not a vector type, it emits a
+/// diagnostic error and returns `true`. Otherwise, it returns `false`.
+///
+/// \param TheCall The function call expression to check.
+/// \param NumArgsToCheck The number of arguments to check for vector type.
+/// \return `true` if any of the arguments is not a vector type, `false`
+/// otherwise.
+
+bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
+ for (unsigned i = 0; i < NumArgsToCheck; ++i) {
+ ExprResult Arg = TheCall->getArg(i);
+ QualType ArgTy = Arg.get()->getType();
+ auto *VTy = ArgTy->getAs<VectorType>();
+ if (VTy == nullptr) {
+ SemaRef.Diag(Arg.get()->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTy
+ << SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1
+ << 0 << 0;
+ return true;
+ }
+ }
+ return false;
+}
+
bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
CallExpr *TheCall) {
switch (BuiltinID) {
case SPIRV::BI__builtin_spirv_distance: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
- ExprResult A = TheCall->getArg(0);
- QualType ArgTyA = A.get()->getType();
- auto *VTyA = ArgTyA->getAs<VectorType>();
- if (VTyA == nullptr) {
- SemaRef.Diag(A.get()->getBeginLoc(),
- diag::err_typecheck_convert_incompatible)
- << ArgTyA
- << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
- << 0 << 0;
- return true;
- }
-
- ExprResult B = TheCall->getArg(1);
- QualType ArgTyB = B.get()->getType();
- auto *VTyB = ArgTyB->getAs<VectorType>();
- if (VTyB == nullptr) {
- SemaRef.Diag(A.get()->getBeginLoc(),
- diag::err_typecheck_convert_incompatible)
- << ArgTyB
- << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
- << 0 << 0;
+ // Use the helper function to check both arguments
+ if (CheckVectorArgs(TheCall, 2))
return true;
- }
- QualType RetTy = VTyA->getElementType();
+ QualType RetTy =
+ TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_length: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;
- ExprResult A = TheCall->getArg(0);
- QualType ArgTyA = A.get()->getType();
- auto *VTy = ArgTyA->getAs<VectorType>();
- if (VTy == nullptr) {
- SemaRef.Diag(A.get()->getBeginLoc(),
- diag::err_typecheck_convert_incompatible)
- << ArgTyA
- << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
- << 0 << 0;
+
+ // Use the helper function to check the argument
+ if (CheckVectorArgs(TheCall, 1))
return true;
- }
- QualType RetTy = VTy->getElementType();
+
+ QualType RetTy =
+ TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
TheCall->setType(RetTy);
break;
}
- case SPIRV::BI__builtin_spirv_reflect: {
- if (SemaRef.checkArgCount(TheCall, 2))
+ case SPIRV::BI__builtin_spirv_refract: {
+ if (SemaRef.checkArgCount(TheCall, 3))
return true;
- ExprResult A = TheCall->getArg(0);
- QualType ArgTyA = A.get()->getType();
- auto *VTyA = ArgTyA->getAs<VectorType>();
- if (VTyA == nullptr) {
- SemaRef.Diag(A.get()->getBeginLoc(),
- diag::err_typecheck_convert_incompatible)
- << ArgTyA
- << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
- << 0 << 0;
+ // Use the helper function to check the first two arguments
+ if (CheckVectorArgs(TheCall, 2))
----------------
spall wrote:
Same comment here about following the SemaHLSL style.
https://github.com/llvm/llvm-project/pull/136026
More information about the cfe-commits mailing list