[clang] [llvm] [HLSL] Implement dot2add intrinsic (PR #131237)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 20 15:16:51 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang
Author: Sumit Agarwal (sumitsays)
<details>
<summary>Changes</summary>
Resolves #<!-- -->99221
Key points: For the SPIR-V backend, it decomposes into a `dot` followed by an `add`.
- [x] Implement dot2add clang builtin,
- [x] Link dot2add clang builtin with hlsl_intrinsics.h
- [x] Add sema checks for dot2add to CheckHLSLBuiltinFunctionCall in SemaHLSL.cpp
- [x] Add codegen for dot2add to EmitHLSLBuiltinExpr in CGBuiltin.cpp
- [x] Add codegen tests to clang/test/CodeGenHLSL/builtins/dot2add.hlsl
- [x] Add sema tests to clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl
- [x] Create the int_dx_dot2add intrinsic in IntrinsicsDirectX.td
- [x] Create the DXILOpMapping of int_dx_dot2add to 162 in DXIL.td
- [x] Create the dot2add.ll and dot2add_errors.ll tests in llvm/test/CodeGen/DirectX/
- [ ] ~~Create the int_spv_dot2add intrinsic in IntrinsicsSPIRV.td~~ --- Not needed
- [ ] ~~In SPIRVInstructionSelector.cpp create the dot2add lowering and map it to int_spv_dot2add in SPIRVInstructionSelector::selectIntrinsic.~~ --- Not needed
- [ ] ~~Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot2add.ll~~ --- Not needed
---
Full diff: https://github.com/llvm/llvm-project/pull/131237.diff
11 Files Affected:
- (modified) clang/include/clang/Basic/Builtins.td (+6)
- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+15)
- (modified) clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h (+8)
- (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+12)
- (modified) clang/lib/Sema/SemaHLSL.cpp (+50-9)
- (added) clang/test/CodeGenHLSL/builtins/dot2add.hlsl (+17)
- (added) clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl (+11)
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+4)
- (modified) llvm/lib/Target/DirectX/DXIL.td (+11)
- (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+47-2)
- (added) llvm/test/CodeGen/DirectX/dot2add.ll (+8)
``````````diff
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 72a5e495c4059..76ab463ca0ed6 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4891,6 +4891,14 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+// dot2add(half2 a, half2 b, float c) -> float. The prototype is variadic, so
+// operand types are validated in SemaHLSL (CustomTypeChecking).
+def HLSLDot2Add : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_dot2add"];
+  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Prototype = "void(...)";
+}
+
 def HLSLDot4AddI8Packed : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_dot4add_i8packed"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index c126f88b9e3a5..b3d9db5be7d8d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -19681,6 +19681,22 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
         ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
   }
+  case Builtin::BI__builtin_hlsl_dot2add: {
+    // Only DXIL has a dedicated dot2add intrinsic; the HLSL header uses
+    // dot(a, b) + c on every other target. A direct call to the builtin on a
+    // non-DXIL target is user-reachable, so diagnose it instead of treating it
+    // as unreachable (which is undefined behavior in release builds).
+    if (CGM.getTarget().getTriple().getArch() != llvm::Triple::dxil) {
+      CGM.ErrorUnsupported(E, "__builtin_hlsl_dot2add on a non-DXIL target");
+      return llvm::PoisonValue::get(ConvertType(E->getType()));
+    }
+    Value *A = EmitScalarExpr(E->getArg(0));
+    Value *B = EmitScalarExpr(E->getArg(1));
+    Value *C = EmitScalarExpr(E->getArg(2));
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/C->getType(), Intrinsic::dx_dot2add,
+        ArrayRef<Value *>{A, B, C}, nullptr, "hlsl.dot2add");
+  }
   case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
     Value *A = EmitScalarExpr(E->getArg(0));
     Value *B = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 5f7c047dbf340..46653d7b295b2 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -45,6 +45,16 @@ distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
   return length_vec_impl(X - Y);
 }
 
