[clang] [llvm] [HLSL] Implementation of dot intrinsic (PR #81190)

Farzon Lotfi via cfe-commits cfe-commits at lists.llvm.org
Mon Feb 12 18:41:03 PST 2024


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/81190

>From 73cc9fde36a44ba1715a3c9fc6d48196602d5dc4 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Thu, 8 Feb 2024 11:08:59 -0500
Subject: [PATCH] [HLSL] Implementation of dot intrinsic This change implements
 #70073

HLSL has a dot intrinsic defined here:
https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot

The intrinsic itself is defined as an HLSL_LANG LangBuiltin in BuiltinsHLSL.td.
This is used to associate all the dot product overloads declared in
hlsl_intrinsics.h with a single intrinsic check in CGBuiltin.cpp and
SemaChecking.cpp.

As a side note, adding the dot product intrinsic to BuiltinsHLSL.td had a
significant impact on re-compile times. I recommend we move the
other HLSL functions there as well. Further, it lets us tap into the existing
target-specific code organization that exists in Sema and CodeGen.

In IntrinsicsDirectX.td we define the LLVM IR intrinsic for the dot product.
A few goals were in mind for this IR. First, it should operate only on
vectors. Second, the return type should be the vector element type. Third,
the second parameter should be a vector of the same size as the first
parameter. Finally, `a dot b` should be the same as `b dot a`.

In CGBuiltin.cpp, HLSL has so far been built on top of existing clang intrinsics via EmitBuiltinExpr. The dot
product, though, is a target-specific intrinsic, so it needed to establish
a pattern for target builtins via EmitDXILBuiltinExpr.
The call chain now looks like this: EmitBuiltinExpr -> EmitTargetBuiltinExpr -> EmitTargetArchBuiltinExpr -> EmitDXILBuiltinExpr

EmitDXILBuiltinExpr's lowering of the dot product intrinsic distinguishes
between vectors and scalars. This is because HLSL supports the dot product on scalars, which simplifies down to a multiply.

Sema.h and SemaChecking.cpp saw the addition of CheckHLSLBuiltinFunctionCall, a target-specific semantic validation that can be expanded for other HLSL-specific intrinsics.
---
 clang/include/clang/Basic/BuiltinsHLSL.td    |  15 ++
 clang/include/clang/Basic/CMakeLists.txt     |   6 +-
 clang/include/clang/Basic/TargetBuiltins.h   |  10 +
 clang/include/clang/Sema/Sema.h              |   1 +
 clang/lib/Basic/Targets/DirectX.cpp          |  13 ++
 clang/lib/Basic/Targets/DirectX.h            |   4 +-
 clang/lib/CodeGen/CGBuiltin.cpp              |  41 ++++
 clang/lib/CodeGen/CodeGenFunction.h          |   1 +
 clang/lib/Headers/hlsl/hlsl_intrinsics.h     |  68 +++++++
 clang/lib/Sema/SemaChecking.cpp              |  71 ++++++-
 clang/test/CodeGenHLSL/builtins/dot.hlsl     | 202 +++++++++++++++++++
 clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl |  43 ++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td    |   5 +
 13 files changed, 471 insertions(+), 9 deletions(-)
 create mode 100644 clang/include/clang/Basic/BuiltinsHLSL.td
 create mode 100644 clang/test/CodeGenHLSL/builtins/dot.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl

diff --git a/clang/include/clang/Basic/BuiltinsHLSL.td b/clang/include/clang/Basic/BuiltinsHLSL.td
new file mode 100644
index 00000000000000..edff801f905641
--- /dev/null
+++ b/clang/include/clang/Basic/BuiltinsHLSL.td
@@ -0,0 +1,17 @@
+//===--- BuiltinsHLSL.td - HLSL Builtin function database --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+include "clang/Basic/BuiltinsBase.td"
+
+// dot(x, y). Operand types are validated in Sema (CustomTypeChecking), so the
+// prototype is left variadic.
+def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_dot"];
+  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Prototype = "void(...)";
+}
diff --git a/clang/include/clang/Basic/CMakeLists.txt b/clang/include/clang/Basic/CMakeLists.txt
index 7785fb430c069b..01233eb44feffc 100644
--- a/clang/include/clang/Basic/CMakeLists.txt
+++ b/clang/include/clang/Basic/CMakeLists.txt
@@ -65,7 +65,11 @@ clang_tablegen(BuiltinsBPF.inc -gen-clang-builtins
   SOURCE BuiltinsBPF.td
   TARGET ClangBuiltinsBPF)
 
