[clang] [llvm] [HLSL] Implementation of dot intrinsic (PR #81190)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 8 13:25:53 PST 2024


https://github.com/farzonl created https://github.com/llvm/llvm-project/pull/81190

This change implements #70073

HLSL has a dot intrinsic defined here:
https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot

The intrinsic itself is defined as an HLSL_LANG LangBuiltin in Builtins.td. This is used to associate all the dot product overloads declared in hlsl_intrinsics.h with a single intrinsic check in CGBuiltin.cpp.

In IntrinsicsDirectX.td we define the LLVM IR intrinsic for the dot product. A few goals were in mind for this IR. First, it should operate only on vectors. Second, the return type should be the vector element type. Third, the second parameter vector should be of the same size as the first parameter. Finally, `a dot b` should be the same as `b dot a`.

In CGBuiltin.cpp, HLSL has so far built on top of existing Clang intrinsics via EmitBuiltinExpr. The dot product, though, is a target-specific intrinsic, so this change establishes a pattern for target builtins via EmitDXILBuiltinExpr. The call chain now looks like this: EmitBuiltinExpr -> EmitTargetBuiltinExpr -> EmitTargetArchBuiltinExpr -> EmitDXILBuiltinExpr.

The dot product handling in EmitDXILBuiltinExpr makes a distinction between vectors and scalars. This is because HLSL supports the dot product on scalars, which simplifies down to a multiply.

>From c966e50e1be171ce6a642083508faf43ae5f220a Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Thu, 8 Feb 2024 11:08:59 -0500
Subject: [PATCH] [HLSL] Implementation of dot intrinsic This change implements
 #70073

HLSL has a dot intrinsic defined here:
https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot

The intrinsic itself is defined as an HLSL_LANG LangBuiltin in Builtins.td.
This is used to associate all the dot product overloads declared in
hlsl_intrinsics.h with a single intrinsic check in CGBuiltin.cpp.

In IntrinsicsDirectX.td we define the LLVM IR intrinsic for the dot product.
A few goals were in mind for this IR. First, it should operate only on
vectors. Second, the return type should be the vector element type. Third,
the second parameter vector should be of the same size as the first
parameter. Finally, `a dot b` should be the same as `b dot a`.

In CGBuiltin.cpp, HLSL has so far built on top of existing Clang intrinsics
via EmitBuiltinExpr. The dot product, though, is a target-specific
intrinsic, so this change establishes a pattern for target builtins via
EmitDXILBuiltinExpr. The call chain now looks like this:
EmitBuiltinExpr -> EmitTargetBuiltinExpr -> EmitTargetArchBuiltinExpr -> EmitDXILBuiltinExpr

The dot product handling in EmitDXILBuiltinExpr makes a distinction
between vectors and scalars. This is because HLSL supports the dot product
on scalars, which simplifies down to a multiply.
---
 clang/include/clang/Basic/Builtins.td     |   6 +
 clang/lib/CodeGen/CGBuiltin.cpp           |  49 ++++
 clang/lib/CodeGen/CodeGenFunction.h       |   1 +
 clang/lib/Headers/hlsl/hlsl_intrinsics.h  |  86 +++++++
 clang/test/CodeGenHLSL/builtins/dot.hlsl  | 260 ++++++++++++++++++++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td |   7 +-
 6 files changed, 408 insertions(+), 1 deletion(-)
 create mode 100644 clang/test/CodeGenHLSL/builtins/dot.hlsl

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 31a2bdeb2d3e5..1d6fd969900ea 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4518,6 +4518,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void*(unsigned char)";
 }
 
+def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
+  let Prototype = "void(...)"; // Overloads live in hlsl_intrinsics.h.
+  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Spellings = ["__builtin_hlsl_dot"];
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index a7a410dab1a01..6916f40d265c7 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -44,6 +44,7 @@
 #include "llvm/IR/IntrinsicsAMDGPU.h"
 #include "llvm/IR/IntrinsicsARM.h"
 #include "llvm/IR/IntrinsicsBPF.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/IntrinsicsHexagon.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/IntrinsicsPowerPC.h"
