[clang] [llvm] [HLSL][DXIL][SPIRV] Create llvm dot intrinsic and use for HLSL (PR #102872)

Greg Roth via cfe-commits cfe-commits at lists.llvm.org
Fri Aug 16 03:07:38 PDT 2024


https://github.com/pow2clk updated https://github.com/llvm/llvm-project/pull/102872

>From 6fde4bc98d0156024cf7acc27e2e986b9bec3993 Mon Sep 17 00:00:00 2001
From: Greg Roth <grroth at microsoft.com>
Date: Fri, 2 Aug 2024 20:10:04 -0600
Subject: [PATCH 1/6] Create llvm dot intrinsic

Per https://discourse.llvm.org/t/rfc-all-the-math-intrinsics/78294,
`dot` should be an LLVM intrinsic. This adds the LLVM intrinsics
and updates HLSL builtin codegen to emit them.

Removed some stale comments that gave the obsolete impression that
type conversions should be expected to match overloads.

With dot moving into an LLVM intrinsic, the lowering to dx-specific
operations doesn't take place until DXIL intrinsic expansion. This
moves the introduction of arity-specific DX opcodes to DXIL
intrinsic expansion.

Part of #88056
---
 clang/lib/CodeGen/CGBuiltin.cpp               |  47 +++--
 .../CodeGenHLSL/builtins/dot-builtin.hlsl     |  12 +-
 clang/test/CodeGenHLSL/builtins/dot.hlsl      | 160 +++++++++---------
 llvm/include/llvm/IR/Intrinsics.td            |   9 +
 .../Target/DirectX/DXILIntrinsicExpansion.cpp |  61 +++++--
 5 files changed, 159 insertions(+), 130 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7fe80b0cbdfbfa..67148e32014ed2 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18470,22 +18470,14 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
   return Arg;
 }
 
-Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) {
-  if (QT->hasFloatingRepresentation()) {
-    switch (elementCount) {
-    case 2:
-      return Intrinsic::dx_dot2;
-    case 3:
-      return Intrinsic::dx_dot3;
-    case 4:
-      return Intrinsic::dx_dot4;
-    }
-  }
-  if (QT->hasSignedIntegerRepresentation())
-    return Intrinsic::dx_sdot;
-
-  assert(QT->hasUnsignedIntegerRepresentation());
-  return Intrinsic::dx_udot;
+// Return dot product intrinsic that corresponds to the QT scalar type
+Intrinsic::ID getDotProductIntrinsic(QualType QT) {
+  if (QT->isFloatingType())
+    return Intrinsic::fdot;
+  if (QT->isSignedIntegerType())
+    return Intrinsic::sdot;
+  assert(QT->isUnsignedIntegerType());
+  return Intrinsic::udot;
 }
 
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
@@ -18528,37 +18520,38 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     Value *Op1 = EmitScalarExpr(E->getArg(1));
     llvm::Type *T0 = Op0->getType();
     llvm::Type *T1 = Op1->getType();
+
+    // If the arguments are scalars, just emit a multiply
     if (!T0->isVectorTy() && !T1->isVectorTy()) {
       if (T0->isFloatingPointTy())
-        return Builder.CreateFMul(Op0, Op1, "dx.dot");
+        return Builder.CreateFMul(Op0, Op1, "dot");
 
       if (T0->isIntegerTy())
-        return Builder.CreateMul(Op0, Op1, "dx.dot");
+        return Builder.CreateMul(Op0, Op1, "dot");
 
-      // Bools should have been promoted
       llvm_unreachable(
           "Scalar dot product is only supported on ints and floats.");
     }
+    // For vectors, validate types and emit the appropriate intrinsic
+
     // A VectorSplat should have happened
     assert(T0->isVectorTy() && T1->isVectorTy() &&
            "Dot product of vector and scalar is not supported.");
 
-    // A vector sext or sitofp should have happened
-    assert(T0->getScalarType() == T1->getScalarType() &&
-           "Dot product of vectors need the same element types.");
-
     auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
     [[maybe_unused]] auto *VecTy1 =
         E->getArg(1)->getType()->getAs<VectorType>();
-    // A HLSLVectorTruncation should have happend
+
+    assert(VecTy0->getElementType() == VecTy1->getElementType() &&
+           "Dot product of vectors need the same element types.");
+
     assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
            "Dot product requires vectors to be of the same size.");
 
     return Builder.CreateIntrinsic(
         /*ReturnType=*/T0->getScalarType(),
-        getDotProductIntrinsic(E->getArg(0)->getType(),
-                               VecTy0->getNumElements()),
-        ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
+        getDotProductIntrinsic(VecTy0->getElementType()),
+        ArrayRef<Value *>{Op0, Op1}, nullptr, "dot");
   } break;
   case Builtin::BI__builtin_hlsl_lerp: {
     Value *X = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
index b0b95074c972d5..6036f9430db4f0 100644
--- a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
@@ -2,8 +2,8 @@
 
 // CHECK-LABEL: builtin_bool_to_float_type_promotion
 // CHECK: %conv1 = uitofp i1 %loadedv to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: %dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
 // CHECK: ret float %conv2
 float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
   return __builtin_hlsl_dot ( p0, p1 );
@@ -12,8 +12,8 @@ float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
 // CHECK-LABEL: builtin_bool_to_float_arg1_type_promotion
 // CHECK: %conv = uitofp i1 %loadedv to double
 // CHECK: %conv1 = fpext float %1 to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: %dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
 // CHECK: ret float %conv2
 float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
   return __builtin_hlsl_dot ( p0, p1 );
@@ -22,8 +22,8 @@ float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
 // CHECK-LABEL: builtin_dot_int_to_float_promotion
 // CHECK: %conv = fpext float %0 to double
 // CHECK: %conv1 = sitofp i32 %1 to double
-// CHECK: dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
 // CHECK: ret float %conv2
 float builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
   return __builtin_hlsl_dot ( p0, p1 );
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index ae6e45c3f9482a..b9486f433cced1 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -7,155 +7,155 @@
 // RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
 
 #ifdef __HLSL_ENABLE_16_BIT
-// NATIVE_HALF: %dx.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dot
 int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
 int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
 int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
 int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dot
 uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
 uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
 uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
 uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
 #endif
 
-// CHECK: %dx.dot = mul i32 %0, %1
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = mul i32 %0, %1
+// CHECK: ret i32 %dot
 int test_dot_int(int p0, int p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dot
 int test_dot_int2(int2 p0, int2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dot
 int test_dot_int3(int3 p0, int3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dot
 int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = mul i32 %0, %1
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = mul i32 %0, %1
+// CHECK: ret i32 %dot
 uint test_dot_uint(uint p0, uint p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dot
 uint test_dot_uint2(uint2 p0, uint2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dot
 uint test_dot_uint3(uint3 p0, uint3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dot
 uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = mul i64 %0, %1
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = mul i64 %0, %1
+// CHECK: ret i64 %dot
 int64_t test_dot_long(int64_t p0, int64_t p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dot
 int64_t test_dot_long2(int64_t2 p0, int64_t2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dot
 int64_t test_dot_long3(int64_t3 p0, int64_t3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dot
 int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = mul i64 %0, %1
-// CHECK: ret i64 %dx.dot
+// CHECK:  %dot = mul i64 %0, %1
+// CHECK: ret i64 %dot
 uint64_t test_dot_ulong(uint64_t p0, uint64_t p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dot
 uint64_t test_dot_ulong2(uint64_t2 p0, uint64_t2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dot
 uint64_t test_dot_ulong3(uint64_t3 p0, uint64_t3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dot
 uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = fmul half %0, %1
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = fmul float %0, %1
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = fmul half %0, %1
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = fmul float %0, %1
+// NO_HALF: ret float %dot
 half test_dot_half(half p0, half p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: ret float %dot
 half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: ret float %dot
 half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: ret float %dot
 half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = fmul float %0, %1
-// CHECK: ret float %dx.dot
+// CHECK: %dot = fmul float %0, %1
+// CHECK: ret float %dot
 float test_dot_float(float p0, float p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: ret float %dot
 float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: ret float %dot
 float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: ret float %dot
 float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK:  %dot = call float @llvm.fdot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK: ret float %dot
 float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK:  %dot = call float @llvm.fdot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK: ret float %dot
 float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK:  %dot = call float @llvm.fdot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK: ret float %dot
 float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = fmul double %0, %1
-// CHECK: ret double %dx.dot
+// CHECK: %dot = fmul double %0, %1
+// CHECK: ret double %dot
 double test_dot_double(double p0, double p1) { return dot(p0, p1); }
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index b4e758136b39fb..815da809d28a73 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1045,6 +1045,15 @@ let IntrProperties = [IntrNoMem, IntrSpeculatable, IntrWillReturn] in {
   def int_nearbyint : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
   def int_round : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
   def int_roundeven    : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
+  def int_udot : Intrinsic<[LLVMVectorElementType<0>],
+                           [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+                           [IntrNoMem, IntrWillReturn, Commutative] >;
+  def int_sdot : Intrinsic<[LLVMVectorElementType<0>],
+                           [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+                           [IntrNoMem, IntrWillReturn, Commutative] >;
+  def int_fdot : Intrinsic<[LLVMVectorElementType<0>],
+                           [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+                           [IntrNoMem, IntrWillReturn, Commutative] >;
 
   // Truncate a floating point number with a specific rounding mode
   def int_fptrunc_round : DefaultAttrsIntrinsic<[ llvm_anyfloat_ty ],
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index ac85859af8a53e..307ed453b9eb6d 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -45,6 +45,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::dx_length:
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
+  case Intrinsic::fdot:
     return true;
   }
   return false;
@@ -70,31 +71,56 @@ static bool expandAbs(CallInst *Orig) {
   return true;
 }
 
-static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
+static bool expandDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
   assert(DotIntrinsic == Intrinsic::dx_sdot ||
-         DotIntrinsic == Intrinsic::dx_udot);
-  Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
-                                   ? Intrinsic::dx_imad
-                                   : Intrinsic::dx_umad;
+         DotIntrinsic == Intrinsic::dx_udot || DotIntrinsic == Intrinsic::fdot);
   Value *A = Orig->getOperand(0);
   Value *B = Orig->getOperand(1);
   [[maybe_unused]] Type *ATy = A->getType();
   [[maybe_unused]] Type *BTy = B->getType();
   assert(ATy->isVectorTy() && BTy->isVectorTy());
 
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
 
   auto *AVec = dyn_cast<FixedVectorType>(A->getType());
-  Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
-  Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
-  Value *Result = Builder.CreateMul(Elt0, Elt1);
-  for (unsigned I = 1; I < AVec->getNumElements(); I++) {
-    Elt0 = Builder.CreateExtractElement(A, I);
-    Elt1 = Builder.CreateExtractElement(B, I);
-    Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
-                                     ArrayRef<Value *>{Elt0, Elt1, Result},
-                                     nullptr, "dx.mad");
+
+  Type *EltTy = ATy->getScalarType();
+  Value *Result;
+  if (EltTy->isIntegerTy()) {
+    // Expand integer dot product to multiply and add ops
+    Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
+                                     ? Intrinsic::dx_imad
+                                     : Intrinsic::dx_umad;
+    Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
+    Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
+    Result = Builder.CreateMul(Elt0, Elt1);
+    for (unsigned I = 1; I < AVec->getNumElements(); I++) {
+      Elt0 = Builder.CreateExtractElement(A, I);
+      Elt1 = Builder.CreateExtractElement(B, I);
+      Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
+                                       ArrayRef<Value *>{Elt0, Elt1, Result},
+                                       nullptr, "dx.mad");
+    }
+  } else {
+    // Use dx intrinsics for float vectors
+    assert(EltTy->isFloatingPointTy());
+
+    Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
+    switch (AVec->getNumElements()) {
+    case 2:
+      DotIntrinsic = Intrinsic::dx_dot2;
+      break;
+    case 3:
+      DotIntrinsic = Intrinsic::dx_dot3;
+      break;
+    case 4:
+      DotIntrinsic = Intrinsic::dx_dot4;
+      break;
+    default:
+      llvm_unreachable("dot product with vector outside 2-4 range");
+    }
+    Result = Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
+                                     ArrayRef<Value *>{A, B}, nullptr, "dot");
   }
   Orig->replaceAllUsesWith(Result);
   Orig->eraseFromParent();
@@ -316,7 +342,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
     return expandLengthIntrinsic(Orig);
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
-    return expandIntegerDot(Orig, F.getIntrinsicID());
+  case Intrinsic::fdot:
+    return expandDotIntrinsic(Orig, F.getIntrinsicID());
   }
   return false;
 }

>From 7ca6bc5940321c18f5634bb960fa795366097e45 Mon Sep 17 00:00:00 2001
From: Greg Roth <grroth at microsoft.com>
Date: Sat, 10 Aug 2024 17:07:19 -0600
Subject: [PATCH 2/6] Update DX intrinsic expansion for new llvm intrinsics

The new LLVM integer intrinsics replace the previous dx intrinsics.
This updates the DXIL intrinsic expansion code and tests to use and
expect the new integer intrinsics, and to expect the flattened,
size-specific DX floating-point dot variants only after op lowering.

Part of #88056
---
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  20 +--
 llvm/lib/Target/DirectX/DXIL.td               |   6 +-
 .../Target/DirectX/DXILIntrinsicExpansion.cpp |  15 +--
 llvm/test/CodeGen/DirectX/fdot.ll             | 117 ++++++++++--------
 llvm/test/CodeGen/DirectX/idot.ll             |  34 ++---
 5 files changed, 96 insertions(+), 96 deletions(-)

diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..8ce79eb7cbaafa 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -25,26 +25,18 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
 def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; 
 
-def int_dx_dot2 : 
-    Intrinsic<[LLVMVectorElementType<0>], 
+def int_dx_dot2 :
+    Intrinsic<[LLVMVectorElementType<0>],
     [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_dot3 : 
-    Intrinsic<[LLVMVectorElementType<0>], 
+def int_dx_dot3 :
+    Intrinsic<[LLVMVectorElementType<0>],
     [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_dot4 : 
-    Intrinsic<[LLVMVectorElementType<0>], 
+def int_dx_dot4 :
+    Intrinsic<[LLVMVectorElementType<0>],
     [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_sdot : 
-    Intrinsic<[LLVMVectorElementType<0>], 
-    [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_udot : 
-    Intrinsic<[LLVMVectorElementType<0>], 
-    [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [IntrNoMem, IntrWillReturn, Commutative] >;
 
 def int_dx_frac  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 67015cff78a79a..ac79b84a1e9100 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -637,7 +637,7 @@ def UMad :  DXILOp<49, tertiary> {
 
 def Dot2 :  DXILOp<54, dot2> {
   let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + "
-            "a[n]*b[n] where n is between 0 and 1";
+            "a[n]*b[n] where n is 0 to 1 inclusive";
   let LLVMIntrinsic = int_dx_dot2;
   let arguments = !listsplat(overloadTy, 4);
   let result = overloadTy;
@@ -648,7 +648,7 @@ def Dot2 :  DXILOp<54, dot2> {
 
 def Dot3 :  DXILOp<55, dot3> {
   let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + "
-            "a[n]*b[n] where n is between 0 and 2";
+            "a[n]*b[n] where n is 0 to 2 inclusive";
   let LLVMIntrinsic = int_dx_dot3;
   let arguments = !listsplat(overloadTy, 6);
   let result = overloadTy;
@@ -659,7 +659,7 @@ def Dot3 :  DXILOp<55, dot3> {
 
 def Dot4 :  DXILOp<56, dot4> {
   let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + "
-            "a[n]*b[n] where n is between 0 and 3";
+            "a[n]*b[n] where n is 0 to 3 inclusive";
   let LLVMIntrinsic = int_dx_dot4;
   let arguments = !listsplat(overloadTy, 8);
   let result = overloadTy;
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 307ed453b9eb6d..5adb62ed54e81b 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -43,8 +43,8 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_lerp:
   case Intrinsic::dx_length:
-  case Intrinsic::dx_sdot:
-  case Intrinsic::dx_udot:
+  case Intrinsic::sdot:
+  case Intrinsic::udot:
   case Intrinsic::fdot:
     return true;
   }
@@ -72,10 +72,11 @@ static bool expandAbs(CallInst *Orig) {
 }
 
 static bool expandDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
-  assert(DotIntrinsic == Intrinsic::dx_sdot ||
-         DotIntrinsic == Intrinsic::dx_udot || DotIntrinsic == Intrinsic::fdot);
+  assert(DotIntrinsic == Intrinsic::sdot || DotIntrinsic == Intrinsic::udot ||
+         DotIntrinsic == Intrinsic::fdot);
   Value *A = Orig->getOperand(0);
   Value *B = Orig->getOperand(1);
+
   [[maybe_unused]] Type *ATy = A->getType();
   [[maybe_unused]] Type *BTy = B->getType();
   assert(ATy->isVectorTy() && BTy->isVectorTy());
@@ -88,7 +89,7 @@ static bool expandDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
   Value *Result;
   if (EltTy->isIntegerTy()) {
     // Expand integer dot product to multiply and add ops
-    Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
+    Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::sdot
                                      ? Intrinsic::dx_imad
                                      : Intrinsic::dx_umad;
     Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
@@ -340,9 +341,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
     return expandLerpIntrinsic(Orig);
   case Intrinsic::dx_length:
     return expandLengthIntrinsic(Orig);
-  case Intrinsic::dx_sdot:
-  case Intrinsic::dx_udot:
   case Intrinsic::fdot:
+  case Intrinsic::sdot:
+  case Intrinsic::udot:
     return expandDotIntrinsic(Orig, F.getIntrinsicID());
   }
   return false;
diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll
index 56817a172ff9e3..3eb39fd5d4bb74 100644
--- a/llvm/test/CodeGen/DirectX/fdot.ll
+++ b/llvm/test/CodeGen/DirectX/fdot.ll
@@ -1,94 +1,101 @@
+; RUN: opt -S  -dxil-intrinsic-expansion -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
 ; RUN: opt -S  -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
-; Make sure dxil operation function calls for dot are generated for int/uint vectors.
+; Make sure dxil operation function calls for dot are generated for float type vectors.
 
 ; CHECK-LABEL: dot_half2
 define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
 entry:
-; CHECK: extractelement <2 x half> %a, i32 0
-; CHECK: extractelement <2 x half> %a, i32 1
-; CHECK: extractelement <2 x half> %b, i32 0
-; CHECK: extractelement <2 x half> %b, i32 1
-; CHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
-  %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b)
+; DOPCHECK: extractelement <2 x half> %a, i32 0
+; DOPCHECK: extractelement <2 x half> %a, i32 1
+; DOPCHECK: extractelement <2 x half> %b, i32 0
+; DOPCHECK: extractelement <2 x half> %b, i32 1
+; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b)
+  %dx.dot = call half @llvm.fdot.v2f16(<2 x half> %a, <2 x half> %b)
   ret half %dx.dot
 }
 
 ; CHECK-LABEL: dot_half3
 define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
 entry:
-; CHECK: extractelement <3 x half> %a, i32 0
-; CHECK: extractelement <3 x half> %a, i32 1
-; CHECK: extractelement <3 x half> %a, i32 2
-; CHECK: extractelement <3 x half> %b, i32 0
-; CHECK: extractelement <3 x half> %b, i32 1
-; CHECK: extractelement <3 x half> %b, i32 2
-; CHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
-  %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b)
+; DOPCHECK: extractelement <3 x half> %a, i32 0
+; DOPCHECK: extractelement <3 x half> %a, i32 1
+; DOPCHECK: extractelement <3 x half> %a, i32 2
+; DOPCHECK: extractelement <3 x half> %b, i32 0
+; DOPCHECK: extractelement <3 x half> %b, i32 1
+; DOPCHECK: extractelement <3 x half> %b, i32 2
+; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b)
+  %dx.dot = call half @llvm.fdot.v3f16(<3 x half> %a, <3 x half> %b)
   ret half %dx.dot
 }
 
 ; CHECK-LABEL: dot_half4
 define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
 entry:
-; CHECK: extractelement <4 x half> %a, i32 0
-; CHECK: extractelement <4 x half> %a, i32 1
-; CHECK: extractelement <4 x half> %a, i32 2
-; CHECK: extractelement <4 x half> %a, i32 3
-; CHECK: extractelement <4 x half> %b, i32 0
-; CHECK: extractelement <4 x half> %b, i32 1
-; CHECK: extractelement <4 x half> %b, i32 2
-; CHECK: extractelement <4 x half> %b, i32 3
-; CHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
-  %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b)
+; DOPCHECK: extractelement <4 x half> %a, i32 0
+; DOPCHECK: extractelement <4 x half> %a, i32 1
+; DOPCHECK: extractelement <4 x half> %a, i32 2
+; DOPCHECK: extractelement <4 x half> %a, i32 3
+; DOPCHECK: extractelement <4 x half> %b, i32 0
+; DOPCHECK: extractelement <4 x half> %b, i32 1
+; DOPCHECK: extractelement <4 x half> %b, i32 2
+; DOPCHECK: extractelement <4 x half> %b, i32 3
+; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b)
+  %dx.dot = call half @llvm.fdot.v4f16(<4 x half> %a, <4 x half> %b)
   ret half %dx.dot
 }
 
 ; CHECK-LABEL: dot_float2
 define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
 entry:
-; CHECK: extractelement <2 x float> %a, i32 0
-; CHECK: extractelement <2 x float> %a, i32 1
-; CHECK: extractelement <2 x float> %b, i32 0
-; CHECK: extractelement <2 x float> %b, i32 1
-; CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
-  %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b)
+; DOPCHECK: extractelement <2 x float> %a, i32 0
+; DOPCHECK: extractelement <2 x float> %a, i32 1
+; DOPCHECK: extractelement <2 x float> %b, i32 0
+; DOPCHECK: extractelement <2 x float> %b, i32 1
+; DOPCHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+; EXPCHECK: call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b)
+  %dx.dot = call float @llvm.fdot.v2f32(<2 x float> %a, <2 x float> %b)
   ret float %dx.dot
 }
 
 ; CHECK-LABEL: dot_float3
 define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
 entry:
-; CHECK: extractelement <3 x float> %a, i32 0
-; CHECK: extractelement <3 x float> %a, i32 1
-; CHECK: extractelement <3 x float> %a, i32 2
-; CHECK: extractelement <3 x float> %b, i32 0
-; CHECK: extractelement <3 x float> %b, i32 1
-; CHECK: extractelement <3 x float> %b, i32 2
-; CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
-  %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b)
+; DOPCHECK: extractelement <3 x float> %a, i32 0
+; DOPCHECK: extractelement <3 x float> %a, i32 1
+; DOPCHECK: extractelement <3 x float> %a, i32 2
+; DOPCHECK: extractelement <3 x float> %b, i32 0
+; DOPCHECK: extractelement <3 x float> %b, i32 1
+; DOPCHECK: extractelement <3 x float> %b, i32 2
+; DOPCHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+; EXPCHECK: call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b)
+  %dx.dot = call float @llvm.fdot.v3f32(<3 x float> %a, <3 x float> %b)
   ret float %dx.dot
 }
 
 ; CHECK-LABEL: dot_float4
 define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
 entry:
-; CHECK: extractelement <4 x float> %a, i32 0
-; CHECK: extractelement <4 x float> %a, i32 1
-; CHECK: extractelement <4 x float> %a, i32 2
-; CHECK: extractelement <4 x float> %a, i32 3
-; CHECK: extractelement <4 x float> %b, i32 0
-; CHECK: extractelement <4 x float> %b, i32 1
-; CHECK: extractelement <4 x float> %b, i32 2
-; CHECK: extractelement <4 x float> %b, i32 3
-; CHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
-  %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b)
+; DOPCHECK: extractelement <4 x float> %a, i32 0
+; DOPCHECK: extractelement <4 x float> %a, i32 1
+; DOPCHECK: extractelement <4 x float> %a, i32 2
+; DOPCHECK: extractelement <4 x float> %a, i32 3
+; DOPCHECK: extractelement <4 x float> %b, i32 0
+; DOPCHECK: extractelement <4 x float> %b, i32 1
+; DOPCHECK: extractelement <4 x float> %b, i32 2
+; DOPCHECK: extractelement <4 x float> %b, i32 3
+; DOPCHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+; EXPCHECK: call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b)
+  %dx.dot = call float @llvm.fdot.v4f32(<4 x float> %a, <4 x float> %b)
   ret float %dx.dot
 }
 
-declare half  @llvm.dx.dot.v2f16(<2 x half> , <2 x half> )
-declare half  @llvm.dx.dot.v3f16(<3 x half> , <3 x half> )
-declare half  @llvm.dx.dot.v4f16(<4 x half> , <4 x half> )
-declare float @llvm.dx.dot.v2f32(<2 x float>, <2 x float>)
-declare float @llvm.dx.dot.v3f32(<3 x float>, <3 x float>)
-declare float @llvm.dx.dot.v4f32(<4 x float>, <4 x float>)
+declare half  @llvm.fdot.v2f16(<2 x half> , <2 x half> )
+declare half  @llvm.fdot.v3f16(<3 x half> , <3 x half> )
+declare half  @llvm.fdot.v4f16(<4 x half> , <4 x half> )
+declare float @llvm.fdot.v2f32(<2 x float>, <2 x float>)
+declare float @llvm.fdot.v3f32(<3 x float>, <3 x float>)
+declare float @llvm.fdot.v4f32(<4 x float>, <4 x float>)
diff --git a/llvm/test/CodeGen/DirectX/idot.ll b/llvm/test/CodeGen/DirectX/idot.ll
index eac1b91106ddef..94822f92b41351 100644
--- a/llvm/test/CodeGen/DirectX/idot.ll
+++ b/llvm/test/CodeGen/DirectX/idot.ll
@@ -13,12 +13,12 @@ entry:
 ; CHECK: extractelement <2 x i16> %b, i64 1
 ; EXPCHECK: call i16 @llvm.dx.imad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
 ; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 48, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
-  %dx.dot = call i16 @llvm.dx.sdot.v3i16(<2 x i16> %a, <2 x i16> %b)
-  ret i16 %dx.dot
+  %dot = call i16 @llvm.sdot.v3i16(<2 x i16> %a, <2 x i16> %b)
+  ret i16 %dot
 }
 
-; CHECK-LABEL: sdot_int4
-define noundef i32 @sdot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
+; CHECK-LABEL: dot_int4
+define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
 entry:
 ; CHECK: extractelement <4 x i32> %a, i64 0
 ; CHECK: extractelement <4 x i32> %b, i64 0
@@ -35,8 +35,8 @@ entry:
 ; CHECK: extractelement <4 x i32> %b, i64 3
 ; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
 ; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
-  %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
-  ret i32 %dx.dot
+  %dot = call i32 @llvm.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
+  ret i32 %dot
 }
 
 ; CHECK-LABEL: dot_uint16_t3
@@ -53,8 +53,8 @@ entry:
 ; CHECK: extractelement <3 x i16> %b, i64 2
 ; EXPCHECK: call i16 @llvm.dx.umad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
 ; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
-  %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
-  ret i16 %dx.dot
+  %dot = call i16 @llvm.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
+  ret i16 %dot
 }
 
 ; CHECK-LABEL: dot_uint4
@@ -75,8 +75,8 @@ entry:
 ; CHECK: extractelement <4 x i32> %b, i64 3
 ; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
 ; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
-  %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
-  ret i32 %dx.dot
+  %dot = call i32 @llvm.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
+  ret i32 %dot
 }
 
 ; CHECK-LABEL: dot_uint64_t4
@@ -89,12 +89,12 @@ entry:
 ; CHECK: extractelement <2 x i64> %b, i64 1
 ; EXPCHECK: call i64 @llvm.dx.umad.i64(i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
 ; DOPCHECK: call i64 @dx.op.tertiary.i64(i32 49, i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
-  %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
-  ret i64 %dx.dot
+  %dot = call i64 @llvm.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
+  ret i64 %dot
 }
 
-declare i16 @llvm.dx.sdot.v2i16(<2 x i16>, <2 x i16>)
-declare i32 @llvm.dx.sdot.v4i32(<4 x i32>, <4 x i32>)
-declare i16 @llvm.dx.udot.v3i32(<3 x i16>, <3 x i16>)
-declare i32 @llvm.dx.udot.v4i32(<4 x i32>, <4 x i32>)
-declare i64 @llvm.dx.udot.v2i64(<2 x i64>, <2 x i64>)
+declare i16 @llvm.sdot.v2i16(<2 x i16>, <2 x i16>)
+declare i32 @llvm.sdot.v4i32(<4 x i32>, <4 x i32>)
+declare i16 @llvm.udot.v3i32(<3 x i16>, <3 x i16>)
+declare i32 @llvm.udot.v4i32(<4 x i32>, <4 x i32>)
+declare i64 @llvm.udot.v2i64(<2 x i64>, <2 x i64>)

>From edbb80c9a4070bdfef5712a11f47d473a85d4bb3 Mon Sep 17 00:00:00 2001
From: Greg Roth <grroth at microsoft.com>
Date: Tue, 6 Aug 2024 10:50:23 -0600
Subject: [PATCH 3/6] Add SPIRV generation for HLSL dot

Use the new LLVM dot intrinsics to build SPIRV instructions.
This involves generating multiply and add operations for integers
and the existing OpDot operation for floating point. This includes
adding generic opcodes for signed-integer, unsigned-integer, and
floating-point dot products. These require updating an existing test
that covers all such opcodes.

New tests for generating SPIRV float and integer dot intrinsics are
added as well.

Fixes #88056
---
 llvm/include/llvm/Support/TargetOpcodes.def   |  9 ++
 llvm/include/llvm/Target/GenericOpcodes.td    | 21 +++++
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  |  6 ++
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 78 ++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  |  3 +
 .../GlobalISel/legalizer-info-validation.mir  |  9 ++
 .../CodeGen/SPIRV/hlsl-intrinsics/fdot.ll     | 75 ++++++++++++++++
 .../CodeGen/SPIRV/hlsl-intrinsics/idot.ll     | 88 +++++++++++++++++++
 8 files changed, 289 insertions(+)
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll

diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index 9fb6de49fb2055..0808fd9d77be82 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -814,6 +814,15 @@ HANDLE_TARGET_OPCODE(G_FSINH)
 /// Floating point hyperbolic tangent.
 HANDLE_TARGET_OPCODE(G_FTANH)
 
+/// Floating point vector dot product
+HANDLE_TARGET_OPCODE(G_FDOTPROD)
+
+/// Unsigned integer vector dot product
+HANDLE_TARGET_OPCODE(G_UDOTPROD)
+
+/// Signed integer vector dot product
+HANDLE_TARGET_OPCODE(G_SDOTPROD)
+
 /// Floating point square root.
 HANDLE_TARGET_OPCODE(G_FSQRT)
 
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index 36a0a087ba457c..648671f627d649 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1057,6 +1057,27 @@ def G_FTANH : GenericInstruction {
   let hasSideEffects = false;
 }
 
+/// Floating point vector dot product
+def G_FDOTPROD : GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1, type0:$src2);
+  let hasSideEffects = false;
+}
+
+/// Signed integer vector dot product
+def G_SDOTPROD : GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1, type0:$src2);
+  let hasSideEffects = false;
+}
+
+/// Unsigned integer vector dot product
+def G_UDOTPROD : GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1, type0:$src2);
+  let hasSideEffects = false;
+}
+
 // Floating point square root of a value.
 // This returns NaN for negative nonzero values.
 // NOTE: Unlike libm sqrt(), this never sets errno. In all other respects it's
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 0169a0e466d878..d471b172a7dcb2 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -1903,6 +1903,12 @@ unsigned IRTranslator::getSimpleIntrinsicOpcode(Intrinsic::ID ID) {
       return TargetOpcode::G_CTPOP;
     case Intrinsic::exp:
       return TargetOpcode::G_FEXP;
+    case Intrinsic::fdot:
+      return TargetOpcode::G_FDOTPROD;
+    case Intrinsic::sdot:
+      return TargetOpcode::G_SDOTPROD;
+    case Intrinsic::udot:
+      return TargetOpcode::G_UDOTPROD;
     case Intrinsic::exp2:
       return TargetOpcode::G_FEXP2;
     case Intrinsic::exp10:
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ed786bd33aa05b..98b7847a611476 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -178,6 +178,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectRsqrt(Register ResVReg, const SPIRVType *ResType,
                    MachineInstr &I) const;
 
+  bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
+                        MachineInstr &I) const;
+
   void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
                    int OpIdx) const;
   void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