+clang_tablegen(BuiltinsHLSL.inc -gen-clang-builtins
+  SOURCE BuiltinsHLSL.td
+  TARGET ClangBuiltinsHLSL)
+
 clang_tablegen(BuiltinsRISCV.inc -gen-clang-builtins
   SOURCE BuiltinsRISCV.td
   TARGET ClangBuiltinsRISCV)
 
diff --git a/clang/include/clang/Basic/TargetBuiltins.h b/clang/include/clang/Basic/TargetBuiltins.h
index 4333830bf34f24..15cf111ae5e9e5 100644
--- a/clang/include/clang/Basic/TargetBuiltins.h
+++ b/clang/include/clang/Basic/TargetBuiltins.h
@@ -89,6 +89,16 @@ namespace clang {
   };
   }
 
+  /// HLSL builtins, generated from BuiltinsHLSL.td into BuiltinsHLSL.inc.
+  namespace hlsl {
+  enum {
+    LastTIBuiltin = clang::Builtin::FirstTSBuiltin - 1,
+#define BUILTIN(ID, TYPE, ATTRS) BI##ID,
+#include "clang/Basic/BuiltinsHLSL.inc"
+    LastTSBuiltin
+  };
+  } // namespace hlsl
+
   /// PPC builtins
   namespace PPC {
     enum {
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index ed933f27f8df6b..ce08669b1b073d 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -14016,6 +14016,7 @@ class Sema final {
   bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
                                    CallExpr *TheCall);
   bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
+  bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
   bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum);
   bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
                                      CallExpr *TheCall);
diff --git a/clang/lib/Basic/Targets/DirectX.cpp b/clang/lib/Basic/Targets/DirectX.cpp
index 0dd27e6e93b33b..4a9536612eefa6 100644
--- a/clang/lib/Basic/Targets/DirectX.cpp
+++ b/clang/lib/Basic/Targets/DirectX.cpp
@@ -12,11 +12,24 @@
 
 #include "DirectX.h"
 #include "Targets.h"
+#include "clang/Basic/MacroBuilder.h"
+#include "clang/Basic/TargetBuiltins.h"
 
 using namespace clang;
 using namespace clang::targets;
 
