[clang] [llvm] Add Float `Dot` Intrinsic Lowering (PR #86071)

via cfe-commits cfe-commits at lists.llvm.org
Wed Mar 20 20:20:25 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-llvm-ir

Author: Farzon Lotfi (farzonl)

<details>
<summary>Changes</summary>

Completes #<!-- -->83626
- `CGBuiltin.cpp` - modify `getDotProductIntrinsic` to be able to emit `dot2`, `dot3`, and `dot4` intrinsics based on element count
- `IntrinsicsDirectX.td` - for floating point add `dot2`, `dot3`, and `dot4` intrinsics
- `DXIL.td` - add dxilop intrinsic lowering for `dot2`, `dot3`, & `dot4`. 
- `DXILOpLowering.cpp` - add vector arg flattening for dot product. 
- `DXILOpBuilder.h` - modify `createDXILOpCall` to take a `SmallVector` instead of an iterator range
- `DXILOpBuilder.cpp` - modify `createDXILOpCall` by moving the small vector up to the calling function in `DXILOpLowering.cpp`. 
  - Moving one function up gives us access to the `CallInst` and `Function` which were needed to distinguish the dot product intrinsics and get the operands without using the iterator.

---

Patch is 20.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86071.diff


11 Files Affected:

- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+16-9) 
- (modified) clang/test/CodeGenHLSL/builtins/dot.hlsl (+14-14) 
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+9-1) 
- (modified) llvm/lib/Target/DirectX/DXIL.td (+9) 
- (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+3-5) 
- (modified) llvm/lib/Target/DirectX/DXILOpBuilder.h (+2-3) 
- (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+53-2) 
- (added) llvm/test/CodeGen/DirectX/dot2_error.ll (+10) 
- (added) llvm/test/CodeGen/DirectX/dot3_error.ll (+10) 
- (added) llvm/test/CodeGen/DirectX/dot4_error.ll (+10) 
- (added) llvm/test/CodeGen/DirectX/fdot.ll (+94) 


``````````diff
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 77cb269d43c5a8..a4b99181769326 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18036,15 +18036,22 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
   return Arg;
 }
 
-Intrinsic::ID getDotProductIntrinsic(QualType QT) {
+Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) {
+  if (QT->hasFloatingRepresentation()) {
+    switch (elementCount) {
+    case 2:
+      return Intrinsic::dx_dot2;
+    case 3:
+      return Intrinsic::dx_dot3;
+    case 4:
+      return Intrinsic::dx_dot4;
+    }
+  }
   if (QT->hasSignedIntegerRepresentation())
     return Intrinsic::dx_sdot;
-  if (QT->hasUnsignedIntegerRepresentation())
-    return Intrinsic::dx_udot;
 
-  assert(QT->hasFloatingRepresentation());
-  return Intrinsic::dx_dot;
-  ;
+  assert(QT->hasUnsignedIntegerRepresentation());
+  return Intrinsic::dx_udot;
 }
 
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
@@ -18098,8 +18105,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     assert(T0->getScalarType() == T1->getScalarType() &&
            "Dot product of vectors need the same element types.");
 
-    [[maybe_unused]] auto *VecTy0 =
-        E->getArg(0)->getType()->getAs<VectorType>();
+    auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
     [[maybe_unused]] auto *VecTy1 =
         E->getArg(1)->getType()->getAs<VectorType>();
     // A HLSLVectorTruncation should have happend
@@ -18108,7 +18114,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
 
     return Builder.CreateIntrinsic(
         /*ReturnType=*/T0->getScalarType(),
-        getDotProductIntrinsic(E->getArg(0)->getType()),
+        getDotProductIntrinsic(E->getArg(0)->getType(),
+                               VecTy0->getNumElements()),
         ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
   } break;
   case Builtin::BI__builtin_hlsl_lerp: {
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index 0f993193c00cce..307d71cce3cb6d 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -110,21 +110,21 @@ uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
 // NO_HALF: ret float %dx.dot
 half test_dot_half(half p0, half p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1)
 // NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
 // NO_HALF: ret float %dx.dot
 half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1)
 // NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
 // NO_HALF: ret float %dx.dot
 half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1)
 // NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
 // NO_HALF: ret float %dx.dot
 half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
 
