[clang] [llvm] [DXIL] implement dot intrinsic lowering for integers (PR #85662)
via cfe-commits
cfe-commits at lists.llvm.org
Mon Mar 18 09:41:55 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-llvm-ir
Author: Farzon Lotfi (farzonl)
<details>
<summary>Changes</summary>
This implements part 1 of 2 for #83626:
- `CGBuiltin.cpp` - modified to have separate cases for signed and unsigned integers.
- `SemaChecking.cpp` - modified to reject double vectors, so calling the builtin directly cannot generate a double dot product intrinsic.
- `IntrinsicsDirectX.td` - creates the signed and unsigned dot intrinsics needed for instruction expansion.
- `DXILIntrinsicExpansion.cpp` - handle instruction expansion cases for integer dot product.
---
Full diff: https://github.com/llvm/llvm-project/pull/85662.diff
7 Files Affected:
- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+13-1)
- (modified) clang/lib/Sema/SemaChecking.cpp (+15)
- (modified) clang/test/CodeGenHLSL/builtins/dot.hlsl (+18-18)
- (modified) clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl (+9)
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+10-1)
- (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+36)
- (added) llvm/test/CodeGen/DirectX/idot.ll (+100)
``````````diff
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e965df810add54..e89691ab7921c3 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18036,6 +18036,17 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
return Arg;
}
+// Maps a dot() operand type to the DirectX intrinsic that implements it:
+// dx_sdot for signed integers, dx_udot for unsigned integers, else dx_dot.
+static Intrinsic::ID getDotProductIntrinsic(QualType QT) {
+  if (QT->hasSignedIntegerRepresentation())
+    return Intrinsic::dx_sdot;
+  if (QT->hasUnsignedIntegerRepresentation())
+    return Intrinsic::dx_udot;
+  assert(QT->hasFloatingRepresentation() && "Dot product on unknown type.");
+  return Intrinsic::dx_dot;
+}
+
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
const CallExpr *E) {
if (!getLangOpts().HLSL)
@@ -18096,7 +18107,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
"Dot product requires vectors to be of the same size.");
return Builder.CreateIntrinsic(
- /*ReturnType=*/T0->getScalarType(), Intrinsic::dx_dot,
+ /*ReturnType=*/T0->getScalarType(),
+ getDotProductIntrinsic(E->getArg(0)->getType()),
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
} break;
case Builtin::BI__builtin_hlsl_lerp: {
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index a0b256ab5579ee..384b929d37bc82 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5484,6 +5484,19 @@ bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
checkFloatorHalf);
}
+// Rejects vector arguments whose floating-point element type is neither half
+// nor float (i.e. double vectors); callers treat a true result as an error.
+bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
+  auto IsDoubleVector = [](clang::QualType Ty) -> bool {
+    if (!Ty->isVectorType() || !Ty->hasFloatingRepresentation())
+      return false;
+    clang::QualType EltTy = Ty->getAs<clang::VectorType>()->getElementType();
+    return !(EltTy->isHalfType() || EltTy->isFloat32Type());
+  };
+  return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
+                                  IsDoubleVector);
+}
+
void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
QualType ReturnType) {
auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
@@ -5520,6 +5533,8 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
if (SemaBuiltinVectorToScalarMath(TheCall))
return true;
+ if (CheckNoDoubleVectors(this, TheCall))
+ return true;
break;
}
case Builtin::BI__builtin_hlsl_elementwise_rcp: {
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index c064d118caf3e7..0f993193c00cce 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -11,15 +11,15 @@
// NATIVE_HALF: ret i16 %dx.dot
int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
// NATIVE_HALF: ret i16 %dx.dot
int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
// NATIVE_HALF: ret i16 %dx.dot
int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
// NATIVE_HALF: ret i16 %dx.dot
int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
@@ -27,15 +27,15 @@ int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
// NATIVE_HALF: ret i16 %dx.dot
uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
// NATIVE_HALF: ret i16 %dx.dot
uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
// NATIVE_HALF: ret i16 %dx.dot
uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
// NATIVE_HALF: ret i16 %dx.dot
uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
#endif
@@ -44,15 +44,15 @@ uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
// CHECK: ret i32 %dx.dot
int test_dot_int(int p0, int p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
// CHECK: ret i32 %dx.dot
int test_dot_int2(int2 p0, int2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
// CHECK: ret i32 %dx.dot
int test_dot_int3(int3 p0, int3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
// CHECK: ret i32 %dx.dot
int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); }
@@ -60,15 +60,15 @@ int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); }
// CHECK: ret i32 %dx.dot
uint test_dot_uint(uint p0, uint p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: %dx.dot = call i32 @llvm.dx.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
// CHECK: ret i32 %dx.dot
uint test_dot_uint2(uint2 p0, uint2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: %dx.dot = call i32 @llvm.dx.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
// CHECK: ret i32 %dx.dot
uint test_dot_uint3(uint3 p0, uint3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
// CHECK: ret i32 %dx.dot
uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); }
@@ -76,15 +76,15 @@ uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); }
// CHECK: ret i64 %dx.dot
int64_t test_dot_long(int64_t p0, int64_t p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
// CHECK: ret i64 %dx.dot
int64_t test_dot_long2(int64_t2 p0, int64_t2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
// CHECK: ret i64 %dx.dot
int64_t test_dot_long3(int64_t3 p0, int64_t3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
// CHECK: ret i64 %dx.dot
int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); }
@@ -92,15 +92,15 @@ int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); }
// CHECK: ret i64 %dx.dot
uint64_t test_dot_ulong(uint64_t p0, uint64_t p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
// CHECK: ret i64 %dx.dot
uint64_t test_dot_ulong2(uint64_t2 p0, uint64_t2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: %dx.dot = call i64 @llvm.dx.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
// CHECK: ret i64 %dx.dot
uint64_t test_dot_ulong3(uint64_t3 p0, uint64_t3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: %dx.dot = call i64 @llvm.dx.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
// CHECK: ret i64 %dx.dot
uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
diff --git a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
index 59eb9482b9ef92..ba7ffc20484ae0 100644
--- a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
@@ -108,3 +108,12 @@ int test_builtin_dot_bool_type_promotion(bool p0, bool p1) {
return __builtin_hlsl_dot(p0, p1);
// expected-error at -1 {{1st argument must be a vector, integer or floating point type (was 'bool')}}
}
+
+double test_dot_double(double2 p0, double2 p1) { // double vectors must be diagnosed
+ return dot(p0, p1);
+ // expected-error@-1 {{call to 'dot' is ambiguous}}
+}
+double test_dot_double_builtin(double2 p0, double2 p1) { // also via the raw builtin
+ return __builtin_hlsl_dot(p0, p1);
+ // expected-error@-1 {{passing 'double2' (aka 'vector<double, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 5c72f06f96ed12..1164b241ba7b0d 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -23,9 +23,18 @@ def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">,
def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+
def int_dx_dot :
Intrinsic<[LLVMVectorElementType<0>],
- [llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], // float only; integer dot uses dx_sdot/dx_udot
+ [IntrNoMem, IntrWillReturn, Commutative] >;
+def int_dx_sdot : // signed integer dot; DXILIntrinsicExpansion lowers it to mul + dx_imad
+ Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
+def int_dx_udot : // unsigned integer dot; DXILIntrinsicExpansion lowers it to mul + dx_umad
+ Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index bc38c10a1fceb0..0db42bc0a0fb64 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -39,11 +39,44 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_uclamp:
case Intrinsic::dx_lerp:
case Intrinsic::dx_rcp:
+ case Intrinsic::dx_sdot:
+ case Intrinsic::dx_udot:
return true;
}
return false;
}
+// Expands dx.sdot/dx.udot into a lane-0 mul followed by one imad/umad per lane.
+static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
+  assert(DotIntrinsic == Intrinsic::dx_sdot ||
+         DotIntrinsic == Intrinsic::dx_udot);
+  Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
+                                   ? Intrinsic::dx_imad
+                                   : Intrinsic::dx_umad;
+  Value *A = Orig->getOperand(0);
+  Value *B = Orig->getOperand(1);
+  assert(A->getType()->isVectorTy() && B->getType()->isVectorTy());
+  // cast<> asserts on scalable vectors; dyn_cast<> would yield null here.
+  auto *AVec = cast<FixedVectorType>(A->getType());
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+
+  // Seed with a plain multiply of lane 0, then fold remaining lanes via mad.
+  Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
+  Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
+  Value *Result = Builder.CreateMul(Elt0, Elt1);
+  for (unsigned I = 1, E = AVec->getNumElements(); I < E; ++I) {
+    Elt0 = Builder.CreateExtractElement(A, I);
+    Elt1 = Builder.CreateExtractElement(B, I);
+    Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
+                                     ArrayRef<Value *>{Elt0, Elt1, Result},
+                                     nullptr, "dx.mad");
+  }
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
static bool expandExpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
@@ -191,6 +224,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
return expandLerpIntrinsic(Orig);
case Intrinsic::dx_rcp:
return expandRcpIntrinsic(Orig);
+ case Intrinsic::dx_sdot:
+ case Intrinsic::dx_udot:
+ return expandIntegerDot(Orig, F.getIntrinsicID());
}
return false;
}
diff --git a/llvm/test/CodeGen/DirectX/idot.ll b/llvm/test/CodeGen/DirectX/idot.ll
new file mode 100644
index 00000000000000..286bef3f655116
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/idot.ll
@@ -0,0 +1,100 @@
+; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+
+; Make sure dx.sdot/dx.udot on int/uint vectors expand to mul + imad/umad (dxil ops 48/49).
+
+; CHECK-LABEL: dot_int16_t2
+define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) {
+entry:
+; CHECK: extractelement <2 x i16> %a, i64 0
+; CHECK: extractelement <2 x i16> %b, i64 0
+; CHECK: mul i16 %{{.*}}, %{{.*}}
+; CHECK: extractelement <2 x i16> %a, i64 1
+; CHECK: extractelement <2 x i16> %b, i64 1
+; EXPCHECK: call i16 @llvm.dx.imad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 48, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+ %dx.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %a, <2 x i16> %b)
+ ret i16 %dx.dot
+}
+
+; CHECK-LABEL: sdot_int4
+define noundef i32 @sdot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
+entry:
+; CHECK: extractelement <4 x i32> %a, i64 0
+; CHECK: extractelement <4 x i32> %b, i64 0
+; CHECK: mul i32 %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i32> %a, i64 1
+; CHECK: extractelement <4 x i32> %b, i64 1
+; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; CHECK: extractelement <4 x i32> %a, i64 2
+; CHECK: extractelement <4 x i32> %b, i64 2
+; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; CHECK: extractelement <4 x i32> %a, i64 3
+; CHECK: extractelement <4 x i32> %b, i64 3
+; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
+ ret i32 %dx.dot
+}
+
+; CHECK-LABEL: dot_uint16_t3
+define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) {
+entry:
+; CHECK: extractelement <3 x i16> %a, i64 0
+; CHECK: extractelement <3 x i16> %b, i64 0
+; CHECK: mul i16 %{{.*}}, %{{.*}}
+; CHECK: extractelement <3 x i16> %a, i64 1
+; CHECK: extractelement <3 x i16> %b, i64 1
+; EXPCHECK: call i16 @llvm.dx.umad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+; CHECK: extractelement <3 x i16> %a, i64 2
+; CHECK: extractelement <3 x i16> %b, i64 2
+; EXPCHECK: call i16 @llvm.dx.umad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
+ %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
+ ret i16 %dx.dot
+}
+
+; CHECK-LABEL: dot_uint4
+define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
+entry:
+; CHECK: extractelement <4 x i32> %a, i64 0
+; CHECK: extractelement <4 x i32> %b, i64 0
+; CHECK: mul i32 %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i32> %a, i64 1
+; CHECK: extractelement <4 x i32> %b, i64 1
+; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; CHECK: extractelement <4 x i32> %a, i64 2
+; CHECK: extractelement <4 x i32> %b, i64 2
+; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; CHECK: extractelement <4 x i32> %a, i64 3
+; CHECK: extractelement <4 x i32> %b, i64 3
+; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
+ ret i32 %dx.dot
+}
+
+; CHECK-LABEL: dot_uint64_t2
+define noundef i64 @dot_uint64_t2(<2 x i64> noundef %a, <2 x i64> noundef %b) {
+entry:
+; CHECK: extractelement <2 x i64> %a, i64 0
+; CHECK: extractelement <2 x i64> %b, i64 0
+; CHECK: mul i64 %{{.*}}, %{{.*}}
+; CHECK: extractelement <2 x i64> %a, i64 1
+; CHECK: extractelement <2 x i64> %b, i64 1
+; EXPCHECK: call i64 @llvm.dx.umad.i64(i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
+; DOPCHECK: call i64 @dx.op.tertiary.i64(i32 49, i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
+ %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
+ ret i64 %dx.dot
+}
+
+declare i16 @llvm.dx.sdot.v2i16(<2 x i16>, <2 x i16>)
+declare i32 @llvm.dx.sdot.v4i32(<4 x i32>, <4 x i32>)
+declare i16 @llvm.dx.udot.v3i16(<3 x i16>, <3 x i16>)
+declare i32 @llvm.dx.udot.v4i32(<4 x i32>, <4 x i32>)
+declare i64 @llvm.dx.udot.v2i64(<2 x i64>, <2 x i64>)
``````````
</details>
https://github.com/llvm/llvm-project/pull/85662
More information about the cfe-commits
mailing list