[clang] [llvm] [HLSL] implement elementwise firstbithigh hlsl builtin (PR #111082)

Sarah Spall via cfe-commits cfe-commits at lists.llvm.org
Thu Oct 24 13:58:50 PDT 2024


https://github.com/spall updated https://github.com/llvm/llvm-project/pull/111082

>From 6239941c302f616f87ed652151e828a8eae1054c Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Mon, 23 Sep 2024 22:10:59 +0000
Subject: [PATCH 1/8] Implement firstbithigh HLSL builtin

---
 clang/include/clang/Basic/Builtins.td         |   6 +
 clang/lib/CodeGen/CGBuiltin.cpp               |  17 ++
 clang/lib/CodeGen/CGHLSLRuntime.h             |   2 +
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  72 +++++++++
 clang/lib/Sema/SemaHLSL.cpp                   |  18 +++
 .../CodeGenHLSL/builtins/firstbithigh.hlsl    | 153 ++++++++++++++++++
 .../BuiltIns/firstbithigh-errors.hlsl         |  28 ++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |   2 +
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |   2 +
 llvm/lib/Target/DirectX/DXIL.td               |  24 +++
 .../DirectX/DirectXTargetTransformInfo.cpp    |   2 +
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  14 ++
 llvm/test/CodeGen/DirectX/firstbithigh.ll     |  91 +++++++++++
 .../CodeGen/DirectX/firstbitshigh_error.ll    |  10 ++
 .../CodeGen/DirectX/firstbituhigh_error.ll    |  10 ++
 .../SPIRV/hlsl-intrinsics/firstbithigh.ll     |  37 +++++
 16 files changed, 488 insertions(+)
 create mode 100644 clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/firstbithigh-errors.hlsl
 create mode 100644 llvm/test/CodeGen/DirectX/firstbithigh.ll
 create mode 100644 llvm/test/CodeGen/DirectX/firstbitshigh_error.ll
 create mode 100644 llvm/test/CodeGen/DirectX/firstbituhigh_error.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..d294f680bb244d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4792,6 +4792,12 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLFirstBitHigh : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_elementwise_firstbithigh"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLFrac : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_frac"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 3f28b7f26c36fe..51fc0245fd5517 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18641,6 +18641,14 @@ static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
   return RT.getUDotIntrinsic();
 }
 
+static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
+  if (QT->hasSignedIntegerRepresentation()) { // signed scalar or vector
+    return RT.getFirstBitSHighIntrinsic();
+  }
+
+  return RT.getFirstBitUHighIntrinsic();
+}
+
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E,
                                             ReturnValueSlot ReturnValue) {
@@ -18730,6 +18738,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
         ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
   } break;
+  case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {
+    
+    Value *X = EmitScalarExpr(E->getArg(0));
+    
+    return Builder.CreateIntrinsic(
+	/*ReturnType=*/X->getType(),
+	getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
+	ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
+  }
   case Builtin::BI__builtin_hlsl_lerp: {
     Value *X = EmitScalarExpr(E->getArg(0));
     Value *Y = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff7df41b5c62e7..4e37123e3f110a 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -91,6 +91,8 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh)
 
   GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromBinding, handle_fromBinding)
 
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..4b3a4f50ceb981 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -958,6 +958,78 @@ float3 exp2(float3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_exp2)
 float4 exp2(float4);
 