@@ -132,34 +132,34 @@ half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
 // CHECK: ret float %dx.dot
 float test_dot_float(float p0, float p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK:  %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK:  %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK:  %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
 
 // CHECK: %conv = sitofp i32 %1 to float
 // CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
 // CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat)
+// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %splat.splat)
 // CHECK: ret float %dx.dot
 float test_builtin_dot_float2_int_splat(float2 p0, int p1) {
   return dot(p0, p1);
@@ -168,7 +168,7 @@ float test_builtin_dot_float2_int_splat(float2 p0, int p1) {
 // CHECK: %conv = sitofp i32 %1 to float
 // CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
 // CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat)
+// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %splat.splat)
 // CHECK: ret float %dx.dot
 float test_builtin_dot_float3_int_splat(float3 p0, int p1) {
   return dot(p0, p1);
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 1164b241ba7b0d..a871fac46b9fd5 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -24,7 +24,15 @@ def int_dx_any  : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
 def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; 
 
-def int_dx_dot : 
+def int_dx_dot2 : 
+    Intrinsic<[LLVMVectorElementType<0>], 
+    [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+    [IntrNoMem, IntrWillReturn, Commutative] >;
+def int_dx_dot3 : 
+    Intrinsic<[LLVMVectorElementType<0>], 
+    [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+    [IntrNoMem, IntrWillReturn, Commutative] >;
+def int_dx_dot4 : 
     Intrinsic<[LLVMVectorElementType<0>], 
     [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, IntrWillReturn, Commutative] >;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 36eb29d53766f0..9e393902e2f208 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -295,6 +295,15 @@ def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
                          "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
 def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
                          "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
+def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
+                         "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1",
+                         [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>;
+def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
+                         "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2",
+                         [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>;
+def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
+                         "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3",
+                         [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>;
 def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
                              "Reads the thread ID">;
 def GroupId  : DXILOpMapping<94, groupId, int_dx_group_id,
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index a1eacc2d48009c..990557710f8c53 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -254,7 +254,7 @@ namespace dxil {
 
 CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
                                           Type *OverloadTy,
-                                          llvm::iterator_range<Use *> Args) {
+                                          SmallVector<Value *> Args) {
   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
 
   OverloadKind Kind = getOverloadKind(OverloadTy);
@@ -272,10 +272,8 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
     FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
     DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
   }
-  SmallVector<Value *> FullArgs;
-  FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
-  FullArgs.append(Args.begin(), Args.end());
-  return B.CreateCall(DXILFn, FullArgs);
+
+  return B.CreateCall(DXILFn, Args);
 }
 
 Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index f3abcc6e02a4e3..5babeae470178b 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -13,7 +13,7 @@
 #define LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H
 
 #include "DXILConstants.h"
-#include "llvm/ADT/iterator_range.h"
+#include "llvm/ADT/SmallVector.h"
 
 namespace llvm {
 class Module;
@@ -35,8 +35,7 @@ class DXILOpBuilder {
   /// \param OverloadTy Overload type of the DXIL Op call constructed
   /// \return DXIL Op call constructed
   CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
-                             Type *OverloadTy,
-                             llvm::iterator_range<Use *> Args);
+                             Type *OverloadTy, SmallVector<Value *> Args);
   Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
   static const char *getOpCodeName(dxil::OpCode DXILOp);
 
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 3e334b0ec298d3..f09e322f88e1fd 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -30,6 +30,48 @@
 using namespace llvm;
 using namespace llvm::dxil;
 
+static bool isVectorArgExpansion(Function &F) {
+  switch (F.getIntrinsicID()) {
+  case Intrinsic::dx_dot2:
+  case Intrinsic::dx_dot3:
+  case Intrinsic::dx_dot4:
+    return true;
+  }
+  return false;
+}
+
+static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
+  SmallVector<Value *, 4> ExtractedElements;
+  auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
+  for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
+    Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
+    Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
+    ExtractedElements.push_back(ExtractedElement);
+  }
+  return ExtractedElements;
+}
+
+static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
+                                             IRBuilder<> &Builder) {
+  // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
+  unsigned NumOperands = Orig->getNumOperands() - 1;
+  assert(NumOperands > 0);
+  Value *Arg0 = Orig->getOperand(0);
+  [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
+  assert(VecArg0);
+  SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
+  for (unsigned I = 1; I < NumOperands; ++I) {
+    Value *Arg = Orig->getOperand(I);
+    [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
+    assert(VecArg);
+    assert(VecArg0->getElementType() == VecArg->getElementType());
+    assert(VecArg0->getNumElements() == VecArg->getNumElements());
+    auto NextOperandList = populateOperands(Arg, Builder);
+    NewOperands.append(NextOperandList.begin(), NextOperandList.end());
+  }
+  return NewOperands;
+}
+
 static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
   IRBuilder<> B(M.getContext());
   DXILOpBuilder DXILB(M, B);
@@ -39,9 +81,18 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
     if (!CI)
       continue;
 
+    SmallVector<Value *> Args;
+    Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
+    Args.emplace_back(DXILOpArg);
     B.SetInsertPoint(CI);
-    CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(),
-                                              OverloadTy, CI->args());
+    if (isVectorArgExpansion(F)) {
+      SmallVector<Value *> NewArgs = argVectorFlatten(CI, B);
+      Args.append(NewArgs.begin(), NewArgs.end());
+    } else
+      Args.append(CI->arg_begin(), CI->arg_end());
+
+    CallInst *DXILCI =
+        DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args);
 
     CI->replaceAllUsesWith(DXILCI);
     CI->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/dot2_error.ll b/llvm/test/CodeGen/DirectX/dot2_error.ll
new file mode 100644
index 00000000000000..a27bfaedacd573
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot2_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation dot2 does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @dot_double2(<2 x double> noundef %a, <2 x double> noundef %b) {
+entry:
+  %dx.dot = call double @llvm.dx.dot2.v2f64(<2 x double> %a, <2 x double> %b)
+  ret double %dx.dot
+}
diff --git a/llvm/test/CodeGen/DirectX/dot3_error.ll b/llvm/test/CodeGen/DirectX/dot3_error.ll
new file mode 100644
index 00000000000000..eb69fb145038aa
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot3_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation dot3 does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @dot_double3(<3 x double> noundef %a, <3 x double> noundef %b) {
+entry:
+  %dx.dot = call double @llvm.dx.dot3.v3f64(<3 x double> %a, <3 x double> %b)
+  ret double %dx.dot
+}
diff --git a/llvm/test/CodeGen/DirectX/dot4_error.ll b/llvm/test/CodeGen/DirectX/dot4_error.ll
new file mode 100644
index 00000000000000..5cd632684c0c01
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot4_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation dot4 does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @dot_double4(<4 x double> noundef %a, <4 x double> noundef %b) {
+entry:
+  %dx.dot = call double @llvm.dx.dot4.v4f64(<4 x double> %a, <4 x double> %b)
+  ret double %dx.dot
+}
diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll
new file mode 100644
index 00000000000000..3e13b2ad2650c8
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/fdot.ll
@@ -0,0 +1,94 @@
+; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s
+
+; Make sure dxil operation function calls for dot are generated for int/uint vectors.
+
+; CHECK-LABEL: dot_half2
+define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
+entry:
+; CHECK: extractelement <2 x half> %a, i32 0
+; CHECK: extractelement <2 x half> %a, i32 1
+; CHECK: extractelement <2 x half> %b, i32 0
+; CHECK: extractelement <2 x half> %b, i32 1
+; CHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b)
+  ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_half3
+define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
+entry:
+; CHECK: extractelement <3 x half> %a, i32 0
+; CHECK: extractelement <3 x half> %a, i32 1
+; CHECK: extractelement <3 x half> %a, i32 2
+; CHECK: extractelement <3 x half> %b, i32 0
+; CHECK: extractelement <3 x half> %b, i32 1
+; CHECK: extractelement <3 x half> %b, i32 2
+; CHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b)
+  ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_half4
+define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
+entry:
+; CHECK: extractelement <4 x half> %a, i32 0
+; CHECK: extractelement <4 x half> %a, i32 1
+; CHECK: extractelement <4 x half> %a, i32 2
+; CHECK: extractelement <4 x half> %a, i32 3
+; CHECK: extractelement <4 x half> %b, i32 0
+; CHECK: extractelement <4 x half> %b, i32 1
+; CHECK: extractelement <4 x half> %b, i32 2
+; CHECK: extractelement <4 x half> %b, i32 3
+; CHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b)
+  ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_float2
+define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
+entry:
+; CHECK: extractelement <2 x float> %a, i32 0
+; CHECK: extractelement <2 x float> %a, i32 1
+; CHECK: extractelement <2 x float> %b, i32 0
+; CHECK: extractelement <2 x float> %b, i32 1
+; CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+  %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b)
+  ret float %dx.dot
+}
+
+; CHECK-LABEL: dot_float3
+define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
+entry:
+; CHECK: extractelement <3 x float> %a, i32 0
+; CHECK: extractelement <3 x float> %a, i32 1
+; CHECK: extractelement <3 x float> %a, i32 2
+; CHECK: extractelement <3 x float> %b, i32 0
+; CHECK: extractelement <3 x float> %b, i32 1
+; CHECK: extractelement <3 x float> %b, i32 2
+; CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+  %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b)
+  ret float %dx.dot
+}
+
+; CHECK-LABEL: dot_float4
+define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
+entry:
+; CHECK: extractelement <4 x float> %a, i32 0
+; CHECK: extractelement <4 x float> %a, i32 1
+; CHECK: extractelement <4 x float> %a, i32 2
+; CHECK: extractelement <4 x float> %a, i32 3
+; CHECK: extractelement <4 x float> %b, i32 0
+; CHECK: extractelement <4 x fl...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/86071


More information about the cfe-commits mailing list