@@ -6018,6 +6019,8 @@ static Value *EmitTargetArchBuiltinExpr(CodeGenFunction *CGF,
   case llvm::Triple::bpfeb:
   case llvm::Triple::bpfel:
     return CGF->EmitBPFBuiltinExpr(BuiltinID, E);
+  case llvm::Triple::dxil:
+    return CGF->EmitDXILBuiltinExpr(BuiltinID, E);
   case llvm::Triple::x86:
   case llvm::Triple::x86_64:
     return CGF->EmitX86BuiltinExpr(BuiltinID, E);
@@ -17895,6 +17898,52 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
   return Arg;
 }
 
+/// Emit a DXIL-specific builtin, or return nullptr if it is not handled here.
+Value *CodeGenFunction::EmitDXILBuiltinExpr(unsigned BuiltinID,
+                                            const CallExpr *E) {
+  switch (BuiltinID) {
+  case Builtin::BI__builtin_hlsl_dot: {
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    Value *Op1 = EmitScalarExpr(E->getArg(1));
+    llvm::Type *T0 = Op0->getType();
+    llvm::Type *T1 = Op1->getType();
+    // __builtin_hlsl_dot is CustomTypeChecking with no Sema check, so the
+    // operand types may disagree. Diagnose that instead of building ill-typed
+    // IR, and return a placeholder (nullptr would mean "not handled").
+    auto Unsupported = [&](const char *Msg) -> Value * {
+      ErrorUnsupported(E, Msg);
+      llvm::Type *RetTy = ConvertType(E->getType());
+      return RetTy->isVoidTy() ? nullptr : llvm::PoisonValue::get(RetTy);
+    };
+
+    // Scalar operands: the dot product is a single multiply.
+    if (!T0->isVectorTy() && !T1->isVectorTy()) {
+      if (T0 != T1)
+        return Unsupported("Dot product of scalars needs matching types.");
+      if (T0->isFloatingPointTy())
+        return Builder.CreateFMul(Op0, Op1, "dx.dot");
+      if (T0->isIntegerTy())
+        return Builder.CreateMul(Op0, Op1, "dx.dot");
+      return Unsupported(
+          "Dot product on a scalar is only supported on integers and floats.");
+    }
+
+    // Vector operands must have identical types; lowered to llvm.dx.dot.
+    if (!T0->isVectorTy() || !T1->isVectorTy())
+      return Unsupported("Dot product of vector and scalar is not supported.");
+    if (T0->getScalarType() != T1->getScalarType())
+      return Unsupported("Dot product of vectors need the same element types.");
+    if (T0 != T1) // Same element type, so only the widths can differ.
+      return Unsupported(
+          "Dot product requires vectors to be of the same size.");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/T0->getScalarType(), Intrinsic::dx_dot,
+        ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
+  }
+  }
+  return nullptr;
+}
+
 Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
                                               const CallExpr *E) {
   llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent;
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 143ad64e8816b..1632acd2a059f 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4392,6 +4392,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
+  llvm::Value *EmitDXILBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
                                            const CallExpr *E);
   llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index da153d8f8e034..2b45d68166aeb 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -144,6 +144,92 @@ double3 cos(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
 double4 cos(double4);
 
+//===----------------------------------------------------------------------===//
+// dot product builtins (every overload aliases __builtin_hlsl_dot)
+//===----------------------------------------------------------------------===//
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half, half);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half2, half2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half3, half3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half4, half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t, int16_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t2, int16_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t3, int16_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t4, int16_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t, uint16_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t2, uint16_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t3, uint16_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t4, uint16_t4);
+#endif // __HLSL_ENABLE_16_BIT
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float, float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float2, float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float3, float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float4, float4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+double dot(double, double);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+double dot(double2, double2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+double dot(double3, double3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+double dot(double4, double4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int, int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int2, int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int3, int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int4, int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t, int64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t2, int64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t3, int64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t4, int64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint, uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint2, uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint3, uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint4, uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t, uint64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t2, uint64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t3, uint64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t4, uint64_t4);
+
 //===----------------------------------------------------------------------===//
 // floor builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
new file mode 100644
index 0000000000000..506f98a31a2fb
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -0,0 +1,260 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -O3 -o - | FileCheck %s
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -D__HLSL_ENABLE_16_BIT -o - | FileCheck %s --check-prefix=NO_HALF
+
+#ifdef __HLSL_ENABLE_16_BIT // 16-bit integer dot() overloads need this macro
+// CHECK: %dx.dot = mul i16 %0, %1
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = mul i16 %0, %1
+// NO_HALF: ret i16 %dx.dot
+int16_t test_dot_short ( int16_t p0, int16_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NO_HALF: ret i16 %dx.dot
+int16_t test_dot_short2 ( int16_t2 p0, int16_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NO_HALF: ret i16 %dx.dot
+int16_t test_dot_short3 ( int16_t3 p0, int16_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NO_HALF: ret i16 %dx.dot
+int16_t test_dot_short4 ( int16_t4 p0, int16_t4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = mul i16 %0, %1
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = mul i16 %0, %1
+// NO_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort ( uint16_t p0, uint16_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NO_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort2 ( uint16_t2 p0, uint16_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NO_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort3 ( uint16_t3 p0, uint16_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// CHECK: ret i16 %dx.dot
+// NO_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NO_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort4 ( uint16_t4 p0, uint16_t4 p1 ) {
+  return dot ( p0, p1 );
+}
+#endif // __HLSL_ENABLE_16_BIT
+
+
+// CHECK: %dx.dot = mul i32 %0, %1
+// CHECK: ret i32 %dx.dot
+int test_dot_int ( int p0, int p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int2 ( int2 p0, int2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int3 ( int3 p0, int3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int4 ( int4 p0, int4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK: %dx.dot = mul i32 %0, %1
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint ( uint p0, uint p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint2 ( uint2 p0, uint2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint3 ( uint3 p0, uint3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint4 ( uint4 p0, uint4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK: %dx.dot = mul i64 %0, %1
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_long ( int64_t p0, int64_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_long2 ( int64_t2 p0, int64_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_long3 ( int64_t3 p0, int64_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_long4 ( int64_t4 p0, int64_t4 p1) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK:  %dx.dot = mul i64 %0, %1
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_ulong ( uint64_t p0, uint64_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_ulong2 ( uint64_t2 p0, uint64_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_ulong3 ( uint64_t3 p0, uint64_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_ulong4 ( uint64_t4 p0, uint64_t4 p1) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK: %dx.dot = fmul half %0, %1
+// CHECK: ret half %dx.dot
+// NO_HALF: %dx.dot = fmul float %0, %1
+// NO_HALF: ret float %dx.dot
+half test_dot_half ( half p0, half p1 ) {
+  return dot ( p0, p1 );
+}
+
+
+// CHECK: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1)
+// CHECK: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half2 ( half2 p0, half2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1)
+// CHECK: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half3 ( half3 p0, half3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1)
+// CHECK: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half4 ( half4 p0, half4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: define noundef float @
+// CHECK: %dx.dot = fmul float %0, %1
+// CHECK: ret float %dx.dot
+float test_dot_float ( float p0, float p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: ret float %dx.dot
+float test_dot_float2 ( float2 p0, float2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: ret float %dx.dot
+float test_dot_float3 ( float3 p0, float3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: ret float %dx.dot
+float test_dot_float4 ( float4 p0, float4 p1) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: define noundef double @
+// CHECK: %dx.dot = fmul double %0, %1
+// CHECK: ret double %dx.dot
+double test_dot_double ( double p0, double p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call double @llvm.dx.dot.v2f64(<2 x double> %0, <2 x double> %1)
+// CHECK: ret double %dx.dot
+double test_dot_double2 ( double2 p0, double2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call double @llvm.dx.dot.v3f64(<3 x double> %0, <3 x double> %1)
+// CHECK: ret double %dx.dot
+double test_dot_double3 ( double3 p0, double3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call double @llvm.dx.dot.v4f64(<4 x double> %0, <4 x double> %1)
+// CHECK: ret double %dx.dot
+double test_dot_double4 ( double4 p0, double4 p1) {
+  return dot ( p0, p1 );
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 2fe4fdfd5953b..269f221d1be4a 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -19,4 +19,9 @@ def int_dx_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMe
 
 def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">,
     Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>;
+
+// Dot product of two vectors of identical type; returns their element type.
+def int_dx_dot : Intrinsic<[LLVMVectorElementType<0>],
+                           [llvm_anyvector_ty, LLVMMatchType<0>],
+                           [IntrNoMem, IntrWillReturn, Commutative]>;
 }



More information about the llvm-commits mailing list