+//===----------------------------------------------------------------------===//
+// firstbithigh builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T firstbithigh(T Val)
+/// \brief Returns the location of the first set bit starting from the highest
+/// order bit and working downward, per component.
+/// \param Val the input value.
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int16_t firstbithigh(int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int16_t2 firstbithigh(int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int16_t3 firstbithigh(int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int16_t4 firstbithigh(int16_t4);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint16_t firstbithigh(uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint16_t2 firstbithigh(uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint16_t3 firstbithigh(uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint16_t4 firstbithigh(uint16_t4);
+#endif
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int firstbithigh(int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int2 firstbithigh(int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int3 firstbithigh(int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int4 firstbithigh(int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint firstbithigh(uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint2 firstbithigh(uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint3 firstbithigh(uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint4 firstbithigh(uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int64_t firstbithigh(int64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int64_t2 firstbithigh(int64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int64_t3 firstbithigh(int64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int64_t4 firstbithigh(int64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint64_t firstbithigh(uint64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint64_t2 firstbithigh(uint64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint64_t3 firstbithigh(uint64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint64_t4 firstbithigh(uint64_t4);
+  
 //===----------------------------------------------------------------------===//
 // floor builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index c6627b0e993226..864f6a197a97fa 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1927,6 +1927,24 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {
+    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
+      return true;
+
+    const Expr *Arg = TheCall->getArg(0);
+    QualType ArgTy = Arg->getType();
+    QualType EltTy = ArgTy;
+
+    if (auto *VecTy = EltTy->getAs<VectorType>())
+      EltTy = VecTy->getElementType();
+
+    if (!EltTy->isIntegerType()) {
+      Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
+          << 1 << /* integer ty */ 6 << ArgTy;
+      return true;
+    }
+    break;
+  }
   case Builtin::BI__builtin_hlsl_select: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
new file mode 100644
index 00000000000000..9821b308e63521
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
@@ -0,0 +1,153 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s -DTARGET=dx
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s -DTARGET=spv
+
+#ifdef __HLSL_ENABLE_16_BIT
+// CHECK-LABEL: test_firstbithigh_ushort
+// CHECK: call i16 @llvm.[[TARGET]].firstbituhigh.i16
+uint16_t test_firstbithigh_ushort(uint16_t p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ushort2
+// CHECK: call <2 x i16> @llvm.[[TARGET]].firstbituhigh.v2i16
+uint16_t2 test_firstbithigh_ushort2(uint16_t2 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ushort3
+// CHECK: call <3 x i16> @llvm.[[TARGET]].firstbituhigh.v3i16
+uint16_t3 test_firstbithigh_ushort3(uint16_t3 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ushort4
+// CHECK: call <4 x i16> @llvm.[[TARGET]].firstbituhigh.v4i16
+uint16_t4 test_firstbithigh_ushort4(uint16_t4 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_short
+// CHECK: call i16 @llvm.[[TARGET]].firstbitshigh.i16
+int16_t test_firstbithigh_short(int16_t p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_short2
+// CHECK: call <2 x i16> @llvm.[[TARGET]].firstbitshigh.v2i16
+int16_t2 test_firstbithigh_short2(int16_t2 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_short3
+// CHECK: call <3 x i16> @llvm.[[TARGET]].firstbitshigh.v3i16
+int16_t3 test_firstbithigh_short3(int16_t3 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_short4
+// CHECK: call <4 x i16> @llvm.[[TARGET]].firstbitshigh.v4i16
+int16_t4 test_firstbithigh_short4(int16_t4 p0) {
+  return firstbithigh(p0);
+}
+#endif // __HLSL_ENABLE_16_BIT
+
+// CHECK-LABEL: test_firstbithigh_uint
+// CHECK: call i32 @llvm.[[TARGET]].firstbituhigh.i32
+uint test_firstbithigh_uint(uint p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_uint2
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbituhigh.v2i32
+uint2 test_firstbithigh_uint2(uint2 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_uint3
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbituhigh.v3i32
+uint3 test_firstbithigh_uint3(uint3 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_uint4
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbituhigh.v4i32
+uint4 test_firstbithigh_uint4(uint4 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ulong
+// CHECK: call i64 @llvm.[[TARGET]].firstbituhigh.i64
+uint64_t test_firstbithigh_ulong(uint64_t p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ulong2
+// CHECK: call <2 x i64> @llvm.[[TARGET]].firstbituhigh.v2i64
+uint64_t2 test_firstbithigh_ulong2(uint64_t2 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ulong3
+// CHECK: call <3 x i64> @llvm.[[TARGET]].firstbituhigh.v3i64
+uint64_t3 test_firstbithigh_ulong3(uint64_t3 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ulong4
+// CHECK: call <4 x i64> @llvm.[[TARGET]].firstbituhigh.v4i64
+uint64_t4 test_firstbithigh_ulong4(uint64_t4 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_int
+// CHECK: call i32 @llvm.[[TARGET]].firstbitshigh.i32
+int test_firstbithigh_int(int p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_int2
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbitshigh.v2i32
+int2 test_firstbithigh_int2(int2 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_int3
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbitshigh.v3i32
+int3 test_firstbithigh_int3(int3 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_int4
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbitshigh.v4i32
+int4 test_firstbithigh_int4(int4 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_long
+// CHECK: call i64 @llvm.[[TARGET]].firstbitshigh.i64
+int64_t test_firstbithigh_long(int64_t p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_long2
+// CHECK: call <2 x i64> @llvm.[[TARGET]].firstbitshigh.v2i64
+int64_t2 test_firstbithigh_long2(int64_t2 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_long3
+// CHECK: call <3 x i64> @llvm.[[TARGET]].firstbitshigh.v3i64
+int64_t3 test_firstbithigh_long3(int64_t3 p0) {
+  return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_long4
+// CHECK: call <4 x i64> @llvm.[[TARGET]].firstbitshigh.v4i64
+int64_t4 test_firstbithigh_long4(int64_t4 p0) {
+  return firstbithigh(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/firstbithigh-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/firstbithigh-errors.hlsl
new file mode 100644
index 00000000000000..1912ab3ae806b3
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/firstbithigh-errors.hlsl
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+int test_too_few_arg() {
+  return firstbithigh();
+  // expected-error@-1 {{no matching function for call to 'firstbithigh'}}
+}
+
+int test_too_many_arg(int p0) {
+  return firstbithigh(p0, p0);
+  // expected-error@-1 {{no matching function for call to 'firstbithigh'}}
+}
+
+double test_int_builtin(double p0) {
+  return firstbithigh(p0);
+  // expected-error@-1 {{call to 'firstbithigh' is ambiguous}}
+}
+
+double2 test_int_builtin_2(double2 p0) {
+  // Each -verify directive must fit inside a single // comment.
+  return __builtin_hlsl_elementwise_firstbithigh(p0);
+  // expected-error@-1 {{1st argument must be a vector of integers (was 'double2' (aka 'vector<double, 2>'))}}
+}
+
+float test_int_builtin_3(float p0) {
+  // Calling the builtin directly bypasses overload resolution; Sema rejects it.
+  return __builtin_hlsl_elementwise_firstbithigh(p0);
+  // expected-error@-1 {{1st argument must be a vector of integers (was 'float')}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index e30d37f69f781e..cfffd91afde880 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -92,4 +92,6 @@ def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, L
 def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>], 
     [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
 def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_firstbituhigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_firstbitshigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 6df2eb156a0774..33d8199ae2e734 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -97,4 +97,6 @@ let TargetPrefix = "spv" in {
             [llvm_any_ty],
             [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
             [IntrNoMem]>;
+  def int_spv_firstbituhigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+  def int_spv_firstbitshigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..a6d1fe55227145 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -564,6 +564,30 @@ def CBits :  DXILOp<31, unary> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
+def FBH :  DXILOp<33, unary> {
+  let Doc = "Returns the location of the first set bit starting from "
+            "the highest order bit and working downward.";
+  let LLVMIntrinsic = int_dx_firstbituhigh;
+  let arguments = [OverloadTy];
+  let result = OverloadTy;
+  let overloads =
+      [Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
+def FBSH :  DXILOp<34, unary> {
+  let Doc = "Returns the location of the first set bit from "
+            "the highest order bit based on the sign.";
+  let LLVMIntrinsic = int_dx_firstbitshigh;
+  let arguments = [OverloadTy];
+  let result = OverloadTy;
+  let overloads =
+      [Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
 def FMax :  DXILOp<35, binary> {
   let Doc = "Float maximum. FMax(a,b) = a > b ? a : b";
   let LLVMIntrinsic = int_maxnum;
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 231afd8ae3eeaf..b0436a39423405 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -32,6 +32,8 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
   case Intrinsic::dx_rsqrt:
   case Intrinsic::dx_wave_readlane:
   case Intrinsic::dx_splitdouble:
+  case Intrinsic::dx_firstbituhigh:
+  case Intrinsic::dx_firstbitshigh:
     return true;
   default:
     return false;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index d9377fe4b91a1a..2167aef36f6c06 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -213,6 +213,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectPhi(Register ResVReg, const SPIRVType *ResType,
                  MachineInstr &I) const;
 
+  bool selectExtInst(Register ResVReg, const SPIRVType *RestType,
+		     MachineInstr &I, GL::GLSLExtInst GLInst) const;
   bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
                      MachineInstr &I, CL::OpenCLExtInst CLInst) const;
   bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
@@ -761,6 +763,14 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   }
 }
 
+bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
+					     const SPIRVType *ResType,
+					     MachineInstr &I,
+					     GL::GLSLExtInst GLInst) const {
+  return selectExtInst(ResVReg, ResType, I,
+		       {{SPIRV::InstructionSet::GLSL_std_450, GLInst}});
+}
+
 bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
                                              const SPIRVType *ResType,
                                              MachineInstr &I,
@@ -2547,6 +2557,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectExtInst(ResVReg, ResType, I, CL::rsqrt, GL::InverseSqrt);
   case Intrinsic::spv_sign:
     return selectSign(ResVReg, ResType, I);
+  case Intrinsic::spv_firstbituhigh:
+    return selectExtInst(ResVReg, ResType, I, GL::FindUMsb);
+  case Intrinsic::spv_firstbitshigh:
+    return selectExtInst(ResVReg, ResType, I, GL::FindSMsb);
   case Intrinsic::spv_lifetime_start:
   case Intrinsic::spv_lifetime_end: {
     unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart
diff --git a/llvm/test/CodeGen/DirectX/firstbithigh.ll b/llvm/test/CodeGen/DirectX/firstbithigh.ll
new file mode 100644
index 00000000000000..4a97ad6226149f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/firstbithigh.ll
@@ -0,0 +1,91 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; Make sure dxil operation function calls for firstbithigh are generated for all integer types.
+
+define noundef i16 @test_firstbithigh_ushort(i16 noundef %a) {
+entry:
+; CHECK: call i16 @dx.op.unary.i16(i32 33, i16 %{{.*}})
+  %elt.firstbithigh = call i16 @llvm.dx.firstbituhigh.i16(i16 %a)
+  ret i16 %elt.firstbithigh
+}
+
+define noundef i16 @test_firstbithigh_short(i16 noundef %a) {
+entry:
+; CHECK: call i16 @dx.op.unary.i16(i32 34, i16 %{{.*}})
+  %elt.firstbithigh = call i16 @llvm.dx.firstbitshigh.i16(i16 %a)
+  ret i16 %elt.firstbithigh
+}
+
+define noundef i32 @test_firstbithigh_uint(i32 noundef %a) {
+entry:
+; CHECK: call i32 @dx.op.unary.i32(i32 33, i32 %{{.*}})
+  %elt.firstbithigh = call i32 @llvm.dx.firstbituhigh.i32(i32 %a)
+  ret i32 %elt.firstbithigh
+}
+
+define noundef i32 @test_firstbithigh_int(i32 noundef %a) {
+entry:
+; CHECK: call i32 @dx.op.unary.i32(i32 34, i32 %{{.*}})
+  %elt.firstbithigh = call i32 @llvm.dx.firstbitshigh.i32(i32 %a)
+  ret i32 %elt.firstbithigh
+}
+
+define noundef i64 @test_firstbithigh_ulong(i64 noundef %a) {
+entry:
+; CHECK: call i64 @dx.op.unary.i64(i32 33, i64 %{{.*}})
+  %elt.firstbithigh = call i64 @llvm.dx.firstbituhigh.i64(i64 %a)
+  ret i64 %elt.firstbithigh
+}
+
+define noundef i64 @test_firstbithigh_long(i64 noundef %a) {
+entry:
+; CHECK: call i64 @dx.op.unary.i64(i32 34, i64 %{{.*}})
+  %elt.firstbithigh = call i64 @llvm.dx.firstbitshigh.i64(i64 %a)
+  ret i64 %elt.firstbithigh
+}
+
+define noundef <4 x i32> @test_firstbituhigh_vec4_i32(<4 x i32> noundef %a)  {
+entry:
+  ; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
+  ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee0]])
+  ; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
+  ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee1]])
+  ; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
+  ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee2]])
+  ; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
+  ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee3]])
+  ; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
+  ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
+  ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2
+  ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie3]], i64 3
+  %2 = call <4 x i32> @llvm.dx.firstbituhigh.v4i32(<4 x i32> %a)
+  ret <4 x i32> %2
+}
+
+define noundef <4 x i32> @test_firstbitshigh_vec4_i32(<4 x i32> noundef %a)  {
+entry:
+  ; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
+  ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee0]])
+  ; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
+  ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee1]])
+  ; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
+  ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee2]])
+  ; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
+  ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee3]])
+  ; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
+  ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
+  ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2
+  ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie3]], i64 3
+  %2 = call <4 x i32> @llvm.dx.firstbitshigh.v4i32(<4 x i32> %a)
+  ret <4 x i32> %2
+}
+
+declare i16 @llvm.dx.firstbituhigh.i16(i16)
+declare i32 @llvm.dx.firstbituhigh.i32(i32)
+declare i64 @llvm.dx.firstbituhigh.i64(i64)
+declare <4 x i32> @llvm.dx.firstbituhigh.v4i32(<4 x i32>)
+
+declare i16 @llvm.dx.firstbitshigh.i16(i16)
+declare i32 @llvm.dx.firstbitshigh.i32(i32)
+declare i64 @llvm.dx.firstbitshigh.i64(i64)
+declare <4 x i32> @llvm.dx.firstbitshigh.v4i32(<4 x i32>)
diff --git a/llvm/test/CodeGen/DirectX/firstbitshigh_error.ll b/llvm/test/CodeGen/DirectX/firstbitshigh_error.ll
new file mode 100644
index 00000000000000..22982a03e47921
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/firstbitshigh_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
+
+; llvm.dx.firstbitshigh is integer-only (llvm_anyint_ty); an f64 overload is invalid
+; CHECK: invalid intrinsic signature
+
+define noundef double @firstbitshigh_double(double noundef %a) {
+entry:
+  %1 = call double @llvm.dx.firstbitshigh.f64(double %a)
+  ret double %1
+}
diff --git a/llvm/test/CodeGen/DirectX/firstbituhigh_error.ll b/llvm/test/CodeGen/DirectX/firstbituhigh_error.ll
new file mode 100644
index 00000000000000..b611a96ffc2f9c
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/firstbituhigh_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
+
+; llvm.dx.firstbituhigh is integer-only (llvm_anyint_ty); an f64 overload is invalid
+; CHECK: invalid intrinsic signature
+
+define noundef double @firstbituhigh_double(double noundef %a) {
+entry:
+  %1 = call double @llvm.dx.firstbituhigh.f64(double %a)
+  ret double %1
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
new file mode 100644
index 00000000000000..e20a7a4cefe8ee
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
@@ -0,0 +1,37 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: OpMemoryModel Logical GLSL450
+
+define noundef i32 @firstbituhigh_i32(i32 noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindUMsb %[[#]]
+  %elt.firstbituhigh = call i32 @llvm.spv.firstbituhigh.i32(i32 %a)
+  ret i32 %elt.firstbituhigh
+}
+
+define noundef i16 @firstbituhigh_i16(i16 noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindUMsb %[[#]]
+  %elt.firstbituhigh = call i16 @llvm.spv.firstbituhigh.i16(i16 %a)
+  ret i16 %elt.firstbituhigh
+}
+
+define noundef i32 @firstbitshigh_i32(i32 noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindSMsb %[[#]]
+  %elt.firstbitshigh = call i32 @llvm.spv.firstbitshigh.i32(i32 %a)
+  ret i32 %elt.firstbitshigh
+}
+
+define noundef i16 @firstbitshigh_i16(i16 noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindSMsb %[[#]]
+  %elt.firstbitshigh = call i16 @llvm.spv.firstbitshigh.i16(i16 %a)
+  ret i16 %elt.firstbitshigh
+}
+
+declare i16 @llvm.spv.firstbituhigh.i16(i16)
+declare i32 @llvm.spv.firstbituhigh.i32(i32)
+declare i16 @llvm.spv.firstbitshigh.i16(i16)
+declare i32 @llvm.spv.firstbitshigh.i32(i32)

>From b9894f05e5ccb821fe55413ce3864a7c861c91b3 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 4 Oct 2024 01:22:58 +0000
Subject: [PATCH 2/8] Make clang-format happy

---
 clang/lib/CodeGen/CGBuiltin.cpp                    | 10 +++++-----
 clang/lib/Headers/hlsl/hlsl_intrinsics.h           |  2 +-
 llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 10 +++++-----
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 51fc0245fd5517..0cd9e319dbd496 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18739,13 +18739,13 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
   } break;
   case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {
-    
+
     Value *X = EmitScalarExpr(E->getArg(0));
-    
+
     return Builder.CreateIntrinsic(
-	/*ReturnType=*/X->getType(),
-	getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
-	ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
+        /*ReturnType=*/X->getType(),
+        getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
+        ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
   }
   case Builtin::BI__builtin_hlsl_lerp: {
     Value *X = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 4b3a4f50ceb981..6738a19d4b65ca 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1029,7 +1029,7 @@ _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
 uint64_t3 firstbithigh(uint64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
 uint64_t4 firstbithigh(uint64_t4);
-  
+
 //===----------------------------------------------------------------------===//
 // floor builtins
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 2167aef36f6c06..f75ca5d6de2857 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -214,7 +214,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
                  MachineInstr &I) const;
 
   bool selectExtInst(Register ResVReg, const SPIRVType *RestType,
-		     MachineInstr &I, GL::GLSLExtInst GLInst) const;
+                     MachineInstr &I, GL::GLSLExtInst GLInst) const;
   bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
                      MachineInstr &I, CL::OpenCLExtInst CLInst) const;
   bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
@@ -764,11 +764,11 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
 }
 
 bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
-					     const SPIRVType *ResType,
-					     MachineInstr &I,
-					     GL::GLSLExtInst GLInst) const {
+                                             const SPIRVType *ResType,
+                                             MachineInstr &I,
+                                             GL::GLSLExtInst GLInst) const {
   return selectExtInst(ResVReg, ResType, I,
-		       {{SPIRV::InstructionSet::GLSL_std_450, GLInst}});
+                       {{SPIRV::InstructionSet::GLSL_std_450, GLInst}});
 }
 
 bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,

>From 4f34efd75739b6e0006da93bba4d9bff44fcb0ba Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Sat, 5 Oct 2024 00:06:05 +0000
Subject: [PATCH 3/8] address PR comments

---
 clang/lib/CodeGen/CGBuiltin.cpp                    | 1 +
 clang/lib/Sema/SemaHLSL.cpp                        | 2 ++
 llvm/lib/Target/DirectX/DXIL.td                    | 4 ++--
 llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 4 ++--
 4 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 0cd9e319dbd496..30cfcace4d841c 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18646,6 +18646,7 @@ Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
     return RT.getFirstBitSHighIntrinsic();
   }
 
+  assert(QT->hasUnsignedIntegerRepresentation());
   return RT.getFirstBitUHighIntrinsic();
 }
 
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 864f6a197a97fa..66f88627503fcf 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1943,6 +1943,8 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
           << 1 << /* integer ty */ 6 << ArgTy;
       return true;
     }
+
+    TheCall->setType(ArgTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_select: {
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index a6d1fe55227145..9542bdd575c8d8 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -564,7 +564,7 @@ def CBits :  DXILOp<31, unary> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
-def FBH :  DXILOp<33, unary> {
+def FirstbitHi :  DXILOp<33, unaryBits> {
   let Doc = "Returns the location of the first set bit starting from "
             "the highest order bit and working downward.";
   let LLVMIntrinsic = int_dx_firstbituhigh;
@@ -576,7 +576,7 @@ def FBH :  DXILOp<33, unary> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
-def FBSH :  DXILOp<34, unary> {
+def FirstbitSHi :  DXILOp<34, unaryBits> {
   let Doc = "Returns the location of the first set bit from "
             "the highest order bit based on the sign.";
   let LLVMIntrinsic = int_dx_firstbitshigh;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f75ca5d6de2857..71aa38869c3dfc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2557,9 +2557,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectExtInst(ResVReg, ResType, I, CL::rsqrt, GL::InverseSqrt);
   case Intrinsic::spv_sign:
     return selectSign(ResVReg, ResType, I);
-  case Intrinsic::spv_firstbituhigh:
+  case Intrinsic::spv_firstbituhigh: // There is no CL equivalent of FindUMsb
     return selectExtInst(ResVReg, ResType, I, GL::FindUMsb);
-  case Intrinsic::spv_firstbitshigh:
+  case Intrinsic::spv_firstbitshigh: // There is no CL equivalent of FindSMsb
     return selectExtInst(ResVReg, ResType, I, GL::FindSMsb);
   case Intrinsic::spv_lifetime_start:
   case Intrinsic::spv_lifetime_end: {

>From cb1f55c7257da4ad722ccef59e3a9d59d1f065dc Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 11 Oct 2024 15:21:00 +0000
Subject: [PATCH 4/8] update firstbithigh to return uint for every input width.
 Both DXIL and SPIR-V support 16-, 32-, and 64-bit inputs now.

---
 clang/lib/CodeGen/CGBuiltin.cpp               |   2 +-
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  40 ++--
 clang/lib/Sema/SemaHLSL.cpp                   |   8 +-
 .../CodeGenHLSL/builtins/firstbithigh.hlsl    |  72 +++----
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |   4 +-
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |   4 +-
 llvm/lib/Target/DirectX/DXIL.td               |   8 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |   2 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 176 +++++++++++++++++-
 llvm/test/CodeGen/DirectX/firstbithigh.ll     |  40 ++--
 .../SPIRV/hlsl-intrinsics/firstbithigh.ll     |  90 ++++++++-
 11 files changed, 346 insertions(+), 100 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 30cfcace4d841c..4dbf2ecb8052e3 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18744,7 +18744,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     Value *X = EmitScalarExpr(E->getArg(0));
 
     return Builder.CreateIntrinsic(
-        /*ReturnType=*/X->getType(),
+	 /*ReturnType=*/ConvertType(E->getType()),
         getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
         ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
   }
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 6738a19d4b65ca..9655f1a803c8f4 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -970,38 +970,38 @@ float4 exp2(float4);
 #ifdef __HLSL_ENABLE_16_BIT
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int16_t firstbithigh(int16_t);
+uint firstbithigh(int16_t);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int16_t2 firstbithigh(int16_t2);
+uint2 firstbithigh(int16_t2);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int16_t3 firstbithigh(int16_t3);
+uint3 firstbithigh(int16_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int16_t4 firstbithigh(int16_t4);
+uint4 firstbithigh(int16_t4);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint16_t firstbithigh(uint16_t);
+uint firstbithigh(uint16_t);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint16_t2 firstbithigh(uint16_t2);
+uint2 firstbithigh(uint16_t2);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint16_t3 firstbithigh(uint16_t3);
+uint3 firstbithigh(uint16_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint16_t4 firstbithigh(uint16_t4);
+uint4 firstbithigh(uint16_t4);
 #endif
 
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int firstbithigh(int);
+uint firstbithigh(int);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int2 firstbithigh(int2);
+uint2 firstbithigh(int2);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int3 firstbithigh(int3);
+uint3 firstbithigh(int3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int4 firstbithigh(int4);
+uint4 firstbithigh(int4);
 
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
 uint firstbithigh(uint);
@@ -1013,22 +1013,22 @@ _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
 uint4 firstbithigh(uint4);
 
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int64_t firstbithigh(int64_t);
+uint firstbithigh(int64_t);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int64_t2 firstbithigh(int64_t2);
+uint2 firstbithigh(int64_t2);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int64_t3 firstbithigh(int64_t3);
+uint3 firstbithigh(int64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int64_t4 firstbithigh(int64_t4);
+uint4 firstbithigh(int64_t4);
 
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint64_t firstbithigh(uint64_t);
+uint firstbithigh(uint64_t);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint64_t2 firstbithigh(uint64_t2);
+uint2 firstbithigh(uint64_t2);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint64_t3 firstbithigh(uint64_t3);
+uint3 firstbithigh(uint64_t3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint64_t4 firstbithigh(uint64_t4);
+uint4 firstbithigh(uint64_t4);
 
 //===----------------------------------------------------------------------===//
 // floor builtins
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 66f88627503fcf..c8606de8a600af 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1935,8 +1935,12 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     QualType ArgTy = Arg->getType();
     QualType EltTy = ArgTy;
 
-    if (auto *VecTy = EltTy->getAs<VectorType>())
+    QualType ResTy = SemaRef.Context.UnsignedIntTy;
+
+    if (auto *VecTy = EltTy->getAs<VectorType>()) {
       EltTy = VecTy->getElementType();
+      ResTy = SemaRef.Context.getVectorType(ResTy, VecTy->getNumElements(), VecTy->getVectorKind());
+    }
 
     if (!EltTy->isIntegerType()) {
       Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
@@ -1944,7 +1948,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     }
 
-    TheCall->setType(ArgTy);
+    TheCall->setType(ResTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_select: {
diff --git a/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
index 9821b308e63521..ce94a1c15c5332 100644
--- a/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
@@ -8,50 +8,50 @@
 
 #ifdef __HLSL_ENABLE_16_BIT
 // CHECK-LABEL: test_firstbithigh_ushort
-// CHECK: call i16 @llvm.[[TARGET]].firstbituhigh.i16
-int test_firstbithigh_ushort(uint16_t p0) {
+// CHECK: call i32 @llvm.[[TARGET]].firstbituhigh.i16
+uint test_firstbithigh_ushort(uint16_t p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_ushort2
-// CHECK: call <2 x i16> @llvm.[[TARGET]].firstbituhigh.v2i16
-uint16_t2 test_firstbithigh_ushort2(uint16_t2 p0) {
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbituhigh.v2i16
+uint2 test_firstbithigh_ushort2(uint16_t2 p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_ushort3
-// CHECK: call <3 x i16> @llvm.[[TARGET]].firstbituhigh.v3i16
-uint16_t3 test_firstbithigh_ushort3(uint16_t3 p0) {
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbituhigh.v3i16
+uint3 test_firstbithigh_ushort3(uint16_t3 p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_ushort4
-// CHECK: call <4 x i16> @llvm.[[TARGET]].firstbituhigh.v4i16
-uint16_t4 test_firstbithigh_ushort4(uint16_t4 p0) {
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbituhigh.v4i16
+uint4 test_firstbithigh_ushort4(uint16_t4 p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_short
-// CHECK: call i16 @llvm.[[TARGET]].firstbitshigh.i16
-int16_t test_firstbithigh_short(int16_t p0) {
+// CHECK: call i32 @llvm.[[TARGET]].firstbitshigh.i16
+uint test_firstbithigh_short(int16_t p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_short2
-// CHECK: call <2 x i16> @llvm.[[TARGET]].firstbitshigh.v2i16
-int16_t2 test_firstbithigh_short2(int16_t2 p0) {
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbitshigh.v2i16
+uint2 test_firstbithigh_short2(int16_t2 p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_short3
-// CHECK: call <3 x i16> @llvm.[[TARGET]].firstbitshigh.v3i16
-int16_t3 test_firstbithigh_short3(int16_t3 p0) {
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbitshigh.v3i16
+uint3 test_firstbithigh_short3(int16_t3 p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_short4
-// CHECK: call <4 x i16> @llvm.[[TARGET]].firstbitshigh.v4i16
-int16_t4 test_firstbithigh_short4(int16_t4 p0) {
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbitshigh.v4i16
+uint4 test_firstbithigh_short4(int16_t4 p0) {
   return firstbithigh(p0);
 }
 #endif // __HLSL_ENABLE_16_BIT
@@ -81,73 +81,73 @@ uint4 test_firstbithigh_uint4(uint4 p0) {
 }
 
 // CHECK-LABEL: test_firstbithigh_ulong
-// CHECK: call i64 @llvm.[[TARGET]].firstbituhigh.i64
-uint64_t test_firstbithigh_ulong(uint64_t p0) {
+// CHECK: call i32 @llvm.[[TARGET]].firstbituhigh.i64
+uint test_firstbithigh_ulong(uint64_t p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_ulong2
-// CHECK: call <2 x i64> @llvm.[[TARGET]].firstbituhigh.v2i64
-uint64_t2 test_firstbithigh_ulong2(uint64_t2 p0) {
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbituhigh.v2i64
+uint2 test_firstbithigh_ulong2(uint64_t2 p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_ulong3
-// CHECK: call <3 x i64> @llvm.[[TARGET]].firstbituhigh.v3i64
-uint64_t3 test_firstbithigh_ulong3(uint64_t3 p0) {
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbituhigh.v3i64
+uint3 test_firstbithigh_ulong3(uint64_t3 p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_ulong4
-// CHECK: call <4 x i64> @llvm.[[TARGET]].firstbituhigh.v4i64
-uint64_t4 test_firstbithigh_ulong4(uint64_t4 p0) {
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbituhigh.v4i64
+uint4 test_firstbithigh_ulong4(uint64_t4 p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_int
 // CHECK: call i32 @llvm.[[TARGET]].firstbitshigh.i32
-int test_firstbithigh_int(int p0) {
+uint test_firstbithigh_int(int p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_int2
 // CHECK: call <2 x i32> @llvm.[[TARGET]].firstbitshigh.v2i32
-int2 test_firstbithigh_int2(int2 p0) {
+uint2 test_firstbithigh_int2(int2 p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_int3
 // CHECK: call <3 x i32> @llvm.[[TARGET]].firstbitshigh.v3i32
-int3 test_firstbithigh_int3(int3 p0) {
+uint3 test_firstbithigh_int3(int3 p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_int4
 // CHECK: call <4 x i32> @llvm.[[TARGET]].firstbitshigh.v4i32
-int4 test_firstbithigh_int4(int4 p0) {
+uint4 test_firstbithigh_int4(int4 p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_long
-// CHECK: call i64 @llvm.[[TARGET]].firstbitshigh.i64
-int64_t test_firstbithigh_long(int64_t p0) {
+// CHECK: call i32 @llvm.[[TARGET]].firstbitshigh.i64
+uint test_firstbithigh_long(int64_t p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_long2
-// CHECK: call <2 x i64> @llvm.[[TARGET]].firstbitshigh.v2i64
-int64_t2 test_firstbithigh_long2(int64_t2 p0) {
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbitshigh.v2i64
+uint2 test_firstbithigh_long2(int64_t2 p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_long3
-// CHECK: call <3 x i64> @llvm.[[TARGET]].firstbitshigh.v3i64
-int64_t3 test_firstbithigh_long3(int64_t3 p0) {
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbitshigh.v3i64
+uint3 test_firstbithigh_long3(int64_t3 p0) {
   return firstbithigh(p0);
 }
 
 // CHECK-LABEL: test_firstbithigh_long4
-// CHECK: call <4 x i64> @llvm.[[TARGET]].firstbitshigh.v4i64
-int64_t4 test_firstbithigh_long4(int64_t4 p0) {
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbitshigh.v4i64
+uint4 test_firstbithigh_long4(int64_t4 p0) {
   return firstbithigh(p0);
 }
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index cfffd91afde880..612dbb398720eb 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -92,6 +92,6 @@ def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, L
 def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>], 
     [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
 def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
-def int_dx_firstbituhigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
-def int_dx_firstbitshigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_firstbituhigh : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_anyint_ty], [IntrNoMem]>;
+def int_dx_firstbitshigh : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_anyint_ty], [IntrNoMem]>;
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 33d8199ae2e734..906af27445d527 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -97,6 +97,6 @@ let TargetPrefix = "spv" in {
             [llvm_any_ty],
             [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
             [IntrNoMem]>;
-  def int_spv_firstbituhigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
-  def int_spv_firstbitshigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+  def int_spv_firstbituhigh : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_anyint_ty], [IntrNoMem]>;
+  def int_spv_firstbitshigh : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_anyint_ty], [IntrNoMem]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 9542bdd575c8d8..9fdd79a3a7d6fc 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -564,24 +564,24 @@ def CBits :  DXILOp<31, unary> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
-def FirstbitHi :  DXILOp<33, unaryBits> {
+def FirstbitHi :  DXILOp<33, unary> {
   let Doc = "Returns the location of the first set bit starting from "
             "the highest order bit and working downward.";
   let LLVMIntrinsic = int_dx_firstbituhigh;
   let arguments = [OverloadTy];
-  let result = OverloadTy;
+  let result = Int32Ty;
   let overloads =
       [Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
-def FirstbitSHi :  DXILOp<34, unaryBits> {
+def FirstbitSHi :  DXILOp<34, unary> {
   let Doc = "Returns the location of the first set bit from "
             "the highest order bit based on the sign.";
   let LLVMIntrinsic = int_dx_firstbitshigh;
   let arguments = [OverloadTy];
-  let result = OverloadTy;
+  let result = Int32Ty;
   let overloads =
       [Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 64fde8bf67ab91..fe8f946e46d82e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -424,7 +424,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
     LLT LLTy = LLT::scalar(64);
     Register SpvVecConst =
         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
-    CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::iIDRegClass);
+    CurMF->getRegInfo().setRegClass(SpvVecConst, getRegClass(SpvType));
     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
     DT.add(CA, CurMF, SpvVecConst);
     MachineInstrBuilder MIB;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 71aa38869c3dfc..f4b9ec6834c92c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -92,9 +92,26 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool spvSelect(Register ResVReg, const SPIRVType *ResType,
                  MachineInstr &I) const;
 
+  bool selectFirstBitHigh(Register ResVReg, const SPIRVType *ResType,
+			  MachineInstr &I, bool IsSigned) const;
+
+  bool selectFirstBitHigh16(Register ResVReg, const SPIRVType *ResType,
+			    MachineInstr &I, bool IsSigned) const;
+
+  bool selectFirstBitHigh32(Register ResVReg, const SPIRVType *ResType,
+			    MachineInstr &I, Register SrcReg,
+			    bool IsSigned) const;
+
+  bool selectFirstBitHigh64(Register ResVReg, const SPIRVType *ResType,
+			    MachineInstr &I, bool IsSigned) const;
+
   bool selectGlobalValue(Register ResVReg, MachineInstr &I,
                          const MachineInstr *Init = nullptr) const;
 
+  bool selectNAryOpWithSrcs(Register ResVReg, const SPIRVType *ResType,
+			    MachineInstr &I, std::vector<Register> SrcRegs,
+			    unsigned Opcode) const;
+
   bool selectUnOpWithSrc(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I, Register SrcReg,
                          unsigned Opcode) const;
@@ -818,6 +835,20 @@ bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
   return false;
 }
 
+bool SPIRVInstructionSelector::selectNAryOpWithSrcs(Register ResVReg,
+						    const SPIRVType *ResType,
+			                            MachineInstr &I,
+						    std::vector<Register> Srcs,
+						    unsigned Opcode) const {
+  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType));
+  for(Register SReg : Srcs) { // append every source as a use, in order
+    MIB.addUse(SReg);
+  }
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectUnOpWithSrc(Register ResVReg,
                                                  const SPIRVType *ResType,
                                                  MachineInstr &I,
@@ -2558,9 +2589,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   case Intrinsic::spv_sign:
     return selectSign(ResVReg, ResType, I);
   case Intrinsic::spv_firstbituhigh: // There is no CL equivalent of FindUMsb
-    return selectExtInst(ResVReg, ResType, I, GL::FindUMsb);
+    return selectFirstBitHigh(ResVReg, ResType, I, false);
   case Intrinsic::spv_firstbitshigh: // There is no CL equivalent of FindSMsb
-    return selectExtInst(ResVReg, ResType, I, GL::FindSMsb);
+    return selectFirstBitHigh(ResVReg, ResType, I, true);
   case Intrinsic::spv_lifetime_start:
   case Intrinsic::spv_lifetime_end: {
     unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart
@@ -2640,6 +2671,147 @@ Register SPIRVInstructionSelector::buildPointerToResource(
                                                  MIRBuilder);
 }
 
+bool SPIRVInstructionSelector::selectFirstBitHigh16(Register ResVReg,
+						     const SPIRVType *ResType,
+						     MachineInstr &I,
+						     bool IsSigned) const {
+  unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert;
+  // sign-extend (IsSigned) or zero-extend the i16 operand to the i32 ResType
+  Register ExtReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+  bool Result = selectUnOpWithSrc(ExtReg, ResType, I, I.getOperand(2).getReg(),
+				  Opcode);
+  return Result & selectFirstBitHigh32(ResVReg, ResType, I, ExtReg, IsSigned);
+}
+
+bool SPIRVInstructionSelector::selectFirstBitHigh32(Register ResVReg,
+						    const SPIRVType *ResType,
+						    MachineInstr &I,
+						    Register SrcReg,
+						    bool IsSigned) const {
+  unsigned Opcode = IsSigned ? GL::FindSMsb : GL::FindUMsb; // GLSL.std.450 op
+  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+    .addDef(ResVReg)
+    .addUse(GR.getSPIRVTypeID(ResType))
+    .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+    .addImm(Opcode)
+    .addUse(SrcReg) // callers pass 32-bit int words (scalar or vector)
+    .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
+						     const SPIRVType *ResType,
+						     MachineInstr &I,
+						     bool IsSigned) const {
+  Register OpReg = I.getOperand(2).getReg();
+  // 1. bitcast each i64 lane into two i32 words (a 2*count x i32 vector)
+  unsigned count = GR.getScalarOrVectorComponentCount(ResType);
+  SPIRVType *baseType = GR.retrieveScalarOrVectorIntType(ResType);
+  MachineIRBuilder MIRBuilder(I);
+  SPIRVType *postCastT = GR.getOrCreateSPIRVVectorType(baseType, 2 * count,
+						       MIRBuilder);
+  Register bitcastReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
+  bool Result = selectUnOpWithSrc(bitcastReg, postCastT, I, OpReg,
+				  SPIRV::OpBitcast);
+
+  // 2. run the 32-bit FindSMsb/FindUMsb on every word of the bitcast vector
+  Register FBHReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
+  Result &= selectFirstBitHigh32(FBHReg, postCastT, I, bitcastReg, IsSigned);
+
+  // 3. split the per-word FindMsb results into "high" and "low" vectors.
+  // NOTE(review): OpBitcast puts an i64's low-order bits in the lower-numbered
+  // component, so even indices look like the LOW words and odd the HIGH ones,
+  // the reverse of what follows (i64 1 would yield 32, not 0); confirm.
+  // Scalar ResType is widened to a count-wide vector so one path serves both.
+  SPIRVType *VResType = ResType;
+  bool isScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
+  if (isScalarRes)
+    VResType = GR.getOrCreateSPIRVVectorType(ResType, count, MIRBuilder);
+  // NOTE(review): count is 1 here; SPIR-V OpTypeVector needs >= 2 components.
+  // gather even-indexed words into HighReg (treated as the "high" words)
+  Register HighReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+		     TII.get(SPIRV::OpVectorShuffle))
+    .addDef(HighReg)
+    .addUse(GR.getSPIRVTypeID(VResType))
+    .addUse(FBHReg)
+    .addUse(FBHReg); // unused: every index (< 2*count) picks from the first
+  unsigned i;
+  for(i = 0; i < count*2; i += 2) {
+    MIB.addImm(i);
+  }
+  Result &= MIB.constrainAllUses(TII, TRI, RBI);
+
+  // gather odd-indexed words into LowReg (treated as the "low" words)
+  Register LowReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+  MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+		TII.get(SPIRV::OpVectorShuffle))
+    .addDef(LowReg)
+    .addUse(GR.getSPIRVTypeID(VResType))
+    .addUse(FBHReg)
+    .addUse(FBHReg); // unused: every index (< 2*count) picks from the first
+  for(i = 1; i < count*2; i += 2) {
+    MIB.addImm(i);
+  }
+  Result &= MIB.constrainAllUses(TII, TRI, RBI);
+
+  SPIRVType *BoolType =
+    GR.getOrCreateSPIRVVectorType(GR.getOrCreateSPIRVBoolType(I, TII),
+				  count,
+				  MIRBuilder);
+  // per-lane mask: the "high" word reported no bit found (FindMsb == -1)
+  Register NegOneReg = GR.getOrCreateConstVector(-1, I, VResType, TII);
+  // NOTE(review): FindSMsb on a low word misreads its bit 31 as a sign bit.
+  Register BReg = MRI->createVirtualRegister(GR.getRegClass(BoolType));
+  Result &= selectNAryOpWithSrcs(BReg, BoolType, I, {HighReg, NegOneReg},
+				 SPIRV::OpIEqual);
+
+  // per lane: LowReg where the high word had no bit, otherwise HighReg
+  Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+  Result &= selectNAryOpWithSrcs(TmpReg, VResType, I, {BReg, LowReg, HighReg},
+				 SPIRV::OpSelectVIVCond);
+
+  // bit offset: +32 if the bit came from the high word, +0 from the low word
+  Register ValReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+  bool ZeroAsNull = STI.isOpenCLEnv();
+  Register Reg32 = GR.getOrCreateConstVector(32, I, VResType, TII, ZeroAsNull);
+  Register Reg0 = GR.getOrCreateConstVector(0, I, VResType, TII, ZeroAsNull);
+  Result &= selectNAryOpWithSrcs(ValReg, VResType, I, {BReg, Reg0, Reg32},
+				 SPIRV::OpSelectVIVCond);
+
+  Register AddReg = ResVReg;
+  if(isScalarRes)
+    AddReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+  Result &= selectNAryOpWithSrcs(AddReg, VResType, I, {ValReg, TmpReg},
+			         SPIRV::OpIAddV);
+
+  // scalar ResType: extract lane 0 of the widened vector into ResVReg
+  if (!isScalarRes)
+    return Result;
+  else
+    return Result & selectNAryOpWithSrcs(ResVReg, ResType, I,
+					 {AddReg,
+					  GR.getOrCreateConstInt(0, I, ResType,
+								 TII)},
+					 SPIRV::OpVectorExtractDynamic);
+}
+
+bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
+						   const SPIRVType *ResType,
+						   MachineInstr &I,
+						   bool IsSigned) const {
+  // GLSL FindUMsb/FindSMsb only take 32-bit ints: widen i16, split i64.
+  Register OpReg = I.getOperand(2).getReg();
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg);
+  unsigned bitWidth = GR.getScalarOrVectorBitWidth(OpType);
+  // dispatch on the operand's scalar bit width
+  if (bitWidth == 16)
+    return selectFirstBitHigh16(ResVReg, ResType, I, IsSigned);
+  else if (bitWidth == 32)
+    return selectFirstBitHigh32(ResVReg, ResType, I, OpReg, IsSigned);
+  else // 64 bit; NOTE(review): any other width (e.g. i8) also lands here
+    return selectFirstBitHigh64(ResVReg, ResType, I, IsSigned);
+}
+
 bool SPIRVInstructionSelector::selectAllocaArray(Register ResVReg,
                                                  const SPIRVType *ResType,
                                                  MachineInstr &I) const {
diff --git a/llvm/test/CodeGen/DirectX/firstbithigh.ll b/llvm/test/CodeGen/DirectX/firstbithigh.ll
index 4a97ad6226149f..de0b11c97a9b98 100644
--- a/llvm/test/CodeGen/DirectX/firstbithigh.ll
+++ b/llvm/test/CodeGen/DirectX/firstbithigh.ll
@@ -2,18 +2,18 @@
 
 ; Make sure dxil operation function calls for firstbithigh are generated for all integer types.
 
-define noundef i16 @test_firstbithigh_ushort(i16 noundef %a) {
+define noundef i32 @test_firstbithigh_ushort(i16 noundef %a) {
 entry:
-; CHECK: call i16 @dx.op.unary.i16(i32 33, i16 %{{.*}})
-  %elt.firstbithigh = call i16 @llvm.dx.firstbituhigh.i16(i16 %a)
-  ret i16 %elt.firstbithigh
+; CHECK: call i32 @dx.op.unary.i16(i32 33, i16 %{{.*}})
+  %elt.firstbithigh = call i32 @llvm.dx.firstbituhigh.i16(i16 %a)
+  ret i32 %elt.firstbithigh
 }
 
-define noundef i16 @test_firstbithigh_short(i16 noundef %a) {
+define noundef i32 @test_firstbithigh_short(i16 noundef %a) {
 entry:
-; CHECK: call i16 @dx.op.unary.i16(i32 34, i16 %{{.*}})
-  %elt.firstbithigh = call i16 @llvm.dx.firstbitshigh.i16(i16 %a)
-  ret i16 %elt.firstbithigh
+; CHECK: call i32 @dx.op.unary.i16(i32 34, i16 %{{.*}})
+  %elt.firstbithigh = call i32 @llvm.dx.firstbitshigh.i16(i16 %a)
+  ret i32 %elt.firstbithigh
 }
 
 define noundef i32 @test_firstbithigh_uint(i32 noundef %a) {
@@ -30,18 +30,18 @@ entry:
   ret i32 %elt.firstbithigh
 }
 
-define noundef i64 @test_firstbithigh_ulong(i64 noundef %a) {
+define noundef i32 @test_firstbithigh_ulong(i64 noundef %a) {
 entry:
-; CHECK: call i64 @dx.op.unary.i64(i32 33, i64 %{{.*}})
-  %elt.firstbithigh = call i64 @llvm.dx.firstbituhigh.i64(i64 %a)
-  ret i64 %elt.firstbithigh
+; CHECK: call i32 @dx.op.unary.i64(i32 33, i64 %{{.*}})
+  %elt.firstbithigh = call i32 @llvm.dx.firstbituhigh.i64(i64 %a)
+  ret i32 %elt.firstbithigh
 }
 
-define noundef i64 @test_firstbithigh_long(i64 noundef %a) {
+define noundef i32 @test_firstbithigh_long(i64 noundef %a) {
 entry:
-; CHECK: call i64 @dx.op.unary.i64(i32 34, i64 %{{.*}})
-  %elt.firstbithigh = call i64 @llvm.dx.firstbitshigh.i64(i64 %a)
-  ret i64 %elt.firstbithigh
+; CHECK: call i32 @dx.op.unary.i64(i32 34, i64 %{{.*}})
+  %elt.firstbithigh = call i32 @llvm.dx.firstbitshigh.i64(i64 %a)
+  ret i32 %elt.firstbithigh
 }
 
 define noundef <4 x i32> @test_firstbituhigh_vec4_i32(<4 x i32> noundef %a)  {
@@ -80,12 +80,12 @@ entry:
   ret <4 x i32> %2
 }
 
-declare i16 @llvm.dx.firstbituhigh.i16(i16)
+declare i32 @llvm.dx.firstbituhigh.i16(i16)
 declare i32 @llvm.dx.firstbituhigh.i32(i32)
-declare i64 @llvm.dx.firstbituhigh.i64(i64)
+declare i32 @llvm.dx.firstbituhigh.i64(i64)
 declare <4 x i32> @llvm.dx.firstbituhigh.v4i32(<4 x i32>)
 
-declare i16 @llvm.dx.firstbitshigh.i16(i16)
+declare i32 @llvm.dx.firstbitshigh.i16(i16)
 declare i32 @llvm.dx.firstbitshigh.i32(i32)
-declare i64 @llvm.dx.firstbitshigh.i64(i64)
+declare i32 @llvm.dx.firstbitshigh.i64(i64)
 declare <4 x i32> @llvm.dx.firstbitshigh.v4i32(<4 x i32>)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
index e20a7a4cefe8ee..057b1a9c78722a 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
@@ -10,11 +10,57 @@ entry:
   ret i32 %elt.firstbituhigh
 }
 
-define noundef i16 @firstbituhigh_i16(i16 noundef %a) {
+define noundef <2 x i32> @firstbituhigh_2xi32(<2 x i32> noundef %a) {
 entry:
 ; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindUMsb %[[#]]
-  %elt.firstbituhigh = call i16 @llvm.spv.firstbituhigh.i16(i16 %a)
-  ret i16 %elt.firstbituhigh
+  %elt.firstbituhigh = call <2 x i32> @llvm.spv.firstbituhigh.v2i32(<2 x i32> %a)
+  ret <2 x i32> %elt.firstbituhigh
+}
+
+define noundef i32 @firstbituhigh_i16(i16 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = OpUConvert %[[#]]
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindUMsb [[A]]
+  %elt.firstbituhigh = call i32 @llvm.spv.firstbituhigh.i16(i16 %a)
+  ret i32 %elt.firstbituhigh
+}
+
+define noundef <2 x i32> @firstbituhigh_v2i16(<2 x i16> noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = OpUConvert %[[#]]
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindUMsb [[A]]
+  %elt.firstbituhigh = call <2 x i32> @llvm.spv.firstbituhigh.v2i16(<2 x i16> %a)
+  ret <2 x i32> %elt.firstbituhigh
+}
+
+define noundef i32 @firstbituhigh_i64(i64 noundef %a) {
+entry:
+; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
+; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindUMsb [[O]]
+; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
+; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
+; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
+; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
+; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
+; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
+; CHECK: [[#]] = OpVectorExtractDynamic %[[#]] [[B]] %[[#]]
+  %elt.firstbituhigh = call i32 @llvm.spv.firstbituhigh.i64(i64 %a)
+  ret i32 %elt.firstbituhigh
+}
+
+define noundef <2 x i32> @firstbituhigh_v2i64(<2 x i64> noundef %a) {
+entry:
+; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
+; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindUMsb [[O]]
+; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
+; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
+; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
+; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
+; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
+; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
+; CHECK: OpReturnValue [[B]]
+  %elt.firstbituhigh = call <2 x i32> @llvm.spv.firstbituhigh.v2i64(<2 x i64> %a)
+  ret <2 x i32> %elt.firstbituhigh
 }
 
 define noundef i32 @firstbitshigh_i32(i32 noundef %a) {
@@ -24,14 +70,38 @@ entry:
   ret i32 %elt.firstbitshigh
 }
 
-define noundef i16 @firstbitshigh_i16(i16 noundef %a) {
+define noundef i32 @firstbitshigh_i16(i16 noundef %a) {
 entry:
+; CHECK: [[A:%.*]] = OpSConvert %[[#]]
 ; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindSMsb %[[#]]
-  %elt.firstbitshigh = call i16 @llvm.spv.firstbitshigh.i16(i16 %a)
-  ret i16 %elt.firstbitshigh
+  %elt.firstbitshigh = call i32 @llvm.spv.firstbitshigh.i16(i16 %a)
+  ret i32 %elt.firstbitshigh
+}
+
+define noundef i32 @firstbitshigh_i64(i64 noundef %a) {
+entry:
+; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
+; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindSMsb [[O]]
+; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
+; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
+; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
+; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
+; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
+; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
+; CHECK: [[#]] = OpVectorExtractDynamic %[[#]] [[B]] %[[#]]
+  %elt.firstbitshigh = call i32 @llvm.spv.firstbitshigh.i64(i64 %a)
+  ret i32 %elt.firstbitshigh
 }
 
-declare i16 @llvm.spv.firstbituhigh.i16(i16)
-declare i32 @llvm.spv.firstbituhigh.i32(i32)
-declare i16 @llvm.spv.firstbitshigh.i16(i16)
-declare i32 @llvm.spv.firstbitshigh.i32(i32)
+;declare i16 @llvm.spv.firstbituhigh.i16(i16)
+;declare i32 @llvm.spv.firstbituhigh.i32(i32)
+;declare i64 @llvm.spv.firstbituhigh.i64(i64)
+;declare i16 @llvm.spv.firstbituhigh.v2i16(<2 x i16>)
+;declare i32 @llvm.spv.firstbituhigh.v2i32(<2 x i32>)
+;declare i64 @llvm.spv.firstbituhigh.v2i64(<2 x i64>)
+;declare i16 @llvm.spv.firstbitshigh.i16(i16)
+;declare i32 @llvm.spv.firstbitshigh.i32(i32)
+;declare i64 @llvm.spv.firstbitshigh.i64(i64)
+;declare i16 @llvm.spv.firstbitshigh.v2i16(<2 x i16>)
+;declare i32 @llvm.spv.firstbitshigh.v2i32(<2 x i32>)
+;declare i64 @llvm.spv.firstbitshigh.v2i64(<2 x i64>)

>From 70734cea8f5965e85c87d3d7282ea5e850f87ec6 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 16 Oct 2024 16:41:04 +0000
Subject: [PATCH 5/8] make clang-format happy

---
 clang/lib/CodeGen/CGBuiltin.cpp               |   2 +-
 clang/lib/Sema/SemaHLSL.cpp                   |   3 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 135 +++++++++---------
 3 files changed, 71 insertions(+), 69 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 4dbf2ecb8052e3..50a893d0338bc8 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18744,7 +18744,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     Value *X = EmitScalarExpr(E->getArg(0));
 
     return Builder.CreateIntrinsic(
-	 /*ReturnType=*/ConvertType(E->getType()),
+        /*ReturnType=*/ConvertType(E->getType()),
         getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
         ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
   }
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index c8606de8a600af..edf1b07b75fa7d 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1939,7 +1939,8 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
 
     if (auto *VecTy = EltTy->getAs<VectorType>()) {
       EltTy = VecTy->getElementType();
-      ResTy = SemaRef.Context.getVectorType(ResTy, VecTy->getNumElements(), VecTy->getVectorKind());
+      ResTy = SemaRef.Context.getVectorType(ResTy, VecTy->getNumElements(),
+                                            VecTy->getVectorKind());
     }
 
     if (!EltTy->isIntegerType()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f4b9ec6834c92c..c4a44bcc765b73 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -93,24 +93,24 @@ class SPIRVInstructionSelector : public InstructionSelector {
                  MachineInstr &I) const;
 
   bool selectFirstBitHigh(Register ResVReg, const SPIRVType *ResType,
-			  MachineInstr &I, bool IsSigned) const;
+                          MachineInstr &I, bool IsSigned) const;
 
   bool selectFirstBitHigh16(Register ResVReg, const SPIRVType *ResType,
-			    MachineInstr &I, bool IsSigned) const;
+                            MachineInstr &I, bool IsSigned) const;
 
   bool selectFirstBitHigh32(Register ResVReg, const SPIRVType *ResType,
-			    MachineInstr &I, Register SrcReg,
-			    bool IsSigned) const;
+                            MachineInstr &I, Register SrcReg,
+                            bool IsSigned) const;
 
   bool selectFirstBitHigh64(Register ResVReg, const SPIRVType *ResType,
-			    MachineInstr &I, bool IsSigned) const;
+                            MachineInstr &I, bool IsSigned) const;
 
   bool selectGlobalValue(Register ResVReg, MachineInstr &I,
                          const MachineInstr *Init = nullptr) const;
 
   bool selectNAryOpWithSrcs(Register ResVReg, const SPIRVType *ResType,
-			    MachineInstr &I, std::vector<Register> SrcRegs,
-			    unsigned Opcode) const;
+                            MachineInstr &I, std::vector<Register> SrcRegs,
+                            unsigned Opcode) const;
 
   bool selectUnOpWithSrc(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I, Register SrcReg,
@@ -836,14 +836,14 @@ bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
 }
 
 bool SPIRVInstructionSelector::selectNAryOpWithSrcs(Register ResVReg,
-						    const SPIRVType *ResType,
-			                            MachineInstr &I,
-						    std::vector<Register> Srcs,
-						    unsigned Opcode) const {
+                                                    const SPIRVType *ResType,
+                                                    MachineInstr &I,
+                                                    std::vector<Register> Srcs,
+                                                    unsigned Opcode) const {
   auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
-      .addDef(ResVReg)
-      .addUse(GR.getSPIRVTypeID(ResType));
-  for(Register SReg : Srcs) {
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType));
+  for (Register SReg : Srcs) {
     MIB.addUse(SReg);
   }
   return MIB.constrainAllUses(TII, TRI, RBI);
@@ -2672,46 +2672,46 @@ Register SPIRVInstructionSelector::buildPointerToResource(
 }
 
 bool SPIRVInstructionSelector::selectFirstBitHigh16(Register ResVReg,
-						     const SPIRVType *ResType,
-						     MachineInstr &I,
-						     bool IsSigned) const {
+                                                    const SPIRVType *ResType,
+                                                    MachineInstr &I,
+                                                    bool IsSigned) const {
   unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert;
   // zero or sign extend
   Register ExtReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
-  bool Result = selectUnOpWithSrc(ExtReg, ResType, I, I.getOperand(2).getReg(),
-				  Opcode);
+  bool Result =
+      selectUnOpWithSrc(ExtReg, ResType, I, I.getOperand(2).getReg(), Opcode);
   return Result & selectFirstBitHigh32(ResVReg, ResType, I, ExtReg, IsSigned);
 }
 
 bool SPIRVInstructionSelector::selectFirstBitHigh32(Register ResVReg,
-						    const SPIRVType *ResType,
-						    MachineInstr &I,
-						    Register SrcReg,
-						    bool IsSigned) const {
+                                                    const SPIRVType *ResType,
+                                                    MachineInstr &I,
+                                                    Register SrcReg,
+                                                    bool IsSigned) const {
   unsigned Opcode = IsSigned ? GL::FindSMsb : GL::FindUMsb;
   return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
-    .addDef(ResVReg)
-    .addUse(GR.getSPIRVTypeID(ResType))
-    .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
-    .addImm(Opcode)
-    .addUse(SrcReg)
-    .constrainAllUses(TII, TRI, RBI);
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+      .addImm(Opcode)
+      .addUse(SrcReg)
+      .constrainAllUses(TII, TRI, RBI);
 }
 
 bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
-						     const SPIRVType *ResType,
-						     MachineInstr &I,
-						     bool IsSigned) const {
+                                                    const SPIRVType *ResType,
+                                                    MachineInstr &I,
+                                                    bool IsSigned) const {
   Register OpReg = I.getOperand(2).getReg();
   // 1. split our int64 into 2 pieces using a bitcast
   unsigned count = GR.getScalarOrVectorComponentCount(ResType);
   SPIRVType *baseType = GR.retrieveScalarOrVectorIntType(ResType);
   MachineIRBuilder MIRBuilder(I);
-  SPIRVType *postCastT = GR.getOrCreateSPIRVVectorType(baseType, 2 * count,
-						       MIRBuilder);
+  SPIRVType *postCastT =
+      GR.getOrCreateSPIRVVectorType(baseType, 2 * count, MIRBuilder);
   Register bitcastReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
-  bool Result = selectUnOpWithSrc(bitcastReg, postCastT, I, OpReg,
-				  SPIRV::OpBitcast);
+  bool Result =
+      selectUnOpWithSrc(bitcastReg, postCastT, I, OpReg, SPIRV::OpBitcast);
 
   // 2. call firstbithigh
   Register FBHReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
@@ -2729,46 +2729,48 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
   // count should be one.
 
   Register HighReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
-  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
-		     TII.get(SPIRV::OpVectorShuffle))
-    .addDef(HighReg)
-    .addUse(GR.getSPIRVTypeID(VResType))
-    .addUse(FBHReg)
-    .addUse(FBHReg); // this vector will not be selected from; could be empty
+  auto MIB =
+      BuildMI(*I.getParent(), I, I.getDebugLoc(),
+              TII.get(SPIRV::OpVectorShuffle))
+          .addDef(HighReg)
+          .addUse(GR.getSPIRVTypeID(VResType))
+          .addUse(FBHReg)
+          .addUse(
+              FBHReg); // this vector will not be selected from; could be empty
   unsigned i;
-  for(i = 0; i < count*2; i += 2) {
+  for (i = 0; i < count * 2; i += 2) {
     MIB.addImm(i);
   }
   Result &= MIB.constrainAllUses(TII, TRI, RBI);
 
   // get low bits
   Register LowReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
-  MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
-		TII.get(SPIRV::OpVectorShuffle))
-    .addDef(LowReg)
-    .addUse(GR.getSPIRVTypeID(VResType))
-    .addUse(FBHReg)
-    .addUse(FBHReg); // this vector will not be selected from; could be empty
-  for(i = 1; i < count*2; i += 2) {
+  MIB =
+      BuildMI(*I.getParent(), I, I.getDebugLoc(),
+              TII.get(SPIRV::OpVectorShuffle))
+          .addDef(LowReg)
+          .addUse(GR.getSPIRVTypeID(VResType))
+          .addUse(FBHReg)
+          .addUse(
+              FBHReg); // this vector will not be selected from; could be empty
+  for (i = 1; i < count * 2; i += 2) {
     MIB.addImm(i);
   }
   Result &= MIB.constrainAllUses(TII, TRI, RBI);
 
-  SPIRVType *BoolType =
-    GR.getOrCreateSPIRVVectorType(GR.getOrCreateSPIRVBoolType(I, TII),
-				  count,
-				  MIRBuilder);
+  SPIRVType *BoolType = GR.getOrCreateSPIRVVectorType(
+      GR.getOrCreateSPIRVBoolType(I, TII), count, MIRBuilder);
   // check if the high bits are == -1;
   Register NegOneReg = GR.getOrCreateConstVector(-1, I, VResType, TII);
   // true if -1
   Register BReg = MRI->createVirtualRegister(GR.getRegClass(BoolType));
   Result &= selectNAryOpWithSrcs(BReg, BoolType, I, {HighReg, NegOneReg},
-				 SPIRV::OpIEqual);
+                                 SPIRV::OpIEqual);
 
   // Select low bits if true in BReg, otherwise high bits
   Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
   Result &= selectNAryOpWithSrcs(TmpReg, VResType, I, {BReg, LowReg, HighReg},
-				 SPIRV::OpSelectVIVCond);
+                                 SPIRV::OpSelectVIVCond);
 
   // Add 32 for high bits, 0 for low bits
   Register ValReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
@@ -2776,29 +2778,28 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
   Register Reg32 = GR.getOrCreateConstVector(32, I, VResType, TII, ZeroAsNull);
   Register Reg0 = GR.getOrCreateConstVector(0, I, VResType, TII, ZeroAsNull);
   Result &= selectNAryOpWithSrcs(ValReg, VResType, I, {BReg, Reg0, Reg32},
-				 SPIRV::OpSelectVIVCond);
+                                 SPIRV::OpSelectVIVCond);
 
   Register AddReg = ResVReg;
-  if(isScalarRes)
+  if (isScalarRes)
     AddReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
   Result &= selectNAryOpWithSrcs(AddReg, VResType, I, {ValReg, TmpReg},
-			         SPIRV::OpIAddV);
+                                 SPIRV::OpIAddV);
 
   // convert result back to scalar if necessary
   if (!isScalarRes)
     return Result;
   else
-    return Result & selectNAryOpWithSrcs(ResVReg, ResType, I,
-					 {AddReg,
-					  GR.getOrCreateConstInt(0, I, ResType,
-								 TII)},
-					 SPIRV::OpVectorExtractDynamic);
+    return Result & selectNAryOpWithSrcs(
+                        ResVReg, ResType, I,
+                        {AddReg, GR.getOrCreateConstInt(0, I, ResType, TII)},
+                        SPIRV::OpVectorExtractDynamic);
 }
 
 bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
-						   const SPIRVType *ResType,
-						   MachineInstr &I,
-						   bool IsSigned) const {
+                                                  const SPIRVType *ResType,
+                                                  MachineInstr &I,
+                                                  bool IsSigned) const {
   // FindUMsb intrinsic only supports 32 bit integers
   Register OpReg = I.getOperand(2).getReg();
   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg);

>From a5cc2af07ae7d0b088b014f69c8add2b986dac18 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 16 Oct 2024 19:59:06 +0000
Subject: [PATCH 6/8] switch to unaryBits

---
 llvm/lib/Target/DirectX/DXIL.td           |  4 ++--
 llvm/test/CodeGen/DirectX/firstbithigh.ll | 28 +++++++++++------------
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 9fdd79a3a7d6fc..dfa6c67fc4a083 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -564,7 +564,7 @@ def CBits :  DXILOp<31, unary> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
-def FirstbitHi :  DXILOp<33, unary> {
+def FirstbitHi :  DXILOp<33, unaryBits> {
   let Doc = "Returns the location of the first set bit starting from "
             "the highest order bit and working downward.";
   let LLVMIntrinsic = int_dx_firstbituhigh;
@@ -576,7 +576,7 @@ def FirstbitHi :  DXILOp<33, unary> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
-def FirstbitSHi :  DXILOp<34, unary> {
+def FirstbitSHi :  DXILOp<34, unaryBits> {
   let Doc = "Returns the location of the first set bit from "
             "the highest order bit based on the sign.";
   let LLVMIntrinsic = int_dx_firstbitshigh;
diff --git a/llvm/test/CodeGen/DirectX/firstbithigh.ll b/llvm/test/CodeGen/DirectX/firstbithigh.ll
index de0b11c97a9b98..5584c433fb6f0e 100644
--- a/llvm/test/CodeGen/DirectX/firstbithigh.ll
+++ b/llvm/test/CodeGen/DirectX/firstbithigh.ll
@@ -4,42 +4,42 @@
 
 define noundef i32 @test_firstbithigh_ushort(i16 noundef %a) {
 entry:
-; CHECK: call i32 @dx.op.unary.i16(i32 33, i16 %{{.*}})
+; CHECK: call i32 @dx.op.unaryBits.i16(i32 33, i16 %{{.*}})
   %elt.firstbithigh = call i32 @llvm.dx.firstbituhigh.i16(i16 %a)
   ret i32 %elt.firstbithigh
 }
 
 define noundef i32 @test_firstbithigh_short(i16 noundef %a) {
 entry:
-; CHECK: call i32 @dx.op.unary.i16(i32 34, i16 %{{.*}})
+; CHECK: call i32 @dx.op.unaryBits.i16(i32 34, i16 %{{.*}})
   %elt.firstbithigh = call i32 @llvm.dx.firstbitshigh.i16(i16 %a)
   ret i32 %elt.firstbithigh
 }
 
 define noundef i32 @test_firstbithigh_uint(i32 noundef %a) {
 entry:
-; CHECK: call i32 @dx.op.unary.i32(i32 33, i32 %{{.*}})
+; CHECK: call i32 @dx.op.unaryBits.i32(i32 33, i32 %{{.*}})
   %elt.firstbithigh = call i32 @llvm.dx.firstbituhigh.i32(i32 %a)
   ret i32 %elt.firstbithigh
 }
 
 define noundef i32 @test_firstbithigh_int(i32 noundef %a) {
 entry:
-; CHECK: call i32 @dx.op.unary.i32(i32 34, i32 %{{.*}})
+; CHECK: call i32 @dx.op.unaryBits.i32(i32 34, i32 %{{.*}})
   %elt.firstbithigh = call i32 @llvm.dx.firstbitshigh.i32(i32 %a)
   ret i32 %elt.firstbithigh
 }
 
 define noundef i32 @test_firstbithigh_ulong(i64 noundef %a) {
 entry:
-; CHECK: call i32 @dx.op.unary.i64(i32 33, i64 %{{.*}})
+; CHECK: call i32 @dx.op.unaryBits.i64(i32 33, i64 %{{.*}})
   %elt.firstbithigh = call i32 @llvm.dx.firstbituhigh.i64(i64 %a)
   ret i32 %elt.firstbithigh
 }
 
 define noundef i32 @test_firstbithigh_long(i64 noundef %a) {
 entry:
-; CHECK: call i32 @dx.op.unary.i64(i32 34, i64 %{{.*}})
+; CHECK: call i32 @dx.op.unaryBits.i64(i32 34, i64 %{{.*}})
   %elt.firstbithigh = call i32 @llvm.dx.firstbitshigh.i64(i64 %a)
   ret i32 %elt.firstbithigh
 }
@@ -47,13 +47,13 @@ entry:
 define noundef <4 x i32> @test_firstbituhigh_vec4_i32(<4 x i32> noundef %a)  {
 entry:
   ; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
-  ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee0]])
+  ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unaryBits.i32(i32 33, i32 [[ee0]])
   ; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
-  ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee1]])
+  ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unaryBits.i32(i32 33, i32 [[ee1]])
   ; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
-  ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee2]])
+  ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unaryBits.i32(i32 33, i32 [[ee2]])
   ; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
-  ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee3]])
+  ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unaryBits.i32(i32 33, i32 [[ee3]])
   ; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
   ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
   ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2
@@ -65,13 +65,13 @@ entry:
 define noundef <4 x i32> @test_firstbitshigh_vec4_i32(<4 x i32> noundef %a)  {
 entry:
   ; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
-  ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee0]])
+  ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unaryBits.i32(i32 34, i32 [[ee0]])
   ; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
-  ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee1]])
+  ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unaryBits.i32(i32 34, i32 [[ee1]])
   ; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
-  ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee2]])
+  ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unaryBits.i32(i32 34, i32 [[ee2]])
   ; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
-  ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee3]])
+  ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unaryBits.i32(i32 34, i32 [[ee3]])
   ; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
   ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
   ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2

>From 990ffe11f4aa5c8a8c24b4638d5c70148732a565 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 23 Oct 2024 20:09:01 +0000
Subject: [PATCH 7/8] address PR comments

---
 .../CodeGenHLSL/builtins/firstbithigh.hlsl    |  2 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 59 ++++++++++---------
 2 files changed, 32 insertions(+), 29 deletions(-)

diff --git a/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
index ce94a1c15c5332..debf6b6d3e3f5a 100644
--- a/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
@@ -150,4 +150,4 @@ uint3 test_firstbithigh_long3(int64_t3 p0) {
 // CHECK: call <4 x i32> @llvm.[[TARGET]].firstbitshigh.v4i64
 uint4 test_firstbithigh_long4(int64_t4 p0) {
   return firstbithigh(p0);
-}
\ No newline at end of file
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index c4a44bcc765b73..1dfcbfaadc49a0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2589,9 +2589,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   case Intrinsic::spv_sign:
     return selectSign(ResVReg, ResType, I);
   case Intrinsic::spv_firstbituhigh: // There is no CL equivalent of FindUMsb
-    return selectFirstBitHigh(ResVReg, ResType, I, false);
+    return selectFirstBitHigh(ResVReg, ResType, I, /*IsSigned=*/false);
   case Intrinsic::spv_firstbitshigh: // There is no CL equivalent of FindSMsb
-    return selectFirstBitHigh(ResVReg, ResType, I, true);
+    return selectFirstBitHigh(ResVReg, ResType, I, /*IsSigned=*/true);
   case Intrinsic::spv_lifetime_start:
   case Intrinsic::spv_lifetime_end: {
     unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart
@@ -2729,32 +2729,30 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
   // count should be one.
 
   Register HighReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
-  auto MIB =
-      BuildMI(*I.getParent(), I, I.getDebugLoc(),
-              TII.get(SPIRV::OpVectorShuffle))
-          .addDef(HighReg)
-          .addUse(GR.getSPIRVTypeID(VResType))
-          .addUse(FBHReg)
-          .addUse(
-              FBHReg); // this vector will not be selected from; could be empty
-  unsigned i;
-  for (i = 0; i < count * 2; i += 2) {
-    MIB.addImm(i);
+  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                     TII.get(SPIRV::OpVectorShuffle))
+                 .addDef(HighReg)
+                 .addUse(GR.getSPIRVTypeID(VResType))
+                 .addUse(FBHReg)
+                 .addUse(FBHReg);
+  // ^^ this vector will not be selected from; could be empty
+  unsigned j;
+  for (j = 0; j < count * 2; j += 2) {
+    MIB.addImm(j);
   }
   Result &= MIB.constrainAllUses(TII, TRI, RBI);
 
   // get low bits
   Register LowReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
-  MIB =
-      BuildMI(*I.getParent(), I, I.getDebugLoc(),
-              TII.get(SPIRV::OpVectorShuffle))
-          .addDef(LowReg)
-          .addUse(GR.getSPIRVTypeID(VResType))
-          .addUse(FBHReg)
-          .addUse(
-              FBHReg); // this vector will not be selected from; could be empty
-  for (i = 1; i < count * 2; i += 2) {
-    MIB.addImm(i);
+  MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                TII.get(SPIRV::OpVectorShuffle))
+            .addDef(LowReg)
+            .addUse(GR.getSPIRVTypeID(VResType))
+            .addUse(FBHReg)
+            .addUse(FBHReg);
+  // ^^ this vector will not be selected from; could be empty
+  for (j = 1; j < count * 2; j += 2) {
+    MIB.addImm(j);
   }
   Result &= MIB.constrainAllUses(TII, TRI, RBI);
 
@@ -2783,6 +2781,7 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
   Register AddReg = ResVReg;
   if (isScalarRes)
     AddReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+
   Result &= selectNAryOpWithSrcs(AddReg, VResType, I, {ValReg, TmpReg},
                                  SPIRV::OpIAddV);
 
@@ -2800,17 +2799,21 @@ bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
                                                   const SPIRVType *ResType,
                                                   MachineInstr &I,
                                                   bool IsSigned) const {
-  // FindUMsb intrinsic only supports 32 bit integers
+  // FindUMsb and FindSMsb intrinsics only support 32 bit integers
   Register OpReg = I.getOperand(2).getReg();
   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg);
-  unsigned bitWidth = GR.getScalarOrVectorBitWidth(OpType);
 
-  if (bitWidth == 16)
+  switch (GR.getScalarOrVectorBitWidth(OpType)) {
+  case 16:
     return selectFirstBitHigh16(ResVReg, ResType, I, IsSigned);
-  else if (bitWidth == 32)
+  case 32:
     return selectFirstBitHigh32(ResVReg, ResType, I, OpReg, IsSigned);
-  else // 64 bit
+  case 64:
     return selectFirstBitHigh64(ResVReg, ResType, I, IsSigned);
+  default:
+    report_fatal_error(
+        "spv_firstbituhigh and spv_firstbitshigh only support 16,32,64 bits.");
+  }
 }
 
 bool SPIRVInstructionSelector::selectAllocaArray(Register ResVReg,

>From 285b8ab8d83252e97e39baa7fb46085e018646e1 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 24 Oct 2024 18:21:55 +0000
Subject: [PATCH 8/8] remove use of vectors of length 1 in firstbithigh64 code

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |   9 ++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |   4 +
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 130 +++++++++---------
 .../SPIRV/hlsl-intrinsics/firstbithigh.ll     |  12 +-
 4 files changed, 84 insertions(+), 71 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index fe8f946e46d82e..5b257006f814fd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -449,6 +449,15 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
   return Res;
 }
 
+Register SPIRVGlobalRegistry::getOrCreateConstScalarOrVector(
+    uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
+    const SPIRVInstrInfo &TII, bool ZeroAsNull) {
+  if (SpvType->getOpcode() == SPIRV::OpTypeVector)
+    return getOrCreateConstVector(Val, I, SpvType, TII, ZeroAsNull);
+  else
+    return getOrCreateConstInt(Val, I, SpvType, TII, ZeroAsNull);
+}
+
 Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
                                                      MachineInstr &I,
                                                      SPIRVType *SpvType,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index a95b488960c4c3..7a6174523e6877 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -492,6 +492,10 @@ class SPIRVGlobalRegistry {
   Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
                            SPIRVType *SpvType = nullptr);
 
+  Register getOrCreateConstScalarOrVector(uint64_t Val, MachineInstr &I,
+                                          SPIRVType *SpvType,
+                                          const SPIRVInstrInfo &TII,
+                                          bool ZeroAsNull = true);
   Register getOrCreateConstVector(uint64_t Val, MachineInstr &I,
                                   SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                                   bool ZeroAsNull = true);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 1dfcbfaadc49a0..326536d93d3570 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2717,82 +2717,82 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
   Register FBHReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
   Result &= selectFirstBitHigh32(FBHReg, postCastT, I, bitcastReg, IsSigned);
 
-  // 3. check if result of each top 32 bits is == -1
-  // split result vector into vector of high bits and vector of low bits
-  // get high bits
-  // if ResType is a scalar we need a vector anyways because our code
-  // operates on vectors, even vectors of length one.
-  SPIRVType *VResType = ResType;
+  // 3. split result vector into high bits and low bits
+  Register HighReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+  Register LowReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+
+  bool ZeroAsNull = STI.isOpenCLEnv();
   bool isScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
-  if (isScalarRes)
-    VResType = GR.getOrCreateSPIRVVectorType(ResType, count, MIRBuilder);
-  // count should be one.
+  if (isScalarRes) {
+    // if scalar do a vector extract
+    Result &= selectNAryOpWithSrcs(
+        HighReg, ResType, I,
+        {FBHReg, GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull)},
+        SPIRV::OpVectorExtractDynamic);
+    Result &= selectNAryOpWithSrcs(
+        LowReg, ResType, I,
+        {FBHReg, GR.getOrCreateConstInt(1, I, ResType, TII, ZeroAsNull)},
+        SPIRV::OpVectorExtractDynamic);
+  } else { // vector case do a shufflevector
+    auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                       TII.get(SPIRV::OpVectorShuffle))
+                   .addDef(HighReg)
+                   .addUse(GR.getSPIRVTypeID(ResType))
+                   .addUse(FBHReg)
+                   .addUse(FBHReg);
+    // ^^ this vector will not be selected from; could be empty
+    unsigned j;
+    for (j = 0; j < count * 2; j += 2) {
+      MIB.addImm(j);
+    }
+    Result &= MIB.constrainAllUses(TII, TRI, RBI);
+
+    // get low bits
+    MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                  TII.get(SPIRV::OpVectorShuffle))
+              .addDef(LowReg)
+              .addUse(GR.getSPIRVTypeID(ResType))
+              .addUse(FBHReg)
+              .addUse(FBHReg);
+    // ^^ this vector will not be selected from; could be empty
+    for (j = 1; j < count * 2; j += 2) {
+      MIB.addImm(j);
+    }
+    Result &= MIB.constrainAllUses(TII, TRI, RBI);
+  }
+
+  // 4. check if result of each top 32 bits is == -1
+  SPIRVType *BoolType = GR.getOrCreateSPIRVBoolType(I, TII);
+  if (!isScalarRes)
+    BoolType = GR.getOrCreateSPIRVVectorType(BoolType, count, MIRBuilder);
 
-  Register HighReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
-  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
-                     TII.get(SPIRV::OpVectorShuffle))
-                 .addDef(HighReg)
-                 .addUse(GR.getSPIRVTypeID(VResType))
-                 .addUse(FBHReg)
-                 .addUse(FBHReg);
-  // ^^ this vector will not be selected from; could be empty
-  unsigned j;
-  for (j = 0; j < count * 2; j += 2) {
-    MIB.addImm(j);
-  }
-  Result &= MIB.constrainAllUses(TII, TRI, RBI);
-
-  // get low bits
-  Register LowReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
-  MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
-                TII.get(SPIRV::OpVectorShuffle))
-            .addDef(LowReg)
-            .addUse(GR.getSPIRVTypeID(VResType))
-            .addUse(FBHReg)
-            .addUse(FBHReg);
-  // ^^ this vector will not be selected from; could be empty
-  for (j = 1; j < count * 2; j += 2) {
-    MIB.addImm(j);
-  }
-  Result &= MIB.constrainAllUses(TII, TRI, RBI);
-
-  SPIRVType *BoolType = GR.getOrCreateSPIRVVectorType(
-      GR.getOrCreateSPIRVBoolType(I, TII), count, MIRBuilder);
   // check if the high bits are == -1;
-  Register NegOneReg = GR.getOrCreateConstVector(-1, I, VResType, TII);
+  Register NegOneReg =
+      GR.getOrCreateConstScalarOrVector(-1, I, ResType, TII, ZeroAsNull);
   // true if -1
   Register BReg = MRI->createVirtualRegister(GR.getRegClass(BoolType));
   Result &= selectNAryOpWithSrcs(BReg, BoolType, I, {HighReg, NegOneReg},
                                  SPIRV::OpIEqual);
 
   // Select low bits if true in BReg, otherwise high bits
-  Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
-  Result &= selectNAryOpWithSrcs(TmpReg, VResType, I, {BReg, LowReg, HighReg},
-                                 SPIRV::OpSelectVIVCond);
+  unsigned selectOp =
+      isScalarRes ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond;
+  Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+  Result &= selectNAryOpWithSrcs(TmpReg, ResType, I, {BReg, LowReg, HighReg},
+                                 selectOp);
 
   // Add 32 for high bits, 0 for low bits
-  Register ValReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
-  bool ZeroAsNull = STI.isOpenCLEnv();
-  Register Reg32 = GR.getOrCreateConstVector(32, I, VResType, TII, ZeroAsNull);
-  Register Reg0 = GR.getOrCreateConstVector(0, I, VResType, TII, ZeroAsNull);
-  Result &= selectNAryOpWithSrcs(ValReg, VResType, I, {BReg, Reg0, Reg32},
-                                 SPIRV::OpSelectVIVCond);
-
-  Register AddReg = ResVReg;
-  if (isScalarRes)
-    AddReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
-
-  Result &= selectNAryOpWithSrcs(AddReg, VResType, I, {ValReg, TmpReg},
-                                 SPIRV::OpIAddV);
-
-  // convert result back to scalar if necessary
-  if (!isScalarRes)
-    return Result;
-  else
-    return Result & selectNAryOpWithSrcs(
-                        ResVReg, ResType, I,
-                        {AddReg, GR.getOrCreateConstInt(0, I, ResType, TII)},
-                        SPIRV::OpVectorExtractDynamic);
+  Register ValReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+  Register Reg0 =
+      GR.getOrCreateConstScalarOrVector(0, I, ResType, TII, ZeroAsNull);
+  Register Reg32 =
+      GR.getOrCreateConstScalarOrVector(32, I, ResType, TII, ZeroAsNull);
+  Result &=
+      selectNAryOpWithSrcs(ValReg, ResType, I, {BReg, Reg0, Reg32}, selectOp);
+
+  return Result &=
+         selectNAryOpWithSrcs(ResVReg, ResType, I, {ValReg, TmpReg},
+                              isScalarRes ? SPIRV::OpIAddS : SPIRV::OpIAddV);
 }
 
 bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
index 057b1a9c78722a..2b2012a5d8ba1e 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
@@ -2,6 +2,8 @@
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK: OpMemoryModel Logical GLSL450
+; CHECK: [[Z:%.*]] = OpConstant %[[#]] 0
+; CHECK: [[X:%.*]] = OpConstant %[[#]] 1
 
 define noundef i32 @firstbituhigh_i32(i32 noundef %a) {
 entry:
@@ -37,13 +39,12 @@ define noundef i32 @firstbituhigh_i64(i64 noundef %a) {
 entry:
 ; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
 ; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindUMsb [[O]]
-; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
-; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
+; CHECK: [[M:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[Z]]
+; CHECK: [[L:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[X]]
 ; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
 ; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
 ; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
 ; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
-; CHECK: [[#]] = OpVectorExtractDynamic %[[#]] [[B]] %[[#]]
   %elt.firstbituhigh = call i32 @llvm.spv.firstbituhigh.i64(i64 %a)
   ret i32 %elt.firstbituhigh
 }
@@ -82,13 +83,12 @@ define noundef i32 @firstbitshigh_i64(i64 noundef %a) {
 entry:
 ; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
 ; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindSMsb [[O]]
-; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
-; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
+; CHECK: [[M:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[Z]]
+; CHECK: [[L:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[X]]
 ; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
 ; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
 ; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
 ; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
-; CHECK: [[#]] = OpVectorExtractDynamic %[[#]] [[B]] %[[#]]
   %elt.firstbitshigh = call i32 @llvm.spv.firstbitshigh.i64(i64 %a)
   ret i32 %elt.firstbitshigh
 }



More information about the cfe-commits mailing list