+// Computes dot(a, b) + c. DirectX targets use the dedicated builtin; all other
+// targets fall back to the generic dot product followed by an add.
+constexpr float dot2add_impl(half2 a, half2 b, float c) {
+#ifdef __DIRECTX__
+  return __builtin_hlsl_dot2add(a, b, c);
+#else
+  return dot(a, b) + c;
+#endif
+}
+
 template <typename T> constexpr T reflect_impl(T I, T N) {
   return I - 2 * N * I * N;
 }
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 5459cbeb34fd0..b1c1335ce3328 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -117,6 +117,21 @@ const inline float distance(__detail::HLSL_FIXED_VECTOR<float, N> X,
   return __detail::distance_vec_impl(X, Y);
 }
 
+//===----------------------------------------------------------------------===//
+// dot2add builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn float dot2add(half2 a, half2 b, float c)
+/// \brief Returns the dot product of two half2 vectors plus a float scalar.
+/// \param a The first input vector.
+/// \param b The second input vector.
+/// \param c The float value added to the dot product.
+
+_HLSL_AVAILABILITY(shadermodel, 6.4)
+const inline float dot2add(half2 a, half2 b, float c) {
+  return __detail::dot2add_impl(a, b, c);
+}
+
 //===----------------------------------------------------------------------===//
 // fmod builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 36de110e75e8a..399371c4ae2f6 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2468,8 +2468,26 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     if (SemaRef.BuiltinVectorToScalarMath(TheCall))
       return true;
     if (CheckNoDoubleVectors(&SemaRef, TheCall))
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_dot2add: {
+    // Expected signature: float dot2add(half2 a, half2 b, float c).
+    if (SemaRef.checkArgCount(TheCall, 3))
+      return true;
+
+    // Check against the complete half2 type rather than only its element
+    // type, so a scalar or wrongly-sized vector argument is reported against
+    // the type the builtin actually requires.
+    ASTContext &Ctx = SemaRef.getASTContext();
+    QualType Half2Ty = Ctx.getExtVectorType(Ctx.HalfTy, 2);
+    if (CheckArgTypeMatches(&SemaRef, TheCall->getArg(0), Half2Ty) ||
+        CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), Half2Ty) ||
+        CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), Ctx.FloatTy))
+      return true;
+
+    TheCall->setType(Ctx.FloatTy);
+    break;
+  }
   case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