@@ -380,6 +383,20 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
       MIB.addImm(V);
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
+
+  case TargetOpcode::G_FDOTPROD: {
+    MachineBasicBlock &BB = *I.getParent();
+    return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpDot))
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(ResType))
+        .addUse(I.getOperand(1).getReg())
+        .addUse(I.getOperand(2).getReg())
+        .constrainAllUses(TII, TRI, RBI);
+  }
+  case TargetOpcode::G_SDOTPROD:
+  case TargetOpcode::G_UDOTPROD:
+    return selectIntegerDot(ResVReg, ResType, I);
+
   case TargetOpcode::G_MEMMOVE:
   case TargetOpcode::G_MEMCPY:
   case TargetOpcode::G_MEMSET:
@@ -1366,6 +1383,67 @@ bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
       .constrainAllUses(TII, TRI, RBI);
 }
 
+// Since there is no integer dot implementation, expand by multiplying the
+// vectors component-wise and then summing the elements of the product.
+bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
+                                                const SPIRVType *ResType,
+                                                MachineInstr &I) const {
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(1).isReg());
+  assert(I.getOperand(2).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
+  // Multiply the vectors, then sum the results
+  Register Vec0 = I.getOperand(1).getReg();
+  Register Vec1 = I.getOperand(2).getReg();
+  Register TmpVec = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  SPIRVType *VecType = GR.getSPIRVTypeForVReg(Vec0);
+
+  bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulV))
+                    .addDef(TmpVec)
+                    .addUse(GR.getSPIRVTypeID(VecType))
+                    .addUse(Vec0)
+                    .addUse(Vec1)
+                    .constrainAllUses(TII, TRI, RBI);
+
+  assert(GR.getScalarOrVectorComponentCount(VecType) > 1 &&
+         "dot product requires a vector of at least 2 components");
+
+  Register Res = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+                .addDef(Res)
+                .addUse(GR.getSPIRVTypeID(ResType))
+                .addUse(TmpVec)
+                .addImm(0)
+                .constrainAllUses(TII, TRI, RBI);
+
+  for (unsigned i = 1; i < GR.getScalarOrVectorComponentCount(VecType); i++) {
+    Register Elt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+
+    Result |=
+        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+            .addDef(Elt)
+            .addUse(GR.getSPIRVTypeID(ResType))
+            .addUse(TmpVec)
+            .addImm(i)
+            .constrainAllUses(TII, TRI, RBI);
+
+    Register Sum = i < GR.getScalarOrVectorComponentCount(VecType) - 1
+                       ? MRI->createVirtualRegister(&SPIRV::IDRegClass)
+                       : ResVReg;
+
+    Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
+                  .addDef(Sum)
+                  .addUse(GR.getSPIRVTypeID(ResType))
+                  .addUse(Res)
+                  .addUse(Elt)
+                  .constrainAllUses(TII, TRI, RBI);
+    Res = Sum;
+  }
+
+  return Result;
+}
+
 bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
                                                 const SPIRVType *ResType,
                                                 MachineInstr &I) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index e775f8c57b048e..fb39c107a417ad 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -285,6 +285,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                                G_FCOSH,
                                G_FSINH,
                                G_FTANH,
