[clang] [llvm] [HLSL][SPIRV]Add SPIRV generation for HLSL dot (PR #104656)
Greg Roth via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 19 12:26:36 PDT 2024
https://github.com/pow2clk updated https://github.com/llvm/llvm-project/pull/104656
>From 9aff63478b76f042c05b7ae3dd1a2c099dc615de Mon Sep 17 00:00:00 2001
From: Greg Roth <grroth at microsoft.com>
Date: Fri, 2 Aug 2024 20:10:04 -0600
Subject: [PATCH 1/5] Add SPIRV generation for HLSL dot
This adds the SPIRV fdot, sdot, and udot intrinsics and allows
them to be created at codegen depending on the target architecture.
This required moving some of the DXIL-specific choices out of codegen
and into DXIL intrinsic expansion, and providing codegen with a
more generic fdot intrinsic as well.
Removed some stale comments that gave the obsolete impression that
type conversions should be expected to match overloads.
The SPIRV intrinsic handling involves generating multiply and add
operations for integers and the existing OpDot operation for
floating point.
New tests for generating SPIRV float and integer dot intrinsics are
added as well.
Also changed the existing dot intrinsic definitions to use
DefaultAttrsIntrinsic to match the newly added intrinsics.
Fixes #88056
---
clang/lib/CodeGen/CGBuiltin.cpp | 47 +++--
clang/lib/CodeGen/CGHLSLRuntime.h | 3 +
.../CodeGenHLSL/builtins/dot-builtin.hlsl | 12 +-
clang/test/CodeGenHLSL/builtins/dot.hlsl | 160 +++++++++---------
llvm/include/llvm/IR/IntrinsicsDirectX.td | 34 ++--
llvm/include/llvm/IR/IntrinsicsSPIRV.td | 12 ++
llvm/lib/Target/DirectX/DXIL.td | 6 +-
.../Target/DirectX/DXILIntrinsicExpansion.cpp | 67 ++++++--
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 74 ++++++++
llvm/test/CodeGen/DirectX/fdot.ll | 117 +++++++------
llvm/test/CodeGen/DirectX/idot.ll | 24 +--
.../CodeGen/SPIRV/hlsl-intrinsics/fdot.ll | 75 ++++++++
.../CodeGen/SPIRV/hlsl-intrinsics/idot.ll | 88 ++++++++++
13 files changed, 508 insertions(+), 211 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll
create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index f424ddaa175400..5c49e71df3fcfa 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18471,22 +18471,14 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
return Arg;
}
-Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) {
- if (QT->hasFloatingRepresentation()) {
- switch (elementCount) {
- case 2:
- return Intrinsic::dx_dot2;
- case 3:
- return Intrinsic::dx_dot3;
- case 4:
- return Intrinsic::dx_dot4;
- }
- }
- if (QT->hasSignedIntegerRepresentation())
- return Intrinsic::dx_sdot;
-
- assert(QT->hasUnsignedIntegerRepresentation());
- return Intrinsic::dx_udot;
+// Return the dot product intrinsic that the HLSL runtime RT reports for the
+// scalar element type QT:
+//   floating-point element -> FDot
+//   signed integer element -> SDot
+//   unsigned integer       -> UDot
+// The runtime getters pick the intrinsic for the current target.
+static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
+ if (QT->isFloatingType())
+ return RT.getFDotIntrinsic();
+ if (QT->isSignedIntegerType())
+ return RT.getSDotIntrinsic();
+ assert(QT->isUnsignedIntegerType() &&
+ "Dot product only supports float and integer element types.");
+ return RT.getUDotIntrinsic();
}
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
@@ -18529,37 +18521,38 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
Value *Op1 = EmitScalarExpr(E->getArg(1));
llvm::Type *T0 = Op0->getType();
llvm::Type *T1 = Op1->getType();
+
+ // If the arguments are scalars, just emit a multiply
if (!T0->isVectorTy() && !T1->isVectorTy()) {
if (T0->isFloatingPointTy())
- return Builder.CreateFMul(Op0, Op1, "dx.dot");
+ return Builder.CreateFMul(Op0, Op1, "hlsl.dot");
if (T0->isIntegerTy())
- return Builder.CreateMul(Op0, Op1, "dx.dot");
+ return Builder.CreateMul(Op0, Op1, "hlsl.dot");
- // Bools should have been promoted
llvm_unreachable(
"Scalar dot product is only supported on ints and floats.");
}
+ // For vectors, validate types and emit the appropriate intrinsic
+
// A VectorSplat should have happened
assert(T0->isVectorTy() && T1->isVectorTy() &&
"Dot product of vector and scalar is not supported.");
- // A vector sext or sitofp should have happened
- assert(T0->getScalarType() == T1->getScalarType() &&
- "Dot product of vectors need the same element types.");
-
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
[[maybe_unused]] auto *VecTy1 =
E->getArg(1)->getType()->getAs<VectorType>();
- // A HLSLVectorTruncation should have happend
+
+ assert(VecTy0->getElementType() == VecTy1->getElementType() &&
+ "Dot product of vectors need the same element types.");
+
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
"Dot product requires vectors to be of the same size.");
return Builder.CreateIntrinsic(
/*ReturnType=*/T0->getScalarType(),
- getDotProductIntrinsic(E->getArg(0)->getType(),
- VecTy0->getNumElements()),
- ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
+ getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
+ ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
} break;
case Builtin::BI__builtin_hlsl_lerp: {
Value *X = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index cd604bea2e763d..2d968f74196de1 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -80,6 +80,9 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
//===----------------------------------------------------------------------===//
// End of reserved area for HLSL intrinsic getters.
diff --git a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
index b0b95074c972d5..482f089d4770fd 100644
--- a/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot-builtin.hlsl
@@ -2,8 +2,8 @@
// CHECK-LABEL: builtin_bool_to_float_type_promotion
// CHECK: %conv1 = uitofp i1 %loadedv to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: %hlsl.dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %hlsl.dot to float
// CHECK: ret float %conv2
float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
@@ -12,8 +12,8 @@ float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) {
// CHECK-LABEL: builtin_bool_to_float_arg1_type_promotion
// CHECK: %conv = uitofp i1 %loadedv to double
// CHECK: %conv1 = fpext float %1 to double
-// CHECK: %dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: %hlsl.dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %hlsl.dot to float
// CHECK: ret float %conv2
float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
@@ -22,8 +22,8 @@ float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) {
// CHECK-LABEL: builtin_dot_int_to_float_promotion
// CHECK: %conv = fpext float %0 to double
// CHECK: %conv1 = sitofp i32 %1 to double
-// CHECK: dx.dot = fmul double %conv, %conv1
-// CHECK: %conv2 = fptrunc double %dx.dot to float
+// CHECK: %hlsl.dot = fmul double %conv, %conv1
+// CHECK: %conv2 = fptrunc double %hlsl.dot to float
// CHECK: ret float %conv2
float builtin_dot_int_to_float_promotion ( float p0, int p1 ) {
return __builtin_hlsl_dot ( p0, p1 );
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index ae6e45c3f9482a..6d0cf41f4d98bd 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -7,155 +7,155 @@
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
#ifdef __HLSL_ENABLE_16_BIT
-// NATIVE_HALF: %dx.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %hlsl.dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %hlsl.dot
int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %hlsl.dot
int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.dx.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %hlsl.dot
int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.dx.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %hlsl.dot
int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %hlsl.dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %hlsl.dot
uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.dx.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %hlsl.dot
uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %hlsl.dot
uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %dx.dot
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.dx.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %hlsl.dot
uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
#endif
-// CHECK: %dx.dot = mul i32 %0, %1
-// CHECK: ret i32 %dx.dot
+// CHECK: %hlsl.dot = mul i32 %0, %1
+// CHECK: ret i32 %hlsl.dot
int test_dot_int(int p0, int p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %hlsl.dot = call i32 @llvm.dx.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %hlsl.dot
int test_dot_int2(int2 p0, int2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %hlsl.dot = call i32 @llvm.dx.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %hlsl.dot
int test_dot_int3(int3 p0, int3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %hlsl.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %hlsl.dot
int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = mul i32 %0, %1
-// CHECK: ret i32 %dx.dot
+// CHECK: %hlsl.dot = mul i32 %0, %1
+// CHECK: ret i32 %hlsl.dot
uint test_dot_uint(uint p0, uint p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %hlsl.dot = call i32 @llvm.dx.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %hlsl.dot
uint test_dot_uint2(uint2 p0, uint2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %hlsl.dot = call i32 @llvm.dx.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %hlsl.dot
uint test_dot_uint3(uint3 p0, uint3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
-// CHECK: ret i32 %dx.dot
+// CHECK: %hlsl.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %hlsl.dot
uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = mul i64 %0, %1
-// CHECK: ret i64 %dx.dot
+// CHECK: %hlsl.dot = mul i64 %0, %1
+// CHECK: ret i64 %hlsl.dot
int64_t test_dot_long(int64_t p0, int64_t p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %hlsl.dot = call i64 @llvm.dx.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %hlsl.dot
int64_t test_dot_long2(int64_t2 p0, int64_t2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %hlsl.dot = call i64 @llvm.dx.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %hlsl.dot
int64_t test_dot_long3(int64_t3 p0, int64_t3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %hlsl.dot = call i64 @llvm.dx.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %hlsl.dot
int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = mul i64 %0, %1
-// CHECK: ret i64 %dx.dot
+// CHECK: %hlsl.dot = mul i64 %0, %1
+// CHECK: ret i64 %hlsl.dot
uint64_t test_dot_ulong(uint64_t p0, uint64_t p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %hlsl.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %hlsl.dot
uint64_t test_dot_ulong2(uint64_t2 p0, uint64_t2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %hlsl.dot = call i64 @llvm.dx.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %hlsl.dot
uint64_t test_dot_ulong3(uint64_t3 p0, uint64_t3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
-// CHECK: ret i64 %dx.dot
+// CHECK: %hlsl.dot = call i64 @llvm.dx.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %hlsl.dot
uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = fmul half %0, %1
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = fmul float %0, %1
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %hlsl.dot = fmul half %0, %1
+// NATIVE_HALF: ret half %hlsl.dot
+// NO_HALF: %hlsl.dot = fmul float %0, %1
+// NO_HALF: ret float %hlsl.dot
half test_dot_half(half p0, half p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %hlsl.dot = call half @llvm.dx.fdot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: ret half %hlsl.dot
+// NO_HALF: %hlsl.dot = call float @llvm.dx.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: ret float %hlsl.dot
half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %hlsl.dot = call half @llvm.dx.fdot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: ret half %hlsl.dot
+// NO_HALF: %hlsl.dot = call float @llvm.dx.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: ret float %hlsl.dot
half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1)
-// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
-// NO_HALF: ret float %dx.dot
+// NATIVE_HALF: %hlsl.dot = call half @llvm.dx.fdot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: ret half %hlsl.dot
+// NO_HALF: %hlsl.dot = call float @llvm.dx.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: ret float %hlsl.dot
half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = fmul float %0, %1
-// CHECK: ret float %dx.dot
+// CHECK: %hlsl.dot = fmul float %0, %1
+// CHECK: ret float %hlsl.dot
float test_dot_float(float p0, float p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %hlsl.dot = call float @llvm.dx.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: ret float %hlsl.dot
float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %hlsl.dot = call float @llvm.dx.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: ret float %hlsl.dot
float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %hlsl.dot = call float @llvm.dx.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: ret float %hlsl.dot
float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %hlsl.dot = call float @llvm.dx.fdot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK: ret float %hlsl.dot
float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %hlsl.dot = call float @llvm.dx.fdot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK: ret float %hlsl.dot
float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
-// CHECK: ret float %dx.dot
+// CHECK: %hlsl.dot = call float @llvm.dx.fdot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK: ret float %hlsl.dot
float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = fmul double %0, %1
-// CHECK: ret double %dx.dot
+// CHECK: %hlsl.dot = fmul double %0, %1
+// CHECK: ret double %hlsl.dot
double test_dot_double(double p0, double p1) { return dot(p0, p1); }
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index c9102aa3dd972b..0a5902b95741b4 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -35,26 +35,30 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
-def int_dx_dot2 :
- Intrinsic<[LLVMVectorElementType<0>],
+def int_dx_dot2 :
+ DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_dot3 :
- Intrinsic<[LLVMVectorElementType<0>],
+ [IntrNoMem, Commutative] >;
+def int_dx_dot3 :
+ DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_dot4 :
- Intrinsic<[LLVMVectorElementType<0>],
+ [IntrNoMem, Commutative] >;
+def int_dx_dot4 :
+ DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_sdot :
- Intrinsic<[LLVMVectorElementType<0>],
+ [IntrNoMem, Commutative] >;
+// Generic floating-point dot product, emitted by clang for any vector width.
+// DXILIntrinsicExpansion rewrites it to dx.dot2/dx.dot3/dx.dot4 depending on
+// the number of vector elements.
+def int_dx_fdot :
+ DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, Commutative] >;
+def int_dx_sdot :
+ DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, IntrWillReturn, Commutative] >;
-def int_dx_udot :
- Intrinsic<[LLVMVectorElementType<0>],
+ [IntrNoMem, Commutative] >;
+def int_dx_udot :
+ DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, IntrWillReturn, Commutative] >;
+ [IntrNoMem, Commutative] >;
def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 1b5e463822749e..eba68f080aed9b 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -66,4 +66,16 @@ let TargetPrefix = "spv" in {
def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
+ // Dot product intrinsics. The result is the element type of the first
+ // operand, and the second operand must have the same vector width.
+ // fdot takes floating-point vectors. sdot and udot take integer vectors,
+ // which clang emits for signed and unsigned HLSL element types respectively.
+ def int_spv_fdot :
+ DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, Commutative] >;
+ def int_spv_sdot :
+ DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, Commutative] >;
+ def int_spv_udot :
+ DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, Commutative] >;
}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 67015cff78a79a..ac79b84a1e9100 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -637,7 +637,7 @@ def UMad : DXILOp<49, tertiary> {
def Dot2 : DXILOp<54, dot2> {
let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + "
- "a[n]*b[n] where n is between 0 and 1";
+ "a[n]*b[n] where n is 0 to 1 inclusive";
let LLVMIntrinsic = int_dx_dot2;
let arguments = !listsplat(overloadTy, 4);
let result = overloadTy;
@@ -648,7 +648,7 @@ def Dot2 : DXILOp<54, dot2> {
def Dot3 : DXILOp<55, dot3> {
let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + "
- "a[n]*b[n] where n is between 0 and 2";
+ "a[n]*b[n] where n is 0 to 2 inclusive";
let LLVMIntrinsic = int_dx_dot3;
let arguments = !listsplat(overloadTy, 6);
let result = overloadTy;
@@ -659,7 +659,7 @@ def Dot3 : DXILOp<55, dot3> {
def Dot4 : DXILOp<56, dot4> {
let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + "
- "a[n]*b[n] where n is between 0 and 3";
+ "a[n]*b[n] where n is 0 to 3 inclusive";
let LLVMIntrinsic = int_dx_dot4;
let arguments = !listsplat(overloadTy, 8);
let result = overloadTy;
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 2c481d15be5bde..8316c040580be1 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -44,6 +44,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_lerp:
case Intrinsic::dx_length:
case Intrinsic::dx_normalize:
+ case Intrinsic::dx_fdot:
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return true;
@@ -68,28 +69,65 @@ static Value *expandAbs(CallInst *Orig) {
"dx.max");
}
-static Value *expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
+// Expand the generic dx.fdot intrinsic into the DXIL-specific dx.dot2,
+// dx.dot3 or dx.dot4 intrinsic, chosen by the operands' vector width.
+// Only fixed vectors of 2 to 4 floating-point elements are supported.
+static Value *expandFloatDotIntrinsic(CallInst *Orig) {
+ Value *A = Orig->getOperand(0);
+ Value *B = Orig->getOperand(1);
+ Type *ATy = A->getType();
+ [[maybe_unused]] Type *BTy = B->getType();
+ assert(ATy->isVectorTy() && BTy->isVectorTy());
+
+ IRBuilder<> Builder(Orig);
+
+ // The element count is read unconditionally below, so use cast<>.
+ // It asserts on a type mismatch, whereas dyn_cast<> would return a null
+ // pointer that is then dereferenced.
+ auto *AVec = cast<FixedVectorType>(ATy);
+
+ assert(ATy->getScalarType()->isFloatingPointTy());
+
+ Intrinsic::ID DotIntrinsic = Intrinsic::not_intrinsic;
+ switch (AVec->getNumElements()) {
+ case 2:
+ DotIntrinsic = Intrinsic::dx_dot2;
+ break;
+ case 3:
+ DotIntrinsic = Intrinsic::dx_dot3;
+ break;
+ case 4:
+ DotIntrinsic = Intrinsic::dx_dot4;
+ break;
+ default:
+ llvm_unreachable("dot product with vector outside 2-4 range");
+ }
+ // The DXIL dot intrinsics return the vector's scalar element type.
+ return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
+ ArrayRef<Value *>{A, B}, nullptr, "dot");
+}
+
+// Expand integer dot product to multiply and add ops
+static Value *expandIntegerDotIntrinsic(CallInst *Orig,
+ Intrinsic::ID DotIntrinsic) {
assert(DotIntrinsic == Intrinsic::dx_sdot ||
DotIntrinsic == Intrinsic::dx_udot);
- Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
- ? Intrinsic::dx_imad
- : Intrinsic::dx_umad;
Value *A = Orig->getOperand(0);
Value *B = Orig->getOperand(1);
- [[maybe_unused]] Type *ATy = A->getType();
+ Type *ATy = A->getType();
[[maybe_unused]] Type *BTy = B->getType();
assert(ATy->isVectorTy() && BTy->isVectorTy());
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
+
+ auto *AVec = dyn_cast<FixedVectorType>(ATy);
- auto *AVec = dyn_cast<FixedVectorType>(A->getType());
+ assert(ATy->getScalarType()->isIntegerTy());
+
+ Value *Result;
+ Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
+ ? Intrinsic::dx_imad
+ : Intrinsic::dx_umad;
Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
- Value *Result = Builder.CreateMul(Elt0, Elt1);
- for (unsigned I = 1; I < AVec->getNumElements(); I++) {
- Elt0 = Builder.CreateExtractElement(A, I);
- Elt1 = Builder.CreateExtractElement(B, I);
+ Result = Builder.CreateMul(Elt0, Elt1);
+ for (unsigned i = 1; i < AVec->getNumElements(); i++) {
+ Elt0 = Builder.CreateExtractElement(A, i);
+ Elt1 = Builder.CreateExtractElement(B, i);
Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
ArrayRef<Value *>{Elt0, Elt1, Result},
nullptr, "dx.mad");
@@ -363,9 +401,12 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
case Intrinsic::dx_normalize:
Result = expandNormalizeIntrinsic(Orig);
break;
+ case Intrinsic::dx_fdot:
+ Result = expandFloatDotIntrinsic(Orig);
+ break;
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
- Result = expandIntegerDot(Orig, F.getIntrinsicID());
+ Result = expandIntegerDotIntrinsic(Orig, F.getIntrinsicID());
break;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7cb19279518989..730a2be19ba081 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -184,6 +184,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectRsqrt(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
int OpIdx) const;
void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
@@ -1446,6 +1449,67 @@ bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
+// SPIR-V's OpDot only produces floating-point results, so integer dot products
+// are expanded here:
+//   1. Multiply the two vectors component-wise with OpIMulV.
+//   2. Extract each component of the product with OpCompositeExtract.
+//   3. Accumulate the components with scalar OpIAddS.
+// The final add defines ResVReg directly.
+// Returns true only if every emitted instruction was constrained successfully.
+bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+ assert(I.getNumOperands() == 4);
+ assert(I.getOperand(2).isReg());
+ assert(I.getOperand(3).isReg());
+ MachineBasicBlock &BB = *I.getParent();
+
+ // Multiply the vectors, then sum the results
+ Register Vec0 = I.getOperand(2).getReg();
+ Register Vec1 = I.getOperand(3).getReg();
+ Register TmpVec = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+ SPIRVType *VecType = GR.getSPIRVTypeForVReg(Vec0);
+ // The component count does not change; query it once instead of per iteration.
+ const unsigned NumComps = GR.getScalarOrVectorComponentCount(VecType);
+
+ bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulV))
+ .addDef(TmpVec)
+ .addUse(GR.getSPIRVTypeID(VecType))
+ .addUse(Vec0)
+ .addUse(Vec1)
+ .constrainAllUses(TII, TRI, RBI);
+
+ assert(NumComps > 1 &&
+ "dot product requires a vector of at least 2 components");
+
+ // Seed the running sum with component 0 of the product. Success flags are
+ // combined with &= so that a failure in any step is reported; |= would let
+ // one successful step mask every later failure.
+ Register Res = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+ Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+ .addDef(Res)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(TmpVec)
+ .addImm(0)
+ .constrainAllUses(TII, TRI, RBI);
+
+ // Named Idx rather than I to avoid shadowing the MachineInstr parameter.
+ for (unsigned Idx = 1; Idx < NumComps; ++Idx) {
+ Register Elt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+
+ Result &=
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+ .addDef(Elt)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(TmpVec)
+ .addImm(Idx)
+ .constrainAllUses(TII, TRI, RBI);
+
+ // Intermediate sums get fresh registers; the last add writes ResVReg.
+ Register Sum = Idx < NumComps - 1
+ ? MRI->createVirtualRegister(&SPIRV::IDRegClass)
+ : ResVReg;
+
+ Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
+ .addDef(Sum)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(Res)
+ .addUse(Elt)
+ .constrainAllUses(TII, TRI, RBI);
+ Res = Sum;
+ }
+
+ return Result;
+}
+
bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -2157,6 +2221,16 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
break;
case Intrinsic::spv_thread_id:
return selectSpvThreadId(ResVReg, ResType, I);
+ case Intrinsic::spv_fdot:
+ return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpDot))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(I.getOperand(2).getReg())
+ .addUse(I.getOperand(3).getReg())
+ .constrainAllUses(TII, TRI, RBI);
+ case Intrinsic::spv_udot:
+ case Intrinsic::spv_sdot:
+ return selectIntegerDot(ResVReg, ResType, I);
case Intrinsic::spv_all:
return selectAll(ResVReg, ResType, I);
case Intrinsic::spv_any:
diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll
index 56817a172ff9e3..a13406e4c88611 100644
--- a/llvm/test/CodeGen/DirectX/fdot.ll
+++ b/llvm/test/CodeGen/DirectX/fdot.ll
@@ -1,94 +1,101 @@
+; RUN: opt -S -dxil-intrinsic-expansion -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
-; Make sure dxil operation function calls for dot are generated for int/uint vectors.
+; Make sure dxil operation function calls for dot are generated for float type vectors.
; CHECK-LABEL: dot_half2
define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
entry:
-; CHECK: extractelement <2 x half> %a, i32 0
-; CHECK: extractelement <2 x half> %a, i32 1
-; CHECK: extractelement <2 x half> %b, i32 0
-; CHECK: extractelement <2 x half> %b, i32 1
-; CHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
- %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b)
+; DOPCHECK: extractelement <2 x half> %a, i32 0
+; DOPCHECK: extractelement <2 x half> %a, i32 1
+; DOPCHECK: extractelement <2 x half> %b, i32 0
+; DOPCHECK: extractelement <2 x half> %b, i32 1
+; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b)
+ %dx.dot = call half @llvm.dx.fdot.v2f16(<2 x half> %a, <2 x half> %b)
ret half %dx.dot
}
; CHECK-LABEL: dot_half3
define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
entry:
-; CHECK: extractelement <3 x half> %a, i32 0
-; CHECK: extractelement <3 x half> %a, i32 1
-; CHECK: extractelement <3 x half> %a, i32 2
-; CHECK: extractelement <3 x half> %b, i32 0
-; CHECK: extractelement <3 x half> %b, i32 1
-; CHECK: extractelement <3 x half> %b, i32 2
-; CHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
- %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b)
+; DOPCHECK: extractelement <3 x half> %a, i32 0
+; DOPCHECK: extractelement <3 x half> %a, i32 1
+; DOPCHECK: extractelement <3 x half> %a, i32 2
+; DOPCHECK: extractelement <3 x half> %b, i32 0
+; DOPCHECK: extractelement <3 x half> %b, i32 1
+; DOPCHECK: extractelement <3 x half> %b, i32 2
+; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b)
+ %dx.dot = call half @llvm.dx.fdot.v3f16(<3 x half> %a, <3 x half> %b)
ret half %dx.dot
}
; CHECK-LABEL: dot_half4
define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
entry:
-; CHECK: extractelement <4 x half> %a, i32 0
-; CHECK: extractelement <4 x half> %a, i32 1
-; CHECK: extractelement <4 x half> %a, i32 2
-; CHECK: extractelement <4 x half> %a, i32 3
-; CHECK: extractelement <4 x half> %b, i32 0
-; CHECK: extractelement <4 x half> %b, i32 1
-; CHECK: extractelement <4 x half> %b, i32 2
-; CHECK: extractelement <4 x half> %b, i32 3
-; CHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
- %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b)
+; DOPCHECK: extractelement <4 x half> %a, i32 0
+; DOPCHECK: extractelement <4 x half> %a, i32 1
+; DOPCHECK: extractelement <4 x half> %a, i32 2
+; DOPCHECK: extractelement <4 x half> %a, i32 3
+; DOPCHECK: extractelement <4 x half> %b, i32 0
+; DOPCHECK: extractelement <4 x half> %b, i32 1
+; DOPCHECK: extractelement <4 x half> %b, i32 2
+; DOPCHECK: extractelement <4 x half> %b, i32 3
+; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b)
+ %dx.dot = call half @llvm.dx.fdot.v4f16(<4 x half> %a, <4 x half> %b)
ret half %dx.dot
}
; CHECK-LABEL: dot_float2
define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
entry:
-; CHECK: extractelement <2 x float> %a, i32 0
-; CHECK: extractelement <2 x float> %a, i32 1
-; CHECK: extractelement <2 x float> %b, i32 0
-; CHECK: extractelement <2 x float> %b, i32 1
-; CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
- %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b)
+; DOPCHECK: extractelement <2 x float> %a, i32 0
+; DOPCHECK: extractelement <2 x float> %a, i32 1
+; DOPCHECK: extractelement <2 x float> %b, i32 0
+; DOPCHECK: extractelement <2 x float> %b, i32 1
+; DOPCHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+; EXPCHECK: call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b)
+ %dx.dot = call float @llvm.dx.fdot.v2f32(<2 x float> %a, <2 x float> %b)
ret float %dx.dot
}
; CHECK-LABEL: dot_float3
define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
entry:
-; CHECK: extractelement <3 x float> %a, i32 0
-; CHECK: extractelement <3 x float> %a, i32 1
-; CHECK: extractelement <3 x float> %a, i32 2
-; CHECK: extractelement <3 x float> %b, i32 0
-; CHECK: extractelement <3 x float> %b, i32 1
-; CHECK: extractelement <3 x float> %b, i32 2
-; CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
- %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b)
+; DOPCHECK: extractelement <3 x float> %a, i32 0
+; DOPCHECK: extractelement <3 x float> %a, i32 1
+; DOPCHECK: extractelement <3 x float> %a, i32 2
+; DOPCHECK: extractelement <3 x float> %b, i32 0
+; DOPCHECK: extractelement <3 x float> %b, i32 1
+; DOPCHECK: extractelement <3 x float> %b, i32 2
+; DOPCHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+; EXPCHECK: call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b)
+ %dx.dot = call float @llvm.dx.fdot.v3f32(<3 x float> %a, <3 x float> %b)
ret float %dx.dot
}
; CHECK-LABEL: dot_float4
define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
entry:
-; CHECK: extractelement <4 x float> %a, i32 0
-; CHECK: extractelement <4 x float> %a, i32 1
-; CHECK: extractelement <4 x float> %a, i32 2
-; CHECK: extractelement <4 x float> %a, i32 3
-; CHECK: extractelement <4 x float> %b, i32 0
-; CHECK: extractelement <4 x float> %b, i32 1
-; CHECK: extractelement <4 x float> %b, i32 2
-; CHECK: extractelement <4 x float> %b, i32 3
-; CHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
- %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b)
+; DOPCHECK: extractelement <4 x float> %a, i32 0
+; DOPCHECK: extractelement <4 x float> %a, i32 1
+; DOPCHECK: extractelement <4 x float> %a, i32 2
+; DOPCHECK: extractelement <4 x float> %a, i32 3
+; DOPCHECK: extractelement <4 x float> %b, i32 0
+; DOPCHECK: extractelement <4 x float> %b, i32 1
+; DOPCHECK: extractelement <4 x float> %b, i32 2
+; DOPCHECK: extractelement <4 x float> %b, i32 3
+; DOPCHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+; EXPCHECK: call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b)
+ %dx.dot = call float @llvm.dx.fdot.v4f32(<4 x float> %a, <4 x float> %b)
ret float %dx.dot
}
-declare half @llvm.dx.dot.v2f16(<2 x half> , <2 x half> )
-declare half @llvm.dx.dot.v3f16(<3 x half> , <3 x half> )
-declare half @llvm.dx.dot.v4f16(<4 x half> , <4 x half> )
-declare float @llvm.dx.dot.v2f32(<2 x float>, <2 x float>)
-declare float @llvm.dx.dot.v3f32(<3 x float>, <3 x float>)
-declare float @llvm.dx.dot.v4f32(<4 x float>, <4 x float>)
+declare half @llvm.dx.fdot.v2f16(<2 x half> , <2 x half> )
+declare half @llvm.dx.fdot.v3f16(<3 x half> , <3 x half> )
+declare half @llvm.dx.fdot.v4f16(<4 x half> , <4 x half> )
+declare float @llvm.dx.fdot.v2f32(<2 x float>, <2 x float>)
+declare float @llvm.dx.fdot.v3f32(<3 x float>, <3 x float>)
+declare float @llvm.dx.fdot.v4f32(<4 x float>, <4 x float>)
diff --git a/llvm/test/CodeGen/DirectX/idot.ll b/llvm/test/CodeGen/DirectX/idot.ll
index eac1b91106ddef..5848868ed0556a 100644
--- a/llvm/test/CodeGen/DirectX/idot.ll
+++ b/llvm/test/CodeGen/DirectX/idot.ll
@@ -13,12 +13,12 @@ entry:
; CHECK: extractelement <2 x i16> %b, i64 1
; EXPCHECK: call i16 @llvm.dx.imad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 48, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
- %dx.dot = call i16 @llvm.dx.sdot.v3i16(<2 x i16> %a, <2 x i16> %b)
- ret i16 %dx.dot
+ %dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %a, <2 x i16> %b)
+ ret i16 %dot
}
-; CHECK-LABEL: sdot_int4
-define noundef i32 @sdot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
+; CHECK-LABEL: dot_int4
+define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
entry:
; CHECK: extractelement <4 x i32> %a, i64 0
; CHECK: extractelement <4 x i32> %b, i64 0
@@ -35,8 +35,8 @@ entry:
; CHECK: extractelement <4 x i32> %b, i64 3
; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
- %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
- ret i32 %dx.dot
+ %dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
+ ret i32 %dot
}
; CHECK-LABEL: dot_uint16_t3
@@ -53,8 +53,8 @@ entry:
; CHECK: extractelement <3 x i16> %b, i64 2
; EXPCHECK: call i16 @llvm.dx.umad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
- %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
- ret i16 %dx.dot
+ %dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
+ ret i16 %dot
}
; CHECK-LABEL: dot_uint4
@@ -75,8 +75,8 @@ entry:
; CHECK: extractelement <4 x i32> %b, i64 3
; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
- %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
- ret i32 %dx.dot
+ %dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
+ ret i32 %dot
}
; CHECK-LABEL: dot_uint64_t4
@@ -89,8 +89,8 @@ entry:
; CHECK: extractelement <2 x i64> %b, i64 1
; EXPCHECK: call i64 @llvm.dx.umad.i64(i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
; DOPCHECK: call i64 @dx.op.tertiary.i64(i32 49, i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
- %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
- ret i64 %dx.dot
+ %dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
+ ret i64 %dot
}
declare i16 @llvm.dx.sdot.v2i16(<2 x i16>, <2 x i16>)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll
new file mode 100644
index 00000000000000..5a8d4581aa0cdb
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll
@@ -0,0 +1,75 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure SPIR-V OpDot instructions are generated for float type vectors.
+
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#vec2_float_16:]] = OpTypeVector %[[#float_16]] 2
+; CHECK-DAG: %[[#vec3_float_16:]] = OpTypeVector %[[#float_16]] 3
+; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#vec2_float_32:]] = OpTypeVector %[[#float_32]] 2
+; CHECK-DAG: %[[#vec3_float_32:]] = OpTypeVector %[[#float_32]] 3
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+
+
+define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_float_16]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_float_16]]
+; CHECK: OpDot %[[#float_16]] %[[#arg0]] %[[#arg1]]
+ %dx.dot = call half @llvm.spv.fdot.v2f16(<2 x half> %a, <2 x half> %b)
+ ret half %dx.dot
+}
+
+define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_16]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_16]]
+; CHECK: OpDot %[[#float_16]] %[[#arg0]] %[[#arg1]]
+ %dx.dot = call half @llvm.spv.fdot.v3f16(<3 x half> %a, <3 x half> %b)
+ ret half %dx.dot
+}
+
+define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_16]]
+; CHECK: OpDot %[[#float_16]] %[[#arg0]] %[[#arg1]]
+ %dx.dot = call half @llvm.spv.fdot.v4f16(<4 x half> %a, <4 x half> %b)
+ ret half %dx.dot
+}
+
+define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_float_32]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_float_32]]
+; CHECK: OpDot %[[#float_32]] %[[#arg0]] %[[#arg1]]
+ %dx.dot = call float @llvm.spv.fdot.v2f32(<2 x float> %a, <2 x float> %b)
+ ret float %dx.dot
+}
+
+define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_32]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_32]]
+; CHECK: OpDot %[[#float_32]] %[[#arg0]] %[[#arg1]]
+ %dx.dot = call float @llvm.spv.fdot.v3f32(<3 x float> %a, <3 x float> %b)
+ ret float %dx.dot
+}
+
+define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
+; CHECK: OpDot %[[#float_32]] %[[#arg0]] %[[#arg1]]
+ %dx.dot = call float @llvm.spv.fdot.v4f32(<4 x float> %a, <4 x float> %b)
+ ret float %dx.dot
+}
+
+declare half @llvm.spv.fdot.v2f16(<2 x half> , <2 x half> )
+declare half @llvm.spv.fdot.v3f16(<3 x half> , <3 x half> )
+declare half @llvm.spv.fdot.v4f16(<4 x half> , <4 x half> )
+declare float @llvm.spv.fdot.v2f32(<2 x float>, <2 x float>)
+declare float @llvm.spv.fdot.v3f32(<3 x float>, <3 x float>)
+declare float @llvm.spv.fdot.v4f32(<4 x float>, <4 x float>)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
new file mode 100644
index 00000000000000..22b6ed6bdfcbc5
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
@@ -0,0 +1,88 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure SPIR-V multiply/extract/add sequences are generated for dot on int/uint vectors.
+
+; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16
+; CHECK-DAG: %[[#vec2_int_16:]] = OpTypeVector %[[#int_16]] 2
+; CHECK-DAG: %[[#vec3_int_16:]] = OpTypeVector %[[#int_16]] 3
+; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32
+; CHECK-DAG: %[[#vec4_int_32:]] = OpTypeVector %[[#int_32]] 4
+; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64
+; CHECK-DAG: %[[#vec2_int_64:]] = OpTypeVector %[[#int_64]] 2
+
+define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_16]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_16]]
+; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
+; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
+; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
+; CHECK: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
+ %dot = call i16 @llvm.spv.sdot.v2i16(<2 x i16> %a, <2 x i16> %b)
+ ret i16 %dot
+}
+
+define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
+; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
+; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
+; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
+; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
+; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
+; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
+; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
+; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
+ %dot = call i32 @llvm.spv.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
+ ret i32 %dot
+}
+
+define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_int_16]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_int_16]]
+; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
+; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
+; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
+; CHECK: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
+; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
+; CHECK: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]
+ %dot = call i16 @llvm.spv.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
+ ret i16 %dot
+}
+
+define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
+; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
+; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
+; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
+; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
+; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
+; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
+; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
+; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
+ %dot = call i32 @llvm.spv.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
+ ret i32 %dot
+}
+
+define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
+; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
+; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
+; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
+; CHECK: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]
+ %dot = call i64 @llvm.spv.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
+ ret i64 %dot
+}
+
+declare i16 @llvm.spv.sdot.v2i16(<2 x i16>, <2 x i16>)
+declare i32 @llvm.spv.sdot.v4i32(<4 x i32>, <4 x i32>)
+declare i16 @llvm.spv.udot.v3i16(<3 x i16>, <3 x i16>)
+declare i32 @llvm.spv.udot.v4i32(<4 x i32>, <4 x i32>)
+declare i64 @llvm.spv.udot.v2i64(<2 x i64>, <2 x i64>)
>From 9cf27a4d8b1da0e7b51eacb9fb6096155c294d3f Mon Sep 17 00:00:00 2001
From: Greg Roth <grroth at microsoft.com>
Date: Fri, 16 Aug 2024 19:13:06 -0600
Subject: [PATCH 2/5] Correct fdot test RUN line
---
llvm/test/CodeGen/DirectX/fdot.ll | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll
index a13406e4c88611..aa1b15972e266d 100644
--- a/llvm/test/CodeGen/DirectX/fdot.ll
+++ b/llvm/test/CodeGen/DirectX/fdot.ll
@@ -1,5 +1,5 @@
; RUN: opt -S -dxil-intrinsic-expansion -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
-; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
; Make sure dxil operation function calls for dot are generated for float type vectors.
>From 9b0e613bb7b150ae143897c513a37c619e4ec8f8 Mon Sep 17 00:00:00 2001
From: Greg Roth <grroth at microsoft.com>
Date: Mon, 19 Aug 2024 12:45:28 -0600
Subject: [PATCH 3/5] Respond to feedback
Add SPIRV checks to the dot.hlsl codegen test. This required
reorganizing the test to capture the interchange format early
and removing the dependence on argument positions.
Reluctantly conform iterator variable to style guide
Add float dot selector for SPIRV with asserts
---
clang/test/CodeGenHLSL/builtins/dot.hlsl | 143 ++++++++++--------
.../Target/DirectX/DXILIntrinsicExpansion.cpp | 6 +-
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 37 ++++-
3 files changed, 109 insertions(+), 77 deletions(-)
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index 6d0cf41f4d98bd..2b76fae61147b4 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -1,161 +1,172 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
-// RUN: --check-prefixes=CHECK,NATIVE_HALF
+// RUN: --check-prefixes=CHECK,DXCHECK,NATIVE_HALF
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
-// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,DXCHECK,NO_HALF
-#ifdef __HLSL_ENABLE_16_BIT
-// NATIVE_HALF: %hlsl.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %hlsl.dot
-int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); }
-
-// NATIVE_HALF: %hlsl.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %hlsl.dot
-int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); }
-
-// NATIVE_HALF: %hlsl.dot = call i16 @llvm.dx.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %hlsl.dot
-int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); }
-
-// NATIVE_HALF: %hlsl.dot = call i16 @llvm.dx.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %hlsl.dot
-int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
-
-// NATIVE_HALF: %hlsl.dot = mul i16 %0, %1
-// NATIVE_HALF: ret i16 %hlsl.dot
-uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); }
-
-// NATIVE_HALF: %hlsl.dot = call i16 @llvm.dx.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
-// NATIVE_HALF: ret i16 %hlsl.dot
-uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); }
-
-// NATIVE_HALF: %hlsl.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
-// NATIVE_HALF: ret i16 %hlsl.dot
-uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); }
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,SPVCHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,SPVCHECK,NO_HALF
-// NATIVE_HALF: %hlsl.dot = call i16 @llvm.dx.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
-// NATIVE_HALF: ret i16 %hlsl.dot
-uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
-#endif
-// CHECK: %hlsl.dot = mul i32 %0, %1
+// CHECK: %hlsl.dot = mul i32
// CHECK: ret i32 %hlsl.dot
int test_dot_int(int p0, int p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call i32 @llvm.dx.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// Capture the expected interchange format so not every check needs to be duplicated
+// DXCHECK: %hlsl.dot = call i32 @llvm.[[ICF:dx]].sdot.v2i32(<2 x i32>
+// SPVCHECK: %hlsl.dot = call i32 @llvm.[[ICF:spv]].sdot.v2i32(<2 x i32>
// CHECK: ret i32 %hlsl.dot
int test_dot_int2(int2 p0, int2 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call i32 @llvm.dx.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: %hlsl.dot = call i32 @llvm.[[ICF]].sdot.v3i32(<3 x i32>
// CHECK: ret i32 %hlsl.dot
int test_dot_int3(int3 p0, int3 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: %hlsl.dot = call i32 @llvm.[[ICF]].sdot.v4i32(<4 x i32>
// CHECK: ret i32 %hlsl.dot
int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = mul i32 %0, %1
+// CHECK: %hlsl.dot = mul i32
// CHECK: ret i32 %hlsl.dot
uint test_dot_uint(uint p0, uint p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call i32 @llvm.dx.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: %hlsl.dot = call i32 @llvm.[[ICF]].udot.v2i32(<2 x i32>
// CHECK: ret i32 %hlsl.dot
uint test_dot_uint2(uint2 p0, uint2 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call i32 @llvm.dx.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: %hlsl.dot = call i32 @llvm.[[ICF]].udot.v3i32(<3 x i32>
// CHECK: ret i32 %hlsl.dot
uint test_dot_uint3(uint3 p0, uint3 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: %hlsl.dot = call i32 @llvm.[[ICF]].udot.v4i32(<4 x i32>
// CHECK: ret i32 %hlsl.dot
uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = mul i64 %0, %1
+// CHECK: %hlsl.dot = mul i64
// CHECK: ret i64 %hlsl.dot
int64_t test_dot_long(int64_t p0, int64_t p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call i64 @llvm.dx.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: %hlsl.dot = call i64 @llvm.[[ICF]].sdot.v2i64(<2 x i64>
// CHECK: ret i64 %hlsl.dot
int64_t test_dot_long2(int64_t2 p0, int64_t2 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call i64 @llvm.dx.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: %hlsl.dot = call i64 @llvm.[[ICF]].sdot.v3i64(<3 x i64>
// CHECK: ret i64 %hlsl.dot
int64_t test_dot_long3(int64_t3 p0, int64_t3 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call i64 @llvm.dx.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: %hlsl.dot = call i64 @llvm.[[ICF]].sdot.v4i64(<4 x i64>
// CHECK: ret i64 %hlsl.dot
int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = mul i64 %0, %1
+// CHECK: %hlsl.dot = mul i64
// CHECK: ret i64 %hlsl.dot
uint64_t test_dot_ulong(uint64_t p0, uint64_t p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: %hlsl.dot = call i64 @llvm.[[ICF]].udot.v2i64(<2 x i64>
// CHECK: ret i64 %hlsl.dot
uint64_t test_dot_ulong2(uint64_t2 p0, uint64_t2 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call i64 @llvm.dx.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: %hlsl.dot = call i64 @llvm.[[ICF]].udot.v3i64(<3 x i64>
// CHECK: ret i64 %hlsl.dot
uint64_t test_dot_ulong3(uint64_t3 p0, uint64_t3 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call i64 @llvm.dx.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: %hlsl.dot = call i64 @llvm.[[ICF]].udot.v4i64(<4 x i64>
// CHECK: ret i64 %hlsl.dot
uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %hlsl.dot = fmul half %0, %1
+#ifdef __HLSL_ENABLE_16_BIT
+// NATIVE_HALF: %hlsl.dot = mul i16
+// NATIVE_HALF: ret i16 %hlsl.dot
+int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); }
+
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.[[ICF]].sdot.v2i16(<2 x i16>
+// NATIVE_HALF: ret i16 %hlsl.dot
+int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); }
+
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.[[ICF]].sdot.v3i16(<3 x i16>
+// NATIVE_HALF: ret i16 %hlsl.dot
+int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); }
+
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.[[ICF]].sdot.v4i16(<4 x i16>
+// NATIVE_HALF: ret i16 %hlsl.dot
+int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
+
+// NATIVE_HALF: %hlsl.dot = mul i16
+// NATIVE_HALF: ret i16 %hlsl.dot
+uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); }
+
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.[[ICF]].udot.v2i16(<2 x i16>
+// NATIVE_HALF: ret i16 %hlsl.dot
+uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); }
+
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.[[ICF]].udot.v3i16(<3 x i16>
+// NATIVE_HALF: ret i16 %hlsl.dot
+uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); }
+
+// NATIVE_HALF: %hlsl.dot = call i16 @llvm.[[ICF]].udot.v4i16(<4 x i16>
+// NATIVE_HALF: ret i16 %hlsl.dot
+uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
+#endif
+
+// NATIVE_HALF: %hlsl.dot = fmul half
// NATIVE_HALF: ret half %hlsl.dot
-// NO_HALF: %hlsl.dot = fmul float %0, %1
+// NO_HALF: %hlsl.dot = fmul float
// NO_HALF: ret float %hlsl.dot
half test_dot_half(half p0, half p1) { return dot(p0, p1); }
-// NATIVE_HALF: %hlsl.dot = call half @llvm.dx.fdot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: %hlsl.dot = call half @llvm.[[ICF]].fdot.v2f16(<2 x half>
// NATIVE_HALF: ret half %hlsl.dot
-// NO_HALF: %hlsl.dot = call float @llvm.dx.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: %hlsl.dot = call float @llvm.[[ICF]].fdot.v2f32(<2 x float>
// NO_HALF: ret float %hlsl.dot
half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %hlsl.dot = call half @llvm.dx.fdot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: %hlsl.dot = call half @llvm.[[ICF]].fdot.v3f16(<3 x half>
// NATIVE_HALF: ret half %hlsl.dot
-// NO_HALF: %hlsl.dot = call float @llvm.dx.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: %hlsl.dot = call float @llvm.[[ICF]].fdot.v3f32(<3 x float>
// NO_HALF: ret float %hlsl.dot
half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %hlsl.dot = call half @llvm.dx.fdot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: %hlsl.dot = call half @llvm.[[ICF]].fdot.v4f16(<4 x half>
// NATIVE_HALF: ret half %hlsl.dot
-// NO_HALF: %hlsl.dot = call float @llvm.dx.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: %hlsl.dot = call float @llvm.[[ICF]].fdot.v4f32(<4 x float>
// NO_HALF: ret float %hlsl.dot
half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = fmul float %0, %1
+// CHECK: %hlsl.dot = fmul float
// CHECK: ret float %hlsl.dot
float test_dot_float(float p0, float p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call float @llvm.dx.fdot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: %hlsl.dot = call float @llvm.[[ICF]].fdot.v2f32(<2 x float>
// CHECK: ret float %hlsl.dot
float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call float @llvm.dx.fdot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: %hlsl.dot = call float @llvm.[[ICF]].fdot.v3f32(<3 x float>
// CHECK: ret float %hlsl.dot
float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call float @llvm.dx.fdot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: %hlsl.dot = call float @llvm.[[ICF]].fdot.v4f32(<4 x float>
// CHECK: ret float %hlsl.dot
float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call float @llvm.dx.fdot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK: %hlsl.dot = call float @llvm.[[ICF]].fdot.v2f32(<2 x float> %splat.splat, <2 x float>
// CHECK: ret float %hlsl.dot
float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call float @llvm.dx.fdot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK: %hlsl.dot = call float @llvm.[[ICF]].fdot.v3f32(<3 x float> %splat.splat, <3 x float>
// CHECK: ret float %hlsl.dot
float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = call float @llvm.dx.fdot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK: %hlsl.dot = call float @llvm.[[ICF]].fdot.v4f32(<4 x float> %splat.splat, <4 x float>
// CHECK: ret float %hlsl.dot
float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
-// CHECK: %hlsl.dot = fmul double %0, %1
+// CHECK: %hlsl.dot = fmul double
// CHECK: ret double %hlsl.dot
double test_dot_double(double p0, double p1) { return dot(p0, p1); }
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 8316c040580be1..f54d1e07754dd6 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -125,9 +125,9 @@ static Value *expandIntegerDotIntrinsic(CallInst *Orig,
Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
Result = Builder.CreateMul(Elt0, Elt1);
- for (unsigned i = 1; i < AVec->getNumElements(); i++) {
- Elt0 = Builder.CreateExtractElement(A, i);
- Elt1 = Builder.CreateExtractElement(B, i);
+ for (unsigned I = 1; I < AVec->getNumElements(); I++) {
+ Elt0 = Builder.CreateExtractElement(A, I);
+ Elt1 = Builder.CreateExtractElement(B, I);
Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
ArrayRef<Value *>{Elt0, Elt1, Result},
nullptr, "dx.mad");
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 730a2be19ba081..a33dbbc753b429 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -184,6 +184,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectRsqrt(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectFloatDot(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -1449,8 +1452,31 @@ bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
-// Since there is no integer dot implementation, expand by piecewise multiplying
-// and adding the results, making use of FMA operations where possible.
+// Select an OpDot instruction for the floating-point dot product intrinsic
+bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+ assert(I.getNumOperands() == 4);
+ assert(I.getOperand(2).isReg());
+ assert(I.getOperand(3).isReg());
+
+ [[maybe_unused]] SPIRVType *VecType =
+ GR.getSPIRVTypeForVReg(I.getOperand(2).getReg());
+
+ assert(GR.getScalarOrVectorComponentCount(VecType) > 1 &&
+ "dot product requires a vector of at least 2 components");
+
+ MachineBasicBlock &BB = *I.getParent();
+ return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpDot))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(I.getOperand(2).getReg())
+ .addUse(I.getOperand(3).getReg())
+ .constrainAllUses(TII, TRI, RBI);
+}
+
+// Since pre-1.6 SPIRV has no integer dot implementation, expand by piecewise
+// multiplying and adding the results
bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -2222,12 +2248,7 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_thread_id:
return selectSpvThreadId(ResVReg, ResType, I);
case Intrinsic::spv_fdot:
- return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpDot))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addUse(I.getOperand(2).getReg())
- .addUse(I.getOperand(3).getReg())
- .constrainAllUses(TII, TRI, RBI);
+ return selectFloatDot(ResVReg, ResType, I);
case Intrinsic::spv_udot:
case Intrinsic::spv_sdot:
return selectIntegerDot(ResVReg, ResType, I);
>From 22fca601ae0cb608b1efd50ef7e58483d0432196 Mon Sep 17 00:00:00 2001
From: Greg Roth <grroth at microsoft.com>
Date: Mon, 19 Aug 2024 13:15:44 -0600
Subject: [PATCH 4/5] Leverage dot product expansion in normalize expansion
Since normalize requires a dot product operation, changing the signature of
expandFloatDotIntrinsic so that it can be called from within another
operation's expansion allows that code to be shared.
---
.../Target/DirectX/DXILIntrinsicExpansion.cpp | 42 ++++++-------------
1 file changed, 12 insertions(+), 30 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index f54d1e07754dd6..27e44988e5f92a 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -70,9 +70,7 @@ static Value *expandAbs(CallInst *Orig) {
}
// Create DXIL dot intrinsics for floating point dot operations
-static Value *expandFloatDotIntrinsic(CallInst *Orig) {
- Value *A = Orig->getOperand(0);
- Value *B = Orig->getOperand(1);
+static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
Type *ATy = A->getType();
[[maybe_unused]] Type *BTy = B->getType();
assert(ATy->isVectorTy() && BTy->isVectorTy());
@@ -95,7 +93,10 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig) {
DotIntrinsic = Intrinsic::dx_dot4;
break;
default:
- llvm_unreachable("dot product with vector outside 2-4 range");
+ report_fatal_error(
+ Twine("Invalid dot product input vector: length is outside 2-4"),
+ /* gen_crash_diag=*/false);
+ return nullptr;
}
return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
ArrayRef<Value *>{A, B}, nullptr, "dot");
@@ -249,6 +250,8 @@ static Value *expandLog10Intrinsic(CallInst *Orig) {
return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
}
+// The dot product of the vector operand with itself is its squared length.
+// Multiply the vector by the reciprocal square root of that to normalize it.
static Value *expandNormalizeIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Type *Ty = Orig->getType();
@@ -267,30 +270,7 @@ static Value *expandNormalizeIntrinsic(CallInst *Orig) {
return Builder.CreateFDiv(X, X);
}
- unsigned XVecSize = XVec->getNumElements();
- Value *DotProduct = nullptr;
- // use the dot intrinsic corresponding to the vector size
- switch (XVecSize) {
- case 1:
- report_fatal_error(Twine("Invalid input vector: length is zero"),
- /* gen_crash_diag=*/false);
- break;
- case 2:
- DotProduct = Builder.CreateIntrinsic(
- EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
- break;
- case 3:
- DotProduct = Builder.CreateIntrinsic(
- EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
- break;
- case 4:
- DotProduct = Builder.CreateIntrinsic(
- EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
- break;
- default:
- report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
- /* gen_crash_diag=*/false);
- }
+ Value *DotProduct = expandFloatDotIntrinsic(Orig, X, X);
// verify that the length is non-zero
// (if the dot product is non-zero, then the length is non-zero)
@@ -305,7 +285,8 @@ static Value *expandNormalizeIntrinsic(CallInst *Orig) {
ArrayRef<Value *>{DotProduct},
nullptr, "dx.rsqrt");
- Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
+ Value *MultiplicandVec =
+ Builder.CreateVectorSplat(XVec->getNumElements(), Multiplicand);
return Builder.CreateFMul(X, MultiplicandVec);
}
@@ -402,7 +383,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
Result = expandNormalizeIntrinsic(Orig);
break;
case Intrinsic::dx_fdot:
- Result = expandFloatDotIntrinsic(Orig);
+ Result =
+ expandFloatDotIntrinsic(Orig, Orig->getOperand(0), Orig->getOperand(1));
break;
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
>From 1cf4842555a64f4102437b430f5190601f53810b Mon Sep 17 00:00:00 2001
From: Greg Roth <grroth at microsoft.com>
Date: Mon, 19 Aug 2024 13:24:10 -0600
Subject: [PATCH 5/5] Small comment clarifications
---
llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp | 1 +
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 4 ++--
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 27e44988e5f92a..f9f51fa30e344e 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -70,6 +70,7 @@ static Value *expandAbs(CallInst *Orig) {
}
// Create DXIL dot intrinsics for floating point dot operations
+// inserting the dot product of the A and B values at the position of Orig
static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
Type *ATy = A->getType();
[[maybe_unused]] Type *BTy = B->getType();
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index a33dbbc753b429..4d2a14b3be42e3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1475,8 +1475,8 @@ bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
-// Since pre-1.6 SPIRV has no integer dot implementation, expand by piecewise
-// multiplying and adding the results
+// Since SPIR-V before 1.6 has no integer dot product instruction,
+// expand by multiplying element-wise and adding the products
bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
More information about the llvm-commits
mailing list