[clang] [llvm] Add Float `Dot` Intrinsic Lowering (PR #86071)
via cfe-commits
cfe-commits at lists.llvm.org
Wed Mar 20 20:20:25 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-llvm-ir
Author: Farzon Lotfi (farzonl)
<details>
<summary>Changes</summary>
Completes #<!-- -->83626
- `CGBuiltin.cpp` - modify `getDotProductIntrinsic` to be able to emit `dot2`, `dot3`, and `dot4` intrinsics based on element count
- `IntrinsicsDirectX.td` - for floating point add `dot2`, `dot3`, and `dot4` intrinsics
- `DXIL.td` - add dxilop intrinsic lowering for `dot2`, `dot3`, & `dot4`.
- `DXILOpLowering.cpp` - add vector arg flattening for dot product.
- `DXILOpBuilder.h` - modify `createDXILOpCall` to take a `SmallVector<Value *>` instead of an `iterator_range<Use *>`
- `DXILOpBuilder.cpp` - modify `createDXILOpCall` by moving the small vector up to the calling function in `DXILOpLowering.cpp`.
- Moving one function up gives us access to the `CallInst` and `Function` which were needed to distinguish the dot product intrinsics and get the operands without using the iterator.
---
Patch is 20.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86071.diff
11 Files Affected:
- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+16-9)
- (modified) clang/test/CodeGenHLSL/builtins/dot.hlsl (+14-14)
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+9-1)
- (modified) llvm/lib/Target/DirectX/DXIL.td (+9)
- (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+3-5)
- (modified) llvm/lib/Target/DirectX/DXILOpBuilder.h (+2-3)
- (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+53-2)
- (added) llvm/test/CodeGen/DirectX/dot2_error.ll (+10)
- (added) llvm/test/CodeGen/DirectX/dot3_error.ll (+10)
- (added) llvm/test/CodeGen/DirectX/dot4_error.ll (+10)
- (added) llvm/test/CodeGen/DirectX/fdot.ll (+94)
``````````diff
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 77cb269d43c5a8..a4b99181769326 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18036,15 +18036,22 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
return Arg;
}
-Intrinsic::ID getDotProductIntrinsic(QualType QT) {
+Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) {
+ if (QT->hasFloatingRepresentation()) {
+ switch (elementCount) {
+ case 2:
+ return Intrinsic::dx_dot2;
+ case 3:
+ return Intrinsic::dx_dot3;
+ case 4:
+ return Intrinsic::dx_dot4;
+ }
+ }
if (QT->hasSignedIntegerRepresentation())
return Intrinsic::dx_sdot;
- if (QT->hasUnsignedIntegerRepresentation())
- return Intrinsic::dx_udot;
- assert(QT->hasFloatingRepresentation());
- return Intrinsic::dx_dot;
- ;
+ assert(QT->hasUnsignedIntegerRepresentation());
+ return Intrinsic::dx_udot;
}
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
@@ -18098,8 +18105,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
assert(T0->getScalarType() == T1->getScalarType() &&
"Dot product of vectors need the same element types.");
- [[maybe_unused]] auto *VecTy0 =
- E->getArg(0)->getType()->getAs<VectorType>();
+ auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
[[maybe_unused]] auto *VecTy1 =
E->getArg(1)->getType()->getAs<VectorType>();
// A HLSLVectorTruncation should have happend
@@ -18108,7 +18114,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return Builder.CreateIntrinsic(
/*ReturnType=*/T0->getScalarType(),
- getDotProductIntrinsic(E->getArg(0)->getType()),
+ getDotProductIntrinsic(E->getArg(0)->getType(),
+ VecTy0->getNumElements()),
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
} break;
case Builtin::BI__builtin_hlsl_lerp: {
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index 0f993193c00cce..307d71cce3cb6d 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -110,21 +110,21 @@ uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
// NO_HALF: ret float %dx.dot
half test_dot_half(half p0, half p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1)
// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
// NO_HALF: ret float %dx.dot
half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1)
// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
// NO_HALF: ret float %dx.dot
half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1)
// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
// NO_HALF: ret float %dx.dot
half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
@@ -132,34 +132,34 @@ half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
// CHECK: ret float %dx.dot
float test_dot_float(float p0, float p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
// CHECK: %conv = sitofp i32 %1 to float
// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat)
+// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %splat.splat)
// CHECK: ret float %dx.dot
float test_builtin_dot_float2_int_splat(float2 p0, int p1) {
return dot(p0, p1);
@@ -168,7 +168,7 @@ float test_builtin_dot_float2_int_splat(float2 p0, int p1) {
// CHECK: %conv = sitofp i32 %1 to float
// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat)
+// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %splat.splat)
// CHECK: ret float %dx.dot
float test_builtin_dot_float3_int_splat(float3 p0, int p1) {
return dot(p0, p1);
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 1164b241ba7b0d..a871fac46b9fd5 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -24,7 +24,15 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
-def int_dx_dot :
+def int_dx_dot2 :
+ Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
+def int_dx_dot3 :
+ Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
+def int_dx_dot4 :
Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 36eb29d53766f0..9e393902e2f208 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -295,6 +295,15 @@ def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
"Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
"Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
+def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
+ "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1",
+ [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>;
+def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
+ "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2",
+ [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>;
+def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
+ "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3",
+ [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>;
def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
"Reads the thread ID">;
def GroupId : DXILOpMapping<94, groupId, int_dx_group_id,
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index a1eacc2d48009c..990557710f8c53 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -254,7 +254,7 @@ namespace dxil {
CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
Type *OverloadTy,
- llvm::iterator_range<Use *> Args) {
+ SmallVector<Value *> Args) {
const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
OverloadKind Kind = getOverloadKind(OverloadTy);
@@ -272,10 +272,8 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
}
- SmallVector<Value *> FullArgs;
- FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
- FullArgs.append(Args.begin(), Args.end());
- return B.CreateCall(DXILFn, FullArgs);
+
+ return B.CreateCall(DXILFn, Args);
}
Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index f3abcc6e02a4e3..5babeae470178b 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -13,7 +13,7 @@
#define LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H
#include "DXILConstants.h"
-#include "llvm/ADT/iterator_range.h"
+#include "llvm/ADT/SmallVector.h"
namespace llvm {
class Module;
@@ -35,8 +35,7 @@ class DXILOpBuilder {
/// \param OverloadTy Overload type of the DXIL Op call constructed
/// \return DXIL Op call constructed
CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
- Type *OverloadTy,
- llvm::iterator_range<Use *> Args);
+ Type *OverloadTy, SmallVector<Value *> Args);
Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
static const char *getOpCodeName(dxil::OpCode DXILOp);
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 3e334b0ec298d3..f09e322f88e1fd 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -30,6 +30,48 @@
using namespace llvm;
using namespace llvm::dxil;
+static bool isVectorArgExpansion(Function &F) {
+ switch (F.getIntrinsicID()) {
+ case Intrinsic::dx_dot2:
+ case Intrinsic::dx_dot3:
+ case Intrinsic::dx_dot4:
+ return true;
+ }
+ return false;
+}
+
+static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
+ SmallVector<Value *, 4> ExtractedElements;
+ auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
+ for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
+ Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
+ Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
+ ExtractedElements.push_back(ExtractedElement);
+ }
+ return ExtractedElements;
+}
+
+static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
+ IRBuilder<> &Builder) {
+ // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
+ unsigned NumOperands = Orig->getNumOperands() - 1;
+ assert(NumOperands > 0);
+ Value *Arg0 = Orig->getOperand(0);
+ [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
+ assert(VecArg0);
+ SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
+ for (unsigned I = 1; I < NumOperands; ++I) {
+ Value *Arg = Orig->getOperand(I);
+ [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
+ assert(VecArg);
+ assert(VecArg0->getElementType() == VecArg->getElementType());
+ assert(VecArg0->getNumElements() == VecArg->getNumElements());
+ auto NextOperandList = populateOperands(Arg, Builder);
+ NewOperands.append(NextOperandList.begin(), NextOperandList.end());
+ }
+ return NewOperands;
+}
+
static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
IRBuilder<> B(M.getContext());
DXILOpBuilder DXILB(M, B);
@@ -39,9 +81,18 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
if (!CI)
continue;
+ SmallVector<Value *> Args;
+ Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
+ Args.emplace_back(DXILOpArg);
B.SetInsertPoint(CI);
- CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(),
- OverloadTy, CI->args());
+ if (isVectorArgExpansion(F)) {
+ SmallVector<Value *> NewArgs = argVectorFlatten(CI, B);
+ Args.append(NewArgs.begin(), NewArgs.end());
+ } else
+ Args.append(CI->arg_begin(), CI->arg_end());
+
+ CallInst *DXILCI =
+ DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args);
CI->replaceAllUsesWith(DXILCI);
CI->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/dot2_error.ll b/llvm/test/CodeGen/DirectX/dot2_error.ll
new file mode 100644
index 00000000000000..a27bfaedacd573
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot2_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation dot2 does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @dot_double2(<2 x double> noundef %a, <2 x double> noundef %b) {
+entry:
+ %dx.dot = call double @llvm.dx.dot2.v2f64(<2 x double> %a, <2 x double> %b)
+ ret double %dx.dot
+}
diff --git a/llvm/test/CodeGen/DirectX/dot3_error.ll b/llvm/test/CodeGen/DirectX/dot3_error.ll
new file mode 100644
index 00000000000000..eb69fb145038aa
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot3_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation dot3 does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @dot_double3(<3 x double> noundef %a, <3 x double> noundef %b) {
+entry:
+ %dx.dot = call double @llvm.dx.dot3.v3f64(<3 x double> %a, <3 x double> %b)
+ ret double %dx.dot
+}
diff --git a/llvm/test/CodeGen/DirectX/dot4_error.ll b/llvm/test/CodeGen/DirectX/dot4_error.ll
new file mode 100644
index 00000000000000..5cd632684c0c01
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot4_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation dot4 does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @dot_double4(<4 x double> noundef %a, <4 x double> noundef %b) {
+entry:
+ %dx.dot = call double @llvm.dx.dot4.v4f64(<4 x double> %a, <4 x double> %b)
+ ret double %dx.dot
+}
diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll
new file mode 100644
index 00000000000000..3e13b2ad2650c8
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/fdot.ll
@@ -0,0 +1,94 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for dot are generated for half/float vectors.
+
+; CHECK-LABEL: dot_half2
+define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
+entry:
+; CHECK: extractelement <2 x half> %a, i32 0
+; CHECK: extractelement <2 x half> %a, i32 1
+; CHECK: extractelement <2 x half> %b, i32 0
+; CHECK: extractelement <2 x half> %b, i32 1
+; CHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b)
+ ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_half3
+define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
+entry:
+; CHECK: extractelement <3 x half> %a, i32 0
+; CHECK: extractelement <3 x half> %a, i32 1
+; CHECK: extractelement <3 x half> %a, i32 2
+; CHECK: extractelement <3 x half> %b, i32 0
+; CHECK: extractelement <3 x half> %b, i32 1
+; CHECK: extractelement <3 x half> %b, i32 2
+; CHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b)
+ ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_half4
+define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
+entry:
+; CHECK: extractelement <4 x half> %a, i32 0
+; CHECK: extractelement <4 x half> %a, i32 1
+; CHECK: extractelement <4 x half> %a, i32 2
+; CHECK: extractelement <4 x half> %a, i32 3
+; CHECK: extractelement <4 x half> %b, i32 0
+; CHECK: extractelement <4 x half> %b, i32 1
+; CHECK: extractelement <4 x half> %b, i32 2
+; CHECK: extractelement <4 x half> %b, i32 3
+; CHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b)
+ ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_float2
+define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
+entry:
+; CHECK: extractelement <2 x float> %a, i32 0
+; CHECK: extractelement <2 x float> %a, i32 1
+; CHECK: extractelement <2 x float> %b, i32 0
+; CHECK: extractelement <2 x float> %b, i32 1
+; CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+ %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b)
+ ret float %dx.dot
+}
+
+; CHECK-LABEL: dot_float3
+define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
+entry:
+; CHECK: extractelement <3 x float> %a, i32 0
+; CHECK: extractelement <3 x float> %a, i32 1
+; CHECK: extractelement <3 x float> %a, i32 2
+; CHECK: extractelement <3 x float> %b, i32 0
+; CHECK: extractelement <3 x float> %b, i32 1
+; CHECK: extractelement <3 x float> %b, i32 2
+; CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+ %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b)
+ ret float %dx.dot
+}
+
+; CHECK-LABEL: dot_float4
+define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
+entry:
+; CHECK: extractelement <4 x float> %a, i32 0
+; CHECK: extractelement <4 x float> %a, i32 1
+; CHECK: extractelement <4 x float> %a, i32 2
+; CHECK: extractelement <4 x float> %a, i32 3
+; CHECK: extractelement <4 x float> %b, i32 0
+; CHECK: extractelement <4 x fl...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/86071
More information about the cfe-commits
mailing list