+                               G_FDOTPROD,
+                               G_SDOTPROD,
+                               G_UDOTPROD,
                                G_FSQRT,
                                G_FFLOOR,
                                G_FRINT,
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
index 87a415b45cca9a..57faf42a5156c6 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
@@ -716,6 +716,15 @@
 # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
 # DEBUG-NEXT: .. the first uncovered type index: 1, OK
 # DEBUG-NEXT: .. the first uncovered imm index: 0, OK
+# DEBUG-NEXT: G_FDOTPROD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
+# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
+# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
+# DEBUG-NEXT: G_UDOTPROD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
+# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
+# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
+# DEBUG-NEXT: G_SDOTPROD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
+# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
+# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
 # DEBUG-NEXT: G_FSQRT (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
 # DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll
new file mode 100644
index 00000000000000..964decf237a5cf
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll
@@ -0,0 +1,75 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure SPIRV operation function calls for dot are generated for float type vectors.
+
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#vec2_float_16:]] = OpTypeVector %[[#float_16]] 2
+; CHECK-DAG: %[[#vec3_float_16:]] = OpTypeVector %[[#float_16]] 3
+; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#vec2_float_32:]] = OpTypeVector %[[#float_32]] 2
+; CHECK-DAG: %[[#vec3_float_32:]] = OpTypeVector %[[#float_32]] 3
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+
+
+define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_float_16]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_float_16]]
+; CHECK: OpDot %[[#float_16]] %[[#arg0:]] %[[#arg1:]]
+  %dx.dot = call half @llvm.fdot.v2f16(<2 x half> %a, <2 x half> %b)
+  ret half %dx.dot
+}
+
+define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_16]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_16]]
+; CHECK: OpDot %[[#float_16]] %[[#arg0:]] %[[#arg1:]]
+  %dx.dot = call half @llvm.fdot.v3f16(<3 x half> %a, <3 x half> %b)
+  ret half %dx.dot
+}
+
+define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_16]]
+; CHECK: OpDot %[[#float_16]] %[[#arg0:]] %[[#arg1:]]
+  %dx.dot = call half @llvm.fdot.v4f16(<4 x half> %a, <4 x half> %b)
+  ret half %dx.dot
+}
+
+define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_float_32]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_float_32]]
+; CHECK: OpDot %[[#float_32]] %[[#arg0:]] %[[#arg1:]]
+  %dx.dot = call float @llvm.fdot.v2f32(<2 x float> %a, <2 x float> %b)
+  ret float %dx.dot
+}
+
+define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_32]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_32]]
+; CHECK: OpDot %[[#float_32]] %[[#arg0:]] %[[#arg1:]]
+  %dx.dot = call float @llvm.fdot.v3f32(<3 x float> %a, <3 x float> %b)
+  ret float %dx.dot
+}
+
+define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
+; CHECK: OpDot %[[#float_32]] %[[#arg0:]] %[[#arg1:]]
+  %dx.dot = call float @llvm.fdot.v4f32(<4 x float> %a, <4 x float> %b)
+  ret float %dx.dot
+}
+
+declare half  @llvm.fdot.v2f16(<2 x half> , <2 x half> )
+declare half  @llvm.fdot.v3f16(<3 x half> , <3 x half> )
+declare half  @llvm.fdot.v4f16(<4 x half> , <4 x half> )
+declare float @llvm.fdot.v2f32(<2 x float>, <2 x float>)
+declare float @llvm.fdot.v3f32(<3 x float>, <3 x float>)
+declare float @llvm.fdot.v4f32(<4 x float>, <4 x float>)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
new file mode 100644
index 00000000000000..05f28920e205b7
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
@@ -0,0 +1,88 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure SPIRV operation function calls for dot are generated for int/uint vectors.
+
+; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16
+; CHECK-DAG: %[[#vec2_int_16:]] = OpTypeVector %[[#int_16]] 2
+; CHECK-DAG: %[[#vec3_int_16:]] = OpTypeVector %[[#int_16]] 3
+; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32
+; CHECK-DAG: %[[#vec4_int_32:]] = OpTypeVector %[[#int_32]] 4
+; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64
+; CHECK-DAG: %[[#vec2_int_64:]] = OpTypeVector %[[#int_64]] 2
+
+define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_16]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_16]]
+; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
+; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
+; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
+; CHECK: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
+  %dot = call i16 @llvm.sdot.v3i16(<2 x i16> %a, <2 x i16> %b)
+  ret i16 %dot
+}
+
+define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
+; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
+; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
+; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
+; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
+; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
+; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
+; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
+; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
+  %dot = call i32 @llvm.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
+  ret i32 %dot
+}
+
+define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_int_16]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_int_16]]
+; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
+; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
+; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
+; CHECK: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
+; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
+; CHECK: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]
+  %dot = call i16 @llvm.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
+  ret i16 %dot
+}
+
+define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
+; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
+; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
+; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
+; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
+; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
+; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
+; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
+; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
+  %dot = call i32 @llvm.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
+  ret i32 %dot
+}
+
+define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
+; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
+; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
+; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
+; CHECK: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]
+  %dot = call i64 @llvm.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
+  ret i64 %dot
+}
+
+declare i16 @llvm.sdot.v2i16(<2 x i16>, <2 x i16>)
+declare i32 @llvm.sdot.v4i32(<4 x i32>, <4 x i32>)
+declare i16 @llvm.udot.v3i32(<3 x i16>, <3 x i16>)
+declare i32 @llvm.udot.v4i32(<4 x i32>, <4 x i32>)
+declare i64 @llvm.udot.v2i64(<2 x i64>, <2 x i64>)

