[clang] [llvm] Add normalize builtins and normalize HLSL function to DirectX and SPIR-V backend (PR #102683)

Joshua Batista via cfe-commits cfe-commits at lists.llvm.org
Sat Aug 10 00:46:45 PDT 2024


https://github.com/bob80905 updated https://github.com/llvm/llvm-project/pull/102683

>From 547b4da91b20811db156a8c73fcb2f381cfed7bd Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Fri, 9 Aug 2024 10:48:10 -0700
Subject: [PATCH 1/8] suboptimal expansion of normalize done

---
 clang/include/clang/Basic/Builtins.td         |   6 +
 clang/lib/CodeGen/CGBuiltin.cpp               |  23 ++++
 clang/lib/CodeGen/CGHLSLRuntime.h             |   1 +
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  32 +++++
 clang/lib/Sema/SemaHLSL.cpp                   |  12 ++
 .../test/CodeGenHLSL/builtins/normalize.hlsl  |  73 +++++++++++
 .../SemaHLSL/BuiltIns/normalize-errors.hlsl   |  31 +++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |   1 +
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |   1 +
 .../Target/DirectX/DXILIntrinsicExpansion.cpp |  51 ++++++++
 llvm/test/CodeGen/DirectX/normalize.ll        | 116 ++++++++++++++++++
 llvm/test/CodeGen/DirectX/normalize_error.ll  |  10 ++
 12 files changed, 357 insertions(+)
 create mode 100644 clang/test/CodeGenHLSL/builtins/normalize.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
 create mode 100644 llvm/test/CodeGen/DirectX/normalize.ll
 create mode 100644 llvm/test/CodeGen/DirectX/normalize_error.ll

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b025a7681bfac3..0a874d8638df43 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4725,6 +4725,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_normalize"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_rcp"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7fe80b0cbdfbfa..2507c858d77209 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18584,6 +18584,29 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
         nullptr, "hlsl.length");
   }