+static constexpr Builtin::Info BuiltinInfo[] = {
+#define BUILTIN(ID, TYPE, ATTRS)                                               \
+  {#ID, TYPE, ATTRS, nullptr, HeaderDesc::NO_HEADER, ALL_LANGUAGES},
+#include "clang/Basic/BuiltinsHLSL.inc"
+};
+
 void DirectXTargetInfo::getTargetDefines(const LangOptions &Opts,
                                          MacroBuilder &Builder) const {
   DefineStd(Builder, "DIRECTX", Opts);
 }
+
+ArrayRef<Builtin::Info> DirectXTargetInfo::getTargetBuiltins() const {
+  return llvm::ArrayRef(BuiltinInfo,
+                        clang::hlsl::LastTSBuiltin - Builtin::FirstTSBuiltin);
+}
diff --git a/clang/lib/Basic/Targets/DirectX.h b/clang/lib/Basic/Targets/DirectX.h
index acfcc8c47ba950..99c9522b0dd542 100644
--- a/clang/lib/Basic/Targets/DirectX.h
+++ b/clang/lib/Basic/Targets/DirectX.h
@@ -73,9 +73,7 @@ class LLVM_LIBRARY_VISIBILITY DirectXTargetInfo : public TargetInfo {
     return Feature == "directx";
   }
 
-  ArrayRef<Builtin::Info> getTargetBuiltins() const override {
-    return std::nullopt;
-  }
+  ArrayRef<Builtin::Info> getTargetBuiltins() const override;
 
   std::string_view getClobbers() const override { return ""; }
 
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index a7a410dab1a018..ea1ed56ebca3ed 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -44,6 +44,7 @@
 #include "llvm/IR/IntrinsicsAMDGPU.h"
 #include "llvm/IR/IntrinsicsARM.h"
 #include "llvm/IR/IntrinsicsBPF.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/IntrinsicsHexagon.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/IntrinsicsPowerPC.h"
@@ -6018,6 +6019,8 @@ static Value *EmitTargetArchBuiltinExpr(CodeGenFunction *CGF,
   case llvm::Triple::bpfeb:
   case llvm::Triple::bpfel:
     return CGF->EmitBPFBuiltinExpr(BuiltinID, E);
+  case llvm::Triple::dxil:
+    return CGF->EmitDXILBuiltinExpr(BuiltinID, E);
   case llvm::Triple::x86:
   case llvm::Triple::x86_64:
     return CGF->EmitX86BuiltinExpr(BuiltinID, E);
@@ -17895,6 +17898,50 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
   return Arg;
 }
 
+// Lower DXIL target builtins. Returns nullptr for builtins not handled here.
+Value *CodeGenFunction::EmitDXILBuiltinExpr(unsigned BuiltinID,
+                                            const CallExpr *E) {
+  switch (BuiltinID) {
+  case hlsl::BI__builtin_hlsl_dot: {
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    Value *Op1 = EmitScalarExpr(E->getArg(1));
+    llvm::Type *T0 = Op0->getType();
+    llvm::Type *T1 = Op1->getType();
+
+    // HLSL permits dot() on two scalars, which reduces to a multiply.
+    if (!T0->isVectorTy() && !T1->isVectorTy()) {
+      if (T0->isFloatingPointTy())
+        return Builder.CreateFMul(Op0, Op1, "dx.dot");
+
+      if (T0->isIntegerTy())
+        return Builder.CreateMul(Op0, Op1, "dx.dot");
+
+      // Sema only admits integer and floating-point scalars. Use
+      // llvm_unreachable rather than assert(false) so release builds do not
+      // silently fall through into the vector-only code below.
+      llvm_unreachable(
+          "Dot product on a scalar is only supported on integers and floats.");
+    }
+    assert(T0->isVectorTy() && T1->isVectorTy() &&
+           "Dot product of vector and scalar is not supported.");
+    assert(T0->getScalarType() == T1->getScalarType() &&
+           "Dot product of vectors need the same element types.");
+    // Compare widths on the LLVM types inside the assert so that no locals
+    // are left unused (-Wunused-variable) when assertions are compiled out.
+    assert(cast<llvm::FixedVectorType>(T0)->getNumElements() ==
+               cast<llvm::FixedVectorType>(T1)->getNumElements() &&
+           "Dot product requires vectors to be of the same size.");
+
+    // dx.dot is overloaded on the operand vector type and returns its
+    // element type.
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/T0->getScalarType(), Intrinsic::dx_dot,
+        ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
+  }
+  }
+  return nullptr;
+}
+
 Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
                                               const CallExpr *E) {
   llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent;
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index fc9b32878068c1..e585ff5d66111a 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4395,6 +4395,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
+  llvm::Value *EmitDXILBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
                                            const CallExpr *E);
   llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index da153d8f8e0349..9295ef3093c4f8 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -144,6 +144,74 @@ double3 cos(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
 double4 cos(double4);
 
+//===----------------------------------------------------------------------===//
+// dot product builtins (every overload aliases __builtin_hlsl_dot)
+//===----------------------------------------------------------------------===//
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half, half);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half2, half2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half3, half3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half4, half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t, int16_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t2, int16_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t3, int16_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t4, int16_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t, uint16_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t2, uint16_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t3, uint16_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t4, uint16_t4);
+#endif
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float, float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float2, float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float3, float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float4, float4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+double dot(double, double);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int, int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int2, int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int3, int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int4, int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint, uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint2, uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint3, uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint4, uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t, int64_t);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t, uint64_t);
+
 //===----------------------------------------------------------------------===//
 // floor builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index afe2673479e40a..3e6ccf40a7cddf 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2084,6 +2084,8 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
     return CheckBPFBuiltinFunctionCall(BuiltinID, TheCall);
   case llvm::Triple::hexagon:
     return CheckHexagonBuiltinFunctionCall(BuiltinID, TheCall);
+  case llvm::Triple::dxil:
+    return CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall);
   case llvm::Triple::mips:
   case llvm::Triple::mipsel:
   case llvm::Triple::mips64:
@@ -2120,10 +2122,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
 // not a valid type, emit an error message and return true. Otherwise return
 // false.
 static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
