[clang] [llvm] [HLSL] Implementation of dot intrinsic (PR #81190)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 23 17:19:22 PST 2024
https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/81190
>From f6188a3308188aa3037b05f685a6065bfc2d69fa Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Thu, 8 Feb 2024 11:08:59 -0500
Subject: [PATCH 1/6] [HLSL] Implementation of dot intrinsic This change
implements #70073
HLSL has a dot intrinsic defined here:
https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot
The intrinsic itself is defined as an HLSL_LANG LangBuiltin in Builtins.td.
This is used to associate all the dot product overloads declared in hlsl_intrinsics.h
with a single intrinsic check in CGBuiltin.cpp & SemaChecking.cpp.
In IntrinsicsDirectX.td we define the LLVM IR intrinsic for the dot product.
A few goals were in mind for this IR. First, it should operate only on
vectors. Second, the return type should be the vector element type. Third,
the second parameter vector should be the same size as the first
parameter. Finally, `a dot b` should be the same as `b dot a`.
In CGBuiltin.cpp, HLSL has built on top of existing clang intrinsics via EmitBuiltinExpr. The dot
product, though, is a language-specific intrinsic and so is guarded behind getLangOpts().HLSL.
The call chain looks like this: EmitBuiltinExpr -> EmitHLSLBuiltinExpr.
For the dot product intrinsic, EmitHLSLBuiltinExpr makes a distinction
between vectors and scalars. This is because HLSL supports the dot product on scalars, which simplifies down to a multiply.
Sema.h & SemaChecking.cpp saw the addition of CheckHLSLBuiltinFunctionCall, a language-specific semantic validation that can be extended to other HLSL-specific intrinsics.
---
clang/include/clang/Basic/Builtins.td | 6 +
.../clang/Basic/DiagnosticSemaKinds.td | 2 +
clang/include/clang/Sema/Sema.h | 1 +
clang/lib/CodeGen/CGBuiltin.cpp | 49 ++++
clang/lib/CodeGen/CodeGenFunction.h | 1 +
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 92 ++++++++
clang/lib/Sema/SemaChecking.cpp | 84 ++++++-
clang/test/CodeGenHLSL/builtins/dot.hlsl | 216 ++++++++++++++++++
clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl | 46 ++++
llvm/include/llvm/IR/IntrinsicsDirectX.td | 5 +
10 files changed, 497 insertions(+), 5 deletions(-)
create mode 100644 clang/test/CodeGenHLSL/builtins/dot.hlsl
create mode 100644 clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index df74026c5d2d50..771c4f5d4121f4 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4524,6 +4524,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void*(unsigned char)";
}
+def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_dot"];
+ let Attributes = [NoThrow, Const, CustomTypeChecking];
+ let Prototype = "void(...)";
+}
+
// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index a7f2858477bee6..9cce89b92be309 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10270,6 +10270,8 @@ def err_vec_builtin_non_vector : Error<
"first two arguments to %0 must be vectors">;
def err_vec_builtin_incompatible_vector : Error<
"first two arguments to %0 must have the same type">;
+def err_vec_builtin_incompatible_size : Error<
+ "first two arguments to %0 must have the same size">;
def err_vsx_builtin_nonconstant_argument : Error<
"argument %0 to %1 must be a 2-bit unsigned literal (i.e. 0, 1, 2 or 3)">;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index e457694e4625db..2c06c8edca329a 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -14057,6 +14057,7 @@ class Sema final {
bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
CallExpr *TheCall);
bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
+ bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum);
bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
CallExpr *TheCall);
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 734eb5a035ca49..393ab497fb15c4 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -44,6 +44,7 @@
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IntrinsicsBPF.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/IntrinsicsHexagon.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/IntrinsicsPowerPC.h"
@@ -5982,6 +5983,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr");
}
+ // EmitHLSLBuiltinExpr will check getLangOpts().HLSL
+ if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E))
+ return RValue::get(V);
+
if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice)
return EmitHipStdParUnsupportedBuiltin(this, FD);
@@ -17959,6 +17964,50 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
return Arg;
}
+// Emits IR for HLSL-only builtins. Returns nullptr when not compiling HLSL or
+// when BuiltinID is not handled here, so the caller (EmitBuiltinExpr) falls
+// through to the rest of its builtin handling.
+//
+// __builtin_hlsl_dot: two scalar operands lower to a plain multiply (fmul for
+// floating point, mul for integers); two vector operands lower to a call to
+// llvm.dx.dot whose result type is the vector element type.
+// NOTE(review): the scalar fallback after the float/int checks is
+// `assert(false && ...)`, which is compiled out under NDEBUG, so an
+// unexpected scalar type falls through to the vector path in release builds;
+// llvm_unreachable would fail loudly instead.
+Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
+ const CallExpr *E) {
+ if (!getLangOpts().HLSL)
+ return nullptr;
+
+ switch (BuiltinID) {
+ case Builtin::BI__builtin_hlsl_dot: {
+ Value *Op0 = EmitScalarExpr(E->getArg(0));
+ Value *Op1 = EmitScalarExpr(E->getArg(1));
+ llvm::Type *T0 = Op0->getType();
+ llvm::Type *T1 = Op1->getType();
+ if (!T0->isVectorTy() && !T1->isVectorTy()) {
+ if (T0->isFloatingPointTy()) {
+ return Builder.CreateFMul(Op0, Op1, "dx.dot");
+ }
+
+ if (T0->isIntegerTy()) {
+ return Builder.CreateMul(Op0, Op1, "dx.dot");
+ }
+ assert(
+ false &&
+ "Dot product on a scalar is only supported on integers and floats.");
+ }
+ assert(T0->isVectorTy() && T1->isVectorTy() &&
+ "Dot product of vector and scalar is not supported.");
+
+ // NOTE: this assert will need to be revisited after overload resoltion
+ // PR merges.
+ assert(T0->getScalarType() == T1->getScalarType() &&
+ "Dot product of vectors need the same element types.");
+
+ auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
+ auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>();
+ assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
+ "Dot product requires vectors to be of the same size.");
+
+ return Builder.CreateIntrinsic(
+ /*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot,
+ ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
+ } break;
+ }
+ return nullptr;
+}
+
Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
const CallExpr *E) {
llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent;
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 92ce0edeaf9e9c..b2800f699ff4b9 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4405,6 +4405,7 @@ class CodeGenFunction : public CodeGenTypeCache {
llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
+ llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
const CallExpr *E);
llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index f87ac977997962..a92f0d0849ba77 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -179,6 +179,98 @@ double3 cos(double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
double4 cos(double4);
+//===----------------------------------------------------------------------===//
+// dot product builtins
+//===----------------------------------------------------------------------===//
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half, half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half2, half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half3, half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half4, half4);
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t, int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t2, int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t3, int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t4, int16_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t, uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t2, uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t3, uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t4, uint16_t4);
+#endif
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float, float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float2, float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float3, float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float4, float4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+double dot(double, double);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int, int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int2, int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int3, int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int4, int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint, uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint2, uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint3, uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint4, uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t, int64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t2, int64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t3, int64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t4, int64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t, uint64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t2, uint64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t3, uint64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t4, uint64_t4);
+
//===----------------------------------------------------------------------===//
// floor builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 7fa295ebd94044..b55866f0b20c62 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2120,10 +2120,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
// not a valid type, emit an error message and return true. Otherwise return
// false.
static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
- QualType Ty) {
- if (!Ty->getAs<VectorType>() && !ConstantMatrixType::isValidElementType(Ty)) {
+ QualType ArgTy, int ArgIndex) {
+ if (!ArgTy->getAs<VectorType>() &&
+ !ConstantMatrixType::isValidElementType(ArgTy)) {
return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
- << 1 << /* vector, integer or float ty*/ 0 << Ty;
+ << ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
}
return false;
@@ -2961,6 +2962,10 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
}
}
+ if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall)) {
+ return ExprError();
+ }
+
// Since the target specific builtins for each arch overlap, only check those
// of the arch we are compiling for.
if (Context.BuiltinInfo.isTSBuiltin(BuiltinID)) {
@@ -5161,6 +5166,75 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
return false;
}
+// Note: returning true in this case results in CheckBuiltinFunctionCall
+// returning an ExprError
+bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
+ switch (BuiltinID) {
+ case Builtin::BI__builtin_hlsl_dot: {
+ if (checkArgCount(*this, TheCall, 2)) {
+ return true;
+ }
+ Expr *Arg0 = TheCall->getArg(0);
+ QualType ArgTy0 = Arg0->getType();
+
+ Expr *Arg1 = TheCall->getArg(1);
+ QualType ArgTy1 = Arg1->getType();
+
+ auto *VecTy0 = ArgTy0->getAs<VectorType>();
+ auto *VecTy1 = ArgTy1->getAs<VectorType>();
+ SourceLocation BuiltinLoc = TheCall->getBeginLoc();
+ // Highlights both operands in every diagnostic emitted below.
+ SourceRange ArgRange(Arg0->getBeginLoc(), Arg1->getEndLoc());
+
+ // Each operand must be a vector or a valid matrix element type; anything
+ // else (e.g. bool) is rejected with err_builtin_invalid_arg_type.
+ if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
+ return true;
+ }
+ if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
+ return true;
+ }
+
+ // Scalar dot scalar: both operands must have the same type once typedef
+ // sugar and qualifiers are ignored. A mismatch must be diagnosed here;
+ // returning true on its own would reject the call with no error message.
+ if (!VecTy0 && !VecTy1) {
+ if (!Context.hasSameUnqualifiedType(ArgTy0, ArgTy1)) {
+ Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
+ << TheCall->getDirectCallee() << ArgRange;
+ return true;
+ }
+ return false;
+ }
+
+ // Mixing a scalar operand with a vector operand is not supported.
+ if (!VecTy0 || !VecTy1) {
+ Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
+ << TheCall->getDirectCallee() << ArgRange;
+ return true;
+ }
+
+ if (!Context.hasSameUnqualifiedType(VecTy0->getElementType(),
+ VecTy1->getElementType())) {
+ // Note: This case should never happen. If type promotion occurs
+ // then element types won't be different. This diag error is here
+ // because EmitHLSLBuiltinExpr asserts on this case.
+ Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
+ << TheCall->getDirectCallee() << ArgRange;
+ return true;
+ }
+ if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
+ Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size)
+ << TheCall->getDirectCallee() << ArgRange;
+ return true;
+ }
+ break;
+ }
+ }
+ return false;
+}
+
bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID,
CallExpr *TheCall) {
// position of memory order and scope arguments in the builtin
@@ -19594,7 +19668,7 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
TheCall->setArg(0, A.get());
QualType TyA = A.get()->getType();
- if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
+ if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
return true;
TheCall->setType(TyA);
@@ -19622,7 +19696,7 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
diag::err_typecheck_call_different_arg_types)
<< TyA << TyB;
- if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
+ if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
return true;
TheCall->setArg(0, A.get());
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
new file mode 100644
index 00000000000000..9a895cd190ba9f
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -0,0 +1,216 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -O3 -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// -fnative-half-type sets __HLSL_ENABLE_16_BIT
+#ifdef __HLSL_ENABLE_16_BIT
+// NATIVE_HALF: %dx.dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dx.dot
+int16_t test_dot_short ( int16_t p0, int16_t p1 ) {
+ return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+int16_t test_dot_short2 ( int16_t2 p0, int16_t2 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+int16_t test_dot_short3 ( int16_t3 p0, int16_t3 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+int16_t test_dot_short4 ( int16_t4 p0, int16_t4 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort ( uint16_t p0, uint16_t p1 ) {
+ return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort2 ( uint16_t2 p0, uint16_t2 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort3 ( uint16_t3 p0, uint16_t3 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort4 ( uint16_t4 p0, uint16_t4 p1 ) {
+ return dot ( p0, p1 );
+}
+#endif
+
+// CHECK: %dx.dot = mul i32 %0, %1
+// CHECK: ret i32 %dx.dot
+int test_dot_int ( int p0, int p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int2 ( int2 p0, int2 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int3 ( int3 p0, int3 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int4 ( int4 p0, int4 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = mul i32 %0, %1
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint ( uint p0, uint p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint2 ( uint2 p0, uint2 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint3 ( uint3 p0, uint3 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint4 ( uint4 p0, uint4 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = mul i64 %0, %1
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_long ( int64_t p0, int64_t p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_long2 ( int64_t2 p0, int64_t2 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_long3 ( int64_t3 p0, int64_t3 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_long4 ( int64_t4 p0, int64_t4 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = mul i64 %0, %1
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_ulong ( uint64_t p0, uint64_t p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_ulong2 ( uint64_t2 p0, uint64_t2 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_ulong3 ( uint64_t3 p0, uint64_t3 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_ulong4 ( uint64_t4 p0, uint64_t4 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = fmul half %0, %1
+// NATIVE_HALF: ret half %dx.dot
+// NO_HALF: %dx.dot = fmul float %0, %1
+// NO_HALF: ret float %dx.dot
+half test_dot_half ( half p0, half p1 ) {
+ return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half2 ( half2 p0, half2 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half3 ( half3 p0, half3 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half4 ( half4 p0, half4 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = fmul float %0, %1
+// CHECK: ret float %dx.dot
+float test_dot_float ( float p0, float p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: ret float %dx.dot
+float test_dot_float2 ( float2 p0, float2 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: ret float %dx.dot
+float test_dot_float3 ( float3 p0, float3 p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: ret float %dx.dot
+float test_dot_float4 ( float4 p0, float4 p1) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = fmul double %0, %1
+// CHECK: ret double %dx.dot
+double test_dot_double ( double p0, double p1 ) {
+ return dot ( p0, p1 );
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
new file mode 100644
index 00000000000000..a5acb400ab9c7b
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
@@ -0,0 +1,46 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+// NOTE: This test is marked XFAIL because when overload resolution merges
+// NOTE: test_dot_element_type_mismatch & test_dot_scalar_mismatch will have different behavior
+// XFAIL: *
+
+float test_first_arg_is_not_vector ( float p0, float2 p1 ) {
+ return dot ( p0, p1 );
+ // expected-error at -1 {{first two arguments to 'dot' must be vectors}}
+}
+
+float test_second_arg_is_not_vector ( float2 p0, float p1 ) {
+ return dot ( p0, p1 );
+ // expected-error at -1 {{first two arguments to 'dot' must be vectors}}
+}
+
+int test_dot_unsupported_scalar_arg0 ( bool p0, int p1 ) {
+ return dot ( p0, p1 );
+ // expected-error at -1 {{1st argument must be a vector, integer or floating point type (was 'bool')}}
+}
+
+int test_dot_unsupported_scalar_arg1 ( int p0, bool p1 ) {
+ return dot ( p0, p1 );
+ // expected-error at -1 {{2nd argument must be a vector, integer or floating point type (was 'bool')}}
+}
+
+float test_dot_scalar_mismatch ( float p0, int p1 ) {
+ return dot ( p0, p1 );
+ // expected-error at -1 {{call to 'dot' is ambiguous}}
+}
+
+float test_dot_vector_size_mismatch ( float3 p0, float2 p1 ) {
+ return dot ( p0, p1 );
+ // expected-error at -1 {{first two arguments to 'dot' must have the same size}}
+}
+
+float test__no_second_arg ( float2 p0) {
+ return dot ( p0 );
+ // expected-error at -1 {{no matching function for call to 'dot'}}
+}
+
+float test_dot_element_type_mismatch ( int2 p0, float2 p1 ) {
+ return dot ( p0, p1 );
+ // expected-error at -1 {{call to 'dot' is ambiguous}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 2fe4fdfd5953be..c192d4b84417c9 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -19,4 +19,9 @@ def int_dx_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMe
def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">,
Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>;
+
+// Dot product of two vectors. The result is the element type of operand 0.
+// Operand 1 must have that same element type and, because operand 0 is always
+// a vector (llvm_anyvector_ty), the same number of elements.
+// Marked Commutative: dot(a, b) == dot(b, a).
+def int_dx_dot :
+ Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
}
>From dc912a8db2f66dd465feb4023a7019f5b6c73c2c Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Tue, 20 Feb 2024 10:56:04 -0500
Subject: [PATCH 2/6] Add tests for call directly to builtin Add more
robustness to SemaChecking
---
clang/include/clang/Basic/Builtins.td | 2 +-
clang/include/clang/Sema/Sema.h | 2 +
clang/lib/CodeGen/CGBuiltin.cpp | 10 +-
clang/lib/Sema/SemaChecking.cpp | 225 ++++++++++++++-----
clang/test/CodeGenHLSL/builtins/dot.hlsl | 170 ++++++++++++++
clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl | 50 ++---
6 files changed, 376 insertions(+), 83 deletions(-)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 771c4f5d4121f4..e3432f7925ba14 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4526,7 +4526,7 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_dot"];
- let Attributes = [NoThrow, Const, CustomTypeChecking];
+ let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 2c06c8edca329a..ce2319d4ad58bd 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -14123,6 +14123,8 @@ class Sema final {
bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc);
+ bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
+ bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 393ab497fb15c4..7bc31055a6b49d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -17977,26 +17977,28 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
llvm::Type *T1 = Op1->getType();
if (!T0->isVectorTy() && !T1->isVectorTy()) {
if (T0->isFloatingPointTy()) {
- return Builder.CreateFMul(Op0, Op1, "dx.dot");
+ return Builder.CreateFMul(Op0, Op1, "dx.dot");
}
if (T0->isIntegerTy()) {
- return Builder.CreateMul(Op0, Op1, "dx.dot");
+ return Builder.CreateMul(Op0, Op1, "dx.dot");
}
+ // Bools should have been promoted
assert(
false &&
"Dot product on a scalar is only supported on integers and floats.");
}
+ // A VectorSplat should have happened
assert(T0->isVectorTy() && T1->isVectorTy() &&
"Dot product of vector and scalar is not supported.");
- // NOTE: this assert will need to be revisited after overload resoltion
- // PR merges.
+ // A vector sext or sitofp should have happened
assert(T0->getScalarType() == T1->getScalarType() &&
"Dot product of vectors need the same element types.");
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>();
+ // A HLSLVectorTruncation should have happend
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
"Dot product requires vectors to be of the same size.");
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index b55866f0b20c62..0c638a02e08ccf 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5166,69 +5166,167 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
return false;
}
-// Note: returning true in this case results in CheckBuiltinFunctionCall
-// returning an ExprError
-bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
- switch (BuiltinID) {
- case Builtin::BI__builtin_hlsl_dot: {
- if (checkArgCount(*this, TheCall, 2)) {
+bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
+ unsigned NumArgs = TheCall->getNumArgs();
+
+ for (unsigned i = 0; i < NumArgs; ++i) {
+ ExprResult A = TheCall->getArg(i);
+ if (!A.get()->getType()->isBooleanType())
+ return false;
+ }
+ // if we got here all args are bool
+ for (unsigned i = 0; i < NumArgs; ++i) {
+ ExprResult A = TheCall->getArg(i);
+ ExprResult ResA = S->PerformImplicitConversion(A.get(), S->Context.IntTy,
+ Sema::AA_Converting);
+ if (ResA.isInvalid())
return true;
- }
- Expr *Arg0 = TheCall->getArg(0);
- QualType ArgTy0 = Arg0->getType();
+ TheCall->setArg(0, ResA.get());
+ }
+ return false;
+}
- Expr *Arg1 = TheCall->getArg(1);
- QualType ArgTy1 = Arg1->getType();
+int overloadOrder(Sema *S, QualType ArgTyA) {
+ auto kind = ArgTyA->getAs<BuiltinType>()->getKind();
+ switch (kind) {
+ case BuiltinType::Short:
+ case BuiltinType::UShort:
+ return 1;
+ case BuiltinType::Int:
+ case BuiltinType::UInt:
+ return 2;
+ case BuiltinType::Long:
+ case BuiltinType::ULong:
+ return 3;
+ case BuiltinType::LongLong:
+ case BuiltinType::ULongLong:
+ return 4;
+ case BuiltinType::Float16:
+ case BuiltinType::Half:
+ return 5;
+ case BuiltinType::Float:
+ return 6;
+ default:
+ break;
+ }
+ return 0;
+}
- auto *VecTy0 = ArgTy0->getAs<VectorType>();
- auto *VecTy1 = ArgTy1->getAs<VectorType>();
- SourceLocation BuiltinLoc = TheCall->getBeginLoc();
+QualType getVecLargestBitness(Sema *S, QualType ArgTyA, QualType ArgTyB) {
+ auto *VecTyA = ArgTyA->getAs<VectorType>();
+ auto *VecTyB = ArgTyB->getAs<VectorType>();
+ QualType VecTyAElem = VecTyA->getElementType();
+ QualType VecTyBElem = VecTyB->getElementType();
+ int vecAElemWidth = overloadOrder(S, VecTyAElem);
+ int vecBElemWidth = overloadOrder(S, VecTyBElem);
+ return vecAElemWidth > vecBElemWidth ? ArgTyA : ArgTyB;
+}
- // if arg0 is bool then call Diag with err_builtin_invalid_arg_type
- if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
- return true;
- }
+void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
+ assert(TheCall->getNumArgs() > 1);
+ ExprResult A = TheCall->getArg(0);
+ ExprResult B = TheCall->getArg(1);
+ QualType ArgTyA = A.get()->getType();
+ QualType ArgTyB = B.get()->getType();
- // if arg1 is bool then call Diag with err_builtin_invalid_arg_type
- if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
- return true;
+ auto *VecTyA = ArgTyA->getAs<VectorType>();
+ auto *VecTyB = ArgTyB->getAs<VectorType>();
+ if (VecTyA == nullptr && VecTyB == nullptr)
+ return;
+ if (VecTyA == nullptr || VecTyB == nullptr)
+ return;
+ if (VecTyA->getNumElements() == VecTyB->getNumElements())
+ return;
+
+ Expr *LargerArg = B.get();
+ Expr *SmallerArg = A.get();
+ int largerIndex = 1;
+ if (VecTyA->getNumElements() > VecTyB->getNumElements()) {
+ LargerArg = A.get();
+ SmallerArg = B.get();
+ largerIndex = 0;
+ }
+ S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
+ << LargerArg->getType() << SmallerArg->getType()
+ << LargerArg->getSourceRange() << SmallerArg->getSourceRange();
+ ExprResult ResLargerArg = S->ImpCastExprToType(
+ LargerArg, SmallerArg->getType(), CK_HLSLVectorTruncation);
+ TheCall->setArg(largerIndex, ResLargerArg.get());
+ return;
+}
+
+bool PromoteVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
+ assert(TheCall->getNumArgs() > 1);
+ ExprResult A = TheCall->getArg(0);
+ ExprResult B = TheCall->getArg(1);
+ QualType ArgTyA = A.get()->getType();
+ QualType ArgTyB = B.get()->getType();
+
+ auto *VecTyA = ArgTyA->getAs<VectorType>();
+ auto *VecTyB = ArgTyB->getAs<VectorType>();
+ if (VecTyA == nullptr && VecTyB == nullptr)
+ return false;
+ if (VecTyA && VecTyB) {
+ if (VecTyA->getElementType() == VecTyB->getElementType()) {
+ TheCall->setType(VecTyA->getElementType());
+ return false;
+ }
+ SourceLocation BuiltinLoc = TheCall->getBeginLoc();
+ QualType CastType = getVecLargestBitness(S, ArgTyA, ArgTyB);
+ if (CastType == ArgTyA) {
+ ExprResult ResB = S->SemaConvertVectorExpr(
+ B.get(), S->Context.CreateTypeSourceInfo(ArgTyA), BuiltinLoc,
+ B.get()->getBeginLoc());
+ TheCall->setArg(1, ResB.get());
+ TheCall->setType(VecTyA->getElementType());
+ return false;
}
- if (VecTy0 == nullptr && VecTy1 == nullptr) {
- if (ArgTy0 != ArgTy1) {
- return true;
- } else {
- return false;
- }
+ if (CastType == ArgTyB) {
+ ExprResult ResA = S->SemaConvertVectorExpr(
+ A.get(), S->Context.CreateTypeSourceInfo(ArgTyB), BuiltinLoc,
+ A.get()->getBeginLoc());
+ TheCall->setArg(0, ResA.get());
+ TheCall->setType(VecTyB->getElementType());
+ return false;
}
+ return false;
+ }
- if ((VecTy0 == nullptr && VecTy1 != nullptr) ||
- (VecTy0 != nullptr && VecTy1 == nullptr)) {
+ if (VecTyB) {
+ // Convert to the vector result type
+ ExprResult ResA = A;
+ if (VecTyB->getElementType() != ArgTyA)
+ ResA = S->ImpCastExprToType(ResA.get(), VecTyB->getElementType(),
+ CK_FloatingCast);
+ ResA = S->ImpCastExprToType(ResA.get(), ArgTyB, CK_VectorSplat);
+ TheCall->setArg(0, ResA.get());
+ }
+ if (VecTyA) {
+ ExprResult ResB = B;
+ if (VecTyA->getElementType() != ArgTyB)
+ ResB = S->ImpCastExprToType(ResB.get(), VecTyA->getElementType(),
+ CK_FloatingCast);
+ ResB = S->ImpCastExprToType(ResB.get(), ArgTyA, CK_VectorSplat);
+ TheCall->setArg(1, ResB.get());
+ }
+ return false;
+}
- Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
- << TheCall->getDirectCallee()
- << SourceRange(TheCall->getArg(0)->getBeginLoc(),
- TheCall->getArg(1)->getEndLoc());
+// Note: returning true in this case results in CheckBuiltinFunctionCall
+// returning an ExprError
+bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
+ switch (BuiltinID) {
+ case Builtin::BI__builtin_hlsl_dot: {
+ if (checkArgCount(*this, TheCall, 2))
return true;
- }
-
- if (VecTy0->getElementType() != VecTy1->getElementType()) {
- // Note: This case should never happen. If type promotion occurs
- // then element types won't be different. This diag error is here
- // b\c EmitHLSLBuiltinExpr asserts on this case.
- Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
- << TheCall->getDirectCallee()
- << SourceRange(TheCall->getArg(0)->getBeginLoc(),
- TheCall->getArg(1)->getEndLoc());
+ if (PromoteBoolsToInt(this, TheCall))
return true;
- }
- if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
- Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size)
- << TheCall->getDirectCallee()
- << SourceRange(TheCall->getArg(0)->getBeginLoc(),
- TheCall->getArg(1)->getEndLoc());
+ if (PromoteVectorElementCallArgs(this, TheCall))
+ return true;
+ PromoteVectorArgTruncation(this, TheCall);
+ if (SemaBuiltinVectorToScalarMath(TheCall))
return true;
- }
break;
}
}
@@ -19676,6 +19774,29 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
}
bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
+ QualType Res;
+ bool result = SemaBuiltinVectorMath(TheCall, Res);
+ if (result)
+ return true;
+ TheCall->setType(Res);
+ return false;
+}
+
+bool Sema::SemaBuiltinVectorToScalarMath(CallExpr *TheCall) {
+ QualType Res;
+ bool result = SemaBuiltinVectorMath(TheCall, Res);
+ if (result)
+ return true;
+
+ if (auto *VecTy0 = Res->getAs<VectorType>()) {
+ TheCall->setType(VecTy0->getElementType());
+ } else {
+ TheCall->setType(Res);
+ }
+ return false;
+}
+
+bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
if (checkArgCount(*this, TheCall, 2))
return true;
@@ -19683,8 +19804,7 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
ExprResult B = TheCall->getArg(1);
// Do standard promotions between the two arguments, returning their common
// type.
- QualType Res =
- UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
+ Res = UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
if (A.isInvalid() || B.isInvalid())
return true;
@@ -19701,7 +19821,6 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
TheCall->setArg(0, A.get());
TheCall->setArg(1, B.get());
- TheCall->setType(Res);
return false;
}
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index 9a895cd190ba9f..b2cd3b6302af6a 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -55,6 +55,34 @@ uint16_t test_dot_ushort3 ( uint16_t3 p0, uint16_t3 p1 ) {
uint16_t test_dot_ushort4 ( uint16_t4 p0, uint16_t4 p1 ) {
return dot ( p0, p1 );
}
+
+// NATIVE_HALF: %conv = sitofp <2 x i16> %1 to <2 x float>
+// NATIVE_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %conv)
+// NATIVE_HALF: ret float %dx.dot
+float test_builtin_dot_vec_int16_to_float_promotion( float2 p0, int16_t2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+}
+
+// NATIVE_HALF: %conv = sitofp <2 x i16> %1 to <2 x half>
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %conv)
+// NATIVE_HALF: ret half %dx.dot
+half test_builtin_dot_vec_int16_to_half_promotion( half2 p0, int16_t2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+}
+
+// NATIVE_HALF: %conv = sext <2 x i16> %1 to <2 x i32>
+// NATIVE_HALF: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %conv)
+// NATIVE_HALF: ret i32 %dx.dot
+int test_builtin_dot_vec_int16_to_int_promotion( int2 p0, int16_t2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+}
+
+// NATIVE_HALF: %conv = sext <2 x i16> %1 to <2 x i64>
+// NATIVE_HALF: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %conv)
+// NATIVE_HALF: ret i64 %dx.dot
+int64_t test_builtin_dot_vec_int16_to_int64_promotion( int64_t2 p0, int16_t2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+}
#endif
// CHECK: %dx.dot = mul i32 %0, %1
@@ -184,6 +212,13 @@ half test_dot_half3 ( half3 p0, half3 p1 ) {
half test_dot_half4 ( half4 p0, half4 p1 ) {
return dot ( p0, p1 );
}
+// NATIVE_HALF: %conv = fpext <2 x half> %1 to <2 x float>
+// NATIVE_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %conv)
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: ret float %dx.dot
+float test_builtin_dot_vec_half_to_float_promotion( float2 p0, half2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+}
// CHECK: %dx.dot = fmul float %0, %1
// CHECK: ret float %dx.dot
@@ -209,8 +244,143 @@ float test_dot_float4 ( float4 p0, float4 p1) {
return dot ( p0, p1 );
}
+// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK: ret float %dx.dot
+float test_dot_float2_splat ( float p0, float2 p1 ) {
+ return dot( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK: ret float %dx.dot
+float test_dot_float3_splat ( float p0, float3 p1 ) {
+ return dot( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK: ret float %dx.dot
+float test_dot_float4_splat ( float p0, float4 p1 ) {
+ return dot( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK: ret float %dx.dot
+float test_builtin_dot_float2_splat ( float p0, float2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK: ret float %dx.dot
+float test_builtin_dot_float3_splat ( float p0, float3 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK: ret float %dx.dot
+float test_builtin_dot_float4_splat ( float p0, float4 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+}
+
+// CHECK: %conv = sitofp i32 %1 to float
+// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
+// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat)
+// CHECK: ret float %dx.dot
+float test_dot_float2_int_splat ( float2 p0, int p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+}
+
+// CHECK: %conv = sitofp i32 %1 to float
+// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
+// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat)
+// CHECK: ret float %dx.dot
+float test_builtin_dot_float2_int_splat ( float2 p0, int p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %conv = sitofp i32 %1 to float
+// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
+// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
+// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat)
+// CHECK: ret float %dx.dot
+float test_dot_float3_int_splat ( float3 p0, int p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+}
+
+// CHECK: %conv = sitofp i32 %1 to float
+// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
+// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
+// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat)
+// CHECK: ret float %dx.dot
+float test_builtin_dot_float3_int_splat ( float3 p0, int p1 ) {
+ return dot ( p0, p1 );
+}
+
// CHECK: %dx.dot = fmul double %0, %1
// CHECK: ret double %dx.dot
double test_dot_double ( double p0, double p1 ) {
return dot ( p0, p1 );
}
+
+// CHECK: %conv = zext i1 %tobool to i32
+// CHECK: %dx.dot = mul i32 %conv, %1
+// CHECK: ret i32 %dx.dot
+int test_dot_bool_scalar_arg0_type_promotion ( bool p0, int p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %conv = zext i1 %tobool to i32
+// CHECK: %dx.dot = mul i32 %0, %conv
+// CHECK: ret i32 %dx.dot
+int test_dot_bool_scalar_arg1_type_promotion ( int p0, bool p1 ) {
+ return dot ( p0, p1 );
+}
+
+// CHECK: %conv1 = uitofp i1 %tobool to double
+// CHECK: %dx.dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: ret float %conv2
+float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+}
+
+// CHECK: %conv = uitofp i1 %tobool to double
+// CHECK: %conv1 = fpext float %1 to double
+// CHECK: %dx.dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: ret float %conv2
+float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+}
+
+// CHECK: %conv = zext i1 %tobool to i32
+// CHECK: %conv3 = zext i1 %tobool2 to i32
+// CHECK: %dx.dot = mul i32 %conv, %conv3
+// CHECK: ret i32 %dx.dot
+int test_builtin_dot_bool_type_promotion ( bool p0, bool p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+}
+
+// CHECK: %conv = fpext float %0 to double
+// CHECK: %conv1 = sitofp i32 %1 to double
+// CHECK: dx.dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: ret float %conv2
+float test_builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+ // expected-error at -1 {{call to 'dot' is ambiguous}}
+}
+
+// CHECK: %conv = sitofp <2 x i32> %0 to <2 x float>
+// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %conv, <2 x float> %1)
+// CHECK: ret float %dx.dot
+float test_builtin_dot_vec_int_to_float_promotion ( int2 p0, float2 p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+}
+
+// CHECK: %conv = sext <2 x i32> %1 to <2 x i64>
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %conv)
+// CHECK: ret i64 %dx.dot
+int64_t test_builtin_dot_vec_int_to_int64_promotion( int64_t2 p0, int2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
index a5acb400ab9c7b..2f1a833f5ca364 100644
--- a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
@@ -1,46 +1,46 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
-// NOTE: This test is marked XFAIL because when overload resolution merges
-// NOTE: test_dot_element_type_mismatch & test_dot_scalar_mismatch will have different behavior
-// XFAIL: *
-float test_first_arg_is_not_vector ( float p0, float2 p1 ) {
- return dot ( p0, p1 );
- // expected-error at -1 {{first two arguments to 'dot' must be vectors}}
+float test_no_second_arg ( float2 p0) {
+ return __builtin_hlsl_dot ( p0 );
+ // expected-error at -1 {{too few arguments to function call, expected 2, have 1}}
}
-float test_second_arg_is_not_vector ( float2 p0, float p1 ) {
- return dot ( p0, p1 );
- // expected-error at -1 {{first two arguments to 'dot' must be vectors}}
+float test_too_many_arg ( float2 p0) {
+ return __builtin_hlsl_dot ( p0, p0, p0 );
+ // expected-error at -1 {{too many arguments to function call, expected 2, have 3}}
}
-int test_dot_unsupported_scalar_arg0 ( bool p0, int p1 ) {
- return dot ( p0, p1 );
- // expected-error at -1 {{1st argument must be a vector, integer or floating point type (was 'bool')}}
+//NOTE: eventually behavior should match builtin
+float test_dot_no_second_arg ( float2 p0) {
+ return dot ( p0 );
+ // expected-error at -1 {{no matching function for call to 'dot'}}
}
-int test_dot_unsupported_scalar_arg1 ( int p0, bool p1 ) {
+float test_dot_vector_size_mismatch ( float3 p0, float2 p1 ) {
return dot ( p0, p1 );
- // expected-error at -1 {{2nd argument must be a vector, integer or floating point type (was 'bool')}}
+ // expected-warning at -1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'float __attribute__((ext_vector_type(2)))' (vector of 2 'float' values)}}
}
-float test_dot_scalar_mismatch ( float p0, int p1 ) {
- return dot ( p0, p1 );
- // expected-error at -1 {{call to 'dot' is ambiguous}}
+float test_dot_builtin_vector_size_mismatch ( float3 p0, float2 p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+ // expected-warning at -1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'float2' (aka 'vector<float, 2>')}}
}
-float test_dot_vector_size_mismatch ( float3 p0, float2 p1 ) {
- return dot ( p0, p1 );
- // expected-error at -1 {{first two arguments to 'dot' must have the same size}}
-}
-float test__no_second_arg ( float2 p0) {
- return dot ( p0 );
- // expected-error at -1 {{no matching function for call to 'dot'}}
+//NOTE: this case runs into the same problem as the below example
+//int Fn1(int p0, int p1);
+//int Fn1(float p0, float p1);
+//int test_dot_scalar_mismatch ( float p0, int p1 ) {
+// return Fn1( p0, p1 );
+//}
+float test_dot_scalar_mismatch ( float p0, int p1 ) {
+ return dot ( p0, p1 );
+ // expected-error at -1 {{call to 'dot' is ambiguous}}
}
float test_dot_element_type_mismatch ( int2 p0, float2 p1 ) {
return dot ( p0, p1 );
// expected-error at -1 {{call to 'dot' is ambiguous}}
-}
+}
\ No newline at end of file
>From eb438b7e16ac7ff49e8a0318a231801eeee2684a Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Thu, 22 Feb 2024 21:41:22 -0500
Subject: [PATCH 3/6] remove type promotion changes
---
.../clang/Basic/DiagnosticSemaKinds.td | 5 +-
clang/lib/CodeGen/CGBuiltin.cpp | 4 +-
clang/lib/Sema/SemaChecking.cpp | 134 ++++++++----------
.../CodeGenHLSL/builtins/dot-builtin.hlsl | 83 +++++++++++
clang/test/CodeGenHLSL/builtins/dot.hlsl | 120 ----------------
clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl | 50 ++++++-
6 files changed, 192 insertions(+), 204 deletions(-)
create mode 100644 clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 9cce89b92be309..e7c11e1a1e1d82 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10270,8 +10270,6 @@ def err_vec_builtin_non_vector : Error<
"first two arguments to %0 must be vectors">;
def err_vec_builtin_incompatible_vector : Error<
"first two arguments to %0 must have the same type">;
-def err_vec_builtin_incompatible_size : Error<
- "first two arguments to %0 must have the same size">;
def err_vsx_builtin_nonconstant_argument : Error<
"argument %0 to %1 must be a 2-bit unsigned literal (i.e. 0, 1, 2 or 3)">;
@@ -12120,6 +12118,9 @@ def err_hlsl_param_qualifier_mismatch :
def warn_hlsl_impcast_vector_truncation : Warning<
"implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;
+def warn_hlsl_impcast_bitwidth_reduction : Warning<
+ "implicit conversion from larger type: %0 to smaller type %1, possible loss of data">, InGroup<Conversion>;
+
// Layout randomization diagnostics.
def err_non_designated_init_used : Error<
"a randomized struct can only be initialized with a designated initializer">;
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7bc31055a6b49d..3b682f5bf6e770 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -17977,11 +17977,11 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
llvm::Type *T1 = Op1->getType();
if (!T0->isVectorTy() && !T1->isVectorTy()) {
if (T0->isFloatingPointTy()) {
- return Builder.CreateFMul(Op0, Op1, "dx.dot");
+ return Builder.CreateFMul(Op0, Op1, "dx.dot");
}
if (T0->isIntegerTy()) {
- return Builder.CreateMul(Op0, Op1, "dx.dot");
+ return Builder.CreateMul(Op0, Op1, "dx.dot");
}
// Bools should have been promoted
assert(
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 0c638a02e08ccf..a734160f7547a0 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5166,6 +5166,9 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
return false;
}
+// Helper function for CheckHLSLBuiltinFunctionCall
+// Note: UsualArithmeticConversions handles the case where at least
+// one arg isn't a bool
bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
unsigned NumArgs = TheCall->getNumArgs();
@@ -5181,47 +5184,13 @@ bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
Sema::AA_Converting);
if (ResA.isInvalid())
return true;
- TheCall->setArg(0, ResA.get());
+ TheCall->setArg(i, ResA.get());
}
return false;
}
-int overloadOrder(Sema *S, QualType ArgTyA) {
- auto kind = ArgTyA->getAs<BuiltinType>()->getKind();
- switch (kind) {
- case BuiltinType::Short:
- case BuiltinType::UShort:
- return 1;
- case BuiltinType::Int:
- case BuiltinType::UInt:
- return 2;
- case BuiltinType::Long:
- case BuiltinType::ULong:
- return 3;
- case BuiltinType::LongLong:
- case BuiltinType::ULongLong:
- return 4;
- case BuiltinType::Float16:
- case BuiltinType::Half:
- return 5;
- case BuiltinType::Float:
- return 6;
- default:
- break;
- }
- return 0;
-}
-
-QualType getVecLargestBitness(Sema *S, QualType ArgTyA, QualType ArgTyB) {
- auto *VecTyA = ArgTyA->getAs<VectorType>();
- auto *VecTyB = ArgTyB->getAs<VectorType>();
- QualType VecTyAElem = VecTyA->getElementType();
- QualType VecTyBElem = VecTyB->getElementType();
- int vecAElemWidth = overloadOrder(S, VecTyAElem);
- int vecBElemWidth = overloadOrder(S, VecTyBElem);
- return vecAElemWidth > vecBElemWidth ? ArgTyA : ArgTyB;
-}
-
+// Helper function for CheckHLSLBuiltinFunctionCall
+// Handles the CK_HLSLVectorTruncation case for builtins
void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() > 1);
ExprResult A = TheCall->getArg(0);
@@ -5246,6 +5215,7 @@ void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
SmallerArg = B.get();
largerIndex = 0;
}
+
S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
<< LargerArg->getType() << SmallerArg->getType()
<< LargerArg->getSourceRange() << SmallerArg->getSourceRange();
@@ -5255,61 +5225,79 @@ void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
return;
}
-bool PromoteVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
+// Helper function for CheckHLSLBuiltinFunctionCall: when `source` is a vector with non-floating elements and `targetTy` is a floating type, replaces `source` with its conversion to a float vector of the same length, warning if the elements are wider than float.
+void CheckVectorFloatPromotion(Sema *S, ExprResult &source, QualType targetTy,
+ SourceRange targetSrcRange,
+ SourceLocation BuiltinLoc) {
+ auto *vecTyTarget = source.get()->getType()->getAs<VectorType>(); // NOTE(review): despite its name, this is the vector type of `source`, not of `targetTy`
+ assert(vecTyTarget);
+ QualType vecElemT = vecTyTarget->getElementType();
+ if (!vecElemT->isFloatingType() && targetTy->isFloatingType()) {
+ QualType floatVecTy = S->Context.getVectorType( // NOTE(review): always `float` elements, even if targetTy is half or double -- confirm intended
+ S->Context.FloatTy, vecTyTarget->getNumElements(), VectorKind::Generic);
+ int floatByteSize =
+ S->Context.getTypeSizeInChars(S->Context.FloatTy).getQuantity();
+ int vecElemByteSize = S->Context.getTypeSizeInChars(vecElemT).getQuantity();
+ if (vecElemByteSize > floatByteSize) // elements wider than float: warn about possible loss of data
+ S->Diag(BuiltinLoc, diag::warn_hlsl_impcast_bitwidth_reduction)
+ << source.get()->getType() << floatVecTy
+ << source.get()->getSourceRange() << targetSrcRange;
+
+ source = S->SemaConvertVectorExpr( // replace `source` with its conversion to floatVecTy
+ source.get(), S->Context.CreateTypeSourceInfo(floatVecTy), BuiltinLoc,
+ source.get()->getBeginLoc());
+ }
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall: splats `source` to the vector type `targetTy`, first casting it to the vector's element type when that element type is floating and differs from `source`'s type.
+void PromoteVectorArgSplat(Sema *S, ExprResult &source, QualType targetTy) {
+ QualType sourceTy = source.get()->getType();
+ auto *vecTyTarget = targetTy->getAs<VectorType>();
+ QualType vecElemT = vecTyTarget->getElementType();
+ if (vecElemT->isFloatingType() && sourceTy != vecElemT)
+ // Without this cast, splatting onto a float vector would do an unnecessary cast to double.
+ source = S->ImpCastExprToType(source.get(), vecElemT, CK_FloatingCast); // NOTE(review): also reached when sourceTy is an integer type; CK_IntegralToFloating looks like the correct cast kind -- confirm
+ source = S->ImpCastExprToType(source.get(), targetTy, CK_VectorSplat); // NOTE(review): non-floating element types get no element conversion before the splat -- confirm callers guarantee a match
+}
+
+// Helper function for CheckHLSLBuiltinFunctionCall
+bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() > 1);
ExprResult A = TheCall->getArg(0);
ExprResult B = TheCall->getArg(1);
QualType ArgTyA = A.get()->getType();
QualType ArgTyB = B.get()->getType();
-
auto *VecTyA = ArgTyA->getAs<VectorType>();
auto *VecTyB = ArgTyB->getAs<VectorType>();
+
if (VecTyA == nullptr && VecTyB == nullptr)
return false;
+
if (VecTyA && VecTyB) {
if (VecTyA->getElementType() == VecTyB->getElementType()) {
TheCall->setType(VecTyA->getElementType());
return false;
}
- SourceLocation BuiltinLoc = TheCall->getBeginLoc();
- QualType CastType = getVecLargestBitness(S, ArgTyA, ArgTyB);
- if (CastType == ArgTyA) {
- ExprResult ResB = S->SemaConvertVectorExpr(
- B.get(), S->Context.CreateTypeSourceInfo(ArgTyA), BuiltinLoc,
- B.get()->getBeginLoc());
- TheCall->setArg(1, ResB.get());
- TheCall->setType(VecTyA->getElementType());
- return false;
- }
-
- if (CastType == ArgTyB) {
- ExprResult ResA = S->SemaConvertVectorExpr(
- A.get(), S->Context.CreateTypeSourceInfo(ArgTyB), BuiltinLoc,
- A.get()->getBeginLoc());
- TheCall->setArg(0, ResA.get());
- TheCall->setType(VecTyB->getElementType());
- return false;
- }
- return false;
+ // Note: type promotion is intended to be handeled via the intrinsics
+ // and not the builtin itself.
+ S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
+ << TheCall->getDirectCallee()
+ << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
+ return true;
}
if (VecTyB) {
- // Convert to the vector result type
- ExprResult ResA = A;
- if (VecTyB->getElementType() != ArgTyA)
- ResA = S->ImpCastExprToType(ResA.get(), VecTyB->getElementType(),
- CK_FloatingCast);
- ResA = S->ImpCastExprToType(ResA.get(), ArgTyB, CK_VectorSplat);
- TheCall->setArg(0, ResA.get());
+ CheckVectorFloatPromotion(S, B, ArgTyA, A.get()->getSourceRange(),
+ TheCall->getBeginLoc());
+ PromoteVectorArgSplat(S, A, B.get()->getType());
}
if (VecTyA) {
- ExprResult ResB = B;
- if (VecTyA->getElementType() != ArgTyB)
- ResB = S->ImpCastExprToType(ResB.get(), VecTyA->getElementType(),
- CK_FloatingCast);
- ResB = S->ImpCastExprToType(ResB.get(), ArgTyA, CK_VectorSplat);
- TheCall->setArg(1, ResB.get());
+ CheckVectorFloatPromotion(S, A, ArgTyB, B.get()->getSourceRange(),
+ TheCall->getBeginLoc());
+ PromoteVectorArgSplat(S, B, A.get()->getType());
}
+ TheCall->setArg(0, A.get());
+ TheCall->setArg(1, B.get());
return false;
}
@@ -5322,7 +5310,7 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
if (PromoteBoolsToInt(this, TheCall))
return true;
- if (PromoteVectorElementCallArgs(this, TheCall))
+ if (CheckVectorElementCallArgs(this, TheCall))
return true;
PromoteVectorArgTruncation(this, TheCall);
if (SemaBuiltinVectorToScalarMath(TheCall))
diff --git a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
new file mode 100644
index 00000000000000..d68c9c11289c60
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
@@ -0,0 +1,83 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -O3 -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK: ret float %dx.dot
+float test_builtin_dot_float2_splat ( float p0, float2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK: ret float %dx.dot
+float test_builtin_dot_float3_splat ( float p0, float3 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK: ret float %dx.dot
+float test_builtin_dot_float4_splat ( float p0, float4 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+}
+
+// CHECK: %conv = sitofp i32 %1 to float
+// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
+// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat)
+// CHECK: ret float %dx.dot
+float test_dot_float2_int_splat ( float2 p0, int p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+}
+
+// CHECK: %conv = sitofp i32 %1 to float
+// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
+// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
+// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat)
+// CHECK: ret float %dx.dot
+float test_dot_float3_int_splat ( float3 p0, int p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+}
+
+// CHECK: %conv1 = uitofp i1 %tobool to double
+// CHECK: %dx.dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: ret float %conv2
+float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+}
+
+// CHECK: %conv = uitofp i1 %tobool to double
+// CHECK: %conv1 = fpext float %1 to double
+// CHECK: %dx.dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: ret float %conv2
+float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+}
+
+// CHECK: %conv = zext i1 %tobool to i32
+// CHECK: %conv3 = zext i1 %tobool2 to i32
+// CHECK: %dx.dot = mul i32 %conv, %conv3
+// CHECK: ret i32 %dx.dot
+int test_builtin_dot_bool_type_promotion ( bool p0, bool p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+}
+
+// CHECK: %conv = fpext float %0 to double
+// CHECK: %conv1 = sitofp i32 %1 to double
+// CHECK: dx.dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: ret float %conv2
+float test_builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+}
+
+
+// CHECK: %conv = sitofp <2 x i32> %0 to <2 x float>
+// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %conv, <2 x float> %splat.splat)
+// CHECK: ret float %dx.dot
+float test_builtin_dot_int_vect_to_float_vec_promotion ( int2 p0, float p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+}
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index b2cd3b6302af6a..4f14b100b33bbb 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -55,34 +55,6 @@ uint16_t test_dot_ushort3 ( uint16_t3 p0, uint16_t3 p1 ) {
uint16_t test_dot_ushort4 ( uint16_t4 p0, uint16_t4 p1 ) {
return dot ( p0, p1 );
}
-
-// NATIVE_HALF: %conv = sitofp <2 x i16> %1 to <2 x float>
-// NATIVE_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %conv)
-// NATIVE_HALF: ret float %dx.dot
-float test_builtin_dot_vec_int16_to_float_promotion( float2 p0, int16_t2 p1 ) {
- return __builtin_hlsl_dot( p0, p1 );
-}
-
-// NATIVE_HALF: %conv = sitofp <2 x i16> %1 to <2 x half>
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %conv)
-// NATIVE_HALF: ret half %dx.dot
-half test_builtin_dot_vec_int16_to_half_promotion( half2 p0, int16_t2 p1 ) {
- return __builtin_hlsl_dot( p0, p1 );
-}
-
-// NATIVE_HALF: %conv = sext <2 x i16> %1 to <2 x i32>
-// NATIVE_HALF: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %conv)
-// NATIVE_HALF: ret i32 %dx.dot
-int test_builtin_dot_vec_int16_to_int_promotion( int2 p0, int16_t2 p1 ) {
- return __builtin_hlsl_dot( p0, p1 );
-}
-
-// NATIVE_HALF: %conv = sext <2 x i16> %1 to <2 x i64>
-// NATIVE_HALF: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %conv)
-// NATIVE_HALF: ret i64 %dx.dot
-int64_t test_builtin_dot_vec_int16_to_int64_promotion( int64_t2 p0, int16_t2 p1 ) {
- return __builtin_hlsl_dot( p0, p1 );
-}
#endif
// CHECK: %dx.dot = mul i32 %0, %1
@@ -212,13 +184,6 @@ half test_dot_half3 ( half3 p0, half3 p1 ) {
half test_dot_half4 ( half4 p0, half4 p1 ) {
return dot ( p0, p1 );
}
-// NATIVE_HALF: %conv = fpext <2 x half> %1 to <2 x float>
-// NATIVE_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %conv)
-// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
-// CHECK: ret float %dx.dot
-float test_builtin_dot_vec_half_to_float_promotion( float2 p0, half2 p1 ) {
- return __builtin_hlsl_dot( p0, p1 );
-}
// CHECK: %dx.dot = fmul float %0, %1
// CHECK: ret float %dx.dot
@@ -262,33 +227,6 @@ float test_dot_float4_splat ( float p0, float4 p1 ) {
return dot( p0, p1 );
}
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
-// CHECK: ret float %dx.dot
-float test_builtin_dot_float2_splat ( float p0, float2 p1 ) {
- return __builtin_hlsl_dot( p0, p1 );
-}
-
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
-// CHECK: ret float %dx.dot
-float test_builtin_dot_float3_splat ( float p0, float3 p1 ) {
- return __builtin_hlsl_dot( p0, p1 );
-}
-
-// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
-// CHECK: ret float %dx.dot
-float test_builtin_dot_float4_splat ( float p0, float4 p1 ) {
- return __builtin_hlsl_dot( p0, p1 );
-}
-
-// CHECK: %conv = sitofp i32 %1 to float
-// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
-// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat)
-// CHECK: ret float %dx.dot
-float test_dot_float2_int_splat ( float2 p0, int p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
-}
-
// CHECK: %conv = sitofp i32 %1 to float
// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
@@ -298,15 +236,6 @@ float test_builtin_dot_float2_int_splat ( float2 p0, int p1 ) {
return dot ( p0, p1 );
}
-// CHECK: %conv = sitofp i32 %1 to float
-// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
-// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat)
-// CHECK: ret float %dx.dot
-float test_dot_float3_int_splat ( float3 p0, int p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
-}
-
// CHECK: %conv = sitofp i32 %1 to float
// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
@@ -335,52 +264,3 @@ int test_dot_bool_scalar_arg0_type_promotion ( bool p0, int p1 ) {
int test_dot_bool_scalar_arg1_type_promotion ( int p0, bool p1 ) {
return dot ( p0, p1 );
}
-
-// CHECK: %conv1 = uitofp i1 %tobool to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
-// CHECK: ret float %conv2
-float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
-}
-
-// CHECK: %conv = uitofp i1 %tobool to double
-// CHECK: %conv1 = fpext float %1 to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
-// CHECK: ret float %conv2
-float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
-}
-
-// CHECK: %conv = zext i1 %tobool to i32
-// CHECK: %conv3 = zext i1 %tobool2 to i32
-// CHECK: %dx.dot = mul i32 %conv, %conv3
-// CHECK: ret i32 %dx.dot
-int test_builtin_dot_bool_type_promotion ( bool p0, bool p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
-}
-
-// CHECK: %conv = fpext float %0 to double
-// CHECK: %conv1 = sitofp i32 %1 to double
-// CHECK: dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
-// CHECK: ret float %conv2
-float test_builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
- // expected-error at -1 {{call to 'dot' is ambiguous}}
-}
-
-// CHECK: %conv = sitofp <2 x i32> %0 to <2 x float>
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %conv, <2 x float> %1)
-// CHECK: ret float %dx.dot
-float test_builtin_dot_vec_int_to_float_promotion ( int2 p0, float2 p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
-}
-
-// CHECK: %conv = sext <2 x i32> %1 to <2 x i64>
-// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %conv)
-// CHECK: ret i64 %dx.dot
-int64_t test_builtin_dot_vec_int_to_int64_promotion( int64_t2 p0, int2 p1 ) {
- return __builtin_hlsl_dot( p0, p1 );
-}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
index 2f1a833f5ca364..50e0144ee7a304 100644
--- a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
@@ -28,13 +28,11 @@ float test_dot_builtin_vector_size_mismatch ( float3 p0, float2 p1 ) {
// expected-warning at -1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'float2' (aka 'vector<float, 2>')}}
}
+float test_dot_builtin_vector_elem_size_reduction ( int64_t2 p0, float p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+ // expected-warning at -1 {{conversion from larger type: 'int64_t2' (aka 'vector<int64_t, 2>') to smaller type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values), possible loss of data}}
+}
-//NOTE: this case runs into the same problem as the below example
-//int Fn1(int p0, int p1);
-//int Fn1(float p0, float p1);
-//int test_dot_scalar_mismatch ( float p0, int p1 ) {
-// return Fn1( p0, p1 );
-//}
float test_dot_scalar_mismatch ( float p0, int p1 ) {
return dot ( p0, p1 );
// expected-error at -1 {{call to 'dot' is ambiguous}}
@@ -43,4 +41,42 @@ float test_dot_scalar_mismatch ( float p0, int p1 ) {
float test_dot_element_type_mismatch ( int2 p0, float2 p1 ) {
return dot ( p0, p1 );
// expected-error at -1 {{call to 'dot' is ambiguous}}
-}
\ No newline at end of file
+}
+
+//NOTE: for all the *_promotion we are intentionally not handling type promotion in builtins
+float test_builtin_dot_vec_int_to_float_promotion ( int2 p0, float2 p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}}
+}
+
+int64_t test_builtin_dot_vec_int_to_int64_promotion( int64_t2 p0, int2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}}
+}
+
+float test_builtin_dot_vec_half_to_float_promotion( float2 p0, half2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}}
+}
+
+#ifdef __HLSL_ENABLE_16_BIT
+float test_builtin_dot_vec_int16_to_float_promotion( float2 p0, int16_t2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}}
+}
+
+half test_builtin_dot_vec_int16_to_half_promotion( half2 p0, int16_t2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}}
+}
+
+int test_builtin_dot_vec_int16_to_int_promotion( int2 p0, int16_t2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}}
+}
+
+int64_t test_builtin_dot_vec_int16_to_int64_promotion( int64_t2 p0, int16_t2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}}
+}
+#endif
>From 8d8d1ff964edbe7d323842f393121ec60ff57d26 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Fri, 23 Feb 2024 13:09:23 -0500
Subject: [PATCH 4/6] remove hlsl error to go with a more generic one
---
clang/include/clang/Basic/DiagnosticSemaKinds.td | 3 ---
clang/lib/Sema/SemaChecking.cpp | 10 +++-------
clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl | 9 ++-------
clang/test/SemaHLSL/BuiltIns/dot-warning.hlsl | 14 ++++++++++++++
4 files changed, 19 insertions(+), 17 deletions(-)
create mode 100644 clang/test/SemaHLSL/BuiltIns/dot-warning.hlsl
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index e7c11e1a1e1d82..a7f2858477bee6 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12118,9 +12118,6 @@ def err_hlsl_param_qualifier_mismatch :
def warn_hlsl_impcast_vector_truncation : Warning<
"implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;
-def warn_hlsl_impcast_bitwidth_reduction : Warning<
- "implicit conversion from larger type: %0 to smaller type %1, possible loss of data">, InGroup<Conversion>;
-
// Layout randomization diagnostics.
def err_non_designated_init_used : Error<
"a randomized struct can only be initialized with a designated initializer">;
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index a734160f7547a0..b0ff1b7e6c308e 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5235,14 +5235,10 @@ void CheckVectorFloatPromotion(Sema *S, ExprResult &source, QualType targetTy,
if (!vecElemT->isFloatingType() && targetTy->isFloatingType()) {
QualType floatVecTy = S->Context.getVectorType(
S->Context.FloatTy, vecTyTarget->getNumElements(), VectorKind::Generic);
- int floatByteSize =
- S->Context.getTypeSizeInChars(S->Context.FloatTy).getQuantity();
- int vecElemByteSize = S->Context.getTypeSizeInChars(vecElemT).getQuantity();
- if (vecElemByteSize > floatByteSize)
- S->Diag(BuiltinLoc, diag::warn_hlsl_impcast_bitwidth_reduction)
- << source.get()->getType() << floatVecTy
- << source.get()->getSourceRange() << targetSrcRange;
+ S->Diag(BuiltinLoc, diag::warn_impcast_integer_float_precision)
+ << source.get()->getType() << floatVecTy
+ << source.get()->getSourceRange() << targetSrcRange;
source = S->SemaConvertVectorExpr(
source.get(), S->Context.CreateTypeSourceInfo(floatVecTy), BuiltinLoc,
source.get()->getBeginLoc());
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
index 50e0144ee7a304..334132d8412956 100644
--- a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
@@ -1,6 +1,6 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
-// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
-// RUN: -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm \
+// RUN: -disable-llvm-passes -verify -verify-ignore-unexpected
float test_no_second_arg ( float2 p0) {
return __builtin_hlsl_dot ( p0 );
@@ -28,11 +28,6 @@ float test_dot_builtin_vector_size_mismatch ( float3 p0, float2 p1 ) {
// expected-warning at -1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'float2' (aka 'vector<float, 2>')}}
}
-float test_dot_builtin_vector_elem_size_reduction ( int64_t2 p0, float p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
- // expected-warning at -1 {{conversion from larger type: 'int64_t2' (aka 'vector<int64_t, 2>') to smaller type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values), possible loss of data}}
-}
-
float test_dot_scalar_mismatch ( float p0, int p1 ) {
return dot ( p0, p1 );
// expected-error at -1 {{call to 'dot' is ambiguous}}
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-warning.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-warning.hlsl
new file mode 100644
index 00000000000000..fbdeb5064f6a8c
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/dot-warning.hlsl
@@ -0,0 +1,14 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm \
+// RUN: -disable-llvm-passes -Wimplicit-int-float-conversion -verify -verify-ignore-unexpected
+
+// Checks the int->float vector promotion warning emitted for __builtin_hlsl_dot.
+float test_dot_builtin_vector_elem_size_reduction ( int64_t2 p0, float p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+ // expected-warning at -1 {{implicit conversion from 'int64_t2' (aka 'vector<int64_t, 2>') to '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values) may lose precision}}
+}
+
+float test_dot_builtin_int_vector_elem_size_reduction ( int2 p0, float p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+ // expected-warning at -1 {{implicit conversion from 'int2' (aka 'vector<int, 2>') to '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values) may lose precision}}
+}
>From 080b65fcce390fffa4914e3fd8972284511a2600 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Fri, 23 Feb 2024 18:09:10 -0500
Subject: [PATCH 5/6] remove all type promotion
---
clang/lib/CodeGen/CGBuiltin.cpp | 6 +-
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 6 +
clang/lib/Sema/SemaChecking.cpp | 160 ++++--------------
.../CodeGenHLSL/builtins/dot-builtin.hlsl | 63 +------
clang/test/CodeGenHLSL/builtins/dot.hlsl | 6 +-
clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl | 41 ++++-
6 files changed, 90 insertions(+), 192 deletions(-)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 3b682f5bf6e770..e4da4e9681dcd2 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -17996,8 +17996,10 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
assert(T0->getScalarType() == T1->getScalarType() &&
"Dot product of vectors need the same element types.");
- auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
- auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>();
+ [[maybe_unused]] auto *VecTy0 =
+ E->getArg(0)->getType()->getAs<VectorType>();
+ [[maybe_unused]] auto *VecTy1 =
+ E->getArg(1)->getType()->getAs<VectorType>();
// A HLSLVectorTruncation should have happend
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
"Dot product requires vectors to be of the same size.");
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index a92f0d0849ba77..08e5d981a4a4ca 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -182,6 +182,12 @@ double4 cos(double4);
//===----------------------------------------------------------------------===//
// dot product builtins
//===----------------------------------------------------------------------===//
+
+/// \fn K dot(T X, T Y)
+/// \brief Return the dot product (a scalar value) of \a X and \a Y.
+/// \param X The X input value.
+/// \param Y The Y input value.
+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
half dot(half, half);
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index b0ff1b7e6c308e..984088e345c806 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2962,9 +2962,8 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
}
}
- if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall)) {
+ if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall))
return ExprError();
- }
// Since the target specific builtins for each arch overlap, only check those
// of the arch we are compiling for.
@@ -5166,96 +5165,6 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
return false;
}
-// Helper function for CheckHLSLBuiltinFunctionCall
-// Note: UsualArithmeticConversions handles the case where at least
-// one arg isn't a bool
-bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
- unsigned NumArgs = TheCall->getNumArgs();
-
- for (unsigned i = 0; i < NumArgs; ++i) {
- ExprResult A = TheCall->getArg(i);
- if (!A.get()->getType()->isBooleanType())
- return false;
- }
- // if we got here all args are bool
- for (unsigned i = 0; i < NumArgs; ++i) {
- ExprResult A = TheCall->getArg(i);
- ExprResult ResA = S->PerformImplicitConversion(A.get(), S->Context.IntTy,
- Sema::AA_Converting);
- if (ResA.isInvalid())
- return true;
- TheCall->setArg(i, ResA.get());
- }
- return false;
-}
-
-// Helper function for CheckHLSLBuiltinFunctionCall
-// Handles the CK_HLSLVectorTruncation case for builtins
-void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
- assert(TheCall->getNumArgs() > 1);
- ExprResult A = TheCall->getArg(0);
- ExprResult B = TheCall->getArg(1);
- QualType ArgTyA = A.get()->getType();
- QualType ArgTyB = B.get()->getType();
-
- auto *VecTyA = ArgTyA->getAs<VectorType>();
- auto *VecTyB = ArgTyB->getAs<VectorType>();
- if (VecTyA == nullptr && VecTyB == nullptr)
- return;
- if (VecTyA == nullptr || VecTyB == nullptr)
- return;
- if (VecTyA->getNumElements() == VecTyB->getNumElements())
- return;
-
- Expr *LargerArg = B.get();
- Expr *SmallerArg = A.get();
- int largerIndex = 1;
- if (VecTyA->getNumElements() > VecTyB->getNumElements()) {
- LargerArg = A.get();
- SmallerArg = B.get();
- largerIndex = 0;
- }
-
- S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
- << LargerArg->getType() << SmallerArg->getType()
- << LargerArg->getSourceRange() << SmallerArg->getSourceRange();
- ExprResult ResLargerArg = S->ImpCastExprToType(
- LargerArg, SmallerArg->getType(), CK_HLSLVectorTruncation);
- TheCall->setArg(largerIndex, ResLargerArg.get());
- return;
-}
-
-// Helper function for CheckHLSLBuiltinFunctionCall
-void CheckVectorFloatPromotion(Sema *S, ExprResult &source, QualType targetTy,
- SourceRange targetSrcRange,
- SourceLocation BuiltinLoc) {
- auto *vecTyTarget = source.get()->getType()->getAs<VectorType>();
- assert(vecTyTarget);
- QualType vecElemT = vecTyTarget->getElementType();
- if (!vecElemT->isFloatingType() && targetTy->isFloatingType()) {
- QualType floatVecTy = S->Context.getVectorType(
- S->Context.FloatTy, vecTyTarget->getNumElements(), VectorKind::Generic);
-
- S->Diag(BuiltinLoc, diag::warn_impcast_integer_float_precision)
- << source.get()->getType() << floatVecTy
- << source.get()->getSourceRange() << targetSrcRange;
- source = S->SemaConvertVectorExpr(
- source.get(), S->Context.CreateTypeSourceInfo(floatVecTy), BuiltinLoc,
- source.get()->getBeginLoc());
- }
-}
-
-// Helper function for CheckHLSLBuiltinFunctionCall
-void PromoteVectorArgSplat(Sema *S, ExprResult &source, QualType targetTy) {
- QualType sourceTy = source.get()->getType();
- auto *vecTyTarget = targetTy->getAs<VectorType>();
- QualType vecElemT = vecTyTarget->getElementType();
- if (vecElemT->isFloatingType() && sourceTy != vecElemT)
- // if float vec splat wil do an unnecessary cast to double
- source = S->ImpCastExprToType(source.get(), vecElemT, CK_FloatingCast);
- source = S->ImpCastExprToType(source.get(), targetTy, CK_VectorSplat);
-}
-
// Helper function for CheckHLSLBuiltinFunctionCall
bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() > 1);
@@ -5265,36 +5174,42 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
QualType ArgTyB = B.get()->getType();
auto *VecTyA = ArgTyA->getAs<VectorType>();
auto *VecTyB = ArgTyB->getAs<VectorType>();
-
+ SourceLocation BuiltinLoc = TheCall->getBeginLoc();
if (VecTyA == nullptr && VecTyB == nullptr)
return false;
if (VecTyA && VecTyB) {
- if (VecTyA->getElementType() == VecTyB->getElementType()) {
- TheCall->setType(VecTyA->getElementType());
- return false;
+ bool retValue = false;
+ if (VecTyA->getElementType() != VecTyB->getElementType()) {
+ // Note: type promotion is intended to be handled via the intrinsics
+ // and not the builtin itself.
+ S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
+ << TheCall->getDirectCallee()
+ << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
+ retValue = true;
+ }
+ if (VecTyA->getNumElements() != VecTyB->getNumElements()) {
+ // if we get here a HLSLVectorTruncation is needed.
+ S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
+ << TheCall->getDirectCallee()
+ << SourceRange(TheCall->getArg(0)->getBeginLoc(),
+ TheCall->getArg(1)->getEndLoc());
+ retValue = true;
}
- // Note: type promotion is intended to be handeled via the intrinsics
- // and not the builtin itself.
- S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
- << TheCall->getDirectCallee()
- << SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());
- return true;
- }
- if (VecTyB) {
- CheckVectorFloatPromotion(S, B, ArgTyA, A.get()->getSourceRange(),
- TheCall->getBeginLoc());
- PromoteVectorArgSplat(S, A, B.get()->getType());
- }
- if (VecTyA) {
- CheckVectorFloatPromotion(S, A, ArgTyB, B.get()->getSourceRange(),
- TheCall->getBeginLoc());
- PromoteVectorArgSplat(S, B, A.get()->getType());
+ // Only a well-formed call (matching element types and sizes) gets the scalar element type as its result type.
+ if (!retValue)
+ TheCall->setType(VecTyA->getElementType());
+ return retValue;
}
- TheCall->setArg(0, A.get());
- TheCall->setArg(1, B.get());
- return false;
+
+ // Note: if we get here one of the args is a scalar which
+ // requires a VectorSplat on Arg0 or Arg1
+ S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
+ << TheCall->getDirectCallee()
+ << SourceRange(TheCall->getArg(0)->getBeginLoc(),
+ TheCall->getArg(1)->getEndLoc());
+ return true;
}
// Note: returning true in this case results in CheckBuiltinFunctionCall
@@ -5304,11 +5219,8 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
case Builtin::BI__builtin_hlsl_dot: {
if (checkArgCount(*this, TheCall, 2))
return true;
- if (PromoteBoolsToInt(this, TheCall))
- return true;
if (CheckVectorElementCallArgs(this, TheCall))
return true;
- PromoteVectorArgTruncation(this, TheCall);
if (SemaBuiltinVectorToScalarMath(TheCall))
return true;
break;
@@ -19759,8 +19671,7 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
QualType Res;
- bool result = SemaBuiltinVectorMath(TheCall, Res);
- if (result)
+ if (SemaBuiltinVectorMath(TheCall, Res))
return true;
TheCall->setType(Res);
return false;
@@ -19768,15 +19679,14 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
bool Sema::SemaBuiltinVectorToScalarMath(CallExpr *TheCall) {
QualType Res;
- bool result = SemaBuiltinVectorMath(TheCall, Res);
- if (result)
+ if (SemaBuiltinVectorMath(TheCall, Res))
return true;
- if (auto *VecTy0 = Res->getAs<VectorType>()) {
+ if (auto *VecTy0 = Res->getAs<VectorType>())
TheCall->setType(VecTy0->getElementType());
- } else {
+ else
TheCall->setType(Res);
- }
+
return false;
}
diff --git a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
index d68c9c11289c60..9881dabc3a1106 100644
--- a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
@@ -1,44 +1,6 @@
-// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
-// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
-// RUN: -emit-llvm -disable-llvm-passes -O3 -o - | FileCheck %s \
-// RUN: --check-prefixes=CHECK
-
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
-// CHECK: ret float %dx.dot
-float test_builtin_dot_float2_splat ( float p0, float2 p1 ) {
- return __builtin_hlsl_dot( p0, p1 );
-}
-
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
-// CHECK: ret float %dx.dot
-float test_builtin_dot_float3_splat ( float p0, float3 p1 ) {
- return __builtin_hlsl_dot( p0, p1 );
-}
-
-// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
-// CHECK: ret float %dx.dot
-float test_builtin_dot_float4_splat ( float p0, float4 p1 ) {
- return __builtin_hlsl_dot( p0, p1 );
-}
-
-// CHECK: %conv = sitofp i32 %1 to float
-// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
-// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat)
-// CHECK: ret float %dx.dot
-float test_dot_float2_int_splat ( float2 p0, int p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
-}
-
-// CHECK: %conv = sitofp i32 %1 to float
-// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
-// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat)
-// CHECK: ret float %dx.dot
-float test_dot_float3_int_splat ( float3 p0, int p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
-}
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+// CHECK-LABEL: builtin_bool_to_float_type_promotion
// CHECK: %conv1 = uitofp i1 %tobool to double
// CHECK: %dx.dot = fmul double %conv, %conv1
// CHECK: %conv2 = fptrunc double %dx.dot to float
@@ -47,6 +9,7 @@ float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
}
+// CHECK-LABEL: builtin_bool_to_float_arg1_type_promotion
// CHECK: %conv = uitofp i1 %tobool to double
// CHECK: %conv1 = fpext float %1 to double
// CHECK: %dx.dot = fmul double %conv, %conv1
@@ -56,28 +19,12 @@ float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
}
-// CHECK: %conv = zext i1 %tobool to i32
-// CHECK: %conv3 = zext i1 %tobool2 to i32
-// CHECK: %dx.dot = mul i32 %conv, %conv3
-// CHECK: ret i32 %dx.dot
-int test_builtin_dot_bool_type_promotion ( bool p0, bool p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
-}
-
+// CHECK-LABEL: builtin_dot_int_to_float_promotion
// CHECK: %conv = fpext float %0 to double
// CHECK: %conv1 = sitofp i32 %1 to double
// CHECK: dx.dot = fmul double %conv, %conv1
// CHECK: %conv2 = fptrunc double %dx.dot to float
// CHECK: ret float %conv2
-float test_builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
-}
-
-
-// CHECK: %conv = sitofp <2 x i32> %0 to <2 x float>
-// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %conv, <2 x float> %splat.splat)
-// CHECK: ret float %dx.dot
-float test_builtin_dot_int_vect_to_float_vec_promotion ( int2 p0, float p1 ) {
+float builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
}
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index 4f14b100b33bbb..25f2f867845e9f 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -1,8 +1,8 @@
-// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
-// RUN: -emit-llvm -disable-llvm-passes -O3 -o - | FileCheck %s \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,NATIVE_HALF
-// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
index 334132d8412956..825bb7a337c626 100644
--- a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
@@ -1,6 +1,4 @@
-// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
-// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm \
-// RUN: -disable-llvm-passes -verify -verify-ignore-unexpected
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
float test_no_second_arg ( float2 p0) {
return __builtin_hlsl_dot ( p0 );
@@ -25,7 +23,7 @@ float test_dot_vector_size_mismatch ( float3 p0, float2 p1 ) {
float test_dot_builtin_vector_size_mismatch ( float3 p0, float2 p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
- // expected-warning at -1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'float2' (aka 'vector<float, 2>')}}
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}}
}
float test_dot_scalar_mismatch ( float p0, int p1 ) {
@@ -75,3 +73,38 @@ int64_t test_builtin_dot_vec_int16_to_int64_promotion( int64_t2 p0, int16_t2 p1
// expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must have the same type}}
}
#endif
+
+float test_builtin_dot_float2_splat ( float p0, float2 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}}
+}
+
+float test_builtin_dot_float3_splat ( float p0, float3 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}}
+}
+
+float test_builtin_dot_float4_splat ( float p0, float4 p1 ) {
+ return __builtin_hlsl_dot( p0, p1 );
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}}
+}
+
+float test_dot_float2_int_splat ( float2 p0, int p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}}
+}
+
+float test_dot_float3_int_splat ( float3 p0, int p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}}
+}
+
+float test_builtin_dot_int_vect_to_float_vec_promotion ( int2 p0, float p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+ // expected-error at -1 {{first two arguments to '__builtin_hlsl_dot' must be vectors}}
+}
+
+int test_builtin_dot_bool_type_promotion ( bool p0, bool p1 ) {
+ return __builtin_hlsl_dot ( p0, p1 );
+ // expected-error at -1 {{1st argument must be a vector, integer or floating point type (was 'bool')}}
+}
>From 37091d136d03face541c65fe91838d9f425eb5c8 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Fri, 23 Feb 2024 20:00:49 -0500
Subject: [PATCH 6/6] address pr comments
---
clang/lib/CodeGen/CGBuiltin.cpp | 12 +++---
clang/test/CodeGenHLSL/builtins/dot.hlsl | 13 +++---
clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl | 1 -
clang/test/SemaHLSL/BuiltIns/dot-warning.hlsl | 14 -------
.../test/SemaHLSL/OverloadResolutionBugs.hlsl | 40 +++++++++++++++++++
5 files changed, 51 insertions(+), 29 deletions(-)
delete mode 100644 clang/test/SemaHLSL/BuiltIns/dot-warning.hlsl
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e4da4e9681dcd2..54d7451a9d6221 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -17976,17 +17976,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
llvm::Type *T0 = Op0->getType();
llvm::Type *T1 = Op1->getType();
if (!T0->isVectorTy() && !T1->isVectorTy()) {
- if (T0->isFloatingPointTy()) {
+ if (T0->isFloatingPointTy())
return Builder.CreateFMul(Op0, Op1, "dx.dot");
- }
- if (T0->isIntegerTy()) {
+ if (T0->isIntegerTy())
return Builder.CreateMul(Op0, Op1, "dx.dot");
- }
+
// Bools should have been promoted
- assert(
- false &&
- "Dot product on a scalar is only supported on integers and floats.");
+ llvm_unreachable(
+ "Scalar dot product is only supported on ints and floats.");
}
// A VectorSplat should have happened
assert(T0->isVectorTy() && T1->isVectorTy() &&
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index 25f2f867845e9f..b2c1bae31d13b1 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -6,7 +6,6 @@
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
-// -fnative-half-type sets __HLSL_ENABLE_16_BIT
#ifdef __HLSL_ENABLE_16_BIT
// NATIVE_HALF: %dx.dot = mul i16 %0, %1
// NATIVE_HALF: ret i16 %dx.dot
@@ -113,19 +112,19 @@ int64_t test_dot_long ( int64_t p0, int64_t p1 ) {
// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
// CHECK: ret i64 %dx.dot
-int64_t test_dot_uint2 ( int64_t2 p0, int64_t2 p1 ) {
+int64_t test_dot_long2 ( int64_t2 p0, int64_t2 p1 ) {
return dot ( p0, p1 );
}
// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
// CHECK: ret i64 %dx.dot
-int64_t test_dot_uint3 ( int64_t3 p0, int64_t3 p1 ) {
+int64_t test_dot_long3 ( int64_t3 p0, int64_t3 p1 ) {
return dot ( p0, p1 );
}
// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
// CHECK: ret i64 %dx.dot
-int64_t test_dot_uint4 ( int64_t4 p0, int64_t4 p1 ) {
+int64_t test_dot_long4 ( int64_t4 p0, int64_t4 p1 ) {
return dot ( p0, p1 );
}
@@ -137,19 +136,19 @@ uint64_t test_dot_ulong ( uint64_t p0, uint64_t p1 ) {
// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
// CHECK: ret i64 %dx.dot
-uint64_t test_dot_uint2 ( uint64_t2 p0, uint64_t2 p1 ) {
+uint64_t test_dot_ulong2 ( uint64_t2 p0, uint64_t2 p1 ) {
return dot ( p0, p1 );
}
// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
// CHECK: ret i64 %dx.dot
-uint64_t test_dot_uint3 ( uint64_t3 p0, uint64_t3 p1 ) {
+uint64_t test_dot_ulong3 ( uint64_t3 p0, uint64_t3 p1 ) {
return dot ( p0, p1 );
}
// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
// CHECK: ret i64 %dx.dot
-uint64_t test_dot_uint4 ( uint64_t4 p0, uint64_t4 p1 ) {
+uint64_t test_dot_ulong4 ( uint64_t4 p0, uint64_t4 p1 ) {
return dot ( p0, p1 );
}
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
index 825bb7a337c626..54d093aa7ce3a4 100644
--- a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
@@ -10,7 +10,6 @@ float test_too_many_arg ( float2 p0) {
// expected-error at -1 {{too many arguments to function call, expected 2, have 3}}
}
-//NOTE: eventually behavior should match builtin
float test_dot_no_second_arg ( float2 p0) {
return dot ( p0 );
// expected-error at -1 {{no matching function for call to 'dot'}}
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-warning.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-warning.hlsl
deleted file mode 100644
index fbdeb5064f6a8c..00000000000000
--- a/clang/test/SemaHLSL/BuiltIns/dot-warning.hlsl
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
-// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm \
-// RUN: -disable-llvm-passes -Wimplicit-int-float-conversion-verify -verify-ignore-unexpected
-
-
-float test_dot_builtin_vector_elem_size_reduction ( int64_t2 p0, float p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
- // expected-Warning at -1 {{implicit conversion from 'int64_t2' (aka 'vector<int64_t, 2>') to '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values) may lose precision}}
-}
-
-float test_dot_builtin_int_vector_elem_size_reduction ( int2 p0, float p1 ) {
- return __builtin_hlsl_dot ( p0, p1 );
- // expected-Warning at -1 {{implicit conversion from 'int2' (aka 'vector<int, 2>') to '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values) may lose precision}}
-}
diff --git a/clang/test/SemaHLSL/OverloadResolutionBugs.hlsl b/clang/test/SemaHLSL/OverloadResolutionBugs.hlsl
index 135d6cf335c133..8464f1c1a7c2cd 100644
--- a/clang/test/SemaHLSL/OverloadResolutionBugs.hlsl
+++ b/clang/test/SemaHLSL/OverloadResolutionBugs.hlsl
@@ -24,6 +24,46 @@ void Call4(int16_t H) {
Fn4(H);
}
+int test_builtin_dot_bool_type_promotion(bool p0, bool p1) {
+  return dot(p0, p1);
+}
+
+float test_dot_scalar_mismatch(float p0, int p1) {
+  return dot(p0, p1);
+}
+
+float test_dot_element_type_mismatch(int2 p0, float2 p1) {
+  return dot(p0, p1);
+}
+
+float test_builtin_dot_vec_int_to_float_promotion(int2 p0, float2 p1) {
+  return dot(p0, p1);
+}
+
+int64_t test_builtin_dot_vec_int_to_int64_promotion(int64_t2 p0, int2 p1) {
+  return dot(p0, p1);
+}
+
+float test_builtin_dot_vec_half_to_float_promotion(float2 p0, half2 p1) {
+  return dot(p0, p1);
+}
+
+float test_builtin_dot_vec_int16_to_float_promotion(float2 p0, int16_t2 p1) {
+  return dot(p0, p1);
+}
+
+half test_builtin_dot_vec_int16_to_half_promotion(half2 p0, int16_t2 p1) {
+  return dot(p0, p1);
+}
+
+int test_builtin_dot_vec_int16_to_int_promotion(int2 p0, int16_t2 p1) {
+  return dot(p0, p1);
+}
+
+int64_t test_builtin_dot_vec_int16_to_int64_promotion(int64_t2 p0, int16_t2 p1) {
+  return dot(p0, p1);
+}
+
// https://github.com/llvm/llvm-project/issues/81049
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
More information about the llvm-commits
mailing list