[clang] [llvm] [HLSL] Implementation of dot intrinsic (PR #81190)

Chris B via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 23 15:24:52 PST 2024


================
@@ -5161,6 +5166,157 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
   return false;
 }
 
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// If every argument of TheCall has boolean type, implicitly converts each
+// argument to int. Argument lists that contain at least one non-bool are left
+// untouched, because UsualArithmeticConversions handles that case.
+// Returns true if any of the int conversions is invalid, false otherwise.
+bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
+  const unsigned ArgCount = TheCall->getNumArgs();
+
+  for (unsigned Idx = 0; Idx != ArgCount; ++Idx)
+    if (!TheCall->getArg(Idx)->getType()->isBooleanType())
+      return false;
+
+  // Every argument is a bool: convert each one to int in place.
+  for (unsigned Idx = 0; Idx != ArgCount; ++Idx) {
+    ExprResult Converted = S->PerformImplicitConversion(
+        TheCall->getArg(Idx), S->Context.IntTy, Sema::AA_Converting);
+    if (Converted.isInvalid())
+      return true;
+    TheCall->setArg(Idx, Converted.get());
+  }
+  return false;
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// Handles the CK_HLSLVectorTruncation case for builtins: when the first two
+// arguments are both vectors with different element counts, emits
+// warn_hlsl_impcast_vector_truncation and replaces the longer vector argument
+// with an implicit CK_HLSLVectorTruncation cast to the shorter one's type.
+// Any other combination of argument types is left unchanged.
+void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() > 1);
+  Expr *Arg0 = TheCall->getArg(0);
+  Expr *Arg1 = TheCall->getArg(1);
+  const auto *Vec0 = Arg0->getType()->getAs<VectorType>();
+  const auto *Vec1 = Arg1->getType()->getAs<VectorType>();
+
+  // Only a vector/vector pair with mismatched lengths needs truncation.
+  if (!Vec0 || !Vec1 || Vec0->getNumElements() == Vec1->getNumElements())
+    return;
+
+  const bool FirstIsLarger = Vec0->getNumElements() > Vec1->getNumElements();
+  const unsigned LargerIdx = FirstIsLarger ? 0 : 1;
+  Expr *Larger = FirstIsLarger ? Arg0 : Arg1;
+  Expr *Smaller = FirstIsLarger ? Arg1 : Arg0;
+
+  S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
+      << Larger->getType() << Smaller->getType() << Larger->getSourceRange()
+      << Smaller->getSourceRange();
+  ExprResult Truncated =
+      S->ImpCastExprToType(Larger, Smaller->getType(), CK_HLSLVectorTruncation);
+  TheCall->setArg(LargerIdx, Truncated.get());
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// `source` must have vector type. If its element type is not floating point
+// while `targetTy` is, emits warn_impcast_integer_float_precision and replaces
+// `source` with a ConvertVectorExpr to a `float` vector that has the same
+// number of elements. Otherwise `source` is left unchanged.
+void CheckVectorFloatPromotion(Sema *S, ExprResult &source, QualType targetTy,
+                               SourceRange targetSrcRange,
+                               SourceLocation BuiltinLoc) {
+  Expr *SrcExpr = source.get();
+  const auto *SrcVecTy = SrcExpr->getType()->getAs<VectorType>();
+  assert(SrcVecTy);
+
+  // Nothing to do unless a non-float vector meets a floating-point operand.
+  if (SrcVecTy->getElementType()->isFloatingType() ||
+      !targetTy->isFloatingType())
+    return;
+
+  // NOTE(review): the promoted element type is always `float`, regardless of
+  // which floating type `targetTy` is (half/double) -- confirm this is intended.
+  QualType FloatVecTy = S->Context.getVectorType(
+      S->Context.FloatTy, SrcVecTy->getNumElements(), VectorKind::Generic);
+
+  S->Diag(BuiltinLoc, diag::warn_impcast_integer_float_precision)
+      << SrcExpr->getType() << FloatVecTy << SrcExpr->getSourceRange()
+      << targetSrcRange;
+  TypeSourceInfo *FloatVecTSI = S->Context.CreateTypeSourceInfo(FloatVecTy);
+  source = S->SemaConvertVectorExpr(SrcExpr, FloatVecTSI, BuiltinLoc,
+                                    SrcExpr->getBeginLoc());
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// Splats the scalar expression `source` to the vector type `targetTy`
+// (CK_VectorSplat), updating `source` in place.
+// If the scalar is arithmetic and its unqualified type differs from the
+// vector's element type, it is first converted to the element type. This
+// avoids, e.g., an unnecessary cast to double when splatting into a float
+// vector, and guarantees the splat operand matches the element type.
+void PromoteVectorArgSplat(Sema *S, ExprResult &source, QualType targetTy) {
+  QualType sourceTy = source.get()->getType();
+  const auto *vecTyTarget = targetTy->getAs<VectorType>();
+  assert(vecTyTarget && "splat target must be a vector type");
+  QualType vecElemT = vecTyTarget->getElementType();
+  if (sourceTy->isArithmeticType() &&
+      !S->Context.hasSameUnqualifiedType(sourceTy, vecElemT)) {
+    // A hard-coded CK_FloatingCast is only valid for floating -> floating.
+    // An int/bool scalar splatted into a float vector needs
+    // CK_IntegralToFloating, and an integer of another width needs
+    // CK_IntegralCast; PrepareScalarCast picks the correct kind.
+    CastKind ScalarCK = S->PrepareScalarCast(source, vecElemT);
+    source = S->ImpCastExprToType(source.get(), vecElemT, ScalarCK);
+  }
+  source = S->ImpCastExprToType(source.get(), targetTy, CK_VectorSplat);
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// Reconciles the first two arguments of an element-wise vector builtin call:
+//  - scalar/scalar: left untouched;
+//  - vector/vector: the call's type is set to the shared element type, or
+//    err_vec_builtin_incompatible_vector is emitted if the element types
+//    differ;
+//  - vector/scalar: the vector may be promoted to a float vector (see
+//    CheckVectorFloatPromotion), then the scalar is splatted to the vector's
+//    (possibly promoted) type (see PromoteVectorArgSplat).
+// Returns true if an error was diagnosed, false otherwise.
+bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() > 1);
+  ExprResult A = TheCall->getArg(0);
+  ExprResult B = TheCall->getArg(1);
+  QualType ArgTyA = A.get()->getType();
+  QualType ArgTyB = B.get()->getType();
+  const auto *VecTyA = ArgTyA->getAs<VectorType>();
+  const auto *VecTyB = ArgTyB->getAs<VectorType>();
+
+  // Scalar/scalar: nothing to reconcile here.
+  if (!VecTyA && !VecTyB)
+    return false;
+
+  if (VecTyA && VecTyB) {
+    // Note: type promotion is intended to be handled via the intrinsics
+    // and not the builtin itself, so mismatched element types are an error.
+    if (VecTyA->getElementType() != VecTyB->getElementType()) {
+      S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
+          << TheCall->getDirectCallee()
+          << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
+      return true;
+    }
+    TheCall->setType(VecTyA->getElementType());
+    return false;
+  }
+
+  // Exactly one argument is a vector.
+  if (VecTyA) {
+    CheckVectorFloatPromotion(S, A, ArgTyB, B.get()->getSourceRange(),
+                              TheCall->getBeginLoc());
+    PromoteVectorArgSplat(S, B, A.get()->getType());
+  } else {
+    CheckVectorFloatPromotion(S, B, ArgTyA, A.get()->getSourceRange(),
+                              TheCall->getBeginLoc());
+    PromoteVectorArgSplat(S, A, B.get()->getType());
+  }
+  TheCall->setArg(0, A.get());
+  TheCall->setArg(1, B.get());
+  return false;
+}
+
+// Note: returning true in this case results in CheckBuiltinFunctionCall
+// returning an ExprError
+bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
+  switch (BuiltinID) {
+  case Builtin::BI__builtin_hlsl_dot: {
+    if (checkArgCount(*this, TheCall, 2))
+      return true;
+    if (PromoteBoolsToInt(this, TheCall))
+      return true;
+    if (CheckVectorElementCallArgs(this, TheCall))
+      return true;
+    PromoteVectorArgTruncation(this, TheCall);
+    if (SemaBuiltinVectorToScalarMath(TheCall))
----------------
llvm-beanz wrote:

We shouldn’t be testing overload resolution in this change. We should test overload resolution separately as we implement HLSL’s overload resolution rules. It isn’t possible to get correct overload resolution until we implement it, and adding a bunch of complexity to try and test something that isn’t implemented seems wrong.

The call to `dot` should not be ambiguous in any cases where exact-match overloads exist, so we should be able to test code generation for the intrinsic in the exact match cases, and have comprehensive overload resolution tests later when that gets implemented.

DXC implemented overload resolution logic multiple times, just like this, which led to confusing and subtle bugs. It would be best not to make that mistake again.

https://github.com/llvm/llvm-project/pull/81190


More information about the llvm-commits mailing list