-                                        QualType Ty) {
-  if (!Ty->getAs<VectorType>() && !ConstantMatrixType::isValidElementType(Ty)) {
+                                        QualType ArgTy, int ArgIndex) {
+  if (!ArgTy->getAs<VectorType>() &&
+      !ConstantMatrixType::isValidElementType(ArgTy)) {
     return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
-           << 1 << /* vector, integer or float ty*/ 0 << Ty;
+           << ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
   }
 
   return false;
@@ -5158,6 +5161,60 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
   return false;
 }
 
+// Note: returning true in this case results in CheckBuiltinFunctionCall
+// returning an ExprError
+bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
+  switch (BuiltinID) {
+  case hlsl::BI__builtin_hlsl_dot: {
+    if (checkArgCount(*this, TheCall, 2))
+      return true;
+
+    Expr *Arg0 = TheCall->getArg(0);
+    Expr *Arg1 = TheCall->getArg(1);
+    QualType ArgTy0 = Arg0->getType();
+    QualType ArgTy1 = Arg1->getType();
+    SourceLocation BuiltinLoc = TheCall->getBeginLoc();
+    SourceRange ArgsRange(Arg0->getBeginLoc(), Arg1->getEndLoc());
+
+    // Each operand must be a vector, an integer or a floating-point value;
+    // this is what rejects e.g. bool operands.
+    if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1) ||
+        checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2))
+      return true;
+
+    const auto *VecTy0 = ArgTy0->getAs<VectorType>();
+    const auto *VecTy1 = ArgTy1->getAs<VectorType>();
+
+    // Scalar dot product: both operands must have the same type.
+    if (!VecTy0 && !VecTy1) {
+      if (Context.hasSameUnqualifiedType(ArgTy0, ArgTy1))
+        return false;
+      Diag(BuiltinLoc, diag::err_typecheck_call_different_arg_types)
+          << ArgTy0 << ArgTy1 << ArgsRange;
+      return true;
+    }
+
+    // A scalar cannot be dotted with a vector.
+    if (!VecTy0 || !VecTy1) {
+      Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
+          << TheCall->getDirectCallee() << ArgsRange;
+      return true;
+    }
+
+    // Vector dot product: element types and widths must both match.
+    if (!Context.hasSameUnqualifiedType(VecTy0->getElementType(),
+                                        VecTy1->getElementType()) ||
+        VecTy0->getNumElements() != VecTy1->getNumElements()) {
+      Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
+          << TheCall->getDirectCallee() << ArgsRange;
+      return true;
+    }
+    break;
+  }
+  }
+  return false;
+}
+
 bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID,
                                           CallExpr *TheCall) {
   // position of memory order and scope arguments in the builtin
@@ -19576,7 +19633,7 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
   TheCall->setArg(0, A.get());
   QualType TyA = A.get()->getType();
 
-  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
+  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
     return true;
 
   TheCall->setType(TyA);
@@ -19604,7 +19661,7 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
                 diag::err_typecheck_call_different_arg_types)
            << TyA << TyB;
 
