[clang] [llvm] [HLSL] Implementation of dot intrinsic (PR #81190)

Chris B via cfe-commits cfe-commits at lists.llvm.org
Fri Feb 23 13:58:50 PST 2024


================
@@ -5161,6 +5166,157 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
   return false;
 }
 
+// Helper function for CheckHLSLBuiltinFunctionCall
+// If every argument of the call has boolean type, implicitly converts each
+// argument to int in place. If any argument is not a bool the call is left
+// untouched (the original author notes that UsualArithmeticConversions
+// handles the case where at least one arg isn't a bool).
+// Returns true only when one of the implicit conversions fails.
+bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
+  const unsigned ArgCount = TheCall->getNumArgs();
+
+  // Stop without modifying anything as soon as a non-bool argument is seen.
+  for (unsigned Idx = 0; Idx != ArgCount; ++Idx)
+    if (!TheCall->getArg(Idx)->getType()->isBooleanType())
+      return false;
+
+  // Every argument is a bool: rewrite each one as an int.
+  for (unsigned Idx = 0; Idx != ArgCount; ++Idx) {
+    ExprResult Converted = S->PerformImplicitConversion(
+        TheCall->getArg(Idx), S->Context.IntTy, Sema::AA_Converting);
+    if (Converted.isInvalid())
+      return true;
+    TheCall->setArg(Idx, Converted.get());
+  }
+  return false;
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall
+// Handles the CK_HLSLVectorTruncation case for builtins: when the first two
+// arguments are both vectors with different element counts, warns and casts
+// the argument with more elements to the type of the one with fewer.
+// Does nothing if either argument is not a vector or the counts match.
+void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() > 1);
+  Expr *Arg0 = TheCall->getArg(0);
+  Expr *Arg1 = TheCall->getArg(1);
+  const auto *Vec0 = Arg0->getType()->getAs<VectorType>();
+  const auto *Vec1 = Arg1->getType()->getAs<VectorType>();
+
+  // Truncation only applies to a pair of vectors.
+  if (!Vec0 || !Vec1)
+    return;
+
+  const unsigned Count0 = Vec0->getNumElements();
+  const unsigned Count1 = Vec1->getNumElements();
+  if (Count0 == Count1)
+    return;
+
+  // Pick out which argument has more elements; that one gets truncated.
+  const bool FirstIsWider = Count0 > Count1;
+  const unsigned WiderIdx = FirstIsWider ? 0 : 1;
+  Expr *Wider = FirstIsWider ? Arg0 : Arg1;
+  Expr *Narrower = FirstIsWider ? Arg1 : Arg0;
+
+  S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
+      << Wider->getType() << Narrower->getType() << Wider->getSourceRange()
+      << Narrower->getSourceRange();
+  ExprResult Truncated = S->ImpCastExprToType(Wider, Narrower->getType(),
+                                              CK_HLSLVectorTruncation);
+  TheCall->setArg(WiderIdx, Truncated.get());
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall
+// `source` must be a vector-typed expression. When its element type is not
+// floating point but `targetTy` is, emits
+// warn_impcast_integer_float_precision at `BuiltinLoc` and replaces `source`
+// with a conversion to a vector of Context.FloatTy with the same number of
+// elements. In every other case `source` is left unchanged.
+void CheckVectorFloatPromotion(Sema *S, ExprResult &source, QualType targetTy,
+                               SourceRange targetSrcRange,
+                               SourceLocation BuiltinLoc) {
+  Expr *SrcExpr = source.get();
+  const auto *SrcVecTy = SrcExpr->getType()->getAs<VectorType>();
+  assert(SrcVecTy);
+
+  // Promotion applies only to a non-float vector paired with a float target.
+  const bool SrcIsFloatVec = SrcVecTy->getElementType()->isFloatingType();
+  if (SrcIsFloatVec || !targetTy->isFloatingType())
+    return;
+
+  QualType FloatVecTy = S->Context.getVectorType(
+      S->Context.FloatTy, SrcVecTy->getNumElements(), VectorKind::Generic);
+
+  S->Diag(BuiltinLoc, diag::warn_impcast_integer_float_precision)
+      << SrcExpr->getType() << FloatVecTy << SrcExpr->getSourceRange()
+      << targetSrcRange;
+  source = S->SemaConvertVectorExpr(
+      SrcExpr, S->Context.CreateTypeSourceInfo(FloatVecTy), BuiltinLoc,
+      SrcExpr->getBeginLoc());
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall
+// Splats the scalar expression `source` to the vector type `targetTy`,
+// replacing `source` with the resulting implicit cast. `targetTy` must be a
+// vector type.
+void PromoteVectorArgSplat(Sema *S, ExprResult &source, QualType targetTy) {
+  QualType sourceTy = source.get()->getType();
+  auto *vecTyTarget = targetTy->getAs<VectorType>();
+  assert(vecTyTarget && "splat target must be a vector type");
+  QualType vecElemT = vecTyTarget->getElementType();
+  if (vecElemT->isFloatingType() && sourceTy != vecElemT) {
+    // Convert the scalar to the element type first; otherwise a float vec
+    // splat will do an unnecessary cast to double.
+    // CK_FloatingCast is only valid for float-to-float conversions, so an
+    // integer or bool scalar must use CK_IntegralToFloating instead.
+    CastKind scalarCK = sourceTy->isIntegralOrEnumerationType()
+                            ? CK_IntegralToFloating
+                            : CK_FloatingCast;
+    source = S->ImpCastExprToType(source.get(), vecElemT, scalarCK);
+  }
+  source = S->ImpCastExprToType(source.get(), targetTy, CK_VectorSplat);
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall
+// Reconciles the first two arguments of the call:
+//  - two non-vectors: nothing to do;
+//  - two vectors: element types must match exactly, otherwise an error is
+//    diagnosed; on a match the call's type is set to the element type;
+//  - one vector and one scalar: the vector may be promoted to a float vector
+//    and the scalar is splatted to the vector's (possibly promoted) type.
+// Returns true if an error was diagnosed, false otherwise.
+bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() > 1);
+  ExprResult Lhs = TheCall->getArg(0);
+  ExprResult Rhs = TheCall->getArg(1);
+  QualType LhsTy = Lhs.get()->getType();
+  QualType RhsTy = Rhs.get()->getType();
+  const auto *LhsVec = LhsTy->getAs<VectorType>();
+  const auto *RhsVec = RhsTy->getAs<VectorType>();
+
+  if (!LhsVec && !RhsVec)
+    return false;
+
+  if (LhsVec && RhsVec) {
+    if (LhsVec->getElementType() != RhsVec->getElementType()) {
+      // Note: type promotion is intended to be handled via the intrinsics
+      //  and not the builtin itself.
+      S->Diag(TheCall->getBeginLoc(),
+              diag::err_vec_builtin_incompatible_vector)
+          << TheCall->getDirectCallee()
+          << SourceRange(Lhs.get()->getBeginLoc(), Rhs.get()->getEndLoc());
+      return true;
+    }
+    TheCall->setType(LhsVec->getElementType());
+    return false;
+  }
+
+  // Exactly one of the two arguments is a vector from here on.
+  if (RhsVec) {
+    CheckVectorFloatPromotion(S, Rhs, LhsTy, Lhs.get()->getSourceRange(),
+                              TheCall->getBeginLoc());
+    PromoteVectorArgSplat(S, Lhs, Rhs.get()->getType());
+  } else {
+    CheckVectorFloatPromotion(S, Lhs, RhsTy, Rhs.get()->getSourceRange(),
+                              TheCall->getBeginLoc());
+    PromoteVectorArgSplat(S, Rhs, Lhs.get()->getType());
+  }
+  TheCall->setArg(0, Lhs.get());
+  TheCall->setArg(1, Rhs.get());
+  return false;
+}
+
+// Note: returning true in this case results in CheckBuiltinFunctionCall
+// returning an ExprError
+bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
+  switch (BuiltinID) {
+  case Builtin::BI__builtin_hlsl_dot: {
+    if (checkArgCount(*this, TheCall, 2))
+      return true;
+    if (PromoteBoolsToInt(this, TheCall))
+      return true;
+    if (CheckVectorElementCallArgs(this, TheCall))
+      return true;
+    PromoteVectorArgTruncation(this, TheCall);
+    if (SemaBuiltinVectorToScalarMath(TheCall))
----------------
llvm-beanz wrote:

I'm a little concerned about this. It looks like this is a lot of logic to duplicate type promotion and resolution. Can't we just rely on function overload resolution to generate conversion sequences, and have the direct call for the builtin only support cases where the argument types exactly match?

https://github.com/llvm/llvm-project/pull/81190


More information about the cfe-commits mailing list