diff --git a/clang/test/CodeGenHLSL/builtins/dot2add.hlsl b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
new file mode 100644
index 0000000000000..ce325327a01b5
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
@@ -0,0 +1,17 @@
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.4-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test lowering to the DXIL intrinsic and to the SPIR-V dot + fadd expansion.
+
+float test(half2 p1, half2 p2, float p3) {
+  // CHECK-SPIRV: %[[MUL:.*]] = call {{.*}} float @llvm.spv.fdot.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}})
+  // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr, align 4
+  // CHECK-SPIRV: %[[RES:.*]] = fadd {{.*}} float %[[MUL]], %[[C]]
+  // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, float %{{.*}})
+  // CHECK: ret float %[[RES]]
+  return dot2add(p1, p2, p3);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl
new file mode 100644
index 0000000000000..61282a319dafd
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl
@@ -0,0 +1,26 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+bool test_too_few_arg() {
+  return __builtin_hlsl_dot2add();
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
+}
+
+bool test_too_many_arg(half2 p1, half2 p2, float p3) {
+  return __builtin_hlsl_dot2add(p1, p2, p3, p1);
+  // expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
+}
+
+float test_first_arg_not_half2(float2 p1, half2 p2, float p3) {
+  return __builtin_hlsl_dot2add(p1, p2, p3);
+  // expected-error@-1 {{to parameter of incompatible type}}
+}
+
+float test_second_arg_scalar(half2 p1, half p2, float p3) {
+  return __builtin_hlsl_dot2add(p1, p2, p3);
+  // expected-error@-1 {{to parameter of incompatible type}}
+}
+
+float test_third_arg_not_float(half2 p1, half2 p2, half2 p3) {
+  return __builtin_hlsl_dot2add(p1, p2, p3);
+  // expected-error@-1 {{to parameter of incompatible type}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index ead7286f4311c..775d325feeb14 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -100,6 +100,12 @@ def int_dx_udot :
     DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
     [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, Commutative] >;
+// dot2add(a, b, c) = dot(a, b) + c with a float accumulator c. Marked
+// Commutative because swapping a and b does not change the dot product.
+def int_dx_dot2add :
+    DefaultAttrsIntrinsic<[llvm_float_ty],
+    [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty],
+    [IntrNoMem, Commutative]>;
 def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
 def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
 
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index ebe1d876d58b1..193b592a525a0 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -1098,6 +1098,18 @@ def RawBufferStore : DXILOp<140, rawBufferStore> {
   let stages = [Stages<DXIL1_2, [all_stages]>];
 }
 
+// dot2AddHalf only exists from DXIL 1.4 (shader model 6.4), which is also the
+// availability of the HLSL dot2add intrinsic.
+def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
+  let Doc = "dot product of two half2 vectors, with accumulate to float";
+  let intrinsics = [IntrinSelect<int_dx_dot2add>];
+  let arguments = [FloatTy, HalfTy, HalfTy, HalfTy, HalfTy];
+  let result = FloatTy;
+  let overloads = [Overloads<DXIL1_4, []>];
+  let stages = [Stages<DXIL1_4, [all_stages]>];
+  let attributes = [Attributes<DXIL1_4, [ReadNone]>];
+}
+
 def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
   let Doc = "signed dot product of 4 x i8 vectors packed into i32, with "
             "accumulate to i32";
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index dff9f3e03079e..f7ed0c5071d75 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -54,10 +54,11 @@ static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
   return ExtractedElements;
 }
 
+// Flattens the first NumOperands call operands, which must all be fixed
+// vectors with the same element type and width, into their scalar elements.
 static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
-                                             IRBuilder<> &Builder) {
-  // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
-  unsigned NumOperands = Orig->getNumOperands() - 1;
+                                             IRBuilder<> &Builder,
+                                             unsigned NumOperands) {
   assert(NumOperands > 0);
   Value *Arg0 = Orig->getOperand(0);
   [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
@@ -74,7 +75,13 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
   }
   return NewOperands;
 }
 
+static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
+                                             IRBuilder<> &Builder) {
+  // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
+  return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1);
+}
+
 namespace {
 class OpLowerer {
   Module &M;
@@ -168,6 +181,12 @@ class OpLowerer {
         }
       } else if (IsVectorArgExpansion) {
         Args = argVectorFlatten(CI, OpBuilder.getIRB());
+      } else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) {
+        // dot2AddHalf takes the float accumulator (the last call argument)
+        // first, followed by the flattened elements of the two vectors.
+        unsigned NumVectorArgs = CI->arg_size() - 1;
+        Args.push_back(CI->getArgOperand(NumVectorArgs));
+        Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumVectorArgs));
       } else {
         Args.append(CI->arg_begin(), CI->arg_end());
       }
diff --git a/llvm/test/CodeGen/DirectX/dot2add.ll b/llvm/test/CodeGen/DirectX/dot2add.ll
new file mode 100644
index 0000000000000..b1019c36b56e8
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot2add.ll
@@ -0,0 +1,8 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.4-compute %s | FileCheck %s
+
+define noundef float @dot2add_simple(<2 x half> noundef %a, <2 x half> noundef %b, float %c) {
+entry:
+; CHECK: call float @dx.op.dot2AddHalf(i32 162, float %c, half %0, half %1, half %2, half %3)
+  %ret = call float @llvm.dx.dot2add(<2 x half> %a, <2 x half> %b, float %c)
+  ret float %ret
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/131237
More information about the llvm-commits
mailing list