-  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
+  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
     return true;
 
   TheCall->setArg(0, A.get());
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
new file mode 100644
index 00000000000000..f995585e85f8af
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -0,0 +1,207 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -O3 -o - | FileCheck %s
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -D__HLSL_ENABLE_16_BIT -o - | FileCheck %s --check-prefix=NO_HALF
+
+// The first RUN line enables native 16-bit types, so `half` is lowered to
+// LLVM `half`. The second RUN line only defines __HLSL_ENABLE_16_BIT, so the
+// 16-bit integer overloads are still declared but `half` is lowered to
+// `float`; those expectations use the NO_HALF prefix.
+
+#ifdef __HLSL_ENABLE_16_BIT
+// CHECK: %dx.dot = mul i16 %0, %1
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = mul i16 %0, %1
+// NO_HALF: ret i16 %dx.dot
+int16_t test_dot_short ( int16_t p0, int16_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NO_HALF: ret i16 %dx.dot
+int16_t test_dot_short2 ( int16_t2 p0, int16_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NO_HALF: ret i16 %dx.dot
+int16_t test_dot_short3 ( int16_t3 p0, int16_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NO_HALF: ret i16 %dx.dot
+int16_t test_dot_short4 ( int16_t4 p0, int16_t4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = mul i16 %0, %1
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = mul i16 %0, %1
+// NO_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort ( uint16_t p0, uint16_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NO_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort2 ( uint16_t2 p0, uint16_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NO_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort3 ( uint16_t3 p0, uint16_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NO_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort4 ( uint16_t4 p0, uint16_t4 p1 ) {
+  return dot ( p0, p1 );
+}
+#endif
+
+
+// CHECK: %dx.dot = mul i32 %0, %1
+// CHECK: ret i32 %dx.dot
+int test_dot_int ( int p0, int p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int2 ( int2 p0, int2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int3 ( int3 p0, int3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int4 ( int4 p0, int4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK: %dx.dot = mul i32 %0, %1
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint ( uint p0, uint p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint2 ( uint2 p0, uint2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint3 ( uint3 p0, uint3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint4 ( uint4 p0, uint4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK: %dx.dot = mul i64 %0, %1
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_long ( int64_t p0, int64_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = mul i64 %0, %1
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_ulong ( uint64_t p0, uint64_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = fmul half %0, %1
+// CHECK: ret half %dx.dot
+// NO_HALF: %dx.dot = fmul float %0, %1
+// NO_HALF: ret float %dx.dot
+half test_dot_half ( half p0, half p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1)
+// CHECK: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half2 ( half2 p0, half2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1)
+// CHECK: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half3 ( half3 p0, half3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1)
+// CHECK: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half4 ( half4 p0, half4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: define noundef float @
+// CHECK: %dx.dot = fmul float %0, %1
+// CHECK: ret float %dx.dot
+float test_dot_float ( float p0, float p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: ret float %dx.dot
+float test_dot_float2 ( float2 p0, float2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: ret float %dx.dot
+float test_dot_float3 ( float3 p0, float3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: ret float %dx.dot
+float test_dot_float4 ( float4 p0, float4 p1) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: define noundef double @
+// CHECK: %dx.dot = fmul double %0, %1
+// CHECK: ret double %dx.dot
+double test_dot_double ( double p0, double p1 ) {
+  return dot ( p0, p1 );
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
new file mode 100644
index 00000000000000..d4751fda973f96
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
@@ -0,0 +1,43 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+float test_first_arg_is_not_vector ( float p0, float2 p1 ) {
+  return dot ( p0, p1 );
+  // expected-error@-1 {{first two arguments to 'dot' must be vectors}}
+}
+
+float test_second_arg_is_not_vector ( float2 p0, float p1 ) {
+  return dot ( p0, p1 );
+  // expected-error@-1 {{first two arguments to 'dot' must be vectors}}
+}
+
+int test_dot_unsupported_scalar_arg0 ( bool p0, int p1 ) {
+  return dot ( p0, p1 );
+  // expected-error@-1 {{1st argument must be a vector, integer or floating point type (was 'bool')}}
+}
+
+int test_dot_unsupported_scalar_arg1 ( int p0, bool p1 ) {
+  return dot ( p0, p1 );
+  // expected-error@-1 {{2nd argument must be a vector, integer or floating point type (was 'bool')}}
+}
+
+float test_dot_scalar_mismatch ( float p0, int p1 ) {
+  return dot ( p0, p1 );
+  // expected-error@-1 {{call to 'dot' is ambiguous}}
+}
+
+float test_dot_vector_size_mismatch ( float3 p0, float2 p1 ) {
+  return dot ( p0, p1 );
+  // expected-error@-1 {{no matching function for call to 'dot'}}
+}
+
+float test__no_second_arg ( float2 p0) {
+  return dot ( p0 );
+  // expected-error@-1 {{no matching function for call to 'dot'}}
+}
+
+float test_dot_element_type_mismatch ( int2 p0, float2 p1 ) {
+  return dot ( p0, p1 );
+  // expected-error@-1 {{no matching function for call to 'dot'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 2fe4fdfd5953be..c192d4b84417c9 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -19,4 +19,11 @@ def int_dx_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMe
 
 def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">,
     Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>;
+
+// Dot product of two same-typed vectors; returns the vector element type.
+def int_dx_dot :
+    Intrinsic<[LLVMVectorElementType<0>],
+              [llvm_anyvector_ty,
+               LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+              [IntrNoMem, IntrWillReturn, Commutative]>;
 }



More information about the cfe-commits mailing list