[clang] [llvm] [HLSL][DXIL][SPIRV] Create llvm dot intrinsic and use for HLSL (PR #102872)

via cfe-commits cfe-commits at lists.llvm.org
Mon Aug 12 03:04:57 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-clang

Author: Greg Roth (pow2clk)

<details>
<summary>Changes</summary>

Per https://discourse.llvm.org/t/rfc-all-the-math-intrinsics/78294,
`dot` should be an LLVM intrinsic. This adds the LLVM intrinsics
and updates the HLSL builtin codegen to emit them.

Removed some stale comments that gave the obsolete impression that
type conversions should be expected to match overloads.

With dot moving into an LLVM intrinsic, the lowering to dx-specific
operations doesn't take place until DXIL intrinsic expansion. This
moves the introduction of arity-specific DX opcodes to DXIL
intrinsic expansion.

The new LLVM integer intrinsics replace the previous dx integer
intrinsics. The DXIL intrinsic expansion code and its tests are
updated to use the new integer intrinsics, and to expect the
flattened, vector-size-specific DX floating-point variants only
after op lowering.

Use the new LLVM dot intrinsics to build SPIRV instructions.
This involves generating multiply and add operations for integers
and using the existing OpDot operation for floating point. It also
adds generic opcodes for signed integer, unsigned integer, and
floating-point dot products (G_SDOTPROD, G_UDOTPROD, G_FDOTPROD),
which requires updating the existing test that covers all such opcodes.

New tests for generating SPIRV float and integer dot intrinsics are
added as well.

Fixes #<!-- -->88056

---

Patch is 52.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102872.diff


17 Files Affected:

- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+20-27) 
- (modified) clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl (+6-6) 
- (modified) clang/test/CodeGenHLSL/builtins/dot.hlsl (+80-80) 
- (modified) llvm/include/llvm/IR/Intrinsics.td (+9) 
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+6-14) 
- (modified) llvm/include/llvm/Support/TargetOpcodes.def (+9) 
- (modified) llvm/include/llvm/Target/GenericOpcodes.td (+21) 
- (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+6) 
- (modified) llvm/lib/Target/DirectX/DXIL.td (+3-3) 
- (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+50-22) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+69) 
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+3) 
- (modified) llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir (+9) 
- (modified) llvm/test/CodeGen/DirectX/fdot.ll (+62-55) 
- (modified) llvm/test/CodeGen/DirectX/idot.ll (+17-17) 
- (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll (+75) 
- (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll (+88) 


``````````diff
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7fe80b0cbdfbfa..67148e32014ed2 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18470,22 +18470,14 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
   return Arg;
 }
 
-Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) {
-  if (QT->hasFloatingRepresentation()) {
-    switch (elementCount) {
-    case 2:
-      return Intrinsic::dx_dot2;
-    case 3:
-      return Intrinsic::dx_dot3;
-    case 4:
-      return Intrinsic::dx_dot4;
-    }
-  }
-  if (QT->hasSignedIntegerRepresentation())
-    return Intrinsic::dx_sdot;
-
-  assert(QT->hasUnsignedIntegerRepresentation());
-  return Intrinsic::dx_udot;
+// Return dot product intrinsic that corresponds to the QT scalar type
+Intrinsic::ID getDotProductIntrinsic(QualType QT) {
+  if (QT->isFloatingType())
+    return Intrinsic::fdot;
+  if (QT->isSignedIntegerType())
+    return Intrinsic::sdot;
+  assert(QT->isUnsignedIntegerType());
+  return Intrinsic::udot;
 }
 
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
@@ -18528,37 +18520,38 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     Value *Op1 = EmitScalarExpr(E->getArg(1));
     llvm::Type *T0 = Op0->getType();
     llvm::Type *T1 = Op1->getType();
+
+    // If the arguments are scalars, just emit a multiply
     if (!T0->isVectorTy() && !T1->isVectorTy()) {
       if (T0->isFloatingPointTy())
-        return Builder.CreateFMul(Op0, Op1, "dx.dot");
+        return Builder.CreateFMul(Op0, Op1, "dot");
 
       if (T0->isIntegerTy())
-        return Builder.CreateMul(Op0, Op1, "dx.dot");
+        return Builder.CreateMul(Op0, Op1, "dot");
 
-      // Bools should have been promoted
       llvm_unreachable(
           "Scalar dot product is only supported on ints and floats.");
     }
+    // For vectors, validate types and emit the appropriate intrinsic
+
     // A VectorSplat should have happened
     assert(T0->isVectorTy() && T1->isVectorTy() &&
            "Dot product of vector and scalar is not supported.");
 
-    // A vector sext or sitofp should have happened
-    assert(T0->getScalarType() == T1->getScalarType() &&
-           "Dot product of vectors need the same element types.");
-
     auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
     [[maybe_unused]] auto *VecTy1 =
         E->getArg(1)->getType()->getAs<VectorType>();
-    // A HLSLVectorTruncation should have happend
+
+    assert(VecTy0->getElementType() == VecTy1->getElementType() &&
+           "Dot product of vectors need the same element types.");
+
     assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
            "Dot product requires vectors to be of the same size.");
 
     return Builder.CreateIntrinsic(
         /*ReturnType=*/T0->getScalarType(),
-        getDotProductIntrinsic(E->getArg(0)->getType(),
-                               VecTy0->getNumElements()),
-        ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
+        getDotProductIntrinsic(VecTy0->getElementType()),
+        ArrayRef<Value *>{Op0, Op1}, nullptr, "dot");
   } break;
   case Builtin::BI__builtin_hlsl_lerp: {
     Value *X = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
index b0b95074c972d5..6036f9430db4f0 100644
--- a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
@@ -2,8 +2,8 @@
 
 // CHECK-LABEL: builtin_bool_to_float_type_promotion
 // CHECK: %conv1 = uitofp i1 %loadedv to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: %dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
 // CHECK: ret float %conv2
 float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
   return __builtin_hlsl_dot ( p0, p1 );
@@ -12,8 +12,8 @@ float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
 // CHECK-LABEL: builtin_bool_to_float_arg1_type_promotion
 // CHECK: %conv = uitofp i1 %loadedv to double
 // CHECK: %conv1 = fpext float %1 to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: %dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
 // CHECK: ret float %conv2
 float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
   return __builtin_hlsl_dot ( p0, p1 );
@@ -22,8 +22,8 @@ float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
 // CHECK-LABEL: builtin_dot_int_to_float_promotion
 // CHECK: %conv = fpext float %0 to double
 // CHECK: %conv1 = sitofp i32 %1 to double
-// CHECK: dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %dot to float
 // CHECK: ret float %conv2
 float builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
   return __builtin_hlsl_dot ( p0, p1 );
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index ae6e45c3f9482a..b9486f433cced1 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -7,155 +7,155 @@
 // RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
 
 #ifdef __HLSL_ENABLE_16_BIT
-// NATIVE_HALF: %dx.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dot
 int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
 int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
 int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
 int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dot
 uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
 uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
 uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %dot = call i16 @llvm.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dot
 uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
 #endif
 
-// CHECK: %dx.dot = mul i32 %0, %1
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = mul i32 %0, %1
+// CHECK: ret i32 %dot
 int test_dot_int(int p0, int p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dot
 int test_dot_int2(int2 p0, int2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dot
 int test_dot_int3(int3 p0, int3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dot
 int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = mul i32 %0, %1
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = mul i32 %0, %1
+// CHECK: ret i32 %dot
 uint test_dot_uint(uint p0, uint p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dot
 uint test_dot_uint2(uint2 p0, uint2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dot
 uint test_dot_uint3(uint3 p0, uint3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %dot = call i32 @llvm.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dot
 uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = mul i64 %0, %1
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = mul i64 %0, %1
+// CHECK: ret i64 %dot
 int64_t test_dot_long(int64_t p0, int64_t p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dot
 int64_t test_dot_long2(int64_t2 p0, int64_t2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dot
 int64_t test_dot_long3(int64_t3 p0, int64_t3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dot
 int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = mul i64 %0, %1
-// CHECK: ret i64 %dx.dot
+// CHECK:  %dot = mul i64 %0, %1
+// CHECK: ret i64 %dot
 uint64_t test_dot_ulong(uint64_t p0, uint64_t p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dot
 uint64_t test_dot_ulong2(uint64_t2 p0, uint64_t2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dot
 uint64_t test_dot_ulong3(uint64_t3 p0, uint64_t3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %dot = call i64 @llvm.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dot
 uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = fmul half %0, %1
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = fmul float %0, %1
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = fmul half %0, %1
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = fmul float %0, %1
+// NO_HALF: ret float %dot
 half test_dot_half(half p0, half p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: ret float %dot
 half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: ret float %dot
 half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %dot = call half @llvm.fdot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: ret half %dot
+// NO_HALF: %dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: ret float %dot
 half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = fmul float %0, %1
-// CHECK: ret float %dx.dot
+// CHECK: %dot = fmul float %0, %1
+// CHECK: ret float %dot
 float test_dot_float(float p0, float p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: ret float %dot
 float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: ret float %dot
 float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %dot = call float @llvm.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: ret float %dot
 float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK:  %dot = call float @llvm.fdot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK: ret float %dot
 float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK:  %dot = call float @llvm.fdot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK: ret float %dot
 float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK:  %dot = call float @llvm.fdot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK: ret float %dot
 float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = fmul double %0, %1
-// CHECK: ret double %dx.dot
+// CHECK: %dot = fmul double %0, %1
+// CHECK: ret double %dot
 double test_dot_double(double p0, double p1) { return dot(p0, p1); }
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index b4e758136b39fb..815da809d28a73 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1045,6 +1045,15 @@ let IntrProperties = [IntrNoMem, IntrSpeculatable, IntrWillReturn] in {
   def int_nearbyint : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
   def int_round : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
   def int_roundeven    : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
+  def int_udot : Intrinsic<[LLVMVectorElementType<0>],
+                           [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+                           [IntrNoMem, IntrWillReturn, Commutative] >;
+  def int_sdot : Intrinsic<[LLVMVectorElementType<0>],
+                           [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+                           [IntrNoMem, IntrWillReturn, Commutative] >;
+  def int_fdot : Intrinsic<[LLVMVectorElementType<0>],
+                           [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+                           [IntrNoMem, IntrWillReturn, Commutative] >;
 
   // Truncate a floating point number with a specific rounding mode
   def int_fptrunc_round : DefaultAttrsIntrinsic<[ llvm_anyfloat_ty ],
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..8ce79eb7cbaafa 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -25,26 +25,18 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
 def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; 
 
-def int_dx_dot2 : 
-    Intrinsic<[LLVMVectorElementType<0>], 
+def int_dx_dot2 :
+    Intrinsic<[LLVMVectorElementType<0>],
     [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_dot3 : 
-    Intrinsic<[LLVMVectorElementType<0>], 
+def int_dx_dot3 :
+    Intrinsic<[LLVMVectorElementType<0>],
     [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_dot4 : 
-    Intrinsic<[LLVMVectorElementType<0>], 
+def int_dx_dot4 :
+    Intrinsic<[LLVMVectorElementType<0>],
     [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_sdot : 
-    Intrinsic<[LLVMVectorElementType<0>], 
-    [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_udot : 
-    Intrinsic<[LLVMVectorElementType<0>], 
-    [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [IntrNoMem, IntrWillReturn, Commutative] >;
 
 def int_dx_frac  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 
diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index 9fb6de49fb2055..0808fd9d77be82 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -814,6 +814,15 @@ HANDLE_TARGET_OPCODE(G_FSINH)
 /// Floating point hyperbolic tangent.
 HANDLE_TARGET_OPCODE(G_FTANH)
 
+/// Floating point vector dot product
+HANDLE_TARGET_OPCODE(G_FDOTPROD)
+
+/// Unsigned integer vector dot product
+HANDLE_TARGET_OPCODE(G_UDOTPROD)
+
+/// Signed integer vector dot product
+HANDLE_TARGET_OPCODE(G_SDOTPROD)
+
 /// Floating point square root.
 HANDLE_TARGET_OPCODE(G_FSQRT)
 
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index 36a0a087ba457c..648671f627d649 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1057,6 +1057,27 @@ def G_FTANH : GenericInstruction {
   let hasSideEffects = false;
 }
 
+/// Floating point vector dot product
+def G_FDOTPROD : GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$src1, type0:$src2);
+  let hasSideEffects = false;
+}
+
+/// Signed integer vector dot product
+def G_SDOTPROD : GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/102872


More information about the cfe-commits mailing list