[clang] [llvm] [HLSL] Implementation of dot intrinsic (PR #81190)

Farzon Lotfi via cfe-commits cfe-commits at lists.llvm.org
Fri Feb 23 14:36:06 PST 2024


================
@@ -5161,6 +5166,157 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
   return false;
 }
 
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// When every argument of TheCall has boolean type, each argument is implicitly
+// converted to int and written back into the call. If at least one argument is
+// not a bool the call is left untouched; that case is handled by
+// UsualArithmeticConversions.
+// Returns true if one of the conversions is invalid, false otherwise.
+bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
+  const unsigned ArgCount = TheCall->getNumArgs();
+
+  // First pass: leave the call alone unless all arguments are bool.
+  for (unsigned Idx = 0; Idx != ArgCount; ++Idx)
+    if (!TheCall->getArg(Idx)->getType()->isBooleanType())
+      return false;
+
+  // Second pass: all arguments are bool, so convert each one to int in place.
+  for (unsigned Idx = 0; Idx != ArgCount; ++Idx) {
+    ExprResult Converted = S->PerformImplicitConversion(
+        TheCall->getArg(Idx), S->Context.IntTy, Sema::AA_Converting);
+    if (Converted.isInvalid())
+      return true;
+    TheCall->setArg(Idx, Converted.get());
+  }
+  return false;
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// Handles the CK_HLSLVectorTruncation case for builtins: when the first two
+// arguments are both vectors with different element counts, a truncation
+// warning is emitted and the argument with more elements is implicitly cast to
+// the type of the argument with fewer elements. Any other combination of
+// argument types leaves the call unchanged.
+void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() > 1);
+  Expr *Lhs = TheCall->getArg(0);
+  Expr *Rhs = TheCall->getArg(1);
+  const auto *LhsVec = Lhs->getType()->getAs<VectorType>();
+  const auto *RhsVec = Rhs->getType()->getAs<VectorType>();
+
+  // Truncation only applies when both operands are vectors.
+  if (!LhsVec || !RhsVec)
+    return;
+
+  const unsigned LhsCount = LhsVec->getNumElements();
+  const unsigned RhsCount = RhsVec->getNumElements();
+  if (LhsCount == RhsCount)
+    return;
+
+  // Work out which operand has to shrink and where it lives in the call.
+  const bool LhsIsWider = LhsCount > RhsCount;
+  Expr *Wide = LhsIsWider ? Lhs : Rhs;
+  Expr *Narrow = LhsIsWider ? Rhs : Lhs;
+  const unsigned WideIdx = LhsIsWider ? 0 : 1;
+
+  S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
+      << Wide->getType() << Narrow->getType() << Wide->getSourceRange()
+      << Narrow->getSourceRange();
+
+  ExprResult Truncated =
+      S->ImpCastExprToType(Wide, Narrow->getType(), CK_HLSLVectorTruncation);
+  TheCall->setArg(WideIdx, Truncated.get());
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// `source` must hold a vector-typed expression. If its element type is not a
+// floating type while `targetTy` is, an integer-to-float precision warning is
+// emitted at BuiltinLoc and `source` is replaced by a conversion of the vector
+// to a float vector with the same number of elements. Otherwise `source` is
+// left as is. `targetSrcRange` is only used as a highlight range in the
+// diagnostic.
+void CheckVectorFloatPromotion(Sema *S, ExprResult &source, QualType targetTy,
+                               SourceRange targetSrcRange,
+                               SourceLocation BuiltinLoc) {
+  const auto *SrcVecTy = source.get()->getType()->getAs<VectorType>();
+  assert(SrcVecTy);
+
+  // Nothing to promote unless a non-float vector meets a floating target.
+  const bool SrcElemIsFloat = SrcVecTy->getElementType()->isFloatingType();
+  if (SrcElemIsFloat || !targetTy->isFloatingType())
+    return;
+
+  QualType PromotedTy = S->Context.getVectorType(
+      S->Context.FloatTy, SrcVecTy->getNumElements(), VectorKind::Generic);
+
+  S->Diag(BuiltinLoc, diag::warn_impcast_integer_float_precision)
+      << source.get()->getType() << PromotedTy
+      << source.get()->getSourceRange() << targetSrcRange;
+
+  TypeSourceInfo *PromotedTSI = S->Context.CreateTypeSourceInfo(PromotedTy);
+  source = S->SemaConvertVectorExpr(source.get(), PromotedTSI, BuiltinLoc,
+                                    source.get()->getBeginLoc());
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// Splats the scalar expression held in `source` to the vector type `targetTy`,
+// replacing `source` with the resulting cast expression. `targetTy` must be a
+// vector type.
+void PromoteVectorArgSplat(Sema *S, ExprResult &source, QualType targetTy) {
+  QualType sourceTy = source.get()->getType();
+  auto *vecTyTarget = targetTy->getAs<VectorType>();
+  assert(vecTyTarget && "splat target must be a vector type");
+  QualType vecElemT = vecTyTarget->getElementType();
+  if (vecElemT->isFloatingType() && sourceTy != vecElemT) {
+    // Convert the scalar to the vector's element type first; otherwise a
+    // float vec splat will do an unnecessary cast to double.
+    // The cast kind has to describe the actual conversion: an integer (or
+    // bool) scalar needs CK_IntegralToFloating, CK_FloatingCast is only valid
+    // between floating types.
+    CastKind scalarCK =
+        sourceTy->isIntegerType() ? CK_IntegralToFloating : CK_FloatingCast;
+    source = S->ImpCastExprToType(source.get(), vecElemT, scalarCK);
+  }
+  source = S->ImpCastExprToType(source.get(), targetTy, CK_VectorSplat);
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall.
+// Reconciles the types of the first two arguments of TheCall:
+//  - neither is a vector: nothing is changed;
+//  - both are vectors with the same element type: the call's type is set to
+//    that element type;
+//  - both are vectors with different element types: an error is diagnosed;
+//  - exactly one is a vector: the vector is promoted to a float vector if
+//    needed and the scalar is splatted to the vector's type.
+// Returns true if an error was diagnosed, false otherwise.
+bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() > 1);
+  ExprResult Lhs = TheCall->getArg(0);
+  ExprResult Rhs = TheCall->getArg(1);
+  QualType LhsTy = Lhs.get()->getType();
+  QualType RhsTy = Rhs.get()->getType();
+  const auto *LhsVec = LhsTy->getAs<VectorType>();
+  const auto *RhsVec = RhsTy->getAs<VectorType>();
+
+  // Two scalars: there is no vector element type to reconcile.
+  if (!LhsVec && !RhsVec)
+    return false;
+
+  if (LhsVec && RhsVec) {
+    if (LhsVec->getElementType() != RhsVec->getElementType()) {
+      // Note: type promotion is intended to be handled via the intrinsics
+      //  and not the builtin itself.
+      S->Diag(TheCall->getBeginLoc(),
+              diag::err_vec_builtin_incompatible_vector)
+          << TheCall->getDirectCallee()
+          << SourceRange(Lhs.get()->getBeginLoc(), Rhs.get()->getEndLoc());
+      return true;
+    }
+    TheCall->setType(LhsVec->getElementType());
+    return false;
+  }
+
+  // From here on exactly one of the two arguments is a vector.
+  if (RhsVec) {
+    CheckVectorFloatPromotion(S, Rhs, LhsTy, Lhs.get()->getSourceRange(),
+                              TheCall->getBeginLoc());
+    PromoteVectorArgSplat(S, Lhs, Rhs.get()->getType());
+  } else {
+    CheckVectorFloatPromotion(S, Lhs, RhsTy, Rhs.get()->getSourceRange(),
+                              TheCall->getBeginLoc());
+    PromoteVectorArgSplat(S, Rhs, Lhs.get()->getType());
+  }
+
+  TheCall->setArg(0, Lhs.get());
+  TheCall->setArg(1, Rhs.get());
+  return false;
+}
+
+// Note: returning true in this case results in CheckBuiltinFunctionCall
+// returning an ExprError
+bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
+  switch (BuiltinID) {
+  case Builtin::BI__builtin_hlsl_dot: {
+    if (checkArgCount(*this, TheCall, 2))
+      return true;
+    if (PromoteBoolsToInt(this, TheCall))
+      return true;
+    if (CheckVectorElementCallArgs(this, TheCall))
+      return true;
+    PromoteVectorArgTruncation(this, TheCall);
+    if (SemaBuiltinVectorToScalarMath(TheCall))
----------------
farzonl wrote:

Some of the overload behavior can't be tested via the direct call because of the multiple definitions of dot product that it tries to match against. That makes it difficult to check IR parity with DXC when the direct call errors with `call to 'dot' is ambiguous`. That said, I can easily remove this; I'm just concerned about completeness.

https://github.com/llvm/llvm-project/pull/81190


More information about the cfe-commits mailing list