[clang] [llvm] Add length HLSL function to DirectX Backend (PR #101256)
Joshua Batista via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 30 15:16:26 PDT 2024
https://github.com/bob80905 updated https://github.com/llvm/llvm-project/pull/101256
>From 7027cf254ae1b6acfdfbbf5dbeda3c4d6a4b3c43 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Fri, 26 Jul 2024 15:41:01 -0700
Subject: [PATCH 1/2] first attempt
---
clang/docs/LanguageExtensions.rst | 1 +
clang/docs/ReleaseNotes.rst | 1 +
clang/include/clang/Basic/Builtins.td | 6 ++++++
clang/lib/CodeGen/CGBuiltin.cpp | 3 +++
clang/lib/Sema/SemaChecking.cpp | 1 +
clang/test/CodeGenHLSL/builtins/length.hlsl | 1 +
clang/test/SemaHLSL/BuiltIns/length-errors.hlsl | 1 +
7 files changed, 14 insertions(+)
create mode 100644 clang/test/CodeGenHLSL/builtins/length.hlsl
create mode 100644 clang/test/SemaHLSL/BuiltIns/length-errors.hlsl
diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst
index a747464582e77..45f081081a371 100644
--- a/clang/docs/LanguageExtensions.rst
+++ b/clang/docs/LanguageExtensions.rst
@@ -664,6 +664,7 @@ Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±in
T __builtin_elementwise_cosh(T x) return the hyperbolic cosine of angle x in radians floating point types
T __builtin_elementwise_tanh(T x) return the hyperbolic tangent of angle x in radians floating point types
T __builtin_elementwise_floor(T x) return the largest integral value less than or equal to x floating point types
+ T __builtin_elementwise_length(T x) return the length of the specified floating-point vector floating point types
T __builtin_elementwise_log(T x) return the natural logarithm of x floating point types
T __builtin_elementwise_log2(T x) return the base 2 logarithm of x floating point types
T __builtin_elementwise_log10(T x) return the base 10 logarithm of x floating point types
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index dad44f45a847f..46f40889b4b33 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -243,6 +243,7 @@ DWARF Support in Clang
Floating Point Support in Clang
-------------------------------
+- Add ``__builtin_elementwise_length`` builtin for floating point types only.
Fixed Point Support in Clang
----------------------------
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 4133f6ff40cf3..d6122a484c094 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -1248,6 +1248,12 @@ def ElementwiseBitreverse : Builtin {
let Prototype = "void(...)";
}
+def ElementwiseLength : Builtin, F16F128MathTemplate {
+ let Spellings = ["__builtin_elementwise_length"];
+ let Attributes = [NoThrow, Const, CustomTypeChecking, Constexpr];
+ let Prototype = "T(float, T)";
+}
+
def ElementwiseMax : Builtin {
let Spellings = ["__builtin_elementwise_max"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index f0651c280ff95..38c7cc8ab5a78 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3815,6 +3815,9 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_elementwise_exp2:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::exp2, "elt.exp2"));
+ case Builtin::BI__builtin_elementwise_length:
+ return RValue::get(emitBuiltinWithOneOverloadedType<1>(
+ *this, E, llvm::Intrinsic::length, "elt.length"));
case Builtin::BI__builtin_elementwise_log:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::log, "elt.log"));
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index bb30b1e289a1c..09e3b17571528 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2684,6 +2684,7 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
case Builtin::BI__builtin_elementwise_exp:
case Builtin::BI__builtin_elementwise_exp2:
case Builtin::BI__builtin_elementwise_floor:
+ case Builtin::BI__builtin_elementwise_length:
case Builtin::BI__builtin_elementwise_log:
case Builtin::BI__builtin_elementwise_log2:
case Builtin::BI__builtin_elementwise_log10:
diff --git a/clang/test/CodeGenHLSL/builtins/length.hlsl b/clang/test/CodeGenHLSL/builtins/length.hlsl
new file mode 100644
index 0000000000000..2f259b79aa7e2
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/length.hlsl
@@ -0,0 +1 @@
+s
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/length-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/length-errors.hlsl
new file mode 100644
index 0000000000000..2f259b79aa7e2
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/length-errors.hlsl
@@ -0,0 +1 @@
+s
\ No newline at end of file
>From fc20777ddb0d2e083fa92f3c1673e87874f8f935 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Tue, 30 Jul 2024 15:05:57 -0700
Subject: [PATCH 2/2] sema and codegen hlsl passes
---
clang/include/clang/Basic/Builtins.td | 12 +--
clang/lib/CodeGen/CGBuiltin.cpp | 19 ++++-
clang/lib/CodeGen/CGHLSLRuntime.h | 1 +
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 33 +++++++++
clang/lib/Sema/SemaChecking.cpp | 1 -
clang/lib/Sema/SemaHLSL.cpp | 22 ++++++
clang/test/CodeGenHLSL/builtins/length.hlsl | 74 ++++++++++++++++++-
.../test/SemaHLSL/BuiltIns/length-errors.hlsl | 32 +++++++-
llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 +
llvm/include/llvm/IR/IntrinsicsSPIRV.td | 2 +
.../Target/DirectX/DXILIntrinsicExpansion.cpp | 39 ++++++++++
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 2 +-
12 files changed, 226 insertions(+), 13 deletions(-)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index d6122a484c094..0baadf0d196b2 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -1248,12 +1248,6 @@ def ElementwiseBitreverse : Builtin {
let Prototype = "void(...)";
}
-def ElementwiseLength : Builtin, F16F128MathTemplate {
- let Spellings = ["__builtin_elementwise_length"];
- let Attributes = [NoThrow, Const, CustomTypeChecking, Constexpr];
- let Prototype = "T(float, T)";
-}
-
def ElementwiseMax : Builtin {
let Spellings = ["__builtin_elementwise_max"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
@@ -4713,6 +4707,12 @@ def HLSLIsinf : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLLength : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_elementwise_length"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def HLSLLerp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_lerp"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 38c7cc8ab5a78..a28073ca9ccc5 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3815,9 +3815,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BI__builtin_elementwise_exp2:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::exp2, "elt.exp2"));
- case Builtin::BI__builtin_elementwise_length:
- return RValue::get(emitBuiltinWithOneOverloadedType<1>(
- *this, E, llvm::Intrinsic::length, "elt.length"));
case Builtin::BI__builtin_elementwise_log:
return RValue::get(emitBuiltinWithOneOverloadedType<1>(
*this, E, llvm::Intrinsic::log, "elt.log"));
@@ -18463,6 +18460,22 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getLerpIntrinsic(),
ArrayRef<Value *>{X, Y, S}, nullptr, "hlsl.lerp");
}
+ case Builtin::BI__builtin_hlsl_elementwise_length: {
+ Value *X = EmitScalarExpr(E->getArg(0));
+
+ if (!E->getArg(0)->getType()->hasFloatingRepresentation())
+ llvm_unreachable("length operand must have a float representation");
+ // if the operand is a scalar, we can use the fabs llvm intrinsic directly
+ if (!E->getArg(0)->getType()->isVectorType()) {
+ llvm::Type *ResultType = ConvertType(E->getType());
+ Function *F = CGM.getIntrinsic(Intrinsic::fabs, ResultType);
+ return Builder.CreateCall(F, X);
+ }
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/X->getType()->getScalarType(),
+ CGM.getHLSLRuntime().getLengthIntrinsic(),
+ ArrayRef<Value *>{X}, nullptr, "hlsl.length");
+ }
case Builtin::BI__builtin_hlsl_elementwise_frac: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 8c067f4963955..3f2dc0ae7b84d 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -75,6 +75,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(All, all)
GENERATE_HLSL_INTRINSIC_FUNCTION(Any, any)
GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 09f26a4588c14..21ac25bba1acb 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -908,6 +908,39 @@ float3 lerp(float3, float3, float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_lerp)
float4 lerp(float4, float4, float4);
+
+//===----------------------------------------------------------------------===//
+// length builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T length(T x)
+/// \brief Returns the length of the specified floating-point vector.
+/// \param x [in] The vector of floats, or a scalar float.
+///
+/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + ...).
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+half length(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+half length(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+half length(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+half length(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+float length(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+float length(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+float length(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_length)
+float length(float4);
+
//===----------------------------------------------------------------------===//
// log builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 09e3b17571528..bb30b1e289a1c 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2684,7 +2684,6 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
case Builtin::BI__builtin_elementwise_exp:
case Builtin::BI__builtin_elementwise_exp2:
case Builtin::BI__builtin_elementwise_floor:
- case Builtin::BI__builtin_elementwise_length:
case Builtin::BI__builtin_elementwise_log:
case Builtin::BI__builtin_elementwise_log2:
case Builtin::BI__builtin_elementwise_log10:
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9940bc5b4a606..624cbd3777bb8 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1076,6 +1076,28 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_elementwise_length: {
+ if (SemaRef.checkArgCount(TheCall, 1))
+ return true;
+ if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
+ return true;
+
+ ExprResult A = TheCall->getArg(0);
+ QualType ArgTyA = A.get()->getType();
+ QualType RetTy;
+
+ if (auto *VTy = ArgTyA->getAs<VectorType>())
+ RetTy = VTy->getElementType();
+ else
+ RetTy = TheCall->getArg(0)->getType();
+
+ TheCall->setType(RetTy);
+
+
+ if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+ return true;
+ break;
+ }
case Builtin::BI__builtin_hlsl_mad: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;
diff --git a/clang/test/CodeGenHLSL/builtins/length.hlsl b/clang/test/CodeGenHLSL/builtins/length.hlsl
index 2f259b79aa7e2..0af669f36e6ba 100644
--- a/clang/test/CodeGenHLSL/builtins/length.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/length.hlsl
@@ -1 +1,73 @@
-s
\ No newline at end of file
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: call half @llvm.fabs.f16(half
+// NO_HALF: call float @llvm.fabs.f32(float
+// NATIVE_HALF: ret half
+// NO_HALF: ret float
+half test_length_half(half p0)
+{
+ return length(p0);
+}
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: %hlsl.length = call half @llvm.dx.length.v2f16
+// NO_HALF: %hlsl.length = call float @llvm.dx.length.v2f32(
+// NATIVE_HALF: ret half %hlsl.length
+// NO_HALF: ret float %hlsl.length
+half test_length_half2(half2 p0)
+{
+ return length(p0);
+}
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: %hlsl.length = call half @llvm.dx.length.v3f16
+// NO_HALF: %hlsl.length = call float @llvm.dx.length.v3f32(
+// NATIVE_HALF: ret half %hlsl.length
+// NO_HALF: ret float %hlsl.length
+half test_length_half3(half3 p0)
+{
+ return length(p0);
+}
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: %hlsl.length = call half @llvm.dx.length.v4f16
+// NO_HALF: %hlsl.length = call float @llvm.dx.length.v4f32(
+// NATIVE_HALF: ret half %hlsl.length
+// NO_HALF: ret float %hlsl.length
+half test_length_half4(half4 p0)
+{
+ return length(p0);
+}
+
+// CHECK: define noundef float @
+// CHECK: call float @llvm.fabs.f32(float
+// CHECK: ret float
+float test_length_float(float p0)
+{
+ return length(p0);
+}
+// CHECK: define noundef float @
+// CHECK: %hlsl.length = call float @llvm.dx.length.v2f32(
+// CHECK: ret float %hlsl.length
+float test_length_float2(float2 p0)
+{
+ return length(p0);
+}
+// CHECK: define noundef float @
+// CHECK: %hlsl.length = call float @llvm.dx.length.v3f32(
+// CHECK: ret float %hlsl.length
+float test_length_float3(float3 p0)
+{
+ return length(p0);
+}
+// CHECK: define noundef float @
+// CHECK: %hlsl.length = call float @llvm.dx.length.v4f32(
+// CHECK: ret float %hlsl.length
+float test_length_float4(float4 p0)
+{
+ return length(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/length-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/length-errors.hlsl
index 2f259b79aa7e2..781c344f0da17 100644
--- a/clang/test/SemaHLSL/BuiltIns/length-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/length-errors.hlsl
@@ -1 +1,31 @@
-s
\ No newline at end of file
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+bool test_too_few_arg()
+{
+ return __builtin_hlsl_elementwise_length();
+ // expected-error at -1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+bool2 test_too_many_arg(float2 p0)
+{
+ return __builtin_hlsl_elementwise_length(p0, p0);
+ // expected-error at -1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool builtin_bool_to_float_type_promotion(bool p1)
+{
+ return __builtin_hlsl_elementwise_length(p1);
+ // expected-error at -1 {{passing 'bool' to parameter of incompatible type 'float'}}
+}
+
+bool builtin_length_int_to_float_promotion(int p1)
+{
+ return __builtin_hlsl_elementwise_length(p1);
+ // expected-error at -1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+bool2 builtin_length_int2_to_float2_promotion(int2 p1)
+{
+ return __builtin_hlsl_elementwise_length(p1);
+ // expected-error at -1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index a7f212da2f5b6..47c01f899a926 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -55,6 +55,8 @@ def int_dx_isinf :
def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
+def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
+
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index ef6ddf12c32f6..c91fe859d7cc2 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -63,5 +63,7 @@ let TargetPrefix = "spv" in {
def int_spv_frac : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
+ def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
+     [llvm_anyfloat_ty]>;
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
}
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 4b162a35365c8..7ef5b9eae9310 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -42,6 +42,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_clamp:
case Intrinsic::dx_uclamp:
case Intrinsic::dx_lerp:
+ case Intrinsic::dx_length:
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return true;
@@ -157,6 +158,42 @@ static bool expandAnyIntrinsic(CallInst *Orig) {
return true;
}
+// Expands dx.length on an <N x fT> vector X into
+// sqrt(X[0]*X[0] + ... + X[N-1]*X[N-1]); a one-element vector becomes fabs.
+static bool expandLengthIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  Type *EltTy = Ty->getScalarType();
+
+  // CGBuiltin.cpp emits fabs directly for scalar operands, so only vectors
+  // are expected here.
+  assert(Ty->isVectorTy() && "dx.length only works on vector type");
+  // cast<> asserts on a non-fixed vector instead of yielding a null pointer.
+  unsigned Size = cast<FixedVectorType>(Ty)->getNumElements();
+  Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
+  Value *Result = nullptr;
+  if (Size > 1) {
+    // Accumulate the sum of squares, then take its square root.
+    Value *Sum = Builder.CreateFMul(Elt, Elt);
+    for (unsigned I = 1; I < Size; I++) {
+      Elt = Builder.CreateExtractElement(X, I);
+      Value *Mul = Builder.CreateFMul(Elt, Elt);
+      Sum = Builder.CreateFAdd(Sum, Mul);
+    }
+    Result = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");
+  } else {
+    // For a single element, sqrt(x * x) is |x|.
+    Result = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::fabs, ArrayRef<Value *>{Elt}, nullptr, "elt.abs");
+  }
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
static bool expandLerpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
@@ -280,6 +317,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
return expandClampIntrinsic(Orig, F.getIntrinsicID());
case Intrinsic::dx_lerp:
return expandLerpIntrinsic(Orig);
+ case Intrinsic::dx_length:
+ return expandLengthIntrinsic(Orig);
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return expandIntegerDot(Orig, F.getIntrinsicID());
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 8391e0dec9a39..0f0b7fee96559 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -280,7 +280,7 @@ void SPIRVInstructionSelector::setupMF(MachineFunction &MF, GISelKnownBits *KB,
CodeGenCoverage *CoverageInfo,
ProfileSummaryInfo *PSI,
BlockFrequencyInfo *BFI) {
- MMI = &MF.getMMI().getObjFileInfo<SPIRVMachineModuleInfo>();
+ MMI = &MF.getMMI().getObjFileInfo<SPIRVMachineModuleInfo>();
MRI = &MF.getRegInfo();
GR.setCurrentFunc(MF);
InstructionSelector::setupMF(MF, KB, CoverageInfo, PSI, BFI);
More information about the llvm-commits
mailing list