[clang] 1b2d11d - Add normalize builtins and normalize HLSL function to DirectX and SPIR-V backend (#102683)
via cfe-commits
cfe-commits at lists.llvm.org
Tue Aug 13 15:15:22 PDT 2024
Author: Joshua Batista
Date: 2024-08-13T15:15:19-07:00
New Revision: 1b2d11de938af899c74eacc0218304576fe6052b
URL: https://github.com/llvm/llvm-project/commit/1b2d11de938af899c74eacc0218304576fe6052b
DIFF: https://github.com/llvm/llvm-project/commit/1b2d11de938af899c74eacc0218304576fe6052b.diff
LOG: Add normalize builtins and normalize HLSL function to DirectX and SPIR-V backend (#102683)
This PR adds the normalize intrinsic and an HLSL function that uses it.
The SPIRV backend is also implemented.
Used https://github.com/llvm/llvm-project/pull/101256 as a reference,
along with https://github.com/llvm/llvm-project/pull/102243
Fixes https://github.com/llvm/llvm-project/issues/99139
Added:
clang/test/CodeGenHLSL/builtins/normalize.hlsl
clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
llvm/test/CodeGen/DirectX/normalize.ll
llvm/test/CodeGen/DirectX/normalize_error.ll
llvm/test/CodeGen/SPIRV/hlsl-intrinsics/normalize.ll
Modified:
clang/include/clang/Basic/Builtins.td
clang/lib/CodeGen/CGBuiltin.cpp
clang/lib/CodeGen/CGHLSLRuntime.h
clang/lib/Headers/hlsl/hlsl_intrinsics.h
clang/lib/Sema/SemaHLSL.cpp
llvm/include/llvm/IR/IntrinsicsDirectX.td
llvm/include/llvm/IR/IntrinsicsSPIRV.td
llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Removed:
################################################################################
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b025a7681bfac3..0a874d8638df43 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4725,6 +4725,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_normalize"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_rcp"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7fe80b0cbdfbfa..bfae7f8abdb50a 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18584,6 +18584,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
nullptr, "hlsl.length");
}
+ case Builtin::BI__builtin_hlsl_normalize: {
+ Value *X = EmitScalarExpr(E->getArg(0));
+
+ assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+ "normalize operand must have a float representation");
+
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/X->getType(),
+ CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},
+ nullptr, "hlsl.normalize");
+ }
case Builtin::BI__builtin_hlsl_elementwise_frac: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 590d388c456c45..cd604bea2e763d 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -77,6 +77,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index e35a5262f92809..678cdc77f8a71b 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1352,6 +1352,38 @@ double3 min(double3, double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
double4 min(double4, double4);
+//===----------------------------------------------------------------------===//
+// normalize builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T normalize(T x)
+/// \brief Returns the normalized unit vector of the specified floating-point
+/// vector.
+/// \param x [in] The vector of floats.
+///
+/// Normalize is based on the following formula: x / length(x).
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half normalize(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half2 normalize(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half3 normalize(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half4 normalize(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float normalize(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float2 normalize(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float3 normalize(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float4 normalize(float4);
+
//===----------------------------------------------------------------------===//
// pow builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a9c0c57e88221d..e3e926465e799e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1108,6 +1108,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_normalize: {
+ if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+ return true;
+ if (SemaRef.checkArgCount(TheCall, 1))
+ return true;
+
+ ExprResult A = TheCall->getArg(0);
+ QualType ArgTyA = A.get()->getType();
+ // return type is the same as the input type
+ TheCall->setType(ArgTyA);
+ break;
+ }
// Note these are llvm builtins that we want to catch invalid intrinsic
// generation. Normal handling of these builtins will occur elsewhere.
case Builtin::BI__builtin_elementwise_bitreverse: {
diff --git a/clang/test/CodeGenHLSL/builtins/normalize.hlsl b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
new file mode 100644
index 00000000000000..213959e77e7e1e
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
@@ -0,0 +1,100 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,DXIL_CHECK,DXIL_NATIVE_HALF,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,DXIL_CHECK,NO_HALF,DXIL_NO_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,NATIVE_HALF,SPIR_NATIVE_HALF,SPIR_CHECK
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF,SPIR_NO_HALF,SPIR_CHECK
+
+// DXIL_NATIVE_HALF: define noundef half @
+// SPIR_NATIVE_HALF: define spir_func noundef half @
+// DXIL_NATIVE_HALF: call half @llvm.dx.normalize.f16(half
+// SPIR_NATIVE_HALF: call half @llvm.spv.normalize.f16(half
+// DXIL_NO_HALF: call float @llvm.dx.normalize.f32(float
+// SPIR_NO_HALF: call float @llvm.spv.normalize.f32(float
+// NATIVE_HALF: ret half
+// NO_HALF: ret float
+half test_normalize_half(half p0)
+{
+ return normalize(p0);
+}
+// DXIL_NATIVE_HALF: define noundef <2 x half> @
+// SPIR_NATIVE_HALF: define spir_func noundef <2 x half> @
+// DXIL_NATIVE_HALF: call <2 x half> @llvm.dx.normalize.v2f16(<2 x half>
+// SPIR_NATIVE_HALF: call <2 x half> @llvm.spv.normalize.v2f16(<2 x half>
+// DXIL_NO_HALF: call <2 x float> @llvm.dx.normalize.v2f32(<2 x float>
+// SPIR_NO_HALF: call <2 x float> @llvm.spv.normalize.v2f32(<2 x float>
+// NATIVE_HALF: ret <2 x half> %hlsl.normalize
+// NO_HALF: ret <2 x float> %hlsl.normalize
+half2 test_normalize_half2(half2 p0)
+{
+ return normalize(p0);
+}
+// DXIL_NATIVE_HALF: define noundef <3 x half> @
+// SPIR_NATIVE_HALF: define spir_func noundef <3 x half> @
+// DXIL_NATIVE_HALF: call <3 x half> @llvm.dx.normalize.v3f16(<3 x half>
+// SPIR_NATIVE_HALF: call <3 x half> @llvm.spv.normalize.v3f16(<3 x half>
+// DXIL_NO_HALF: call <3 x float> @llvm.dx.normalize.v3f32(<3 x float>
+// SPIR_NO_HALF: call <3 x float> @llvm.spv.normalize.v3f32(<3 x float>
+// NATIVE_HALF: ret <3 x half> %hlsl.normalize
+// NO_HALF: ret <3 x float> %hlsl.normalize
+half3 test_normalize_half3(half3 p0)
+{
+ return normalize(p0);
+}
+// DXIL_NATIVE_HALF: define noundef <4 x half> @
+// SPIR_NATIVE_HALF: define spir_func noundef <4 x half> @
+// DXIL_NATIVE_HALF: call <4 x half> @llvm.dx.normalize.v4f16(<4 x half>
+// SPIR_NATIVE_HALF: call <4 x half> @llvm.spv.normalize.v4f16(<4 x half>
+// DXIL_NO_HALF: call <4 x float> @llvm.dx.normalize.v4f32(<4 x float>
+// SPIR_NO_HALF: call <4 x float> @llvm.spv.normalize.v4f32(<4 x float>
+// NATIVE_HALF: ret <4 x half> %hlsl.normalize
+// NO_HALF: ret <4 x float> %hlsl.normalize
+half4 test_normalize_half4(half4 p0)
+{
+ return normalize(p0);
+}
+
+// DXIL_CHECK: define noundef float @
+// SPIR_CHECK: define spir_func noundef float @
+// DXIL_CHECK: call float @llvm.dx.normalize.f32(float
+// SPIR_CHECK: call float @llvm.spv.normalize.f32(float
+// CHECK: ret float
+float test_normalize_float(float p0)
+{
+ return normalize(p0);
+}
+// DXIL_CHECK: define noundef <2 x float> @
+// SPIR_CHECK: define spir_func noundef <2 x float> @
+// DXIL_CHECK: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// SPIR_CHECK: %hlsl.normalize = call <2 x float> @llvm.spv.normalize.v2f32(<2 x float>
+// CHECK: ret <2 x float> %hlsl.normalize
+float2 test_normalize_float2(float2 p0)
+{
+ return normalize(p0);
+}
+// DXIL_CHECK: define noundef <3 x float> @
+// SPIR_CHECK: define spir_func noundef <3 x float> @
+// DXIL_CHECK: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// SPIR_CHECK: %hlsl.normalize = call <3 x float> @llvm.spv.normalize.v3f32(<3 x float>
+// CHECK: ret <3 x float> %hlsl.normalize
+float3 test_normalize_float3(float3 p0)
+{
+ return normalize(p0);
+}
+// DXIL_CHECK: define noundef <4 x float> @
+// SPIR_CHECK: define spir_func noundef <4 x float> @
+// DXIL_CHECK: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// SPIR_CHECK: %hlsl.normalize = call <4 x float> @llvm.spv.normalize.v4f32(
+// CHECK: ret <4 x float> %hlsl.normalize
+float4 test_length_float4(float4 p0)
+{
+ return normalize(p0);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
new file mode 100644
index 00000000000000..c72c8b3c222b6b
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+void test_too_few_arg()
+{
+ return __builtin_hlsl_normalize();
+ // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+void test_too_many_arg(float2 p0)
+{
+ return __builtin_hlsl_normalize(p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool builtin_bool_to_float_type_promotion(bool p1)
+{
+ return __builtin_hlsl_normalize(p1);
+ // expected-error@-1 {{passing 'bool' to parameter of incompatible type 'float'}}
+}
+
+bool builtin_normalize_int_to_float_promotion(int p1)
+{
+ return __builtin_hlsl_normalize(p1);
+ // expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+bool2 builtin_normalize_int2_to_float2_promotion(int2 p1)
+{
+ return __builtin_hlsl_normalize(p1);
+ // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..904801e6e9e95f 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -58,6 +58,7 @@ def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType
def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 3f77ef6bfcdbe2..1b5e463822749e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -64,5 +64,6 @@ let TargetPrefix = "spv" in {
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
+ def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
}
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index ac85859af8a53e..626321f44c2bfc 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_uclamp:
case Intrinsic::dx_lerp:
case Intrinsic::dx_length:
+ case Intrinsic::dx_normalize:
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return true;
@@ -229,6 +230,75 @@ static bool expandLog10Intrinsic(CallInst *Orig) {
return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
}
+static bool expandNormalizeIntrinsic(CallInst *Orig) { // Expands a dx.normalize call in place; returns true on every path that does not hit report_fatal_error.
+ Value *X = Orig->getOperand(0);
+ Type *Ty = Orig->getType();
+ Type *EltTy = Ty->getScalarType();
+ IRBuilder<> Builder(Orig->getParent());
+ Builder.SetInsertPoint(Orig);
+ // Scalar path: taken when the result type is not a fixed vector. A constant zero operand is rejected up front.
+ auto *XVec = dyn_cast<FixedVectorType>(Ty);
+ if (!XVec) {
+ if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
+ const APFloat &fpVal = constantFP->getValueAPF();
+ if (fpVal.isZero())
+ report_fatal_error(Twine("Invalid input scalar: length is zero"),
+ /* gen_crash_diag=*/false);
+ }
+ Value *Result = Builder.CreateFDiv(X, X);
+ // The fdiv above divides X by itself; the call's uses are redirected to it and the call is erased.
+ Orig->replaceAllUsesWith(Result);
+ Orig->eraseFromParent();
+ return true;
+ }
+ // Vector path. NOTE(review): Elt is not referenced again in this function, but the extractelement is still emitted.
+ Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
+ unsigned XVecSize = XVec->getNumElements();
+ Value *DotProduct = nullptr;
+ // Pick the dot intrinsic matching the element count (2, 3 or 4); any other count is a fatal error.
+ switch (XVecSize) {
+ case 1:
+ report_fatal_error(Twine("Invalid input vector: length is zero"),
+ /* gen_crash_diag=*/false);
+ break;
+ case 2:
+ DotProduct = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
+ break;
+ case 3:
+ DotProduct = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
+ break;
+ case 4:
+ DotProduct = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
+ break;
+ default:
+ report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
+ /* gen_crash_diag=*/false);
+ }
+
+ // Reject a constant-zero dot product: dot(X, X) is zero exactly when the length of X is zero.
+ // NOTE(review): DotProduct is the result of CreateIntrinsic above; confirm it can ever be a ConstantFP here.
+ if (auto *constantFP = dyn_cast<ConstantFP>(DotProduct)) {
+ const APFloat &fpVal = constantFP->getValueAPF();
+ if (fpVal.isZero())
+ report_fatal_error(Twine("Invalid input vector: length is zero"),
+ /* gen_crash_diag=*/false);
+ }
+ // Scale factor is rsqrt(dot(X, X)), i.e. the reciprocal of the vector's length.
+ Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
+ ArrayRef<Value *>{DotProduct},
+ nullptr, "dx.rsqrt");
+ // Broadcast the scalar factor to every lane before the element-wise multiply.
+ Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
+ Value *Result = Builder.CreateFMul(X, MultiplicandVec);
+ // Redirect all uses of the original call to the expanded value, then delete the call.
+ Orig->replaceAllUsesWith(Result);
+ Orig->eraseFromParent();
+ return true;
+}
+
static bool expandPowIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
@@ -314,6 +384,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
return expandLerpIntrinsic(Orig);
case Intrinsic::dx_length:
return expandLengthIntrinsic(Orig);
+ case Intrinsic::dx_normalize:
+ return expandNormalizeIntrinsic(Orig);
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return expandIntegerDot(Orig, F.getIntrinsicID());
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 1a8341ec48aaf5..48e7b9fd764b3e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -244,6 +244,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectLog10(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectNormalize(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -1409,6 +1412,23 @@ bool SPIRVInstructionSelector::selectFrac(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
+bool SPIRVInstructionSelector::selectNormalize(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+ // Selects the normalize intrinsic as an OpExtInst that invokes GL::Normalize from the GLSL_std_450 instruction set.
+ assert(I.getNumOperands() == 3);
+ assert(I.getOperand(2).isReg()); // operand 2 is passed below as the sole argument of the extended instruction
+ MachineBasicBlock &BB = *I.getParent();
+ // Operands are added in this order: result def, result type id, instruction set, opcode, input register.
+ return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+ .addImm(GL::Normalize)
+ .addUse(I.getOperand(2).getReg())
+ .constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -2142,6 +2162,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectLength(ResVReg, ResType, I);
case Intrinsic::spv_frac:
return selectFrac(ResVReg, ResType, I);
+ case Intrinsic::spv_normalize:
+ return selectNormalize(ResVReg, ResType, I);
case Intrinsic::spv_rsqrt:
return selectRsqrt(ResVReg, ResType, I);
case Intrinsic::spv_lifetime_start:
diff --git a/llvm/test/CodeGen/DirectX/normalize.ll b/llvm/test/CodeGen/DirectX/normalize.ll
new file mode 100644
index 00000000000000..f3533cc56e7c25
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/normalize.ll
@@ -0,0 +1,118 @@
+; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+
+; Make sure dxil operation function calls for normalize are generated for half/float.
+
+declare half @llvm.dx.normalize.f16(half)
+declare <2 x half> @llvm.dx.normalize.v2f16(<2 x half>)
+declare <3 x half> @llvm.dx.normalize.v3f16(<3 x half>)
+declare <4 x half> @llvm.dx.normalize.v4f16(<4 x half>)
+
+declare float @llvm.dx.normalize.f32(float)
+declare <2 x float> @llvm.dx.normalize.v2f32(<2 x float>)
+declare <3 x float> @llvm.dx.normalize.v3f32(<3 x float>)
+declare <4 x float> @llvm.dx.normalize.v4f32(<4 x float>)
+
+define noundef half @test_normalize_half(half noundef %p0) {
+entry:
+ ; CHECK: fdiv half %p0, %p0
+ %hlsl.normalize = call half @llvm.dx.normalize.f16(half %p0)
+ ret half %hlsl.normalize
+}
+
+define noundef <2 x half> @test_normalize_half2(<2 x half> noundef %p0) {
+entry:
+ ; CHECK: extractelement <2 x half> %{{.*}}, i64 0
+ ; EXPCHECK: [[doth2:%.*]] = call half @llvm.dx.dot2.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
+ ; DOPCHECK: [[doth2:%.*]] = call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ ; EXPCHECK: [[rsqrt:%.*]] = call half @llvm.dx.rsqrt.f16(half [[doth2]])
+ ; DOPCHECK: [[rsqrt:%.*]] = call half @dx.op.unary.f16(i32 25, half [[doth2]])
+ ; CHECK: [[splatinserth2:%.*]] = insertelement <2 x half> poison, half [[rsqrt]], i64 0
+ ; CHECK: [[splat:%.*]] = shufflevector <2 x half> [[splatinserth2]], <2 x half> poison, <2 x i32> zeroinitializer
+ ; CHECK: fmul <2 x half> %p0, [[splat]]
+
+ %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16(<2 x half> %p0)
+ ret <2 x half> %hlsl.normalize
+}
+
+define noundef <3 x half> @test_normalize_half3(<3 x half> noundef %p0) {
+entry:
+ ; CHECK: extractelement <3 x half> %{{.*}}, i64 0
+ ; EXPCHECK: [[doth3:%.*]] = call half @llvm.dx.dot3.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
+ ; DOPCHECK: [[doth3:%.*]] = call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ ; EXPCHECK: [[rsqrt:%.*]] = call half @llvm.dx.rsqrt.f16(half [[doth3]])
+ ; DOPCHECK: [[rsqrt:%.*]] = call half @dx.op.unary.f16(i32 25, half [[doth3]])
+ ; CHECK: [[splatinserth3:%.*]] = insertelement <3 x half> poison, half [[rsqrt]], i64 0
+ ; CHECK: [[splat:%.*]] = shufflevector <3 x half> [[splatinserth3]], <3 x half> poison, <3 x i32> zeroinitializer
+ ; CHECK: fmul <3 x half> %p0, [[splat]]
+
+ %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16(<3 x half> %p0)
+ ret <3 x half> %hlsl.normalize
+}
+
+define noundef <4 x half> @test_normalize_half4(<4 x half> noundef %p0) {
+entry:
+ ; CHECK: extractelement <4 x half> %{{.*}}, i64 0
+ ; EXPCHECK: [[doth4:%.*]] = call half @llvm.dx.dot4.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
+ ; DOPCHECK: [[doth4:%.*]] = call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ ; EXPCHECK: [[rsqrt:%.*]] = call half @llvm.dx.rsqrt.f16(half [[doth4]])
+ ; DOPCHECK: [[rsqrt:%.*]] = call half @dx.op.unary.f16(i32 25, half [[doth4]])
+ ; CHECK: [[splatinserth4:%.*]] = insertelement <4 x half> poison, half [[rsqrt]], i64 0
+ ; CHECK: [[splat:%.*]] = shufflevector <4 x half> [[splatinserth4]], <4 x half> poison, <4 x i32> zeroinitializer
+ ; CHECK: fmul <4 x half> %p0, [[splat]]
+
+ %hlsl.normalize = call <4 x half> @llvm.dx.normalize.v4f16(<4 x half> %p0)
+ ret <4 x half> %hlsl.normalize
+}
+
+define noundef float @test_normalize_float(float noundef %p0) {
+entry:
+ ; CHECK: fdiv float %p0, %p0
+ %hlsl.normalize = call float @llvm.dx.normalize.f32(float %p0)
+ ret float %hlsl.normalize
+}
+
+define noundef <2 x float> @test_normalize_float2(<2 x float> noundef %p0) {
+entry:
+ ; CHECK: extractelement <2 x float> %{{.*}}, i64 0
+ ; EXPCHECK: [[dotf2:%.*]] = call float @llvm.dx.dot2.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}})
+ ; DOPCHECK: [[dotf2:%.*]] = call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+ ; EXPCHECK: [[rsqrt:%.*]] = call float @llvm.dx.rsqrt.f32(float [[dotf2]])
+ ; DOPCHECK: [[rsqrt:%.*]] = call float @dx.op.unary.f32(i32 25, float [[dotf2]])
+ ; CHECK: [[splatinsertf2:%.*]] = insertelement <2 x float> poison, float [[rsqrt]], i64 0
+ ; CHECK: [[splat:%.*]] = shufflevector <2 x float> [[splatinsertf2]], <2 x float> poison, <2 x i32> zeroinitializer
+ ; CHECK: fmul <2 x float> %p0, [[splat]]
+
+ %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(<2 x float> %p0)
+ ret <2 x float> %hlsl.normalize
+}
+
+define noundef <3 x float> @test_normalize_float3(<3 x float> noundef %p0) {
+entry:
+ ; CHECK: extractelement <3 x float> %{{.*}}, i64 0
+ ; EXPCHECK: [[dotf3:%.*]] = call float @llvm.dx.dot3.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}})
+ ; DOPCHECK: [[dotf3:%.*]] = call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+ ; EXPCHECK: [[rsqrt:%.*]] = call float @llvm.dx.rsqrt.f32(float [[dotf3]])
+ ; DOPCHECK: [[rsqrt:%.*]] = call float @dx.op.unary.f32(i32 25, float [[dotf3]])
+ ; CHECK: [[splatinsertf3:%.*]] = insertelement <3 x float> poison, float [[rsqrt]], i64 0
+ ; CHECK: [[splat:%.*]] = shufflevector <3 x float> [[splatinsertf3]], <3 x float> poison, <3 x i32> zeroinitializer
+ ; CHECK: fmul <3 x float> %p0, [[splat]]
+
+ %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(<3 x float> %p0)
+ ret <3 x float> %hlsl.normalize
+}
+
+define noundef <4 x float> @test_normalize_float4(<4 x float> noundef %p0) {
+entry:
+ ; CHECK: extractelement <4 x float> %{{.*}}, i64 0
+ ; EXPCHECK: [[dotf4:%.*]] = call float @llvm.dx.dot4.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}})
+ ; DOPCHECK: [[dotf4:%.*]] = call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+ ; EXPCHECK: [[rsqrt:%.*]] = call float @llvm.dx.rsqrt.f32(float [[dotf4]])
+ ; DOPCHECK: [[rsqrt:%.*]] = call float @dx.op.unary.f32(i32 25, float [[dotf4]])
+ ; CHECK: [[splatinsertf4:%.*]] = insertelement <4 x float> poison, float [[rsqrt]], i64 0
+ ; CHECK: [[splat:%.*]] = shufflevector <4 x float> [[splatinsertf4]], <4 x float> poison, <4 x i32> zeroinitializer
+ ; CHECK: fmul <4 x float> %p0, [[splat]]
+
+ %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(<4 x float> %p0)
+ ret <4 x float> %hlsl.normalize
+}
diff --git a/llvm/test/CodeGen/DirectX/normalize_error.ll b/llvm/test/CodeGen/DirectX/normalize_error.ll
new file mode 100644
index 00000000000000..da0905690236db
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/normalize_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
+
+; DXIL operation normalize does not support double overload type
+; CHECK: Cannot create Dot2 operation: Invalid overload type
+
+define noundef <2 x double> @test_normalize_double2(<2 x double> noundef %p0) {
+entry:
+ %hlsl.normalize = call <2 x double> @llvm.dx.normalize.v2f64(<2 x double> %p0) ; suffix v2f64 matches the <2 x double> overload under test
+ ret <2 x double> %hlsl.normalize
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/normalize.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/normalize.ll
new file mode 100644
index 00000000000000..fa73b9c2a4d3ab
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/normalize.ll
@@ -0,0 +1,31 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure SPIRV operation function calls for normalize are lowered correctly.
+
+; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450"
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+
+define noundef <4 x half> @normalize_half4(<4 x half> noundef %a) {
+entry:
+ ; CHECK: %[[#]] = OpFunction %[[#vec4_float_16]] None %[[#]]
+ ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]]
+ ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_16]] %[[#op_ext_glsl]] Normalize %[[#arg0]]
+ %hlsl.normalize = call <4 x half> @llvm.spv.normalize.v4f16(<4 x half> %a)
+ ret <4 x half> %hlsl.normalize
+}
+
+define noundef <4 x float> @normalize_float4(<4 x float> noundef %a) {
+entry:
+ ; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#]]
+ ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
+ ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] Normalize %[[#arg0]]
+ %hlsl.normalize = call <4 x float> @llvm.spv.normalize.v4f32(<4 x float> %a)
+ ret <4 x float> %hlsl.normalize
+}
+
+declare <4 x half> @llvm.spv.normalize.v4f16(<4 x half>)
+declare <4 x float> @llvm.spv.normalize.v4f32(<4 x float>)
More information about the cfe-commits
mailing list