+  case Builtin::BI__builtin_hlsl_normalize: {
+    Value *X = EmitScalarExpr(E->getArg(0));
+
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           "normalize operand must have a float representation");
+
+    // Scalar operand: the intrinsic returns the operand's scalar type.
+    if (!E->getArg(0)->getType()->isVectorType())
+      return Builder.CreateIntrinsic(
+          /*ReturnType=*/X->getType()->getScalarType(),
+          CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},
+          nullptr, "hlsl.normalize");
+
+    // Vector operand: same element type and element count as the operand.
+    auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
+    llvm::Type *retType = X->getType()->getScalarType();
+    retType = llvm::VectorType::get(
+        retType, ElementCount::getFixed(XVecTy->getNumElements()));
+
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/retType, CGM.getHLSLRuntime().getNormalizeIntrinsic(),
+        ArrayRef<Value *>{X}, nullptr, "hlsl.normalize");
+  }
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     if (!E->getArg(0)->getType()->hasFloatingRepresentation())
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 527e73a0e21fc4..80ca432f4b509c 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -77,6 +77,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
   GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
 
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index e35a5262f92809..678cdc77f8a71b 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1352,6 +1352,38 @@ double3 min(double3, double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 double4 min(double4, double4);
 
+//===----------------------------------------------------------------------===//
+// normalize builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T normalize(T x)
+/// \brief Returns \a x scaled to unit length (its normalized form).
+/// \param x [in] The floating-point scalar or vector to normalize.
+///
+/// Normalize is based on the following formula: x / length(x).
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half normalize(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half2 normalize(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half3 normalize(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half4 normalize(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float normalize(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float2 normalize(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float3 normalize(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float4 normalize(float4);
+
 //===----------------------------------------------------------------------===//
 // pow builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a9c0c57e88221d..61f68a415a7d6c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1108,6 +1108,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_normalize: {
+    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+      return true;
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+
+    // normalize returns a value of the same type as its only argument.
+    QualType ArgTyA = TheCall->getArg(0)->getType();
+
+    TheCall->setType(ArgTyA);
+    break;
+  }
   // Note these are llvm builtins that we want to catch invalid intrinsic
   // generation. Normal handling of these builitns will occur elsewhere.
   case Builtin::BI__builtin_elementwise_bitreverse: {
diff --git a/clang/test/CodeGenHLSL/builtins/normalize.hlsl b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
new file mode 100644
index 00000000000000..f46a35866f45d9
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
@@ -0,0 +1,73 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: call half @llvm.dx.normalize.f16(half
+// NO_HALF: call float @llvm.dx.normalize.f32(float
+// NATIVE_HALF: ret half
+// NO_HALF: ret float
+half test_normalize_half(half p0)
+{
+  return normalize(p0);
+}
+// NATIVE_HALF: define noundef <2 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16
+// NO_HALF: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// NATIVE_HALF: ret <2 x half> %hlsl.normalize
+// NO_HALF: ret <2 x float> %hlsl.normalize
+half2 test_normalize_half2(half2 p0)
+{
+  return normalize(p0);
+}
+// NATIVE_HALF: define noundef <3 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16
+// NO_HALF: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// NATIVE_HALF: ret <3 x half> %hlsl.normalize
+// NO_HALF: ret <3 x float> %hlsl.normalize
+half3 test_normalize_half3(half3 p0)
+{
+  return normalize(p0);
+}
+// NATIVE_HALF: define noundef <4 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <4 x half> @llvm.dx.normalize.v4f16
+// NO_HALF: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// NATIVE_HALF: ret <4 x half> %hlsl.normalize
+// NO_HALF: ret <4 x float> %hlsl.normalize
+half4 test_normalize_half4(half4 p0)
+{
+  return normalize(p0);
+}
+
+// CHECK: define noundef float @
+// CHECK: call float @llvm.dx.normalize.f32(float
+// CHECK: ret float
+float test_normalize_float(float p0)
+{
+  return normalize(p0);
+}
+// CHECK: define noundef <2 x float> @
+// CHECK: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// CHECK: ret <2 x float> %hlsl.normalize
+float2 test_normalize_float2(float2 p0)
+{
+  return normalize(p0);
+}
+// CHECK: define noundef <3 x float> @
+// CHECK: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// CHECK: ret <3 x float> %hlsl.normalize
+float3 test_normalize_float3(float3 p0)
+{
+  return normalize(p0);
+}
+// CHECK: define noundef <4 x float> @
+// CHECK: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// CHECK: ret <4 x float> %hlsl.normalize
+float4 test_normalize_float4(float4 p0)
+{
+  return normalize(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
new file mode 100644
index 00000000000000..b348297d37eb1c
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+void test_too_few_arg()
+{
+  return __builtin_hlsl_normalize();
+  // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+void test_too_many_arg(float2 p0)
+{
+  return __builtin_hlsl_normalize(p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool builtin_normalize_bool_to_float_type_promotion(bool p1)
+{
+  return __builtin_hlsl_normalize(p1);
+  // expected-error@-1 {{passing 'bool' to parameter of incompatible type 'float'}}
+}
+
+bool builtin_normalize_int_to_float_promotion(int p1)
+{
+  return __builtin_hlsl_normalize(p1);
+  // expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+bool2 builtin_normalize_int2_to_float2_promotion(int2 p1)
+{
+  return __builtin_hlsl_normalize(p1);
+  // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..904801e6e9e95f 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -58,6 +58,7 @@ def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType
 def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
 def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
 def int_dx_rcp  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 3f77ef6bfcdbe2..1b5e463822749e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -64,5 +64,6 @@ let TargetPrefix = "spv" in {
   def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>], 
     [IntrNoMem, IntrWillReturn] >;
   def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
+  def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
   def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index ac85859af8a53e..649758e47db1dd 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_lerp:
   case Intrinsic::dx_length:
+  case Intrinsic::dx_normalize:
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
     return true;
@@ -229,6 +230,54 @@ static bool expandLog10Intrinsic(CallInst *Orig) {
   return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
 }
 
+static bool expandNormalizeIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  Type *Ty = Orig->getType();
+  Type *EltTy = Ty->getScalarType();
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+
+  Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
+  auto *XVec = dyn_cast<FixedVectorType>(Ty);
+  if (!XVec) {
+    if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
+      const APFloat &fpVal = constantFP->getValueAPF();
+      if (fpVal.isZero())
+        report_fatal_error(Twine("Invalid input scalar: length is zero"),
+                           /* gen_crash_diag=*/false);
+    }
+    Value *Result = Builder.CreateFDiv(X, X);
+    // NOTE(review): X / X is 1.0 even when X < 0; confirm vs x / length(x).
+    Orig->replaceAllUsesWith(Result);
+    Orig->eraseFromParent();
+    return true;
+  }
+
+  unsigned XVecSize = XVec->getNumElements();
+  Value *Sum = Builder.CreateFMul(Elt, Elt);
+  for (unsigned I = 1; I < XVecSize; I++) {
+    Elt = Builder.CreateExtractElement(X, I);
+    Value *Mul = Builder.CreateFMul(Elt, Elt);
+    Sum = Builder.CreateFAdd(Sum, Mul);
+  }
+  Value *Length = Builder.CreateIntrinsic(
+      EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");
+
+  // verify that the length is non-zero
+  if (auto *constantFP = dyn_cast<ConstantFP>(Length)) {
+    const APFloat &fpVal = constantFP->getValueAPF();
+    if (fpVal.isZero())
+      report_fatal_error(Twine("Invalid input vector: length is zero"),
+                         /* gen_crash_diag=*/false);
+  }
+  Value *LengthVec = Builder.CreateVectorSplat(XVecSize, Length);
+  Value *Result = Builder.CreateFDiv(X, LengthVec);
+
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
 static bool expandPowIntrinsic(CallInst *Orig) {
 
   Value *X = Orig->getOperand(0);
@@ -314,6 +363,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
     return expandLerpIntrinsic(Orig);
   case Intrinsic::dx_length:
     return expandLengthIntrinsic(Orig);
+  case Intrinsic::dx_normalize:
+    return expandNormalizeIntrinsic(Orig);
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
     return expandIntegerDot(Orig, F.getIntrinsicID());
diff --git a/llvm/test/CodeGen/DirectX/normalize.ll b/llvm/test/CodeGen/DirectX/normalize.ll
new file mode 100644
index 00000000000000..3f66e8cac98d3a
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/normalize.ll
@@ -0,0 +1,116 @@
+; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
+; RUN: opt -S  -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+
+; Make sure normalize is expanded and lowered to DXIL operations for half/float vectors.
+
+declare half @llvm.dx.normalize.f16(half)
+declare <2 x half> @llvm.dx.normalize.v2f16(<2 x half>)
+declare <3 x half> @llvm.dx.normalize.v3f16(<3 x half>)
+declare <4 x half> @llvm.dx.normalize.v4f16(<4 x half>)
+
+declare float @llvm.dx.normalize.f32(float)
+declare <2 x float> @llvm.dx.normalize.v2f32(<2 x float>)
+declare <3 x float> @llvm.dx.normalize.v3f32(<3 x float>)
+declare <4 x float> @llvm.dx.normalize.v4f32(<4 x float>)
+
+define noundef <2 x half> @test_normalize_half2(<2 x half> noundef %p0) {
+entry:
+  ; CHECK: extractelement <2 x half> %{{.*}}, i64 0
+  ; CHECK: fmul half %{{.*}}, %{{.*}}
+  ; CHECK: extractelement <2 x half> %{{.*}}, i64 1
+  ; CHECK: fmul half %{{.*}}, %{{.*}}
+  ; CHECK: fadd half %{{.*}}, %{{.*}}
+  ; EXPCHECK: call half @llvm.sqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 24, half %{{.*}})
+
+  %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16(<2 x half> %p0)
+  ret <2 x half> %hlsl.normalize
+}
+
+define noundef <3 x half> @test_normalize_half3(<3 x half> noundef %p0) {
+entry:
+  ; CHECK: extractelement <3 x half> %{{.*}}, i64 0
+  ; CHECK: fmul half %{{.*}}, %{{.*}}
+  ; CHECK: extractelement <3 x half> %{{.*}}, i64 1
+  ; CHECK: fmul half %{{.*}}, %{{.*}}
+  ; CHECK: fadd half %{{.*}}, %{{.*}}
+  ; CHECK: extractelement <3 x half> %{{.*}}, i64 2
+  ; CHECK: fmul half %{{.*}}, %{{.*}}
+  ; CHECK: fadd half %{{.*}}, %{{.*}}
+  ; EXPCHECK: call half @llvm.sqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 24, half %{{.*}})
+
+  %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16(<3 x half> %p0)
+  ret <3 x half> %hlsl.normalize
+}
+
+define noundef <4 x half> @test_normalize_half4(<4 x half> noundef %p0) {
+entry:
+  ; CHECK: extractelement <4 x half> %{{.*}}, i64 0
+  ; CHECK: fmul half %{{.*}}, %{{.*}}
+  ; CHECK: extractelement <4 x half> %{{.*}}, i64 1
+  ; CHECK: fmul half %{{.*}}, %{{.*}}
+  ; CHECK: fadd half %{{.*}}, %{{.*}}
+  ; CHECK: extractelement <4 x half> %{{.*}}, i64 2
+  ; CHECK: fmul half %{{.*}}, %{{.*}}
+  ; CHECK: fadd half %{{.*}}, %{{.*}}
+  ; CHECK: extractelement <4 x half> %{{.*}}, i64 3
+  ; CHECK: fmul half %{{.*}}, %{{.*}}
+  ; CHECK: fadd half %{{.*}}, %{{.*}}
+  ; EXPCHECK: call half @llvm.sqrt.f16(half %{{.*}})
+  ; DOPCHECK:  call half @dx.op.unary.f16(i32 24, half %{{.*}})
+
+  %hlsl.normalize = call <4 x half> @llvm.dx.normalize.v4f16(<4 x half> %p0)
+  ret <4 x half> %hlsl.normalize
+}
+
+define noundef <2 x float> @test_normalize_float2(<2 x float> noundef %p0) {
+entry:
+  ; CHECK: extractelement <2 x float> %{{.*}}, i64 0
+  ; CHECK: fmul float %{{.*}}, %{{.*}}
+  ; CHECK: extractelement <2 x float> %{{.*}}, i64 1
+  ; CHECK: fmul float %{{.*}}, %{{.*}}
+  ; CHECK: fadd float %{{.*}}, %{{.*}}
+  ; EXPCHECK: call float @llvm.sqrt.f32(float %{{.*}})
+  ; DOPCHECK: call float @dx.op.unary.f32(i32 24, float %{{.*}})
+
+  %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(<2 x float> %p0)
+  ret <2 x float> %hlsl.normalize
+}
+
+define noundef <3 x float> @test_normalize_float3(<3 x float> noundef %p0) {
+entry:
+  ; CHECK: extractelement <3 x float> %{{.*}}, i64 0
+  ; CHECK: fmul float %{{.*}}, %{{.*}}
+  ; CHECK: extractelement <3 x float> %{{.*}}, i64 1
+  ; CHECK: fmul float %{{.*}}, %{{.*}}
+  ; CHECK: fadd float %{{.*}}, %{{.*}}
+  ; CHECK: extractelement <3 x float> %{{.*}}, i64 2
+  ; CHECK: fmul float %{{.*}}, %{{.*}}
+  ; CHECK: fadd float %{{.*}}, %{{.*}}
+  ; EXPCHECK: call float @llvm.sqrt.f32(float %{{.*}})
+  ; DOPCHECK: call float @dx.op.unary.f32(i32 24, float %{{.*}})
+
+  %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(<3 x float> %p0)
+  ret <3 x float> %hlsl.normalize
+}
+
+define noundef <4 x float> @test_normalize_float4(<4 x float> noundef %p0) {
+entry:
+  ; CHECK: extractelement <4 x float> %{{.*}}, i64 0
+  ; CHECK: fmul float %{{.*}}, %{{.*}}
+  ; CHECK: extractelement <4 x float> %{{.*}}, i64 1
+  ; CHECK: fmul float %{{.*}}, %{{.*}}
+  ; CHECK: fadd float %{{.*}}, %{{.*}}
+  ; CHECK: extractelement <4 x float> %{{.*}}, i64 2
+  ; CHECK: fmul float %{{.*}}, %{{.*}}
+  ; CHECK: fadd float %{{.*}}, %{{.*}}
+  ; CHECK: extractelement <4 x float> %{{.*}}, i64 3
+  ; CHECK: fmul float %{{.*}}, %{{.*}}
+  ; CHECK: fadd float %{{.*}}, %{{.*}}
+  ; EXPCHECK: call float @llvm.sqrt.f32(float %{{.*}})
+  ; DOPCHECK:  call float @dx.op.unary.f32(i32 24, float %{{.*}})
+
+  %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(<4 x float> %p0)
+  ret <4 x float> %hlsl.normalize
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/DirectX/normalize_error.ll b/llvm/test/CodeGen/DirectX/normalize_error.ll
new file mode 100644
index 00000000000000..cd117797e5714f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/normalize_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
+
+; DXIL operation normalize does not support double overload type
+; CHECK: Cannot create Sqrt operation: Invalid overload type
+
+define noundef <2 x double> @test_normalize_double2(<2 x double> noundef %p0) {
+entry:
+  %hlsl.normalize = call <2 x double> @llvm.dx.normalize.v2f64(<2 x double> %p0)
+  ret <2 x double> %hlsl.normalize
+}
\ No newline at end of file

>From e3ca0f0fb4f323b10f7f43ff4a26c053822f74c3 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Fri, 9 Aug 2024 12:36:15 -0700
Subject: [PATCH 2/8] optimize expansion, update tests and add scalar test
 variants

---
 .../Target/DirectX/DXILIntrinsicExpansion.cpp |  43 +++++--
 llvm/test/CodeGen/DirectX/normalize.ll        | 110 +++++++++---------
 llvm/test/CodeGen/DirectX/normalize_error.ll  |   2 +-
 3 files changed, 89 insertions(+), 66 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 649758e47db1dd..e80166e0ff0569 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -237,7 +237,6 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
 
-  Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
   auto *XVec = dyn_cast<FixedVectorType>(Ty);
   if (!XVec) {
     if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
@@ -253,25 +252,47 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
     return true;
   }
 
+  Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
   unsigned XVecSize = XVec->getNumElements();
-  Value *Sum = Builder.CreateFMul(Elt, Elt);
-  for (unsigned I = 1; I < XVecSize; I++) {
-    Elt = Builder.CreateExtractElement(X, I);
-    Value *Mul = Builder.CreateFMul(Elt, Elt);
-    Sum = Builder.CreateFAdd(Sum, Mul);
+  Value *DotProduct = nullptr;
+  switch (XVecSize) {
+  case 1:
+    // A one-element vector's dot product with itself is just Elt * Elt.
+    DotProduct = Builder.CreateFMul(Elt, Elt);
+    break;
+  case 2:
+    DotProduct = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
+    break;
+  case 3:
+    DotProduct = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
+    break;
+  case 4:
+    DotProduct = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
+    break;
+  default:
+    report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
+                       /* gen_crash_diag=*/false);
   }
-  Value *Length = Builder.CreateIntrinsic(
-      EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");
+
+  Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
+                                                ArrayRef<Value *>{DotProduct},
+                                                nullptr, "dx.rsqrt");
 
   // verify that the length is non-zero
-  if (auto *constantFP = dyn_cast<ConstantFP>(Length)) {
+  // (checked on the squared length: rsqrt(0) is +inf rather than 0, so the
+  // rsqrt result itself can never reveal a zero-length input)
+  if (auto *constantFP = dyn_cast<ConstantFP>(DotProduct)) {
     const APFloat &fpVal = constantFP->getValueAPF();
     if (fpVal.isZero())
       report_fatal_error(Twine("Invalid input vector: length is zero"),
                          /* gen_crash_diag=*/false);
   }
-  Value *LengthVec = Builder.CreateVectorSplat(XVecSize, Length);
-  Value *Result = Builder.CreateFDiv(X, LengthVec);
+
+  Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
+  Value *Result = Builder.CreateFMul(X, MultiplicandVec);
 
   Orig->replaceAllUsesWith(Result);
   Orig->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/normalize.ll b/llvm/test/CodeGen/DirectX/normalize.ll
index 3f66e8cac98d3a..8b4a6692e8725f 100644
--- a/llvm/test/CodeGen/DirectX/normalize.ll
+++ b/llvm/test/CodeGen/DirectX/normalize.ll
@@ -13,15 +13,23 @@ declare <2 x float> @llvm.dx.normalize.v2f32(<2 x float>)
 declare <3 x float> @llvm.dx.normalize.v3f32(<3 x float>)
 declare <4 x float> @llvm.dx.normalize.v4f32(<4 x float>)
 
+define noundef half @test_normalize_half(half noundef %p0) {
+entry:
+  ; CHECK: fdiv half %p0, %p0
+  %hlsl.normalize = call half @llvm.dx.normalize.f16(half %p0)
+  ret half %hlsl.normalize
+}
+
 define noundef <2 x half> @test_normalize_half2(<2 x half> noundef %p0) {
 entry:
   ; CHECK: extractelement <2 x half> %{{.*}}, i64 0
-  ; CHECK: fmul half %{{.*}}, %{{.*}}
-  ; CHECK: extractelement <2 x half> %{{.*}}, i64 1
-  ; CHECK: fmul half %{{.*}}, %{{.*}}
-  ; CHECK: fadd half %{{.*}}, %{{.*}}
-  ; EXPCHECK: call half @llvm.sqrt.f16(half %{{.*}})
-  ; DOPCHECK: call half @dx.op.unary.f16(i32 24, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
+  ; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+  ; CHECK: insertelement <2 x half> poison, half %{{.*}}, i64 0
+  ; CHECK: shufflevector <2 x half> %{{.*}}, <2 x half> poison, <2 x i32> zeroinitializer
+  ; CHECK: fmul <2 x half> %{{.*}}, %{{.*}}
 
   %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16(<2 x half> %p0)
   ret <2 x half> %hlsl.normalize
@@ -30,15 +38,13 @@ entry:
 define noundef <3 x half> @test_normalize_half3(<3 x half> noundef %p0) {
 entry:
   ; CHECK: extractelement <3 x half> %{{.*}}, i64 0
-  ; CHECK: fmul half %{{.*}}, %{{.*}}
-  ; CHECK: extractelement <3 x half> %{{.*}}, i64 1
-  ; CHECK: fmul half %{{.*}}, %{{.*}}
-  ; CHECK: fadd half %{{.*}}, %{{.*}}
-  ; CHECK: extractelement <3 x half> %{{.*}}, i64 2
-  ; CHECK: fmul half %{{.*}}, %{{.*}}
-  ; CHECK: fadd half %{{.*}}, %{{.*}}
-  ; EXPCHECK: call half @llvm.sqrt.f16(half %{{.*}})
-  ; DOPCHECK: call half @dx.op.unary.f16(i32 24, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
+  ; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+  ; CHECK: insertelement <3 x half> poison, half %{{.*}}, i64 0
+  ; CHECK: shufflevector <3 x half> %{{.*}}, <3 x half> poison, <3 x i32> zeroinitializer
+  ; CHECK: fmul <3 x half> %{{.*}}, %{{.*}}
 
   %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16(<3 x half> %p0)
   ret <3 x half> %hlsl.normalize
@@ -47,32 +53,35 @@ entry:
 define noundef <4 x half> @test_normalize_half4(<4 x half> noundef %p0) {
 entry:
   ; CHECK: extractelement <4 x half> %{{.*}}, i64 0
-  ; CHECK: fmul half %{{.*}}, %{{.*}}
-  ; CHECK: extractelement <4 x half> %{{.*}}, i64 1
-  ; CHECK: fmul half %{{.*}}, %{{.*}}
-  ; CHECK: fadd half %{{.*}}, %{{.*}}
-  ; CHECK: extractelement <4 x half> %{{.*}}, i64 2
-  ; CHECK: fmul half %{{.*}}, %{{.*}}
-  ; CHECK: fadd half %{{.*}}, %{{.*}}
-  ; CHECK: extractelement <4 x half> %{{.*}}, i64 3
-  ; CHECK: fmul half %{{.*}}, %{{.*}}
-  ; CHECK: fadd half %{{.*}}, %{{.*}}
-  ; EXPCHECK: call half @llvm.sqrt.f16(half %{{.*}})
-  ; DOPCHECK:  call half @dx.op.unary.f16(i32 24, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
+  ; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+  ; CHECK: insertelement <4 x half> poison, half %{{.*}}, i64 0
+  ; CHECK: shufflevector <4 x half> %{{.*}}, <4 x half> poison, <4 x i32> zeroinitializer
+  ; CHECK: fmul <4 x half> %{{.*}}, %{{.*}}
 
   %hlsl.normalize = call <4 x half> @llvm.dx.normalize.v4f16(<4 x half> %p0)
   ret <4 x half> %hlsl.normalize
 }
 
+define noundef float @test_normalize_float(float noundef %p0) {
+entry:
+  ; CHECK: fdiv float %p0, %p0
+  %hlsl.normalize = call float @llvm.dx.normalize.f32(float %p0)
+  ret float %hlsl.normalize
+}
+
 define noundef <2 x float> @test_normalize_float2(<2 x float> noundef %p0) {
 entry:
   ; CHECK: extractelement <2 x float> %{{.*}}, i64 0
-  ; CHECK: fmul float %{{.*}}, %{{.*}}
-  ; CHECK: extractelement <2 x float> %{{.*}}, i64 1
-  ; CHECK: fmul float %{{.*}}, %{{.*}}
-  ; CHECK: fadd float %{{.*}}, %{{.*}}
-  ; EXPCHECK: call float @llvm.sqrt.f32(float %{{.*}})
-  ; DOPCHECK: call float @dx.op.unary.f32(i32 24, float %{{.*}})
+  ; EXPCHECK: call float @llvm.dx.dot2.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}})
+  ; DOPCHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+  ; EXPCHECK: call float @llvm.dx.rsqrt.f32(float %{{.*}})
+  ; DOPCHECK: call float @dx.op.unary.f32(i32 25, float %{{.*}})
+  ; CHECK: insertelement <2 x float> poison, float %{{.*}}, i64 0
+  ; CHECK: shufflevector <2 x float> %{{.*}}, <2 x float> poison, <2 x i32> zeroinitializer
+  ; CHECK: fmul <2 x float> %{{.*}}, %{{.*}}
 
   %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(<2 x float> %p0)
   ret <2 x float> %hlsl.normalize
@@ -81,15 +90,13 @@ entry:
 define noundef <3 x float> @test_normalize_float3(<3 x float> noundef %p0) {
 entry:
   ; CHECK: extractelement <3 x float> %{{.*}}, i64 0
-  ; CHECK: fmul float %{{.*}}, %{{.*}}
-  ; CHECK: extractelement <3 x float> %{{.*}}, i64 1
-  ; CHECK: fmul float %{{.*}}, %{{.*}}
-  ; CHECK: fadd float %{{.*}}, %{{.*}}
-  ; CHECK: extractelement <3 x float> %{{.*}}, i64 2
-  ; CHECK: fmul float %{{.*}}, %{{.*}}
-  ; CHECK: fadd float %{{.*}}, %{{.*}}
-  ; EXPCHECK: call float @llvm.sqrt.f32(float %{{.*}})
-  ; DOPCHECK: call float @dx.op.unary.f32(i32 24, float %{{.*}})
+  ; EXPCHECK: call float @llvm.dx.dot3.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}})
+  ; DOPCHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+  ; EXPCHECK: call float @llvm.dx.rsqrt.f32(float %{{.*}})
+  ; DOPCHECK: call float @dx.op.unary.f32(i32 25, float %{{.*}})
+  ; CHECK: insertelement <3 x float> poison, float %{{.*}}, i64 0
+  ; CHECK: shufflevector <3 x float> %{{.*}}, <3 x float> poison, <3 x i32> zeroinitializer
+  ; CHECK: fmul <3 x float> %{{.*}}, %{{.*}}
 
   %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(<3 x float> %p0)
   ret <3 x float> %hlsl.normalize
@@ -98,18 +105,13 @@ entry:
 define noundef <4 x float> @test_normalize_float4(<4 x float> noundef %p0) {
 entry:
   ; CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  ; CHECK: fmul float %{{.*}}, %{{.*}}
-  ; CHECK: extractelement <4 x float> %{{.*}}, i64 1
-  ; CHECK: fmul float %{{.*}}, %{{.*}}
-  ; CHECK: fadd float %{{.*}}, %{{.*}}
-  ; CHECK: extractelement <4 x float> %{{.*}}, i64 2
-  ; CHECK: fmul float %{{.*}}, %{{.*}}
-  ; CHECK: fadd float %{{.*}}, %{{.*}}
-  ; CHECK: extractelement <4 x float> %{{.*}}, i64 3
-  ; CHECK: fmul float %{{.*}}, %{{.*}}
-  ; CHECK: fadd float %{{.*}}, %{{.*}}
-  ; EXPCHECK: call float @llvm.sqrt.f32(float %{{.*}})
-  ; DOPCHECK:  call float @dx.op.unary.f32(i32 24, float %{{.*}})
+  ; EXPCHECK: call float @llvm.dx.dot4.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}})
+  ; DOPCHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+  ; EXPCHECK: call float @llvm.dx.rsqrt.f32(float %{{.*}})
+  ; DOPCHECK: call float @dx.op.unary.f32(i32 25, float %{{.*}})
+  ; CHECK: insertelement <4 x float> poison, float %{{.*}}, i64 0
+  ; CHECK: shufflevector <4 x float> %{{.*}}, <4 x float> poison, <4 x i32> zeroinitializer
+  ; CHECK: fmul <4 x float> %{{.*}}, %{{.*}}
 
   %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(<4 x float> %p0)
   ret <4 x float> %hlsl.normalize
diff --git a/llvm/test/CodeGen/DirectX/normalize_error.ll b/llvm/test/CodeGen/DirectX/normalize_error.ll
index cd117797e5714f..6278e8d367d35a 100644
--- a/llvm/test/CodeGen/DirectX/normalize_error.ll
+++ b/llvm/test/CodeGen/DirectX/normalize_error.ll
@@ -1,7 +1,7 @@
 ; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation normalize does not support double overload type
-; CHECK: Cannot create Sqrt operation: Invalid overload type
+; CHECK: Cannot create Dot2 operation: Invalid overload type
 
 define noundef <2 x double> @test_normalize_double2(<2 x double> noundef %p0) {
 entry:

>From bd403524edaab7218dd6e03f75e177279deb86f8 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Fri, 9 Aug 2024 13:48:57 -0700
Subject: [PATCH 3/8] add spirv backend

---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 22 ++++++++++++++
 .../SPIRV/hlsl-intrinsics/normalize.ll        | 29 +++++++++++++++++++
 2 files changed, 51 insertions(+)
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/normalize.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ed786bd33aa05b..ba8c0de5d91dc7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -238,6 +238,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectLog10(Register ResVReg, const SPIRVType *ResType,
                    MachineInstr &I) const;
 
+  bool selectNormalize(Register ResVReg, const SPIRVType *ResType,
+                   MachineInstr &I) const;
+
   bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I) const;
 
@@ -1349,6 +1352,23 @@ bool SPIRVInstructionSelector::selectFrac(Register ResVReg,
       .constrainAllUses(TII, TRI, RBI);
 }
 
+bool SPIRVInstructionSelector::selectNormalize(Register ResVReg,
+                                          const SPIRVType *ResType,
+                                          MachineInstr &I) const {
+
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(2).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+      .addImm(GL::Normalize)
+      .addUse(I.getOperand(2).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
                                            const SPIRVType *ResType,
                                            MachineInstr &I) const {
@@ -2080,6 +2100,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectFmix(ResVReg, ResType, I);
   case Intrinsic::spv_frac:
     return selectFrac(ResVReg, ResType, I);
+  case Intrinsic::spv_normalize:
+    return selectNormalize(ResVReg, ResType, I);
   case Intrinsic::spv_rsqrt:
     return selectRsqrt(ResVReg, ResType, I);
   case Intrinsic::spv_lifetime_start:
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/normalize.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/normalize.ll
new file mode 100644
index 00000000000000..1ee0b8b5041f93
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/normalize.ll
@@ -0,0 +1,29 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure SPIRV operation function calls for normalize are lowered correctly.
+
+; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450"
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+
+define noundef <4 x half> @normalize_half4(<4 x half> noundef %a) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_16]] %[[#op_ext_glsl]] Normalize %[[#arg0]]
+  %hlsl.normalize = call <4 x half> @llvm.spv.normalize.v4f16(<4 x half> %a)
+  ret <4 x half> %hlsl.normalize
+}
+
+define noundef <4 x float> @normalize_float4(<4 x float> noundef %a) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] Normalize %[[#arg0]]
+  %hlsl.normalize = call <4 x float> @llvm.spv.normalize.v4f32(<4 x float> %a)
+  ret <4 x float> %hlsl.normalize
+}
+
+declare <4 x half> @llvm.spv.normalize.v4f16(<4 x half>)
+declare <4 x float> @llvm.spv.normalize.v4f32(<4 x float>)
\ No newline at end of file

>From 45a7ff733634794f92d3bf587fe4d28a42747720 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Fri, 9 Aug 2024 13:49:39 -0700
Subject: [PATCH 4/8] clang format

---
 llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ba8c0de5d91dc7..6e27b6c12f8335 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1353,8 +1353,8 @@ bool SPIRVInstructionSelector::selectFrac(Register ResVReg,
 }
 
 bool SPIRVInstructionSelector::selectNormalize(Register ResVReg,
-                                          const SPIRVType *ResType,
-                                          MachineInstr &I) const {
+                                               const SPIRVType *ResType,
+                                               MachineInstr &I) const {
 
   assert(I.getNumOperands() == 3);
   assert(I.getOperand(2).isReg());

>From 6480d2decf907fd7f48185094dfc08db67542cff Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Fri, 9 Aug 2024 14:17:26 -0700
Subject: [PATCH 5/8] clang-format

---
 llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 6e27b6c12f8335..c1e8ccae33db8e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -239,7 +239,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
                    MachineInstr &I) const;
 
   bool selectNormalize(Register ResVReg, const SPIRVType *ResType,
-                   MachineInstr &I) const;
+                       MachineInstr &I) const;
 
   bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I) const;

>From 76a2d07719a34a03f42dba2a627b65c0fbc19167 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Fri, 9 Aug 2024 15:08:28 -0700
Subject: [PATCH 6/8] add comment to reinitiate build tests

---
 clang/lib/Sema/SemaHLSL.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 61f68a415a7d6c..e3e926465e799e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1116,7 +1116,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
 
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
-
+    // The return type is the same as the input type.
     TheCall->setType(ArgTyA);
     break;
   }

>From a9188add047b39f80be23ad19029361d5a8228b8 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Fri, 9 Aug 2024 16:56:29 -0700
Subject: [PATCH 7/8] another comment to kick off builds again

---
 llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e80166e0ff0569..0ef31ca9af2253 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -255,6 +255,7 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
   Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
   unsigned XVecSize = XVec->getNumElements();
   Value *DotProduct = nullptr;
+  // Use the dot intrinsic that corresponds to the vector size.
   switch (XVecSize) {
   case 1:
     report_fatal_error(Twine("Invalid input vector: length is zero"),

>From 13102f6df80525181cc959d84b56c37a0f2a0d10 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Fri, 9 Aug 2024 16:59:32 -0700
Subject: [PATCH 8/8] prevent div by 0

---
 llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 0ef31ca9af2253..626321f44c2bfc 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -278,20 +278,19 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
                        /* gen_crash_diag=*/false);
   }
 
-  Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
-                                                ArrayRef<Value *>{DotProduct},
-                                                nullptr, "dx.rsqrt");
-
   // verify that the length is non-zero
-  // (if the reciprocal sqrt of the length is non-zero, then the length is
-  // non-zero)
-  if (auto *constantFP = dyn_cast<ConstantFP>(Multiplicand)) {
+  // (the length is zero iff the dot product is zero; only a constant-folded dot product can be checked here)
+  if (auto *constantFP = dyn_cast<ConstantFP>(DotProduct)) {
     const APFloat &fpVal = constantFP->getValueAPF();
     if (fpVal.isZero())
       report_fatal_error(Twine("Invalid input vector: length is zero"),
                          /* gen_crash_diag=*/false);
   }
 
+  Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
+                                                ArrayRef<Value *>{DotProduct},
+                                                nullptr, "dx.rsqrt");
+
   Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
   Value *Result = Builder.CreateFMul(X, MultiplicandVec);
 



More information about the cfe-commits mailing list