[clang] [llvm] Add length HLSL function to DirectX Backend (PR #101256)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 30 15:16:13 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang-codegen
Author: Joshua Batista (bob80905)
<details>
<summary>Changes</summary>
This PR adds the length intrinsic and an HLSL function that uses it.
The SPIRV implementation is left for a future PR.
Fixes #<!-- -->99134
---
Full diff: https://github.com/llvm/llvm-project/pull/101256.diff
13 Files Affected:
- (modified) clang/docs/LanguageExtensions.rst (+1)
- (modified) clang/docs/ReleaseNotes.rst (+1)
- (modified) clang/include/clang/Basic/Builtins.td (+6)
- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+16)
- (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
- (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+33)
- (modified) clang/lib/Sema/SemaHLSL.cpp (+22)
- (added) clang/test/CodeGenHLSL/builtins/length.hlsl (+73)
- (added) clang/test/SemaHLSL/BuiltIns/length-errors.hlsl (+31)
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+2)
- (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+2)
- (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+39)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+1-1)
``````````diff
diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst
index a747464582e77..45f081081a371 100644
--- a/clang/docs/LanguageExtensions.rst
+++ b/clang/docs/LanguageExtensions.rst
@@ -664,6 +664,7 @@ Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±in
T __builtin_elementwise_cosh(T x) return the hyperbolic cosine of angle x in radians floating point types
T __builtin_elementwise_tanh(T x) return the hyperbolic tangent of angle x in radians floating point types
T __builtin_elementwise_floor(T x) return the largest integral value less than or equal to x floating point types
+ T __builtin_elementwise_length(T x) return the length of the specified floating-point vector floating point types
T __builtin_elementwise_log(T x) return the natural logarithm of x floating point types
T __builtin_elementwise_log2(T x) return the base 2 logarithm of x floating point types
T __builtin_elementwise_log10(T x) return the base 10 logarithm of x floating point types
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index dad44f45a847f..46f40889b4b33 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -243,6 +243,7 @@ DWARF Support in Clang
Floating Point Support in Clang
-------------------------------
+- Add ``__builtin_elementwise_length``builtin for floating point types only.
Fixed Point Support in Clang
----------------------------
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 4133f6ff40cf3..0baadf0d196b2 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4707,6 +4707,12 @@ def HLSLIsinf : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLLength : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_elementwise_length"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def HLSLLerp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_lerp"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index f0651c280ff95..a28073ca9ccc5 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18460,6 +18460,22 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getLerpIntrinsic(),
ArrayRef<Value *>{X, Y, S}, nullptr, "hlsl.lerp");
}
+ case Builtin::BI__builtin_hlsl_elementwise_length: {
+ Value *X = EmitScalarExpr(E->getArg(0));
+
+ if (!E->getArg(0)->getType()->hasFloatingRepresentation())
+ llvm_unreachable("length operand must have a float representation");
+ // if the operand is a scalar, we can use the fabs llvm intrinsic directly
+ if (!E->getArg(0)->getType()->isVectorType()) {
+ llvm::Type *ResultType = ConvertType(E->getType());
+ Function *F = CGM.getIntrinsic(Intrinsic::fabs, ResultType);
+ return Builder.CreateCall(F, X);
+ }
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/X->getType()->getScalarType(),
+ CGM.getHLSLRuntime().getLengthIntrinsic(),
+ ArrayRef<Value *>{X}, nullptr, "hlsl.length");
+ }
case Builtin::BI__builtin_hlsl_elementwise_frac: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 8c067f4963955..3f2dc0ae7b84d 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -75,6 +75,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(All, all)
GENERATE_HLSL_INTRINSIC_FUNCTION(Any, any)
GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 09f26a4588c14..21ac25bba1acb 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -908,6 +908,39 @@ float3 lerp(float3, float3, float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_lerp)
float4 lerp(float4, float4, float4);
+
+//===----------------------------------------------------------------------===//
+// length builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T length(T x)
+/// \brief Returns the length of the specified floating-point vector.
+/// \param x [in] The vector of floats, or a scalar float.
+///
+/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + �).
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+half length(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+half length(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+half length(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+half length(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+float length(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+float length(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+float length(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+float length(float4);
+
//===----------------------------------------------------------------------===//
// log builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9940bc5b4a606..624cbd3777bb8 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1076,6 +1076,28 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_elementwise_length: {
+ if (SemaRef.checkArgCount(TheCall, 1))
+ return true;
+ if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
+ return true;
+
+ ExprResult A = TheCall->getArg(0);
+ QualType ArgTyA = A.get()->getType();
+ QualType RetTy;
+
+ if (auto *VTy = ArgTyA->getAs<VectorType>())
+ RetTy = VTy->getElementType();
+ else
+ RetTy = TheCall->getArg(0)->getType();
+
+ TheCall->setType(RetTy);
+
+
+ if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+ return true;
+ break;
+ }
case Builtin::BI__builtin_hlsl_mad: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;
diff --git a/clang/test/CodeGenHLSL/builtins/length.hlsl b/clang/test/CodeGenHLSL/builtins/length.hlsl
new file mode 100644
index 0000000000000..0af669f36e6ba
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/length.hlsl
@@ -0,0 +1,73 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: call half @llvm.fabs.f16(half
+// NO_HALF: call float @llvm.fabs.f32(float
+// NATIVE_HALF: ret half
+// NO_HALF: ret float
+half test_length_half(half p0)
+{
+ return length(p0);
+}
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: %hlsl.length = call half @llvm.dx.length.v2f16
+// NO_HALF: %hlsl.length = call float @llvm.dx.length.v2f32(
+// NATIVE_HALF: ret half %hlsl.length
+// NO_HALF: ret float %hlsl.length
+half test_length_half2(half2 p0)
+{
+ return length(p0);
+}
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: %hlsl.length = call half @llvm.dx.length.v3f16
+// NO_HALF: %hlsl.length = call float @llvm.dx.length.v3f32(
+// NATIVE_HALF: ret half %hlsl.length
+// NO_HALF: ret float %hlsl.length
+half test_length_half3(half3 p0)
+{
+ return length(p0);
+}
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: %hlsl.length = call half @llvm.dx.length.v4f16
+// NO_HALF: %hlsl.length = call float @llvm.dx.length.v4f32(
+// NATIVE_HALF: ret half %hlsl.length
+// NO_HALF: ret float %hlsl.length
+half test_length_half4(half4 p0)
+{
+ return length(p0);
+}
+
+// CHECK: define noundef float @
+// CHECK: call float @llvm.fabs.f32(float
+// CHECK: ret float
+float test_length_float(float p0)
+{
+ return length(p0);
+}
+// CHECK: define noundef float @
+// CHECK: %hlsl.length = call float @llvm.dx.length.v2f32(
+// CHECK: ret float %hlsl.length
+float test_length_float2(float2 p0)
+{
+ return length(p0);
+}
+// CHECK: define noundef float @
+// CHECK: %hlsl.length = call float @llvm.dx.length.v3f32(
+// CHECK: ret float %hlsl.length
+float test_length_float3(float3 p0)
+{
+ return length(p0);
+}
+// CHECK: define noundef float @
+// CHECK: %hlsl.length = call float @llvm.dx.length.v4f32(
+// CHECK: ret float %hlsl.length
+float test_length_float4(float4 p0)
+{
+ return length(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/length-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/length-errors.hlsl
new file mode 100644
index 0000000000000..781c344f0da17
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/length-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+bool test_too_few_arg()
+{
+ return __builtin_hlsl_elementwise_length();
+ // expected-error at -1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+bool2 test_too_many_arg(float2 p0)
+{
+ return __builtin_hlsl_elementwise_length(p0, p0);
+ // expected-error at -1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool builtin_bool_to_float_type_promotion(bool p1)
+{
+ return __builtin_hlsl_elementwise_length(p1);
+ // expected-error at -1 {passing 'bool' to parameter of incompatible type 'float'}}
+}
+
+bool builtin_length_int_to_float_promotion(int p1)
+{
+ return __builtin_hlsl_elementwise_length(p1);
+ // expected-error at -1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+bool2 builtin_length_int2_to_float2_promotion(int2 p1)
+{
+ return __builtin_hlsl_elementwise_length(p1);
+ // expected-error at -1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index a7f212da2f5b6..47c01f899a926 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -55,6 +55,8 @@ def int_dx_isinf :
def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
+def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
+
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index ef6ddf12c32f6..c91fe859d7cc2 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -63,5 +63,7 @@ let TargetPrefix = "spv" in {
def int_spv_frac : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
+ def int_spv_length : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
+ [IntrNoMem, IntrWillReturn] >;
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
}
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 4b162a35365c8..7ef5b9eae9310 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -42,6 +42,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_clamp:
case Intrinsic::dx_uclamp:
case Intrinsic::dx_lerp:
+ case Intrinsic::dx_length:
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return true;
@@ -157,6 +158,42 @@ static bool expandAnyIntrinsic(CallInst *Orig) {
return true;
}
+static bool expandLengthIntrinsic(CallInst *Orig) {
+ Value *X = Orig->getOperand(0);
+ IRBuilder<> Builder(Orig->getParent());
+ Builder.SetInsertPoint(Orig);
+ Type *Ty = X->getType();
+ Type *EltTy = Ty->getScalarType();
+
+ // Though dx.length does work on scalar type, we can optimize it to just emit
+ // fabs, in CGBuiltin.cpp. We shouldn't see a scalar type here because
+ // CGBuiltin.cpp should have emitted a fabs call.
+ assert(Ty->isVectorTy() && "dx.length only works on vector type");
+ Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
+ auto *XVec = dyn_cast<FixedVectorType>(Ty);
+ unsigned size = XVec->getNumElements();
+ if (size > 1) {
+ Value *Sum = Builder.CreateFMul(Elt, Elt);
+ for (unsigned i = 1; i < size; i++) {
+ Elt = Builder.CreateExtractElement(X, i);
+ Value *Mul = Builder.CreateFMul(Elt, Elt);
+ Sum = Builder.CreateFAdd(Sum, Mul);
+ }
+ Value *Result = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");
+
+ Orig->replaceAllUsesWith(Result);
+ Orig->eraseFromParent();
+ return true;
+ } else {
+ Value *Result = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::fabs, ArrayRef<Value *>{Elt}, nullptr, "elt.abs");
+ Orig->replaceAllUsesWith(Result);
+ Orig->eraseFromParent();
+ return true;
+ }
+}
+
static bool expandLerpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
@@ -280,6 +317,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
return expandClampIntrinsic(Orig, F.getIntrinsicID());
case Intrinsic::dx_lerp:
return expandLerpIntrinsic(Orig);
+ case Intrinsic::dx_length:
+ return expandLengthIntrinsic(Orig);
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return expandIntegerDot(Orig, F.getIntrinsicID());
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 8391e0dec9a39..0f0b7fee96559 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -280,7 +280,7 @@ void SPIRVInstructionSelector::setupMF(MachineFunction &MF, GISelKnownBits *KB,
CodeGenCoverage *CoverageInfo,
ProfileSummaryInfo *PSI,
BlockFrequencyInfo *BFI) {
- MMI = &MF.getMMI().getObjFileInfo<SPIRVMachineModuleInfo>();
+ // MMI = &MF.getMMI().getObjFileInfo<SPIRVMachineModuleInfo>();
MRI = &MF.getRegInfo();
GR.setCurrentFunc(MF);
InstructionSelector::setupMF(MF, KB, CoverageInfo, PSI, BFI);
``````````
</details>
https://github.com/llvm/llvm-project/pull/101256
More information about the llvm-commits
mailing list