>From c08c0153cbde1f43bebdbb8b50b74e77cdfc40bb Mon Sep 17 00:00:00 2001
From: Greg Roth <grroth at microsoft.com>
Date: Mon, 12 Aug 2024 16:19:03 -0600
Subject: [PATCH 4/6] Respond to feedback

Rename dot-related ops to hlsl.dot

Add documentation for new Generic opcodes and llvm intrinsics

Use DefaultAttrsIntrinsic to define new llvm intrinsics

Split DXIL instruction expansion for integer and float dot products
---
 clang/lib/CodeGen/CGBuiltin.cpp               |   6 +-
 .../CodeGenHLSL/builtins/dot-builtin.hlsl     |  10 +-
 clang/test/CodeGenHLSL/builtins/dot.hlsl      | 160 +++++++++---------
 llvm/docs/GlobalISel/GenericOpcode.rst        |   8 +
 llvm/docs/LangRef.rst                         | 132 +++++++++++++++
 llvm/include/llvm/IR/Intrinsics.td            |  18 +-
 llvm/include/llvm/Target/GenericOpcodes.td    |  42 ++---
 .../Target/DirectX/DXILIntrinsicExpansion.cpp | 100 ++++++-----
 8 files changed, 316 insertions(+), 160 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 67148e32014ed2..c14b7f41a37507 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18524,10 +18524,10 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     // If the arguments are scalars, just emit a multiply
     if (!T0->isVectorTy() && !T1->isVectorTy()) {
       if (T0->isFloatingPointTy())
-        return Builder.CreateFMul(Op0, Op1, "dot");
+        return Builder.CreateFMul(Op0, Op1, "hlsl.dot");
 
       if (T0->isIntegerTy())
-        return Builder.CreateMul(Op0, Op1, "dot");
+        return Builder.CreateMul(Op0, Op1, "hlsl.dot");
 
       llvm_unreachable(
           "Scalar dot product is only supported on ints and floats.");
@@ -18551,7 +18551,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     return Builder.CreateIntrinsic(
         /*ReturnType=*/T0->getScalarType(),
         getDotProductIntrinsic(VecTy0->getElementType()),
-        ArrayRef<Value *>{Op0, Op1}, nullptr, "dot");
+        ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
   } break;
   case Builtin::BI__builtin_hlsl_lerp: {
     Value *X = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
index 6036f9430db4f0..482f089d4770fd 100644
--- a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
@@ -2,8 +2,8 @@
 
 // CHECK-LABEL: builtin_bool_to_float_type_promotion
 // CHECK: %conv1 = uitofp i1 %loadedv to double
-// CHECK: %dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dot to float
+// CHECK: %hlsl.dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %hlsl.dot to float
 // CHECK: ret float %conv2
 float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
   return __builtin_hlsl_dot ( p0, p1 );
@@ -12,8 +12,8 @@ float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
 // CHECK-LABEL: builtin_bool_to_float_arg1_type_promotion
 // CHECK: %conv = uitofp i1 %loadedv to double
 // CHECK: %conv1 = fpext float %1 to double
-// CHECK: %dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dot to float
+// CHECK: %hlsl.dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %hlsl.dot to float
 // CHECK: ret float %conv2
 float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
   return __builtin_hlsl_dot ( p0, p1 );
@@ -23,7 +23,7 @@ float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
 // CHECK: %conv = fpext float %0 to double
 // CHECK: %conv1 = sitofp i32 %1 to double
 // CHECK: dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dot to float
+// CHECK: %conv2 = fptrunc double %hlsl.dot to float
 // CHECK: ret float %conv2
 float builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
   return __builtin_hlsl_dot ( p0, p1 );
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index b9486f433cced1..05233be9165263 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -7,155 +7,155 @@
 // RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
 
 #ifdef __HLSL_ENABLE_16_BIT
-// NATIVE_HALF: %dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dot
+// NATIVE_HALF: %hlsl.dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %hlsl.dot
 int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dot = call i16 @llvm.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dot
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %hlsl.dot
 int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dot = call i16 @llvm.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dot
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %hlsl.dot
 int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dot = call i16 @llvm.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dot
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %hlsl.dot
 int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dot
+// NATIVE_HALF: %hlsl.dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %hlsl.dot
 uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dot = call i16 @llvm.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dot
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %hlsl.dot
 uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dot = call i16 @llvm.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dot
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %hlsl.dot
 uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dot = call i16 @llvm.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dot
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %hlsl.dot
 uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
 #endif
 
-// CHECK: %dot = mul i32 %0, %1
-// CHECK: ret i32 %dot
+// CHECK: %hlsl.dot = mul i32 %0, %1
+// CHECK: ret i32 %hlsl.dot
 int test_dot_int(int p0, int p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call i32 @llvm.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dot
+// CHECK: %hlsl.dot = call i32 @llvm.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %hlsl.dot
 int test_dot_int2(int2 p0, int2 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call i32 @llvm.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dot
+// CHECK: %hlsl.dot = call i32 @llvm.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %hlsl.dot
 int test_dot_int3(int3 p0, int3 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call i32 @llvm.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dot
+// CHECK: %hlsl.dot = call i32 @llvm.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %hlsl.dot
 int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = mul i32 %0, %1
-// CHECK: ret i32 %dot
+// CHECK: %hlsl.dot = mul i32 %0, %1
+// CHECK: ret i32 %hlsl.dot
 uint test_dot_uint(uint p0, uint p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call i32 @llvm.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dot
+// CHECK: %hlsl.dot = call i32 @llvm.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %hlsl.dot
 uint test_dot_uint2(uint2 p0, uint2 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call i32 @llvm.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dot
+// CHECK: %hlsl.dot = call i32 @llvm.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %hlsl.dot
 uint test_dot_uint3(uint3 p0, uint3 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call i32 @llvm.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dot
+// CHECK: %hlsl.dot = call i32 @llvm.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %hlsl.dot
 uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = mul i64 %0, %1
-// CHECK: ret i64 %dot
+// CHECK: %hlsl.dot = mul i64 %0, %1
+// CHECK: ret i64 %hlsl.dot
 int64_t test_dot_long(int64_t p0, int64_t p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call i64 @llvm.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dot
+// CHECK: %hlsl.dot = call i64 @llvm.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %hlsl.dot
 int64_t test_dot_long2(int64_t2 p0, int64_t2 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call i64 @llvm.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dot
+// CHECK: %hlsl.dot = call i64 @llvm.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %hlsl.dot
 int64_t test_dot_long3(int64_t3 p0, int64_t3 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call i64 @llvm.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dot
+// CHECK: %hlsl.dot = call i64 @llvm.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %hlsl.dot
 int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); }
 
-// CHECK:  %dot = mul i64 %0, %1
-// CHECK: ret i64 %dot
+// CHECK:  %hlsl.dot = mul i64 %0, %1
+// CHECK: ret i64 %hlsl.dot
 uint64_t test_dot_ulong(uint64_t p0, uint64_t p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call i64 @llvm.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dot
+// CHECK: %hlsl.dot = call i64 @llvm.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %hlsl.dot
 uint64_t test_dot_ulong2(uint64_t2 p0, uint64_t2 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call i64 @llvm.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dot
+// CHECK: %hlsl.dot = call i64 @llvm.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %hlsl.dot
 uint64_t test_dot_ulong3(uint64_t3 p0, uint64_t3 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call i64 @llvm.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dot
+// CHECK: %hlsl.dot = call i64 @llvm.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %hlsl.dot
 uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dot = fmul half %0, %1
-// NATIVE_HALF: ret half %dot
-// NO_HALF: %dot = fmul float %0, %1
-// NO_HALF: ret float %dot
+// NATIVE_HALF: %hlsl.dot = fmul half %0, %1
+// NATIVE_HALF: ret half %hlsl.dot
+// NO_HALF: %hlsl.dot = fmul float %0, %1
+// NO_HALF: ret float %hlsl.dot
 half test_dot_half(half p0, half p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dot = call half @llvm.fdot.v2f16(<2 x half> %0, <2 x half> %1)
-// NATIVE_HALF: ret half %dot
-// NO_HALF: %dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
-// NO_HALF: ret float %dot
+// NATIVE_HALF: %hlsl.dot = call half @llvm.fdot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: ret half %hlsl.dot
+// NO_HALF: %hlsl.dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: ret float %hlsl.dot
 half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dot = call half @llvm.fdot.v3f16(<3 x half> %0, <3 x half> %1)
-// NATIVE_HALF: ret half %dot
-// NO_HALF: %dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
-// NO_HALF: ret float %dot
+// NATIVE_HALF: %hlsl.dot = call half @llvm.fdot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: ret half %hlsl.dot
+// NO_HALF: %hlsl.dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: ret float %hlsl.dot
 half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dot = call half @llvm.fdot.v4f16(<4 x half> %0, <4 x half> %1)
-// NATIVE_HALF: ret half %dot
-// NO_HALF: %dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
-// NO_HALF: ret float %dot
+// NATIVE_HALF: %hlsl.dot = call half @llvm.fdot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: ret half %hlsl.dot
+// NO_HALF: %hlsl.dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: ret float %hlsl.dot
 half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = fmul float %0, %1
-// CHECK: ret float %dot
+// CHECK: %hlsl.dot = fmul float %0, %1
+// CHECK: ret float %hlsl.dot
 float test_dot_float(float p0, float p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
-// CHECK: ret float %dot
+// CHECK: %hlsl.dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: ret float %hlsl.dot
 float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
-// CHECK: ret float %dot
+// CHECK: %hlsl.dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: ret float %hlsl.dot
 float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
-// CHECK: ret float %dot
+// CHECK: %hlsl.dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: ret float %hlsl.dot
 float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
 
-// CHECK:  %dot = call float @llvm.fdot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
-// CHECK: ret float %dot
+// CHECK:  %hlsl.dot = call float @llvm.fdot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK: ret float %hlsl.dot
 float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
 
-// CHECK:  %dot = call float @llvm.fdot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
-// CHECK: ret float %dot
+// CHECK:  %hlsl.dot = call float @llvm.fdot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK: ret float %hlsl.dot
 float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
 
-// CHECK:  %dot = call float @llvm.fdot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
-// CHECK: ret float %dot
+// CHECK:  %hlsl.dot = call float @llvm.fdot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK: ret float %hlsl.dot
 float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
 
-// CHECK: %dot = fmul double %0, %1
-// CHECK: ret double %dot
+// CHECK: %hlsl.dot = fmul double %0, %1
+// CHECK: ret double %hlsl.dot
 double test_dot_double(double p0, double p1) { return dot(p0, p1); }
diff --git a/llvm/docs/GlobalISel/GenericOpcode.rst b/llvm/docs/GlobalISel/GenericOpcode.rst
index d32aeff5a69bb1..5a6994452c0ef7 100644
--- a/llvm/docs/GlobalISel/GenericOpcode.rst
+++ b/llvm/docs/GlobalISel/GenericOpcode.rst
@@ -633,6 +633,14 @@ G_FCOS, G_FSIN, G_FTAN, G_FACOS, G_FASIN, G_FATAN, G_FCOSH, G_FSINH, G_FTANH
 
 These correspond to the standard C trigonometry functions of the same name.
 
+
+G_FDOTPROD, G_SDOTPROD, G_UDOTPROD
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+These represent the floating point, signed integer, and unsigned integer dot products respectively.
+A dot product takes two equal-sized vectors and multiplies each element by the element in the corresponding
+location of the other vector and then sums all the products, returning a scalar value.
+
 G_INTRINSIC_TRUNC
 ^^^^^^^^^^^^^^^^^
 
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 0ee4d7b444cfcf..b494ec734099ab 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -15674,6 +15674,138 @@ trapping or setting ``errno``.
 When specified with the fast-math-flag 'afn', the result may be approximated
 using a less accurate calculation.
 
+'``llvm.fdot.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+This is an overloaded intrinsic. You can use ``llvm.fdot`` on a 2-4 element 
+vector of 16-bit or 32-bit floating-point types. Not all targets support
+all types however.
+
+::
+
+      declare half     @llvm.fdot.v2f16(<2 x half> %Vec0, <2 x half> %Vec1)
+      declare half     @llvm.fdot.v3f16(<3 x half> %Vec0, <3 x half> %Vec1)
+      declare half     @llvm.fdot.v4f16(<4 x half> %Vec0, <4 x half> %Vec1)
+      declare float     @llvm.fdot.v2f32(<2 x float> %Vec0, <2 x float> %Vec1)
+      declare float     @llvm.fdot.v3f32(<3 x float> %Vec0, <3 x float> %Vec1)
+      declare float     @llvm.fdot.v4f32(<4 x float> %Vec0, <4 x float> %Vec1)
+
+Overview:
+"""""""""
+
+The '``llvm.fdot.*``' intrinsics return the dot product of the two vector operands.
+A dot product takes two equal-sized vectors and multiplies each element by the element in the corresponding
+location of the other vector and then sums all the products, returning a scalar value.
+
+
+Arguments:
+""""""""""
+
+The arguments are vectors to be elementwise multiplied and then summed.
+The vectors may be of 2-4 elements and contain 16-bit or 32-bit float elements.
+The arguments must be vectors of the same size and scalar element type.
+The return type must match the scalar type of the elements.
+
+Semantics:
+""""""""""
+
+Return the summation of the products of the elements of the first and second arguments using
+floating point arithmetic operations.
+
+
+'``llvm.sdot.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+This is an overloaded intrinsic. You can use ``llvm.sdot`` on a 2-4 element 
+vector of 16-bit to 64-bit signed integer types. Not all targets support
+all types however.
+
+::
+
+      declare i16     @llvm.sdot.v2i16(<2 x i16> %Vec0, <2 x i16> %Vec1)
+      declare i16     @llvm.sdot.v3i16(<3 x i16> %Vec0, <3 x i16> %Vec1)
+      declare i16     @llvm.sdot.v4i16(<4 x i16> %Vec0, <4 x i16> %Vec1)
+      declare i32     @llvm.sdot.v2i32(<2 x i32> %Vec0, <2 x i32> %Vec1)
+      declare i32     @llvm.sdot.v3i32(<3 x i32> %Vec0, <3 x i32> %Vec1)
+      declare i32     @llvm.sdot.v4i32(<4 x i32> %Vec0, <4 x i32> %Vec1)
+      declare i64     @llvm.sdot.v2i64(<2 x i64> %Vec0, <2 x i64> %Vec1)
+      declare i64     @llvm.sdot.v3i64(<3 x i64> %Vec0, <3 x i64> %Vec1)
+      declare i64     @llvm.sdot.v4i64(<4 x i64> %Vec0, <4 x i64> %Vec1)
+
+Overview:
+"""""""""
+
+The '``llvm.sdot.*``' intrinsics return the dot product of the two vector operands.
+A dot product takes two equal-sized vectors and multiplies each element by the element in the corresponding
+location of the other vector and then sums all the products, returning a scalar value.
+
+
+Arguments:
+""""""""""
+
+The arguments are vectors to be elementwise multiplied and then summed.
+The vectors may be of 2-4 elements and contain 16-bit to 64-bit signed integer elements.
+The arguments must be vectors of the same size and scalar element type.
+The return type must match the scalar type of the elements.
+
+Semantics:
+""""""""""
+
+Return the summation of the products of the elements of the first and second arguments using
+signed integer arithmetic operations.
+
+
+'``llvm.udot.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+This is an overloaded intrinsic. You can use ``llvm.udot`` on a 2-4 element 
+vector of 16-bit to 64-bit unsigned integer types. Not all targets support
+all types however.
+
+::
+
+      declare i16     @llvm.udot.v2i16(<2 x i16> %Vec0, <2 x i16> %Vec1)
+      declare i16     @llvm.udot.v3i16(<3 x i16> %Vec0, <3 x i16> %Vec1)
+      declare i16     @llvm.udot.v4i16(<4 x i16> %Vec0, <4 x i16> %Vec1)
+      declare i32     @llvm.udot.v2i32(<2 x i32> %Vec0, <2 x i32> %Vec1)
+      declare i32     @llvm.udot.v3i32(<3 x i32> %Vec0, <3 x i32> %Vec1)
+      declare i32     @llvm.udot.v4i32(<4 x i32> %Vec0, <4 x i32> %Vec1)
+      declare i64     @llvm.udot.v2i64(<2 x i64> %Vec0, <2 x i64> %Vec1)
+      declare i64     @llvm.udot.v3i64(<3 x i64> %Vec0, <3 x i64> %Vec1)
+      declare i64     @llvm.udot.v4i64(<4 x i64> %Vec0, <4 x i64> %Vec1)
+
+Overview:
+"""""""""
+
+The '``llvm.udot.*``' intrinsics return the dot product of the two vector operands.
+A dot product takes two equal-sized vectors and multiplies each element by the element in the corresponding
+location of the other vector and then sums all the products, returning a scalar value.
+
+
+Arguments:
+""""""""""
+
+The arguments are vectors to be elementwise multiplied and then summed.
+The vectors may be of 2-4 elements and contain 16-bit to 64-bit unsigned integer elements.
+The arguments must be vectors of the same size and scalar element type.
+The return type must match the scalar type of the elements.
+
+Semantics:
+""""""""""
+
+Return the summation of the products of the elements of the first and second arguments using
+unsigned integer arithmetic operations.
+
+
 '``llvm.pow.*``' Intrinsic
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 815da809d28a73..d412ac4167cf86 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1045,15 +1045,15 @@ let IntrProperties = [IntrNoMem, IntrSpeculatable, IntrWillReturn] in {
   def int_nearbyint : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
   def int_round : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
   def int_roundeven    : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
-  def int_udot : Intrinsic<[LLVMVectorElementType<0>],
-                           [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-                           [IntrNoMem, IntrWillReturn, Commutative] >;
-  def int_sdot : Intrinsic<[LLVMVectorElementType<0>],
-                           [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-                           [IntrNoMem, IntrWillReturn, Commutative] >;
-  def int_fdot : Intrinsic<[LLVMVectorElementType<0>],
-                           [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-                           [IntrNoMem, IntrWillReturn, Commutative] >;
+  def int_udot : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
+                                       [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+				       [IntrNoMem, Commutative] >;
+  def int_sdot : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
+                                       [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+                                       [IntrNoMem, Commutative] >;
+  def int_fdot : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
+                                       [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+                                       [IntrNoMem, Commutative] >;
 
   // Truncate a floating point number with a specific rounding mode
   def int_fptrunc_round : DefaultAttrsIntrinsic<[ llvm_anyfloat_ty ],
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index 648671f627d649..398d16deda306e 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1057,27 +1057,6 @@ def G_FTANH : GenericInstruction {
   let hasSideEffects = false;
 }
 
-/// Floating point vector dot product
-def G_FDOTPROD : GenericInstruction {
-  let OutOperandList = (outs type0:$dst);
-  let InOperandList = (ins type0:$src1, type0:$src2);
-  let hasSideEffects = false;
-}
-
-/// Signed integer vector dot product
-def G_SDOTPROD : GenericInstruction {
-  let OutOperandList = (outs type0:$dst);
-  let InOperandList = (ins type0:$src1, type0:$src2);
-  let hasSideEffects = false;
-}
-
-/// Unsigned integer vector dot product
-def G_UDOTPROD : GenericInstruction {
-  let OutOperandList = (outs type0:$dst);
-  let InOperandList = (ins type0:$src1, type0:$src2);
-  let hasSideEffects = false;
-}
-
 // Floating point square root of a value.
 // This returns NaN for negative nonzero values.
 // NOTE: Unlike libm sqrt(), this never sets errno. In all other respects it's
@@ -1109,6 +1088,27 @@ def G_FNEARBYINT : GenericInstruction {
   let hasSideEffects = false;
 }
 
+/// Floating point vector dot product
+def G_FDOTPROD : GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1, type0:$src2);
+  let hasSideEffects = false;
+}
+
+/// Signed integer vector dot product
+def G_SDOTPROD : GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1, type0:$src2);
+  let hasSideEffects = false;
+}
+
+/// Unsigned integer vector dot product
+def G_UDOTPROD : GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1, type0:$src2);
+  let hasSideEffects = false;
+}
+
 //------------------------------------------------------------------------------
 // Access to floating-point environment.
 //------------------------------------------------------------------------------
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 5adb62ed54e81b..9b7ea0950d82c6 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -71,58 +71,73 @@ static bool expandAbs(CallInst *Orig) {
   return true;
 }
 
-static bool expandDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
-  assert(DotIntrinsic == Intrinsic::sdot || DotIntrinsic == Intrinsic::udot ||
-         DotIntrinsic == Intrinsic::fdot);
+// Create DXIL dot intrinsics for floating point dot operations
+static bool expandFloatDotIntrinsic(CallInst *Orig) {
   Value *A = Orig->getOperand(0);
   Value *B = Orig->getOperand(1);
 
-  [[maybe_unused]] Type *ATy = A->getType();
+  Type *ATy = A->getType();
   [[maybe_unused]] Type *BTy = B->getType();
   assert(ATy->isVectorTy() && BTy->isVectorTy());
 
   IRBuilder<> Builder(Orig);
 
-  auto *AVec = dyn_cast<FixedVectorType>(A->getType());
+  auto *AVec = dyn_cast<FixedVectorType>(ATy);
+
+  assert(ATy->getScalarType()->isFloatingPointTy());
+
+  Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
+  switch (AVec->getNumElements()) {
+  case 2:
+    DotIntrinsic = Intrinsic::dx_dot2;
+    break;
+  case 3:
+    DotIntrinsic = Intrinsic::dx_dot3;
+    break;
+  case 4:
+    DotIntrinsic = Intrinsic::dx_dot4;
+    break;
+  default:
+    llvm_unreachable("dot product with vector outside 2-4 range");
+  }
+  Value *Result = Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
+				   ArrayRef<Value *>{A, B}, nullptr, "dot");
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
+// Expand integer dot product to multiply and add ops
+static bool expandIntegerDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
+  assert(DotIntrinsic == Intrinsic::sdot || DotIntrinsic == Intrinsic::udot);
+  Value *A = Orig->getOperand(0);
+  Value *B = Orig->getOperand(1);
+
+  Type *ATy = A->getType();
+  [[maybe_unused]] Type *BTy = B->getType();
+  assert(ATy->isVectorTy() && BTy->isVectorTy());
+
+  IRBuilder<> Builder(Orig);
+
+  auto *AVec = dyn_cast<FixedVectorType>(ATy);
+
+  assert(ATy->getScalarType()->isIntegerTy());
 
-  Type *EltTy = ATy->getScalarType();
   Value *Result;
-  if (EltTy->isIntegerTy()) {
-    // Expand integer dot product to multiply and add ops
-    Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::sdot
-                                     ? Intrinsic::dx_imad
-                                     : Intrinsic::dx_umad;
-    Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
-    Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
-    Result = Builder.CreateMul(Elt0, Elt1);
-    for (unsigned I = 1; I < AVec->getNumElements(); I++) {
-      Elt0 = Builder.CreateExtractElement(A, I);
-      Elt1 = Builder.CreateExtractElement(B, I);
-      Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
-                                       ArrayRef<Value *>{Elt0, Elt1, Result},
-                                       nullptr, "dx.mad");
-    }
-  } else {
-    // Use dx intrinsics for float vectors
-    assert(EltTy->isFloatingPointTy());
-
-    Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
-    switch (AVec->getNumElements()) {
-    case 2:
-      DotIntrinsic = Intrinsic::dx_dot2;
-      break;
-    case 3:
-      DotIntrinsic = Intrinsic::dx_dot3;
-      break;
-    case 4:
-      DotIntrinsic = Intrinsic::dx_dot4;
-      break;
-    default:
-      llvm_unreachable("dot product with vector outside 2-4 range");
-    }
-    Result = Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
-                                     ArrayRef<Value *>{A, B}, nullptr, "dot");
+  Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::sdot
+    ? Intrinsic::dx_imad
+    : Intrinsic::dx_umad;
+  Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
+  Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
+  Result = Builder.CreateMul(Elt0, Elt1);
+  for (unsigned i = 1; i < AVec->getNumElements(); i++) {
+    Elt0 = Builder.CreateExtractElement(A, i);
+    Elt1 = Builder.CreateExtractElement(B, i);
+    Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
+				     ArrayRef<Value *>{Elt0, Elt1, Result},
+				     nullptr, "dx.mad");
   }
+
   Orig->replaceAllUsesWith(Result);
   Orig->eraseFromParent();
   return true;
@@ -342,9 +357,10 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
   case Intrinsic::dx_length:
     return expandLengthIntrinsic(Orig);
   case Intrinsic::fdot:
+    return expandFloatDotIntrinsic(Orig);
   case Intrinsic::sdot:
   case Intrinsic::udot:
-    return expandDotIntrinsic(Orig, F.getIntrinsicID());
+    return expandIntegerDotIntrinsic(Orig, F.getIntrinsicID());
   }
   return false;
 }

>From 0e05f013b54ea978703c61b3616eddb08f8aae98 Mon Sep 17 00:00:00 2001
From: Greg Roth <grroth at microsoft.com>
Date: Wed, 14 Aug 2024 12:42:33 -0600
Subject: [PATCH 5/6] refactor DXIL intrinsic expansion to move erase and
 replace to the caller

All expansions end with replacing the previous intrinsic with the new
expansion and erasing the old one. By moving this operation to the
caller, these expansion functions can be called in more contexts
and a small amount of duplicated code is consolidated.
---
 .../Target/DirectX/DXILIntrinsicExpansion.cpp | 141 ++++++++----------
 1 file changed, 65 insertions(+), 76 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 9b7ea0950d82c6..b67a0147bca719 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -51,7 +51,7 @@ static bool isIntrinsicExpansion(Function &F) {
   return false;
 }
 
-static bool expandAbs(CallInst *Orig) {
+static Value *expandAbs(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -66,13 +66,11 @@ static bool expandAbs(CallInst *Orig) {
   auto *V = Builder.CreateSub(Zero, X);
   auto *MaxCall =
       Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max");
-  Orig->replaceAllUsesWith(MaxCall);
-  Orig->eraseFromParent();
-  return true;
+  return MaxCall;
 }
 
 // Create DXIL dot intrinsics for floating point dot operations
-static bool expandFloatDotIntrinsic(CallInst *Orig) {
+static Value *expandFloatDotIntrinsic(CallInst *Orig) {
   Value *A = Orig->getOperand(0);
   Value *B = Orig->getOperand(1);
 
@@ -100,15 +98,13 @@ static bool expandFloatDotIntrinsic(CallInst *Orig) {
   default:
     llvm_unreachable("dot product with vector outside 2-4 range");
   }
-  Value *Result = Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
-				   ArrayRef<Value *>{A, B}, nullptr, "dot");
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
+                                 ArrayRef<Value *>{A, B}, nullptr, "dot");
 }
 
 // Expand integer dot product to multiply and add ops
-static bool expandIntegerDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
+static Value *expandIntegerDotIntrinsic(CallInst *Orig,
+                                        Intrinsic::ID DotIntrinsic) {
   assert(DotIntrinsic == Intrinsic::sdot || DotIntrinsic == Intrinsic::udot);
   Value *A = Orig->getOperand(0);
   Value *B = Orig->getOperand(1);
@@ -124,9 +120,8 @@ static bool expandIntegerDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic
   assert(ATy->getScalarType()->isIntegerTy());
 
   Value *Result;
-  Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::sdot
-    ? Intrinsic::dx_imad
-    : Intrinsic::dx_umad;
+  Intrinsic::ID MadIntrinsic =
+      DotIntrinsic == Intrinsic::sdot ? Intrinsic::dx_imad : Intrinsic::dx_umad;
   Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
   Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
   Result = Builder.CreateMul(Elt0, Elt1);
@@ -134,16 +129,13 @@ static bool expandIntegerDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic
     Elt0 = Builder.CreateExtractElement(A, i);
     Elt1 = Builder.CreateExtractElement(B, i);
     Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
-				     ArrayRef<Value *>{Elt0, Elt1, Result},
-				     nullptr, "dx.mad");
+                                     ArrayRef<Value *>{Elt0, Elt1, Result},
+                                     nullptr, "dx.mad");
   }
-
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Result;
 }
 
-static bool expandExpIntrinsic(CallInst *Orig) {
+static Value *expandExpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -160,23 +152,21 @@ static bool expandExpIntrinsic(CallInst *Orig) {
       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
   Exp2Call->setTailCall(Orig->isTailCall());
   Exp2Call->setAttributes(Orig->getAttributes());
-  Orig->replaceAllUsesWith(Exp2Call);
-  Orig->eraseFromParent();
-  return true;
+  return Exp2Call;
 }
 
-static bool expandAnyIntrinsic(CallInst *Orig) {
+static Value *expandAnyIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
 
+  Value *Result = nullptr;
   if (!Ty->isVectorTy()) {
-    Value *Cond = EltTy->isFloatingPointTy()
-                      ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
-                      : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
-    Orig->replaceAllUsesWith(Cond);
+    Result = EltTy->isFloatingPointTy()
+                 ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
+                 : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
   } else {
     auto *XVec = dyn_cast<FixedVectorType>(Ty);
     Value *Cond =
@@ -189,18 +179,16 @@ static bool expandAnyIntrinsic(CallInst *Orig) {
                   X, ConstantVector::getSplat(
                          ElementCount::getFixed(XVec->getNumElements()),
                          ConstantInt::get(EltTy, 0)));
-    Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
+    Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
     for (unsigned I = 1; I < XVec->getNumElements(); I++) {
       Value *Elt = Builder.CreateExtractElement(Cond, I);
       Result = Builder.CreateOr(Result, Elt);
     }
-    Orig->replaceAllUsesWith(Result);
   }
-  Orig->eraseFromParent();
-  return true;
+  return Result;
 }
 
-static bool expandLengthIntrinsic(CallInst *Orig) {
+static Value *expandLengthIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -223,15 +211,11 @@ static bool expandLengthIntrinsic(CallInst *Orig) {
     Value *Mul = Builder.CreateFMul(Elt, Elt);
     Sum = Builder.CreateFAdd(Sum, Mul);
   }
-  Value *Result = Builder.CreateIntrinsic(
-      EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");
-
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateIntrinsic(EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum},
+                                 nullptr, "elt.sqrt");
 }
 
-static bool expandLerpIntrinsic(CallInst *Orig) {
+static Value *expandLerpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
   Value *S = Orig->getOperand(2);
@@ -239,14 +223,11 @@ static bool expandLerpIntrinsic(CallInst *Orig) {
   Builder.SetInsertPoint(Orig);
   auto *V = Builder.CreateFSub(Y, X);
   V = Builder.CreateFMul(S, V);
-  auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateFAdd(X, V, "dx.lerp");
 }
 
-static bool expandLogIntrinsic(CallInst *Orig,
-                               float LogConstVal = numbers::ln2f) {
+static Value *expandLogIntrinsic(CallInst *Orig,
+                                 float LogConstVal = numbers::ln2f) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -262,16 +243,13 @@ static bool expandLogIntrinsic(CallInst *Orig,
       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
   Log2Call->setTailCall(Orig->isTailCall());
   Log2Call->setAttributes(Orig->getAttributes());
-  auto *Result = Builder.CreateFMul(Ln2Const, Log2Call);
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateFMul(Ln2Const, Log2Call);
 }
-static bool expandLog10Intrinsic(CallInst *Orig) {
+static Value *expandLog10Intrinsic(CallInst *Orig) {
   return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
 }
 
-static bool expandPowIntrinsic(CallInst *Orig) {
+static Value *expandPowIntrinsic(CallInst *Orig) {
 
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
@@ -286,9 +264,7 @@ static bool expandPowIntrinsic(CallInst *Orig) {
       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
   Exp2Call->setTailCall(Orig->isTailCall());
   Exp2Call->setAttributes(Orig->getAttributes());
-  Orig->replaceAllUsesWith(Exp2Call);
-  Orig->eraseFromParent();
-  return true;
+  return Exp2Call;
 }
 
 static Intrinsic::ID getMaxForClamp(Type *ElemTy,
@@ -317,7 +293,8 @@ static Intrinsic::ID getMinForClamp(Type *ElemTy,
   return Intrinsic::minnum;
 }
 
-static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
+static Value *expandClampIntrinsic(CallInst *Orig,
+                                   Intrinsic::ID ClampIntrinsic) {
   Value *X = Orig->getOperand(0);
   Value *Min = Orig->getOperand(1);
   Value *Max = Orig->getOperand(2);
@@ -326,43 +303,55 @@ static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
   Builder.SetInsertPoint(Orig);
   auto *MaxCall = Builder.CreateIntrinsic(
       Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
-  auto *MinCall =
-      Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
-                              {MaxCall, Max}, nullptr, "dx.min");
-
-  Orig->replaceAllUsesWith(MinCall);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
+                                 {MaxCall, Max}, nullptr, "dx.min");
 }
 
 static bool expandIntrinsic(Function &F, CallInst *Orig) {
+  Value *Result = nullptr;
   switch (F.getIntrinsicID()) {
   case Intrinsic::abs:
-    return expandAbs(Orig);
+    Result = expandAbs(Orig);
+    break;
   case Intrinsic::exp:
-    return expandExpIntrinsic(Orig);
+    Result = expandExpIntrinsic(Orig);
+    break;
   case Intrinsic::log:
-    return expandLogIntrinsic(Orig);
+    Result = expandLogIntrinsic(Orig);
+    break;
   case Intrinsic::log10:
-    return expandLog10Intrinsic(Orig);
+    Result = expandLog10Intrinsic(Orig);
+    break;
   case Intrinsic::pow:
-    return expandPowIntrinsic(Orig);
+    Result = expandPowIntrinsic(Orig);
+    break;
   case Intrinsic::dx_any:
-    return expandAnyIntrinsic(Orig);
+    Result = expandAnyIntrinsic(Orig);
+    break;
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_clamp:
-    return expandClampIntrinsic(Orig, F.getIntrinsicID());
+    Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
+    break;
   case Intrinsic::dx_lerp:
-    return expandLerpIntrinsic(Orig);
+    Result = expandLerpIntrinsic(Orig);
+    break;
   case Intrinsic::dx_length:
-    return expandLengthIntrinsic(Orig);
+    Result = expandLengthIntrinsic(Orig);
+    break;
   case Intrinsic::fdot:
-    return expandFloatDotIntrinsic(Orig);
+    Result = expandFloatDotIntrinsic(Orig);
+    break;
   case Intrinsic::sdot:
   case Intrinsic::udot:
-    return expandIntegerDotIntrinsic(Orig, F.getIntrinsicID());
+    Result = expandIntegerDotIntrinsic(Orig, F.getIntrinsicID());
+    break;
   }
-  return false;
+
+  if (Result) {
+    Orig->replaceAllUsesWith(Result);
+    Orig->eraseFromParent();
+  }
+  return !!Result;
 }
 
 static bool expansionIntrinsics(Module &M) {

>From ad527a772bc87c0d76fcd8c06133cbf57ca64669 Mon Sep 17 00:00:00 2001
From: Greg Roth <grroth at microsoft.com>
Date: Fri, 16 Aug 2024 02:48:03 -0600
Subject: [PATCH 6/6] Fix legalization of dotprod generic opcodes

---
 llvm/include/llvm/Target/GenericOpcodes.td    |  6 ++--
 llvm/lib/CodeGen/MachineVerifier.cpp          | 30 +++++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  | 11 +++++--
 .../GlobalISel/legalizer-info-validation.mir  |  6 ++--
 4 files changed, 44 insertions(+), 9 deletions(-)

diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index 398d16deda306e..ced103debf6090 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1091,21 +1091,21 @@ def G_FNEARBYINT : GenericInstruction {
 /// Floating point vector dot product
 def G_FDOTPROD : GenericInstruction {
   let OutOperandList = (outs type0:$dst);
-  let InOperandList = (ins type0:$src1, type0:$src2);
+  let InOperandList = (ins type1:$src1, type1:$src2);
   let hasSideEffects = false;
 }
 
 /// Signed integer vector dot product
 def G_SDOTPROD : GenericInstruction {
   let OutOperandList = (outs type0:$dst);
-  let InOperandList = (ins type0:$src1, type0:$src2);
+  let InOperandList = (ins type1:$src1, type1:$src2);
   let hasSideEffects = false;
 }
 
 /// Unsigned integer vector dot product
 def G_UDOTPROD : GenericInstruction {
   let OutOperandList = (outs type0:$dst);
-  let InOperandList = (ins type0:$src1, type0:$src2);
+  let InOperandList = (ins type1:$src1, type1:$src2);
   let hasSideEffects = false;
 }
 
diff --git a/llvm/lib/CodeGen/MachineVerifier.cpp b/llvm/lib/CodeGen/MachineVerifier.cpp
index d22fbe322ec36f..586ea587ef856d 100644
--- a/llvm/lib/CodeGen/MachineVerifier.cpp
+++ b/llvm/lib/CodeGen/MachineVerifier.cpp
@@ -2092,6 +2092,36 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
     }
     break;
   }
+
+  case TargetOpcode::G_SDOTPROD:
+  case TargetOpcode::G_UDOTPROD:
+  case TargetOpcode::G_FDOTPROD: {
+    // Dot products take two vector sources and define one scalar result.
+    LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
+    if (!DstTy.isScalar()) {
+      report("Destination must be a scalar", MI);
+      break;
+    }
+    LLT Src0Ty = MRI->getType(MI->getOperand(1).getReg());
+    LLT Src1Ty = MRI->getType(MI->getOperand(2).getReg());
+    if (!Src0Ty.isVector() || !Src1Ty.isVector()) {
+      report("Sources must be vectors", MI);
+      break;
+    }
+    LLT Src0EltTy = Src0Ty.getScalarType();
+    LLT Src1EltTy = Src1Ty.getScalarType();
+    // Only a mismatch is an error; equal element types are the valid case.
+    if (Src0EltTy != Src1EltTy)
+      report("Source vectors must have the same scalar element types", MI);
+    if (Src0EltTy != DstTy.getScalarType()) {
+      report("Destination type must match source element types", MI);
+      break;
+    }
+
+    // Last check of this case: the helper's result does not alter control flow.
+    verifyVectorElementMatch(Src0Ty, Src1Ty, MI);
+    break;
+  }
   case TargetOpcode::G_PREFETCH: {
     const MachineOperand &AddrOp = MI->getOperand(0);
     if (!AddrOp.isReg() || !MRI->getType(AddrOp.getReg()).isPointer()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index fb39c107a417ad..1434cb31e4fba3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -285,9 +285,6 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                                G_FCOSH,
                                G_FSINH,
                                G_FTANH,
-                               G_FDOTPROD,
-                               G_SDOTPROD,
-                               G_UDOTPROD,
                                G_FSQRT,
                                G_FFLOOR,
                                G_FRINT,
@@ -306,6 +303,14 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
       allFloatScalarsAndVectors, allIntScalarsAndVectors);
 
+  getActionDefinitionsBuilder(G_FDOTPROD)
+      .legalForCartesianProduct(allFloatScalarsAndVectors,
+                                allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder({G_SDOTPROD, G_UDOTPROD})
+      .legalForCartesianProduct(allIntScalarsAndVectors,
+                                allIntScalarsAndVectors);
+
   if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
     getActionDefinitionsBuilder(
         {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
index 57faf42a5156c6..12510b6e58b6e7 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
@@ -716,13 +716,13 @@
 # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
 # DEBUG-NEXT: .. the first uncovered type index: 1, OK
 # DEBUG-NEXT: .. the first uncovered imm index: 0, OK
-# DEBUG-NEXT: G_FDOTPROD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
+# DEBUG-NEXT: G_FDOTPROD (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
 # DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
 # DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
-# DEBUG-NEXT: G_UDOTPROD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
+# DEBUG-NEXT: G_UDOTPROD (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
 # DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
 # DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
-# DEBUG-NEXT: G_SDOTPROD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
+# DEBUG-NEXT: G_SDOTPROD (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
 # DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
 # DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
 # DEBUG-NEXT: G_FSQRT (opcode {{[0-9]+}}): 1 type index, 0 imm indices



More information about the cfe-commits mailing list