[clang] [llvm] [HLSL] implement elementwise firstbithigh hlsl builtin (PR #111082)
Sarah Spall via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 5 10:41:36 PST 2024
https://github.com/spall updated https://github.com/llvm/llvm-project/pull/111082
>From 93ff4fcbe072d7e81f8fe1e4e9e60025c3cc6d30 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Mon, 23 Sep 2024 22:10:59 +0000
Subject: [PATCH 01/10] implement firstbithigh hlsl builtin
---
clang/include/clang/Basic/Builtins.td | 6 +
clang/lib/CodeGen/CGBuiltin.cpp | 17 ++
clang/lib/CodeGen/CGHLSLRuntime.h | 2 +
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 72 +++++++++
clang/lib/Sema/SemaHLSL.cpp | 18 +++
.../CodeGenHLSL/builtins/firstbithigh.hlsl | 153 ++++++++++++++++++
.../BuiltIns/firstbithigh-errors.hlsl | 28 ++++
llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 +
llvm/include/llvm/IR/IntrinsicsSPIRV.td | 2 +
llvm/lib/Target/DirectX/DXIL.td | 24 +++
.../DirectX/DirectXTargetTransformInfo.cpp | 2 +
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 14 ++
llvm/test/CodeGen/DirectX/firstbithigh.ll | 91 +++++++++++
.../CodeGen/DirectX/firstbitshigh_error.ll | 10 ++
.../CodeGen/DirectX/firstbituhigh_error.ll | 10 ++
.../SPIRV/hlsl-intrinsics/firstbithigh.ll | 37 +++++
16 files changed, 488 insertions(+)
create mode 100644 clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
create mode 100644 clang/test/SemaHLSL/BuiltIns/firstbithigh-errors.hlsl
create mode 100644 llvm/test/CodeGen/DirectX/firstbithigh.ll
create mode 100644 llvm/test/CodeGen/DirectX/firstbitshigh_error.ll
create mode 100644 llvm/test/CodeGen/DirectX/firstbituhigh_error.ll
create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 9bd67e0cefebc3..33507992ac3174 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4792,6 +4792,12 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLFirstBitHigh : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_elementwise_firstbithigh"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def HLSLFrac : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_frac"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 34fedd67114751..394d6aaef78532 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18767,6 +18767,14 @@ static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
return RT.getUDotIntrinsic();
}
+Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
+ if (QT->hasSignedIntegerRepresentation()) {
+ return RT.getFirstBitSHighIntrinsic();
+ }
+
+ return RT.getFirstBitUHighIntrinsic();
+}
+
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
const CallExpr *E,
ReturnValueSlot ReturnValue) {
@@ -18856,6 +18864,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
} break;
+ case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {
+
+ Value *X = EmitScalarExpr(E->getArg(0));
+
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/X->getType(),
+ getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
+ ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
+ }
case Builtin::BI__builtin_hlsl_lerp: {
Value *X = EmitScalarExpr(E->getArg(0));
Value *Y = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index cd533cad84e9fb..e56aa6a1edaab2 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -91,6 +91,8 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh)
GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromBinding, handle_fromBinding)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index d9f3a17ea23d8e..9047021489ad5c 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -998,6 +998,78 @@ float3 exp2(float3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_exp2)
float4 exp2(float4);
+//===----------------------------------------------------------------------===//
+// firstbithigh builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T firstbithigh(T Val)
+/// \brief Returns the location of the first set bit starting from the highest
+/// order bit and working downward, per component.
+/// \param Val the input value.
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int16_t firstbithigh(int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int16_t2 firstbithigh(int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int16_t3 firstbithigh(int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int16_t4 firstbithigh(int16_t4);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint16_t firstbithigh(uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint16_t2 firstbithigh(uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint16_t3 firstbithigh(uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint16_t4 firstbithigh(uint16_t4);
+#endif
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int firstbithigh(int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int2 firstbithigh(int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int3 firstbithigh(int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int4 firstbithigh(int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint firstbithigh(uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint2 firstbithigh(uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint3 firstbithigh(uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint4 firstbithigh(uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int64_t firstbithigh(int64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int64_t2 firstbithigh(int64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int64_t3 firstbithigh(int64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+int64_t4 firstbithigh(int64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint64_t firstbithigh(uint64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint64_t2 firstbithigh(uint64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint64_t3 firstbithigh(uint64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
+uint64_t4 firstbithigh(uint64_t4);
+
//===----------------------------------------------------------------------===//
// floor builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a472538236e2d9..8a7c24fab05e6f 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1947,6 +1947,24 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {
+ if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
+ return true;
+
+ const Expr *Arg = TheCall->getArg(0);
+ QualType ArgTy = Arg->getType();
+ QualType EltTy = ArgTy;
+
+ if (auto *VecTy = EltTy->getAs<VectorType>())
+ EltTy = VecTy->getElementType();
+
+ if (!EltTy->isIntegerType()) {
+ Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
+ << 1 << /* integer ty */ 6 << ArgTy;
+ return true;
+ }
+ break;
+ }
case Builtin::BI__builtin_hlsl_select: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;
diff --git a/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
new file mode 100644
index 00000000000000..9821b308e63521
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
@@ -0,0 +1,153 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s -DTARGET=dx
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s -DTARGET=spv
+
+#ifdef __HLSL_ENABLE_16_BIT
+// CHECK-LABEL: test_firstbithigh_ushort
+// CHECK: call i16 @llvm.[[TARGET]].firstbituhigh.i16
+int test_firstbithigh_ushort(uint16_t p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ushort2
+// CHECK: call <2 x i16> @llvm.[[TARGET]].firstbituhigh.v2i16
+uint16_t2 test_firstbithigh_ushort2(uint16_t2 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ushort3
+// CHECK: call <3 x i16> @llvm.[[TARGET]].firstbituhigh.v3i16
+uint16_t3 test_firstbithigh_ushort3(uint16_t3 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ushort4
+// CHECK: call <4 x i16> @llvm.[[TARGET]].firstbituhigh.v4i16
+uint16_t4 test_firstbithigh_ushort4(uint16_t4 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_short
+// CHECK: call i16 @llvm.[[TARGET]].firstbitshigh.i16
+int16_t test_firstbithigh_short(int16_t p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_short2
+// CHECK: call <2 x i16> @llvm.[[TARGET]].firstbitshigh.v2i16
+int16_t2 test_firstbithigh_short2(int16_t2 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_short3
+// CHECK: call <3 x i16> @llvm.[[TARGET]].firstbitshigh.v3i16
+int16_t3 test_firstbithigh_short3(int16_t3 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_short4
+// CHECK: call <4 x i16> @llvm.[[TARGET]].firstbitshigh.v4i16
+int16_t4 test_firstbithigh_short4(int16_t4 p0) {
+ return firstbithigh(p0);
+}
+#endif // __HLSL_ENABLE_16_BIT
+
+// CHECK-LABEL: test_firstbithigh_uint
+// CHECK: call i32 @llvm.[[TARGET]].firstbituhigh.i32
+uint test_firstbithigh_uint(uint p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_uint2
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbituhigh.v2i32
+uint2 test_firstbithigh_uint2(uint2 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_uint3
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbituhigh.v3i32
+uint3 test_firstbithigh_uint3(uint3 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_uint4
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbituhigh.v4i32
+uint4 test_firstbithigh_uint4(uint4 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ulong
+// CHECK: call i64 @llvm.[[TARGET]].firstbituhigh.i64
+uint64_t test_firstbithigh_ulong(uint64_t p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ulong2
+// CHECK: call <2 x i64> @llvm.[[TARGET]].firstbituhigh.v2i64
+uint64_t2 test_firstbithigh_ulong2(uint64_t2 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ulong3
+// CHECK: call <3 x i64> @llvm.[[TARGET]].firstbituhigh.v3i64
+uint64_t3 test_firstbithigh_ulong3(uint64_t3 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_ulong4
+// CHECK: call <4 x i64> @llvm.[[TARGET]].firstbituhigh.v4i64
+uint64_t4 test_firstbithigh_ulong4(uint64_t4 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_int
+// CHECK: call i32 @llvm.[[TARGET]].firstbitshigh.i32
+int test_firstbithigh_int(int p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_int2
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbitshigh.v2i32
+int2 test_firstbithigh_int2(int2 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_int3
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbitshigh.v3i32
+int3 test_firstbithigh_int3(int3 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_int4
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbitshigh.v4i32
+int4 test_firstbithigh_int4(int4 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_long
+// CHECK: call i64 @llvm.[[TARGET]].firstbitshigh.i64
+int64_t test_firstbithigh_long(int64_t p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_long2
+// CHECK: call <2 x i64> @llvm.[[TARGET]].firstbitshigh.v2i64
+int64_t2 test_firstbithigh_long2(int64_t2 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_long3
+// CHECK: call <3 x i64> @llvm.[[TARGET]].firstbitshigh.v3i64
+int64_t3 test_firstbithigh_long3(int64_t3 p0) {
+ return firstbithigh(p0);
+}
+
+// CHECK-LABEL: test_firstbithigh_long4
+// CHECK: call <4 x i64> @llvm.[[TARGET]].firstbitshigh.v4i64
+int64_t4 test_firstbithigh_long4(int64_t4 p0) {
+ return firstbithigh(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/firstbithigh-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/firstbithigh-errors.hlsl
new file mode 100644
index 00000000000000..1912ab3ae806b3
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/firstbithigh-errors.hlsl
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+int test_too_few_arg() {
+ return firstbithigh();
+ // expected-error@-1 {{no matching function for call to 'firstbithigh'}}
+}
+
+int test_too_many_arg(int p0) {
+ return firstbithigh(p0, p0);
+ // expected-error@-1 {{no matching function for call to 'firstbithigh'}}
+}
+
+double test_int_builtin(double p0) {
+ return firstbithigh(p0);
+ // expected-error@-1 {{call to 'firstbithigh' is ambiguous}}
+}
+
+double2 test_int_builtin_2(double2 p0) {
+ return __builtin_hlsl_elementwise_firstbithigh(p0);
+ // expected-error@-1 {{1st argument must be a vector of integers
+ // (was 'double2' (aka 'vector<double, 2>'))}}
+}
+
+float test_int_builtin_3(float p0) {
+ return __builtin_hlsl_elementwise_firstbithigh(p0);
+ // expected-error@-1 {{1st argument must be a vector of integers
+ // (was 'float')}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index e30d37f69f781e..cfffd91afde880 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -92,4 +92,6 @@ def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, L
def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>],
[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_firstbituhigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_firstbitshigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index ddb47390537412..52ad400337a259 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -98,4 +98,6 @@ let TargetPrefix = "spv" in {
[llvm_any_ty],
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
[IntrNoMem]>;
+ def int_spv_firstbituhigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+ def int_spv_firstbitshigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index af12b74351058e..573c607765745f 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -564,6 +564,30 @@ def CountBits : DXILOp<31, unaryBits> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
+def FBH : DXILOp<33, unary> {
+ let Doc = "Returns the location of the first set bit starting from "
+ "the highest order bit and working downward.";
+ let LLVMIntrinsic = int_dx_firstbituhigh;
+ let arguments = [OverloadTy];
+ let result = OverloadTy;
+ let overloads =
+ [Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+ let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
+def FBSH : DXILOp<34, unary> {
+ let Doc = "Returns the location of the first set bit from "
+ "the highest order bit based on the sign.";
+ let LLVMIntrinsic = int_dx_firstbitshigh;
+ let arguments = [OverloadTy];
+ let result = OverloadTy;
+ let overloads =
+ [Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+ let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
def FMax : DXILOp<35, binary> {
let Doc = "Float maximum. FMax(a,b) = a > b ? a : b";
let LLVMIntrinsic = int_maxnum;
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 231afd8ae3eeaf..b0436a39423405 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -32,6 +32,8 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
case Intrinsic::dx_rsqrt:
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_splitdouble:
+ case Intrinsic::dx_firstbituhigh:
+ case Intrinsic::dx_firstbitshigh:
return true;
default:
return false;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 892912a5680113..ba533484afef86 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -213,6 +213,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectPhi(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectExtInst(Register ResVReg, const SPIRVType *RestType,
+ MachineInstr &I, GL::GLSLExtInst GLInst) const;
bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, CL::OpenCLExtInst CLInst) const;
bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
@@ -762,6 +764,14 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
}
}
+bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ GL::GLSLExtInst GLInst) const {
+ return selectExtInst(ResVReg, ResType, I,
+ {{SPIRV::InstructionSet::GLSL_std_450, GLInst}});
+}
+
bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
@@ -2548,6 +2558,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectExtInst(ResVReg, ResType, I, CL::rsqrt, GL::InverseSqrt);
case Intrinsic::spv_sign:
return selectSign(ResVReg, ResType, I);
+ case Intrinsic::spv_firstbituhigh:
+ return selectExtInst(ResVReg, ResType, I, GL::FindUMsb);
+ case Intrinsic::spv_firstbitshigh:
+ return selectExtInst(ResVReg, ResType, I, GL::FindSMsb);
case Intrinsic::spv_group_memory_barrier_with_group_sync: {
Register MemSemReg =
buildI32Constant(SPIRV::MemorySemantics::SequentiallyConsistent, I);
diff --git a/llvm/test/CodeGen/DirectX/firstbithigh.ll b/llvm/test/CodeGen/DirectX/firstbithigh.ll
new file mode 100644
index 00000000000000..4a97ad6226149f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/firstbithigh.ll
@@ -0,0 +1,91 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; Make sure dxil operation function calls for firstbithigh are generated for all integer types.
+
+define noundef i16 @test_firstbithigh_ushort(i16 noundef %a) {
+entry:
+; CHECK: call i16 @dx.op.unary.i16(i32 33, i16 %{{.*}})
+ %elt.firstbithigh = call i16 @llvm.dx.firstbituhigh.i16(i16 %a)
+ ret i16 %elt.firstbithigh
+}
+
+define noundef i16 @test_firstbithigh_short(i16 noundef %a) {
+entry:
+; CHECK: call i16 @dx.op.unary.i16(i32 34, i16 %{{.*}})
+ %elt.firstbithigh = call i16 @llvm.dx.firstbitshigh.i16(i16 %a)
+ ret i16 %elt.firstbithigh
+}
+
+define noundef i32 @test_firstbithigh_uint(i32 noundef %a) {
+entry:
+; CHECK: call i32 @dx.op.unary.i32(i32 33, i32 %{{.*}})
+ %elt.firstbithigh = call i32 @llvm.dx.firstbituhigh.i32(i32 %a)
+ ret i32 %elt.firstbithigh
+}
+
+define noundef i32 @test_firstbithigh_int(i32 noundef %a) {
+entry:
+; CHECK: call i32 @dx.op.unary.i32(i32 34, i32 %{{.*}})
+ %elt.firstbithigh = call i32 @llvm.dx.firstbitshigh.i32(i32 %a)
+ ret i32 %elt.firstbithigh
+}
+
+define noundef i64 @test_firstbithigh_ulong(i64 noundef %a) {
+entry:
+; CHECK: call i64 @dx.op.unary.i64(i32 33, i64 %{{.*}})
+ %elt.firstbithigh = call i64 @llvm.dx.firstbituhigh.i64(i64 %a)
+ ret i64 %elt.firstbithigh
+}
+
+define noundef i64 @test_firstbithigh_long(i64 noundef %a) {
+entry:
+; CHECK: call i64 @dx.op.unary.i64(i32 34, i64 %{{.*}})
+ %elt.firstbithigh = call i64 @llvm.dx.firstbitshigh.i64(i64 %a)
+ ret i64 %elt.firstbithigh
+}
+
+define noundef <4 x i32> @test_firstbituhigh_vec4_i32(<4 x i32> noundef %a) {
+entry:
+ ; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
+ ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee0]])
+ ; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
+ ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee1]])
+ ; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
+ ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee2]])
+ ; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
+ ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee3]])
+ ; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
+ ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
+ ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2
+ ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie3]], i64 3
+ %2 = call <4 x i32> @llvm.dx.firstbituhigh.v4i32(<4 x i32> %a)
+ ret <4 x i32> %2
+}
+
+define noundef <4 x i32> @test_firstbitshigh_vec4_i32(<4 x i32> noundef %a) {
+entry:
+ ; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
+ ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee0]])
+ ; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
+ ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee1]])
+ ; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
+ ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee2]])
+ ; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
+ ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee3]])
+ ; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
+ ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
+ ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2
+ ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie3]], i64 3
+ %2 = call <4 x i32> @llvm.dx.firstbitshigh.v4i32(<4 x i32> %a)
+ ret <4 x i32> %2
+}
+
+declare i16 @llvm.dx.firstbituhigh.i16(i16)
+declare i32 @llvm.dx.firstbituhigh.i32(i32)
+declare i64 @llvm.dx.firstbituhigh.i64(i64)
+declare <4 x i32> @llvm.dx.firstbituhigh.v4i32(<4 x i32>)
+
+declare i16 @llvm.dx.firstbitshigh.i16(i16)
+declare i32 @llvm.dx.firstbitshigh.i32(i32)
+declare i64 @llvm.dx.firstbitshigh.i64(i64)
+declare <4 x i32> @llvm.dx.firstbitshigh.v4i32(<4 x i32>)
diff --git a/llvm/test/CodeGen/DirectX/firstbitshigh_error.ll b/llvm/test/CodeGen/DirectX/firstbitshigh_error.ll
new file mode 100644
index 00000000000000..22982a03e47921
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/firstbitshigh_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
+
+; DXIL operation firstbitshigh does not support double overload type
+; CHECK: invalid intrinsic signature
+
+define noundef double @firstbitshigh_double(double noundef %a) {
+entry:
+ %1 = call double @llvm.dx.firstbitshigh.f64(double %a)
+ ret double %1
+}
diff --git a/llvm/test/CodeGen/DirectX/firstbituhigh_error.ll b/llvm/test/CodeGen/DirectX/firstbituhigh_error.ll
new file mode 100644
index 00000000000000..b611a96ffc2f9c
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/firstbituhigh_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
+
+; DXIL operation firstbituhigh does not support double overload type
+; CHECK: invalid intrinsic signature
+
+define noundef double @firstbituhigh_double(double noundef %a) {
+entry:
+ %1 = call double @llvm.dx.firstbituhigh.f64(double %a)
+ ret double %1
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
new file mode 100644
index 00000000000000..e20a7a4cefe8ee
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
@@ -0,0 +1,37 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: OpMemoryModel Logical GLSL450
+
+define noundef i32 @firstbituhigh_i32(i32 noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindUMsb %[[#]]
+ %elt.firstbituhigh = call i32 @llvm.spv.firstbituhigh.i32(i32 %a)
+ ret i32 %elt.firstbituhigh
+}
+
+define noundef i16 @firstbituhigh_i16(i16 noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindUMsb %[[#]]
+ %elt.firstbituhigh = call i16 @llvm.spv.firstbituhigh.i16(i16 %a)
+ ret i16 %elt.firstbituhigh
+}
+
+define noundef i32 @firstbitshigh_i32(i32 noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindSMsb %[[#]]
+ %elt.firstbitshigh = call i32 @llvm.spv.firstbitshigh.i32(i32 %a)
+ ret i32 %elt.firstbitshigh
+}
+
+define noundef i16 @firstbitshigh_i16(i16 noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindSMsb %[[#]]
+ %elt.firstbitshigh = call i16 @llvm.spv.firstbitshigh.i16(i16 %a)
+ ret i16 %elt.firstbitshigh
+}
+
+declare i16 @llvm.spv.firstbituhigh.i16(i16)
+declare i32 @llvm.spv.firstbituhigh.i32(i32)
+declare i16 @llvm.spv.firstbitshigh.i16(i16)
+declare i32 @llvm.spv.firstbitshigh.i32(i32)
>From c28d084f1bc9d89f7ec83293aff558f96c79cc4d Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 4 Oct 2024 01:22:58 +0000
Subject: [PATCH 02/10] make clang format happy
---
clang/lib/CodeGen/CGBuiltin.cpp | 10 +++++-----
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 2 +-
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 10 +++++-----
3 files changed, 11 insertions(+), 11 deletions(-)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 394d6aaef78532..3d15bfc484f134 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18865,13 +18865,13 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
} break;
case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {
-
+
Value *X = EmitScalarExpr(E->getArg(0));
-
+
return Builder.CreateIntrinsic(
- /*ReturnType=*/X->getType(),
- getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
- ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
+ /*ReturnType=*/X->getType(),
+ getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
+ ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
}
case Builtin::BI__builtin_hlsl_lerp: {
Value *X = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 9047021489ad5c..f1c3a7b46f984b 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1069,7 +1069,7 @@ _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
uint64_t3 firstbithigh(uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
uint64_t4 firstbithigh(uint64_t4);
-
+
//===----------------------------------------------------------------------===//
// floor builtins
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ba533484afef86..b2214344596765 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -214,7 +214,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
MachineInstr &I) const;
bool selectExtInst(Register ResVReg, const SPIRVType *RestType,
- MachineInstr &I, GL::GLSLExtInst GLInst) const;
+ MachineInstr &I, GL::GLSLExtInst GLInst) const;
bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, CL::OpenCLExtInst CLInst) const;
bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
@@ -765,11 +765,11 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
}
bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
- const SPIRVType *ResType,
- MachineInstr &I,
- GL::GLSLExtInst GLInst) const {
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ GL::GLSLExtInst GLInst) const {
return selectExtInst(ResVReg, ResType, I,
- {{SPIRV::InstructionSet::GLSL_std_450, GLInst}});
+ {{SPIRV::InstructionSet::GLSL_std_450, GLInst}});
}
bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
>From 14ae28245b3519863d5e1d43c26a67dd34bed0a5 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Sat, 5 Oct 2024 00:06:05 +0000
Subject: [PATCH 03/10] address PR comments
---
clang/lib/CodeGen/CGBuiltin.cpp | 1 +
clang/lib/Sema/SemaHLSL.cpp | 2 ++
llvm/lib/Target/DirectX/DXIL.td | 4 ++--
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 4 ++--
4 files changed, 7 insertions(+), 4 deletions(-)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 3d15bfc484f134..e17c559f2b13c8 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18772,6 +18772,7 @@ Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
return RT.getFirstBitSHighIntrinsic();
}
+ assert(QT->hasUnsignedIntegerRepresentation());
return RT.getFirstBitUHighIntrinsic();
}
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 8a7c24fab05e6f..04755003f99358 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1963,6 +1963,8 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
<< 1 << /* integer ty */ 6 << ArgTy;
return true;
}
+
+ TheCall->setType(ArgTy);
break;
}
case Builtin::BI__builtin_hlsl_select: {
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 573c607765745f..0b696f4e39bc10 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -564,7 +564,7 @@ def CountBits : DXILOp<31, unaryBits> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
-def FBH : DXILOp<33, unary> {
+def FirstbitHi : DXILOp<33, unaryBits> {
let Doc = "Returns the location of the first set bit starting from "
"the highest order bit and working downward.";
let LLVMIntrinsic = int_dx_firstbituhigh;
@@ -576,7 +576,7 @@ def FBH : DXILOp<33, unary> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
-def FBSH : DXILOp<34, unary> {
+def FirstbitSHi : DXILOp<34, unaryBits> {
let Doc = "Returns the location of the first set bit from "
"the highest order bit based on the sign.";
let LLVMIntrinsic = int_dx_firstbitshigh;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b2214344596765..80416c662e7af6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2558,9 +2558,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectExtInst(ResVReg, ResType, I, CL::rsqrt, GL::InverseSqrt);
case Intrinsic::spv_sign:
return selectSign(ResVReg, ResType, I);
- case Intrinsic::spv_firstbituhigh:
+ case Intrinsic::spv_firstbituhigh: // There is no CL equivalent of FindUMsb
return selectExtInst(ResVReg, ResType, I, GL::FindUMsb);
- case Intrinsic::spv_firstbitshigh:
+ case Intrinsic::spv_firstbitshigh: // There is no CL equivalent of FindSMsb
return selectExtInst(ResVReg, ResType, I, GL::FindSMsb);
case Intrinsic::spv_group_memory_barrier_with_group_sync: {
Register MemSemReg =
>From 72b3aaa735e7dcb3ad87f2ef6db720adf3a9497c Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 11 Oct 2024 15:21:00 +0000
Subject: [PATCH 04/10] update firstbithigh to the correct behavior. both dxil
and spirv support 16, 32, and 64 bit now.
---
clang/lib/CodeGen/CGBuiltin.cpp | 2 +-
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 40 ++--
clang/lib/Sema/SemaHLSL.cpp | 8 +-
.../CodeGenHLSL/builtins/firstbithigh.hlsl | 72 +++----
llvm/include/llvm/IR/IntrinsicsDirectX.td | 4 +-
llvm/include/llvm/IR/IntrinsicsSPIRV.td | 4 +-
llvm/lib/Target/DirectX/DXIL.td | 8 +-
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 2 +-
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 176 +++++++++++++++++-
llvm/test/CodeGen/DirectX/firstbithigh.ll | 40 ++--
.../SPIRV/hlsl-intrinsics/firstbithigh.ll | 90 ++++++++-
11 files changed, 346 insertions(+), 100 deletions(-)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e17c559f2b13c8..3c1ab642b1166d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18870,7 +18870,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
Value *X = EmitScalarExpr(E->getArg(0));
return Builder.CreateIntrinsic(
- /*ReturnType=*/X->getType(),
+ /*ReturnType=*/ConvertType(E->getType()),
getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
}
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index f1c3a7b46f984b..96bc6c571f671c 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1010,38 +1010,38 @@ float4 exp2(float4);
#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int16_t firstbithigh(int16_t);
+uint firstbithigh(int16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int16_t2 firstbithigh(int16_t2);
+uint2 firstbithigh(int16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int16_t3 firstbithigh(int16_t3);
+uint3 firstbithigh(int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int16_t4 firstbithigh(int16_t4);
+uint4 firstbithigh(int16_t4);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint16_t firstbithigh(uint16_t);
+uint firstbithigh(uint16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint16_t2 firstbithigh(uint16_t2);
+uint2 firstbithigh(uint16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint16_t3 firstbithigh(uint16_t3);
+uint3 firstbithigh(uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint16_t4 firstbithigh(uint16_t4);
+uint4 firstbithigh(uint16_t4);
#endif
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int firstbithigh(int);
+uint firstbithigh(int);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int2 firstbithigh(int2);
+uint2 firstbithigh(int2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int3 firstbithigh(int3);
+uint3 firstbithigh(int3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int4 firstbithigh(int4);
+uint4 firstbithigh(int4);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
uint firstbithigh(uint);
@@ -1053,22 +1053,22 @@ _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
uint4 firstbithigh(uint4);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int64_t firstbithigh(int64_t);
+uint firstbithigh(int64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int64_t2 firstbithigh(int64_t2);
+uint2 firstbithigh(int64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int64_t3 firstbithigh(int64_t3);
+uint3 firstbithigh(int64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-int64_t4 firstbithigh(int64_t4);
+uint4 firstbithigh(int64_t4);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint64_t firstbithigh(uint64_t);
+uint firstbithigh(uint64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint64_t2 firstbithigh(uint64_t2);
+uint2 firstbithigh(uint64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint64_t3 firstbithigh(uint64_t3);
+uint3 firstbithigh(uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_firstbithigh)
-uint64_t4 firstbithigh(uint64_t4);
+uint4 firstbithigh(uint64_t4);
//===----------------------------------------------------------------------===//
// floor builtins
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 04755003f99358..f2ae14148311d8 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1955,8 +1955,12 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
QualType ArgTy = Arg->getType();
QualType EltTy = ArgTy;
- if (auto *VecTy = EltTy->getAs<VectorType>())
+ QualType ResTy = SemaRef.Context.UnsignedIntTy;
+
+ if (auto *VecTy = EltTy->getAs<VectorType>()) {
EltTy = VecTy->getElementType();
+ ResTy = SemaRef.Context.getVectorType(ResTy, VecTy->getNumElements(), VecTy->getVectorKind());
+ }
if (!EltTy->isIntegerType()) {
Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
@@ -1964,7 +1968,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
}
- TheCall->setType(ArgTy);
+ TheCall->setType(ResTy);
break;
}
case Builtin::BI__builtin_hlsl_select: {
diff --git a/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
index 9821b308e63521..ce94a1c15c5332 100644
--- a/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
@@ -8,50 +8,50 @@
#ifdef __HLSL_ENABLE_16_BIT
// CHECK-LABEL: test_firstbithigh_ushort
-// CHECK: call i16 @llvm.[[TARGET]].firstbituhigh.i16
-int test_firstbithigh_ushort(uint16_t p0) {
+// CHECK: call i32 @llvm.[[TARGET]].firstbituhigh.i16
+uint test_firstbithigh_ushort(uint16_t p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_ushort2
-// CHECK: call <2 x i16> @llvm.[[TARGET]].firstbituhigh.v2i16
-uint16_t2 test_firstbithigh_ushort2(uint16_t2 p0) {
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbituhigh.v2i16
+uint2 test_firstbithigh_ushort2(uint16_t2 p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_ushort3
-// CHECK: call <3 x i16> @llvm.[[TARGET]].firstbituhigh.v3i16
-uint16_t3 test_firstbithigh_ushort3(uint16_t3 p0) {
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbituhigh.v3i16
+uint3 test_firstbithigh_ushort3(uint16_t3 p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_ushort4
-// CHECK: call <4 x i16> @llvm.[[TARGET]].firstbituhigh.v4i16
-uint16_t4 test_firstbithigh_ushort4(uint16_t4 p0) {
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbituhigh.v4i16
+uint4 test_firstbithigh_ushort4(uint16_t4 p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_short
-// CHECK: call i16 @llvm.[[TARGET]].firstbitshigh.i16
-int16_t test_firstbithigh_short(int16_t p0) {
+// CHECK: call i32 @llvm.[[TARGET]].firstbitshigh.i16
+uint test_firstbithigh_short(int16_t p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_short2
-// CHECK: call <2 x i16> @llvm.[[TARGET]].firstbitshigh.v2i16
-int16_t2 test_firstbithigh_short2(int16_t2 p0) {
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbitshigh.v2i16
+uint2 test_firstbithigh_short2(int16_t2 p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_short3
-// CHECK: call <3 x i16> @llvm.[[TARGET]].firstbitshigh.v3i16
-int16_t3 test_firstbithigh_short3(int16_t3 p0) {
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbitshigh.v3i16
+uint3 test_firstbithigh_short3(int16_t3 p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_short4
-// CHECK: call <4 x i16> @llvm.[[TARGET]].firstbitshigh.v4i16
-int16_t4 test_firstbithigh_short4(int16_t4 p0) {
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbitshigh.v4i16
+uint4 test_firstbithigh_short4(int16_t4 p0) {
return firstbithigh(p0);
}
#endif // __HLSL_ENABLE_16_BIT
@@ -81,73 +81,73 @@ uint4 test_firstbithigh_uint4(uint4 p0) {
}
// CHECK-LABEL: test_firstbithigh_ulong
-// CHECK: call i64 @llvm.[[TARGET]].firstbituhigh.i64
-uint64_t test_firstbithigh_ulong(uint64_t p0) {
+// CHECK: call i32 @llvm.[[TARGET]].firstbituhigh.i64
+uint test_firstbithigh_ulong(uint64_t p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_ulong2
-// CHECK: call <2 x i64> @llvm.[[TARGET]].firstbituhigh.v2i64
-uint64_t2 test_firstbithigh_ulong2(uint64_t2 p0) {
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbituhigh.v2i64
+uint2 test_firstbithigh_ulong2(uint64_t2 p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_ulong3
-// CHECK: call <3 x i64> @llvm.[[TARGET]].firstbituhigh.v3i64
-uint64_t3 test_firstbithigh_ulong3(uint64_t3 p0) {
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbituhigh.v3i64
+uint3 test_firstbithigh_ulong3(uint64_t3 p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_ulong4
-// CHECK: call <4 x i64> @llvm.[[TARGET]].firstbituhigh.v4i64
-uint64_t4 test_firstbithigh_ulong4(uint64_t4 p0) {
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbituhigh.v4i64
+uint4 test_firstbithigh_ulong4(uint64_t4 p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_int
// CHECK: call i32 @llvm.[[TARGET]].firstbitshigh.i32
-int test_firstbithigh_int(int p0) {
+uint test_firstbithigh_int(int p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_int2
// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbitshigh.v2i32
-int2 test_firstbithigh_int2(int2 p0) {
+uint2 test_firstbithigh_int2(int2 p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_int3
// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbitshigh.v3i32
-int3 test_firstbithigh_int3(int3 p0) {
+uint3 test_firstbithigh_int3(int3 p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_int4
// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbitshigh.v4i32
-int4 test_firstbithigh_int4(int4 p0) {
+uint4 test_firstbithigh_int4(int4 p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_long
-// CHECK: call i64 @llvm.[[TARGET]].firstbitshigh.i64
-int64_t test_firstbithigh_long(int64_t p0) {
+// CHECK: call i32 @llvm.[[TARGET]].firstbitshigh.i64
+uint test_firstbithigh_long(int64_t p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_long2
-// CHECK: call <2 x i64> @llvm.[[TARGET]].firstbitshigh.v2i64
-int64_t2 test_firstbithigh_long2(int64_t2 p0) {
+// CHECK: call <2 x i32> @llvm.[[TARGET]].firstbitshigh.v2i64
+uint2 test_firstbithigh_long2(int64_t2 p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_long3
-// CHECK: call <3 x i64> @llvm.[[TARGET]].firstbitshigh.v3i64
-int64_t3 test_firstbithigh_long3(int64_t3 p0) {
+// CHECK: call <3 x i32> @llvm.[[TARGET]].firstbitshigh.v3i64
+uint3 test_firstbithigh_long3(int64_t3 p0) {
return firstbithigh(p0);
}
// CHECK-LABEL: test_firstbithigh_long4
-// CHECK: call <4 x i64> @llvm.[[TARGET]].firstbitshigh.v4i64
-int64_t4 test_firstbithigh_long4(int64_t4 p0) {
+// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbitshigh.v4i64
+uint4 test_firstbithigh_long4(int64_t4 p0) {
return firstbithigh(p0);
}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index cfffd91afde880..612dbb398720eb 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -92,6 +92,6 @@ def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, L
def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>],
[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
-def int_dx_firstbituhigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
-def int_dx_firstbitshigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_firstbituhigh : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_anyint_ty], [IntrNoMem]>;
+def int_dx_firstbitshigh : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_anyint_ty], [IntrNoMem]>;
}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 52ad400337a259..874409c21157a8 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -98,6 +98,6 @@ let TargetPrefix = "spv" in {
[llvm_any_ty],
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
[IntrNoMem]>;
- def int_spv_firstbituhigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
- def int_spv_firstbitshigh : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+ def int_spv_firstbituhigh : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_anyint_ty], [IntrNoMem]>;
+ def int_spv_firstbitshigh : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_anyint_ty], [IntrNoMem]>;
}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 0b696f4e39bc10..e676f90515c710 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -564,24 +564,24 @@ def CountBits : DXILOp<31, unaryBits> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
-def FirstbitHi : DXILOp<33, unaryBits> {
+def FirstbitHi : DXILOp<33, unary> {
let Doc = "Returns the location of the first set bit starting from "
"the highest order bit and working downward.";
let LLVMIntrinsic = int_dx_firstbituhigh;
let arguments = [OverloadTy];
- let result = OverloadTy;
+ let result = Int32Ty;
let overloads =
[Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
-def FirstbitSHi : DXILOp<34, unaryBits> {
+def FirstbitSHi : DXILOp<34, unary> {
let Doc = "Returns the location of the first set bit from "
"the highest order bit based on the sign.";
let LLVMIntrinsic = int_dx_firstbitshigh;
let arguments = [OverloadTy];
- let result = OverloadTy;
+ let result = Int32Ty;
let overloads =
[Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index af0df4d6e5d563..f66506beaa6ed6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -424,7 +424,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
LLT LLTy = LLT::scalar(64);
Register SpvVecConst =
CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
- CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::iIDRegClass);
+ CurMF->getRegInfo().setRegClass(SpvVecConst, getRegClass(SpvType));
assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
DT.add(CA, CurMF, SpvVecConst);
MachineInstrBuilder MIB;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 80416c662e7af6..ce9af5c2968480 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -92,9 +92,26 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool spvSelect(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectFirstBitHigh(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I, bool IsSigned) const;
+
+ bool selectFirstBitHigh16(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I, bool IsSigned) const;
+
+ bool selectFirstBitHigh32(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I, Register SrcReg,
+ bool IsSigned) const;
+
+ bool selectFirstBitHigh64(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I, bool IsSigned) const;
+
bool selectGlobalValue(Register ResVReg, MachineInstr &I,
const MachineInstr *Init = nullptr) const;
+ bool selectNAryOpWithSrcs(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I, std::vector<Register> SrcRegs,
+ unsigned Opcode) const;
+
bool selectUnOpWithSrc(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, Register SrcReg,
unsigned Opcode) const;
@@ -819,6 +836,20 @@ bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
return false;
}
+bool SPIRVInstructionSelector::selectNAryOpWithSrcs(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ std::vector<Register> Srcs,
+ unsigned Opcode) const {
+ auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType));
+ for(Register SReg : Srcs) {
+ MIB.addUse(SReg);
+ }
+ return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectUnOpWithSrc(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
@@ -2559,9 +2590,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_sign:
return selectSign(ResVReg, ResType, I);
case Intrinsic::spv_firstbituhigh: // There is no CL equivalent of FindUMsb
- return selectExtInst(ResVReg, ResType, I, GL::FindUMsb);
+ return selectFirstBitHigh(ResVReg, ResType, I, false);
case Intrinsic::spv_firstbitshigh: // There is no CL equivalent of FindSMsb
- return selectExtInst(ResVReg, ResType, I, GL::FindSMsb);
+ return selectFirstBitHigh(ResVReg, ResType, I, true);
case Intrinsic::spv_group_memory_barrier_with_group_sync: {
Register MemSemReg =
buildI32Constant(SPIRV::MemorySemantics::SequentiallyConsistent, I);
@@ -2682,6 +2713,147 @@ Register SPIRVInstructionSelector::buildPointerToResource(
return AcReg;
}
+bool SPIRVInstructionSelector::selectFirstBitHigh16(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ bool IsSigned) const {
+ unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert;
+ // Zero- or sign-extend the 16-bit operand to the result type.
+ Register ExtReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+ bool Result = selectUnOpWithSrc(ExtReg, ResType, I, I.getOperand(2).getReg(),
+ Opcode);
+ return Result & selectFirstBitHigh32(ResVReg, ResType, I, ExtReg, IsSigned);
+}
+
+bool SPIRVInstructionSelector::selectFirstBitHigh32(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ Register SrcReg,
+ bool IsSigned) const {
+ unsigned Opcode = IsSigned ? GL::FindSMsb : GL::FindUMsb;
+ return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+ .addImm(Opcode)
+ .addUse(SrcReg)
+ .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ bool IsSigned) const {
+ Register OpReg = I.getOperand(2).getReg();
+ // 1. Split each 64-bit element into two 32-bit pieces using a bitcast.
+ unsigned count = GR.getScalarOrVectorComponentCount(ResType);
+ SPIRVType *baseType = GR.retrieveScalarOrVectorIntType(ResType);
+ MachineIRBuilder MIRBuilder(I);
+ SPIRVType *postCastT = GR.getOrCreateSPIRVVectorType(baseType, 2 * count,
+ MIRBuilder);
+ Register bitcastReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
+ bool Result = selectUnOpWithSrc(bitcastReg, postCastT, I, OpReg,
+ SPIRV::OpBitcast);
+
+ // 2. Apply the 32-bit firstbithigh to every 32-bit piece.
+ Register FBHReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
+ Result &= selectFirstBitHigh32(FBHReg, postCastT, I, bitcastReg, IsSigned);
+
+ // 3. Check whether the result for the top 32 bits of each element is -1.
+ // Split the result vector into high-half and low-half result vectors,
+ // high halves first. NOTE(review): OpBitcast maps lower-numbered components
+ // to lower-order bits; confirm the even indices really are the high halves.
+ // A scalar ResType still needs a (length-one) vector type for this code.
+ SPIRVType *VResType = ResType;
+ bool isScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
+ if (isScalarRes)
+ VResType = GR.getOrCreateSPIRVVectorType(ResType, count, MIRBuilder);
+ // count is expected to be 1 for a scalar result type.
+
+ Register HighReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+ auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpVectorShuffle))
+ .addDef(HighReg)
+ .addUse(GR.getSPIRVTypeID(VResType))
+ .addUse(FBHReg)
+ .addUse(FBHReg); // this vector will not be selected from; could be empty
+ unsigned i;
+ for(i = 0; i < count*2; i += 2) {
+ MIB.addImm(i);
+ }
+ Result &= MIB.constrainAllUses(TII, TRI, RBI);
+
+ // get low bits
+ Register LowReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+ MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpVectorShuffle))
+ .addDef(LowReg)
+ .addUse(GR.getSPIRVTypeID(VResType))
+ .addUse(FBHReg)
+ .addUse(FBHReg); // this vector will not be selected from; could be empty
+ for(i = 1; i < count*2; i += 2) {
+ MIB.addImm(i);
+ }
+ Result &= MIB.constrainAllUses(TII, TRI, RBI);
+
+ SPIRVType *BoolType =
+ GR.getOrCreateSPIRVVectorType(GR.getOrCreateSPIRVBoolType(I, TII),
+ count,
+ MIRBuilder);
+ // Check whether the high-half results are -1 (FindUMsb/FindSMsb found no bit).
+ Register NegOneReg = GR.getOrCreateConstVector(-1, I, VResType, TII);
+ // true if -1
+ Register BReg = MRI->createVirtualRegister(GR.getRegClass(BoolType));
+ Result &= selectNAryOpWithSrcs(BReg, BoolType, I, {HighReg, NegOneReg},
+ SPIRV::OpIEqual);
+
+ // Select low bits if true in BReg, otherwise high bits
+ Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+ Result &= selectNAryOpWithSrcs(TmpReg, VResType, I, {BReg, LowReg, HighReg},
+ SPIRV::OpSelectVIVCond);
+
+ // Add 32 for high bits, 0 for low bits
+ Register ValReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+ bool ZeroAsNull = STI.isOpenCLEnv();
+ Register Reg32 = GR.getOrCreateConstVector(32, I, VResType, TII, ZeroAsNull);
+ Register Reg0 = GR.getOrCreateConstVector(0, I, VResType, TII, ZeroAsNull);
+ Result &= selectNAryOpWithSrcs(ValReg, VResType, I, {BReg, Reg0, Reg32},
+ SPIRV::OpSelectVIVCond);
+
+ Register AddReg = ResVReg;
+ if(isScalarRes)
+ AddReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+ Result &= selectNAryOpWithSrcs(AddReg, VResType, I, {ValReg, TmpReg},
+ SPIRV::OpIAddV);
+
+ // convert result back to scalar if necessary
+ if (!isScalarRes)
+ return Result;
+ else
+ return Result & selectNAryOpWithSrcs(ResVReg, ResType, I,
+ {AddReg,
+ GR.getOrCreateConstInt(0, I, ResType,
+ TII)},
+ SPIRV::OpVectorExtractDynamic);
+}
+
+bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ bool IsSigned) const {
+ // GLSL FindUMsb/FindSMsb only support 32-bit integers, so dispatch on width.
+ Register OpReg = I.getOperand(2).getReg();
+ SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg);
+ unsigned bitWidth = GR.getScalarOrVectorBitWidth(OpType);
+
+ if (bitWidth == 16)
+ return selectFirstBitHigh16(ResVReg, ResType, I, IsSigned);
+ else if (bitWidth == 32)
+ return selectFirstBitHigh32(ResVReg, ResType, I, OpReg, IsSigned);
+ else // 64 bit
+ return selectFirstBitHigh64(ResVReg, ResType, I, IsSigned);
+}
+
bool SPIRVInstructionSelector::selectAllocaArray(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
diff --git a/llvm/test/CodeGen/DirectX/firstbithigh.ll b/llvm/test/CodeGen/DirectX/firstbithigh.ll
index 4a97ad6226149f..de0b11c97a9b98 100644
--- a/llvm/test/CodeGen/DirectX/firstbithigh.ll
+++ b/llvm/test/CodeGen/DirectX/firstbithigh.ll
@@ -2,18 +2,18 @@
; Make sure dxil operation function calls for firstbithigh are generated for all integer types.
-define noundef i16 @test_firstbithigh_ushort(i16 noundef %a) {
+define noundef i32 @test_firstbithigh_ushort(i16 noundef %a) {
entry:
-; CHECK: call i16 @dx.op.unary.i16(i32 33, i16 %{{.*}})
- %elt.firstbithigh = call i16 @llvm.dx.firstbituhigh.i16(i16 %a)
- ret i16 %elt.firstbithigh
+; CHECK: call i32 @dx.op.unary.i16(i32 33, i16 %{{.*}})
+ %elt.firstbithigh = call i32 @llvm.dx.firstbituhigh.i16(i16 %a)
+ ret i32 %elt.firstbithigh
}
-define noundef i16 @test_firstbithigh_short(i16 noundef %a) {
+define noundef i32 @test_firstbithigh_short(i16 noundef %a) {
entry:
-; CHECK: call i16 @dx.op.unary.i16(i32 34, i16 %{{.*}})
- %elt.firstbithigh = call i16 @llvm.dx.firstbitshigh.i16(i16 %a)
- ret i16 %elt.firstbithigh
+; CHECK: call i32 @dx.op.unary.i16(i32 34, i16 %{{.*}})
+ %elt.firstbithigh = call i32 @llvm.dx.firstbitshigh.i16(i16 %a)
+ ret i32 %elt.firstbithigh
}
define noundef i32 @test_firstbithigh_uint(i32 noundef %a) {
@@ -30,18 +30,18 @@ entry:
ret i32 %elt.firstbithigh
}
-define noundef i64 @test_firstbithigh_ulong(i64 noundef %a) {
+define noundef i32 @test_firstbithigh_ulong(i64 noundef %a) {
entry:
-; CHECK: call i64 @dx.op.unary.i64(i32 33, i64 %{{.*}})
- %elt.firstbithigh = call i64 @llvm.dx.firstbituhigh.i64(i64 %a)
- ret i64 %elt.firstbithigh
+; CHECK: call i32 @dx.op.unary.i64(i32 33, i64 %{{.*}})
+ %elt.firstbithigh = call i32 @llvm.dx.firstbituhigh.i64(i64 %a)
+ ret i32 %elt.firstbithigh
}
-define noundef i64 @test_firstbithigh_long(i64 noundef %a) {
+define noundef i32 @test_firstbithigh_long(i64 noundef %a) {
entry:
-; CHECK: call i64 @dx.op.unary.i64(i32 34, i64 %{{.*}})
- %elt.firstbithigh = call i64 @llvm.dx.firstbitshigh.i64(i64 %a)
- ret i64 %elt.firstbithigh
+; CHECK: call i32 @dx.op.unary.i64(i32 34, i64 %{{.*}})
+ %elt.firstbithigh = call i32 @llvm.dx.firstbitshigh.i64(i64 %a)
+ ret i32 %elt.firstbithigh
}
define noundef <4 x i32> @test_firstbituhigh_vec4_i32(<4 x i32> noundef %a) {
@@ -80,12 +80,12 @@ entry:
ret <4 x i32> %2
}
-declare i16 @llvm.dx.firstbituhigh.i16(i16)
+declare i32 @llvm.dx.firstbituhigh.i16(i16)
declare i32 @llvm.dx.firstbituhigh.i32(i32)
-declare i64 @llvm.dx.firstbituhigh.i64(i64)
+declare i32 @llvm.dx.firstbituhigh.i64(i64)
declare <4 x i32> @llvm.dx.firstbituhigh.v4i32(<4 x i32>)
-declare i16 @llvm.dx.firstbitshigh.i16(i16)
+declare i32 @llvm.dx.firstbitshigh.i16(i16)
declare i32 @llvm.dx.firstbitshigh.i32(i32)
-declare i64 @llvm.dx.firstbitshigh.i64(i64)
+declare i32 @llvm.dx.firstbitshigh.i64(i64)
declare <4 x i32> @llvm.dx.firstbitshigh.v4i32(<4 x i32>)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
index e20a7a4cefe8ee..057b1a9c78722a 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
@@ -10,11 +10,57 @@ entry:
ret i32 %elt.firstbituhigh
}
-define noundef i16 @firstbituhigh_i16(i16 noundef %a) {
+define noundef <2 x i32> @firstbituhigh_2xi32(<2 x i32> noundef %a) {
entry:
; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindUMsb %[[#]]
- %elt.firstbituhigh = call i16 @llvm.spv.firstbituhigh.i16(i16 %a)
- ret i16 %elt.firstbituhigh
+ %elt.firstbituhigh = call <2 x i32> @llvm.spv.firstbituhigh.v2i32(<2 x i32> %a)
+ ret <2 x i32> %elt.firstbituhigh
+}
+
+define noundef i32 @firstbituhigh_i16(i16 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = OpUConvert %[[#]]
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindUMsb [[A]]
+ %elt.firstbituhigh = call i32 @llvm.spv.firstbituhigh.i16(i16 %a)
+ ret i32 %elt.firstbituhigh
+}
+
+define noundef <2 x i32> @firstbituhigh_v2i16(<2 x i16> noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = OpUConvert %[[#]]
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindUMsb [[A]]
+ %elt.firstbituhigh = call <2 x i32> @llvm.spv.firstbituhigh.v2i16(<2 x i16> %a)
+ ret <2 x i32> %elt.firstbituhigh
+}
+
+define noundef i32 @firstbituhigh_i64(i64 noundef %a) {
+entry:
+; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
+; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindUMsb [[O]]
+; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
+; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
+; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
+; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
+; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
+; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
+; CHECK: [[#]] = OpVectorExtractDynamic %[[#]] [[B]] %[[#]]
+ %elt.firstbituhigh = call i32 @llvm.spv.firstbituhigh.i64(i64 %a)
+ ret i32 %elt.firstbituhigh
+}
+
+define noundef <2 x i32> @firstbituhigh_v2i64(<2 x i64> noundef %a) {
+entry:
+; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
+; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindUMsb [[O]]
+; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
+; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
+; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
+; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
+; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
+; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
+; CHECK: OpReturnValue [[B]]
+ %elt.firstbituhigh = call <2 x i32> @llvm.spv.firstbituhigh.v2i64(<2 x i64> %a)
+ ret <2 x i32> %elt.firstbituhigh
}
define noundef i32 @firstbitshigh_i32(i32 noundef %a) {
@@ -24,14 +70,38 @@ entry:
ret i32 %elt.firstbitshigh
}
-define noundef i16 @firstbitshigh_i16(i16 noundef %a) {
+define noundef i32 @firstbitshigh_i16(i16 noundef %a) {
entry:
+; CHECK: [[A:%.*]] = OpSConvert %[[#]]
; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] FindSMsb %[[#]]
- %elt.firstbitshigh = call i16 @llvm.spv.firstbitshigh.i16(i16 %a)
- ret i16 %elt.firstbitshigh
+ %elt.firstbitshigh = call i32 @llvm.spv.firstbitshigh.i16(i16 %a)
+ ret i32 %elt.firstbitshigh
+}
+
+define noundef i32 @firstbitshigh_i64(i64 noundef %a) {
+entry:
+; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
+; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindSMsb [[O]]
+; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
+; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
+; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
+; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
+; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
+; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
+; CHECK: [[#]] = OpVectorExtractDynamic %[[#]] [[B]] %[[#]]
+ %elt.firstbitshigh = call i32 @llvm.spv.firstbitshigh.i64(i64 %a)
+ ret i32 %elt.firstbitshigh
}
-declare i16 @llvm.spv.firstbituhigh.i16(i16)
-declare i32 @llvm.spv.firstbituhigh.i32(i32)
-declare i16 @llvm.spv.firstbitshigh.i16(i16)
-declare i32 @llvm.spv.firstbitshigh.i32(i32)
+;declare i16 @llvm.spv.firstbituhigh.i16(i16)
+;declare i32 @llvm.spv.firstbituhigh.i32(i32)
+;declare i64 @llvm.spv.firstbituhigh.i64(i64)
+;declare i16 @llvm.spv.firstbituhigh.v2i16(<2 x i16>)
+;declare i32 @llvm.spv.firstbituhigh.v2i32(<2 x i32>)
+;declare i64 @llvm.spv.firstbituhigh.v2i64(<2 x i64>)
+;declare i16 @llvm.spv.firstbitshigh.i16(i16)
+;declare i32 @llvm.spv.firstbitshigh.i32(i32)
+;declare i64 @llvm.spv.firstbitshigh.i64(i64)
+;declare i16 @llvm.spv.firstbitshigh.v2i16(<2 x i16>)
+;declare i32 @llvm.spv.firstbitshigh.v2i32(<2 x i32>)
+;declare i64 @llvm.spv.firstbitshigh.v2i64(<2 x i64>)
>From d080a9b9866ae25d257df61ba4b027d4c1b498f8 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 16 Oct 2024 16:41:04 +0000
Subject: [PATCH 05/10] make clang format happy
---
clang/lib/CodeGen/CGBuiltin.cpp | 2 +-
clang/lib/Sema/SemaHLSL.cpp | 3 +-
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 135 +++++++++---------
3 files changed, 71 insertions(+), 69 deletions(-)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 3c1ab642b1166d..a8bfa7d5df64e4 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18870,7 +18870,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
Value *X = EmitScalarExpr(E->getArg(0));
return Builder.CreateIntrinsic(
- /*ReturnType=*/ConvertType(E->getType()),
+ /*ReturnType=*/ConvertType(E->getType()),
getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
}
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index f2ae14148311d8..ee7a1f455f7b6a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1959,7 +1959,8 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
if (auto *VecTy = EltTy->getAs<VectorType>()) {
EltTy = VecTy->getElementType();
- ResTy = SemaRef.Context.getVectorType(ResTy, VecTy->getNumElements(), VecTy->getVectorKind());
+ ResTy = SemaRef.Context.getVectorType(ResTy, VecTy->getNumElements(),
+ VecTy->getVectorKind());
}
if (!EltTy->isIntegerType()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ce9af5c2968480..e7d533287d5401 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -93,24 +93,24 @@ class SPIRVInstructionSelector : public InstructionSelector {
MachineInstr &I) const;
bool selectFirstBitHigh(Register ResVReg, const SPIRVType *ResType,
- MachineInstr &I, bool IsSigned) const;
+ MachineInstr &I, bool IsSigned) const;
bool selectFirstBitHigh16(Register ResVReg, const SPIRVType *ResType,
- MachineInstr &I, bool IsSigned) const;
+ MachineInstr &I, bool IsSigned) const;
bool selectFirstBitHigh32(Register ResVReg, const SPIRVType *ResType,
- MachineInstr &I, Register SrcReg,
- bool IsSigned) const;
+ MachineInstr &I, Register SrcReg,
+ bool IsSigned) const;
bool selectFirstBitHigh64(Register ResVReg, const SPIRVType *ResType,
- MachineInstr &I, bool IsSigned) const;
+ MachineInstr &I, bool IsSigned) const;
bool selectGlobalValue(Register ResVReg, MachineInstr &I,
const MachineInstr *Init = nullptr) const;
bool selectNAryOpWithSrcs(Register ResVReg, const SPIRVType *ResType,
- MachineInstr &I, std::vector<Register> SrcRegs,
- unsigned Opcode) const;
+ MachineInstr &I, std::vector<Register> SrcRegs,
+ unsigned Opcode) const;
bool selectUnOpWithSrc(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, Register SrcReg,
@@ -837,14 +837,14 @@ bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
}
bool SPIRVInstructionSelector::selectNAryOpWithSrcs(Register ResVReg,
- const SPIRVType *ResType,
- MachineInstr &I,
- std::vector<Register> Srcs,
- unsigned Opcode) const {
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ std::vector<Register> Srcs,
+ unsigned Opcode) const {
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType));
- for(Register SReg : Srcs) {
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType));
+ for (Register SReg : Srcs) {
MIB.addUse(SReg);
}
return MIB.constrainAllUses(TII, TRI, RBI);
@@ -2714,46 +2714,46 @@ Register SPIRVInstructionSelector::buildPointerToResource(
}
bool SPIRVInstructionSelector::selectFirstBitHigh16(Register ResVReg,
- const SPIRVType *ResType,
- MachineInstr &I,
- bool IsSigned) const {
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ bool IsSigned) const {
unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert;
// zero or sign extend
Register ExtReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
- bool Result = selectUnOpWithSrc(ExtReg, ResType, I, I.getOperand(2).getReg(),
- Opcode);
+ bool Result =
+ selectUnOpWithSrc(ExtReg, ResType, I, I.getOperand(2).getReg(), Opcode);
return Result & selectFirstBitHigh32(ResVReg, ResType, I, ExtReg, IsSigned);
}
bool SPIRVInstructionSelector::selectFirstBitHigh32(Register ResVReg,
- const SPIRVType *ResType,
- MachineInstr &I,
- Register SrcReg,
- bool IsSigned) const {
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ Register SrcReg,
+ bool IsSigned) const {
unsigned Opcode = IsSigned ? GL::FindSMsb : GL::FindUMsb;
return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
- .addImm(Opcode)
- .addUse(SrcReg)
- .constrainAllUses(TII, TRI, RBI);
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+ .addImm(Opcode)
+ .addUse(SrcReg)
+ .constrainAllUses(TII, TRI, RBI);
}
bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
- const SPIRVType *ResType,
- MachineInstr &I,
- bool IsSigned) const {
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ bool IsSigned) const {
Register OpReg = I.getOperand(2).getReg();
// 1. split our int64 into 2 pieces using a bitcast
unsigned count = GR.getScalarOrVectorComponentCount(ResType);
SPIRVType *baseType = GR.retrieveScalarOrVectorIntType(ResType);
MachineIRBuilder MIRBuilder(I);
- SPIRVType *postCastT = GR.getOrCreateSPIRVVectorType(baseType, 2 * count,
- MIRBuilder);
+ SPIRVType *postCastT =
+ GR.getOrCreateSPIRVVectorType(baseType, 2 * count, MIRBuilder);
Register bitcastReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
- bool Result = selectUnOpWithSrc(bitcastReg, postCastT, I, OpReg,
- SPIRV::OpBitcast);
+ bool Result =
+ selectUnOpWithSrc(bitcastReg, postCastT, I, OpReg, SPIRV::OpBitcast);
// 2. call firstbithigh
Register FBHReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
@@ -2771,46 +2771,48 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
// count should be one.
Register HighReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
- auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
- TII.get(SPIRV::OpVectorShuffle))
- .addDef(HighReg)
- .addUse(GR.getSPIRVTypeID(VResType))
- .addUse(FBHReg)
- .addUse(FBHReg); // this vector will not be selected from; could be empty
+ auto MIB =
+ BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpVectorShuffle))
+ .addDef(HighReg)
+ .addUse(GR.getSPIRVTypeID(VResType))
+ .addUse(FBHReg)
+ .addUse(
+ FBHReg); // this vector will not be selected from; could be empty
unsigned i;
- for(i = 0; i < count*2; i += 2) {
+ for (i = 0; i < count * 2; i += 2) {
MIB.addImm(i);
}
Result &= MIB.constrainAllUses(TII, TRI, RBI);
// get low bits
Register LowReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
- MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
- TII.get(SPIRV::OpVectorShuffle))
- .addDef(LowReg)
- .addUse(GR.getSPIRVTypeID(VResType))
- .addUse(FBHReg)
- .addUse(FBHReg); // this vector will not be selected from; could be empty
- for(i = 1; i < count*2; i += 2) {
+ MIB =
+ BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpVectorShuffle))
+ .addDef(LowReg)
+ .addUse(GR.getSPIRVTypeID(VResType))
+ .addUse(FBHReg)
+ .addUse(
+ FBHReg); // this vector will not be selected from; could be empty
+ for (i = 1; i < count * 2; i += 2) {
MIB.addImm(i);
}
Result &= MIB.constrainAllUses(TII, TRI, RBI);
- SPIRVType *BoolType =
- GR.getOrCreateSPIRVVectorType(GR.getOrCreateSPIRVBoolType(I, TII),
- count,
- MIRBuilder);
+ SPIRVType *BoolType = GR.getOrCreateSPIRVVectorType(
+ GR.getOrCreateSPIRVBoolType(I, TII), count, MIRBuilder);
// check if the high bits are == -1;
Register NegOneReg = GR.getOrCreateConstVector(-1, I, VResType, TII);
// true if -1
Register BReg = MRI->createVirtualRegister(GR.getRegClass(BoolType));
Result &= selectNAryOpWithSrcs(BReg, BoolType, I, {HighReg, NegOneReg},
- SPIRV::OpIEqual);
+ SPIRV::OpIEqual);
// Select low bits if true in BReg, otherwise high bits
Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
Result &= selectNAryOpWithSrcs(TmpReg, VResType, I, {BReg, LowReg, HighReg},
- SPIRV::OpSelectVIVCond);
+ SPIRV::OpSelectVIVCond);
// Add 32 for high bits, 0 for low bits
Register ValReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
@@ -2818,29 +2820,28 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
Register Reg32 = GR.getOrCreateConstVector(32, I, VResType, TII, ZeroAsNull);
Register Reg0 = GR.getOrCreateConstVector(0, I, VResType, TII, ZeroAsNull);
Result &= selectNAryOpWithSrcs(ValReg, VResType, I, {BReg, Reg0, Reg32},
- SPIRV::OpSelectVIVCond);
+ SPIRV::OpSelectVIVCond);
Register AddReg = ResVReg;
- if(isScalarRes)
+ if (isScalarRes)
AddReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
Result &= selectNAryOpWithSrcs(AddReg, VResType, I, {ValReg, TmpReg},
- SPIRV::OpIAddV);
+ SPIRV::OpIAddV);
// convert result back to scalar if necessary
if (!isScalarRes)
return Result;
else
- return Result & selectNAryOpWithSrcs(ResVReg, ResType, I,
- {AddReg,
- GR.getOrCreateConstInt(0, I, ResType,
- TII)},
- SPIRV::OpVectorExtractDynamic);
+ return Result & selectNAryOpWithSrcs(
+ ResVReg, ResType, I,
+ {AddReg, GR.getOrCreateConstInt(0, I, ResType, TII)},
+ SPIRV::OpVectorExtractDynamic);
}
bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
- const SPIRVType *ResType,
- MachineInstr &I,
- bool IsSigned) const {
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ bool IsSigned) const {
// FindUMsb intrinsic only supports 32 bit integers
Register OpReg = I.getOperand(2).getReg();
SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg);
>From 2b0e964356d88caa3bac4e1ecce3de4fae7683c4 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 16 Oct 2024 19:59:06 +0000
Subject: [PATCH 06/10] switch to unaryBits
---
llvm/lib/Target/DirectX/DXIL.td | 4 ++--
llvm/test/CodeGen/DirectX/firstbithigh.ll | 28 +++++++++++------------
2 files changed, 16 insertions(+), 16 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index e676f90515c710..f5a952d82c11ee 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -564,7 +564,7 @@ def CountBits : DXILOp<31, unaryBits> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
-def FirstbitHi : DXILOp<33, unary> {
+def FirstbitHi : DXILOp<33, unaryBits> {
let Doc = "Returns the location of the first set bit starting from "
"the highest order bit and working downward.";
let LLVMIntrinsic = int_dx_firstbituhigh;
@@ -576,7 +576,7 @@ def FirstbitHi : DXILOp<33, unary> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
-def FirstbitSHi : DXILOp<34, unary> {
+def FirstbitSHi : DXILOp<34, unaryBits> {
let Doc = "Returns the location of the first set bit from "
"the highest order bit based on the sign.";
let LLVMIntrinsic = int_dx_firstbitshigh;
diff --git a/llvm/test/CodeGen/DirectX/firstbithigh.ll b/llvm/test/CodeGen/DirectX/firstbithigh.ll
index de0b11c97a9b98..5584c433fb6f0e 100644
--- a/llvm/test/CodeGen/DirectX/firstbithigh.ll
+++ b/llvm/test/CodeGen/DirectX/firstbithigh.ll
@@ -4,42 +4,42 @@
define noundef i32 @test_firstbithigh_ushort(i16 noundef %a) {
entry:
-; CHECK: call i32 @dx.op.unary.i16(i32 33, i16 %{{.*}})
+; CHECK: call i32 @dx.op.unaryBits.i16(i32 33, i16 %{{.*}})
%elt.firstbithigh = call i32 @llvm.dx.firstbituhigh.i16(i16 %a)
ret i32 %elt.firstbithigh
}
define noundef i32 @test_firstbithigh_short(i16 noundef %a) {
entry:
-; CHECK: call i32 @dx.op.unary.i16(i32 34, i16 %{{.*}})
+; CHECK: call i32 @dx.op.unaryBits.i16(i32 34, i16 %{{.*}})
%elt.firstbithigh = call i32 @llvm.dx.firstbitshigh.i16(i16 %a)
ret i32 %elt.firstbithigh
}
define noundef i32 @test_firstbithigh_uint(i32 noundef %a) {
entry:
-; CHECK: call i32 @dx.op.unary.i32(i32 33, i32 %{{.*}})
+; CHECK: call i32 @dx.op.unaryBits.i32(i32 33, i32 %{{.*}})
%elt.firstbithigh = call i32 @llvm.dx.firstbituhigh.i32(i32 %a)
ret i32 %elt.firstbithigh
}
define noundef i32 @test_firstbithigh_int(i32 noundef %a) {
entry:
-; CHECK: call i32 @dx.op.unary.i32(i32 34, i32 %{{.*}})
+; CHECK: call i32 @dx.op.unaryBits.i32(i32 34, i32 %{{.*}})
%elt.firstbithigh = call i32 @llvm.dx.firstbitshigh.i32(i32 %a)
ret i32 %elt.firstbithigh
}
define noundef i32 @test_firstbithigh_ulong(i64 noundef %a) {
entry:
-; CHECK: call i32 @dx.op.unary.i64(i32 33, i64 %{{.*}})
+; CHECK: call i32 @dx.op.unaryBits.i64(i32 33, i64 %{{.*}})
%elt.firstbithigh = call i32 @llvm.dx.firstbituhigh.i64(i64 %a)
ret i32 %elt.firstbithigh
}
define noundef i32 @test_firstbithigh_long(i64 noundef %a) {
entry:
-; CHECK: call i32 @dx.op.unary.i64(i32 34, i64 %{{.*}})
+; CHECK: call i32 @dx.op.unaryBits.i64(i32 34, i64 %{{.*}})
%elt.firstbithigh = call i32 @llvm.dx.firstbitshigh.i64(i64 %a)
ret i32 %elt.firstbithigh
}
@@ -47,13 +47,13 @@ entry:
define noundef <4 x i32> @test_firstbituhigh_vec4_i32(<4 x i32> noundef %a) {
entry:
; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
- ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee0]])
+ ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unaryBits.i32(i32 33, i32 [[ee0]])
; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
- ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee1]])
+ ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unaryBits.i32(i32 33, i32 [[ee1]])
; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
- ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee2]])
+ ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unaryBits.i32(i32 33, i32 [[ee2]])
; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
- ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 33, i32 [[ee3]])
+ ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unaryBits.i32(i32 33, i32 [[ee3]])
; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2
@@ -65,13 +65,13 @@ entry:
define noundef <4 x i32> @test_firstbitshigh_vec4_i32(<4 x i32> noundef %a) {
entry:
; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
- ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee0]])
+ ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unaryBits.i32(i32 34, i32 [[ee0]])
; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
- ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee1]])
+ ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unaryBits.i32(i32 34, i32 [[ee1]])
; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
- ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee2]])
+ ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unaryBits.i32(i32 34, i32 [[ee2]])
; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
- ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 34, i32 [[ee3]])
+ ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unaryBits.i32(i32 34, i32 [[ee3]])
; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2
>From 85a1d4e620f74efa2011cefb1c848fd62a4ebfd3 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 23 Oct 2024 20:09:01 +0000
Subject: [PATCH 07/10] address pr comments
---
.../CodeGenHLSL/builtins/firstbithigh.hlsl | 2 +-
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 59 ++++++++++---------
2 files changed, 32 insertions(+), 29 deletions(-)
diff --git a/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
index ce94a1c15c5332..debf6b6d3e3f5a 100644
--- a/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/firstbithigh.hlsl
@@ -150,4 +150,4 @@ uint3 test_firstbithigh_long3(int64_t3 p0) {
// CHECK: call <4 x i32> @llvm.[[TARGET]].firstbitshigh.v4i64
uint4 test_firstbithigh_long4(int64_t4 p0) {
return firstbithigh(p0);
-}
\ No newline at end of file
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index e7d533287d5401..b2d0935b939bb1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2590,9 +2590,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_sign:
return selectSign(ResVReg, ResType, I);
case Intrinsic::spv_firstbituhigh: // There is no CL equivalent of FindUMsb
- return selectFirstBitHigh(ResVReg, ResType, I, false);
+ return selectFirstBitHigh(ResVReg, ResType, I, /*IsSigned=*/false);
case Intrinsic::spv_firstbitshigh: // There is no CL equivalent of FindSMsb
- return selectFirstBitHigh(ResVReg, ResType, I, true);
+ return selectFirstBitHigh(ResVReg, ResType, I, /*IsSigned=*/true);
case Intrinsic::spv_group_memory_barrier_with_group_sync: {
Register MemSemReg =
buildI32Constant(SPIRV::MemorySemantics::SequentiallyConsistent, I);
@@ -2771,32 +2771,30 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
// count should be one.
Register HighReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
- auto MIB =
- BuildMI(*I.getParent(), I, I.getDebugLoc(),
- TII.get(SPIRV::OpVectorShuffle))
- .addDef(HighReg)
- .addUse(GR.getSPIRVTypeID(VResType))
- .addUse(FBHReg)
- .addUse(
- FBHReg); // this vector will not be selected from; could be empty
- unsigned i;
- for (i = 0; i < count * 2; i += 2) {
- MIB.addImm(i);
+ auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpVectorShuffle))
+ .addDef(HighReg)
+ .addUse(GR.getSPIRVTypeID(VResType))
+ .addUse(FBHReg)
+ .addUse(FBHReg);
+ // ^^ this vector will not be selected from; could be empty
+ unsigned j;
+ for (j = 0; j < count * 2; j += 2) {
+ MIB.addImm(j);
}
Result &= MIB.constrainAllUses(TII, TRI, RBI);
// get low bits
Register LowReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
- MIB =
- BuildMI(*I.getParent(), I, I.getDebugLoc(),
- TII.get(SPIRV::OpVectorShuffle))
- .addDef(LowReg)
- .addUse(GR.getSPIRVTypeID(VResType))
- .addUse(FBHReg)
- .addUse(
- FBHReg); // this vector will not be selected from; could be empty
- for (i = 1; i < count * 2; i += 2) {
- MIB.addImm(i);
+ MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpVectorShuffle))
+ .addDef(LowReg)
+ .addUse(GR.getSPIRVTypeID(VResType))
+ .addUse(FBHReg)
+ .addUse(FBHReg);
+ // ^^ this vector will not be selected from; could be empty
+ for (j = 1; j < count * 2; j += 2) {
+ MIB.addImm(j);
}
Result &= MIB.constrainAllUses(TII, TRI, RBI);
@@ -2825,6 +2823,7 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
Register AddReg = ResVReg;
if (isScalarRes)
AddReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+
Result &= selectNAryOpWithSrcs(AddReg, VResType, I, {ValReg, TmpReg},
SPIRV::OpIAddV);
@@ -2842,17 +2841,21 @@ bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
bool IsSigned) const {
- // FindUMsb intrinsic only supports 32 bit integers
+ // FindUMsb and FindSMsb intrinsics only support 32 bit integers
Register OpReg = I.getOperand(2).getReg();
SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg);
- unsigned bitWidth = GR.getScalarOrVectorBitWidth(OpType);
- if (bitWidth == 16)
+ switch (GR.getScalarOrVectorBitWidth(OpType)) {
+ case 16:
return selectFirstBitHigh16(ResVReg, ResType, I, IsSigned);
- else if (bitWidth == 32)
+ case 32:
return selectFirstBitHigh32(ResVReg, ResType, I, OpReg, IsSigned);
- else // 64 bit
+ case 64:
return selectFirstBitHigh64(ResVReg, ResType, I, IsSigned);
+ default:
+ report_fatal_error(
+ "spv_firstbituhigh and spv_firstbitshigh only support 16,32,64 bits.");
+ }
}
bool SPIRVInstructionSelector::selectAllocaArray(Register ResVReg,
>From d8b9af7c8f2dd26e45b86f0ddfb4745ad4abfaad Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 24 Oct 2024 18:21:55 +0000
Subject: [PATCH 08/10] remove use of vectors of length 1 in firstbithigh64
code
---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 9 ++
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 4 +
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 130 +++++++++---------
.../SPIRV/hlsl-intrinsics/firstbithigh.ll | 12 +-
4 files changed, 84 insertions(+), 71 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index f66506beaa6ed6..3e491733e0c62a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -449,6 +449,15 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
return Res;
}
+Register SPIRVGlobalRegistry::getOrCreateConstScalarOrVector(
+ uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
+ const SPIRVInstrInfo &TII, bool ZeroAsNull) {
+ if (SpvType->getOpcode() == SPIRV::OpTypeVector)
+ return getOrCreateConstVector(Val, I, SpvType, TII, ZeroAsNull);
+ else
+ return getOrCreateConstInt(Val, I, SpvType, TII, ZeroAsNull);
+}
+
Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
MachineInstr &I,
SPIRVType *SpvType,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index a95b488960c4c3..7a6174523e6877 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -492,6 +492,10 @@ class SPIRVGlobalRegistry {
Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType = nullptr);
+ Register getOrCreateConstScalarOrVector(uint64_t Val, MachineInstr &I,
+ SPIRVType *SpvType,
+ const SPIRVInstrInfo &TII,
+ bool ZeroAsNull = true);
Register getOrCreateConstVector(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b2d0935b939bb1..fbbf4f6d3f33a9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2759,82 +2759,82 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
Register FBHReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
Result &= selectFirstBitHigh32(FBHReg, postCastT, I, bitcastReg, IsSigned);
- // 3. check if result of each top 32 bits is == -1
- // split result vector into vector of high bits and vector of low bits
- // get high bits
- // if ResType is a scalar we need a vector anyways because our code
- // operates on vectors, even vectors of length one.
- SPIRVType *VResType = ResType;
+ // 3. split result vector into high bits and low bits
+ Register HighReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+ Register LowReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+
+ bool ZeroAsNull = STI.isOpenCLEnv();
bool isScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
- if (isScalarRes)
- VResType = GR.getOrCreateSPIRVVectorType(ResType, count, MIRBuilder);
- // count should be one.
+ if (isScalarRes) {
+ // if scalar do a vector extract
+ Result &= selectNAryOpWithSrcs(
+ HighReg, ResType, I,
+ {FBHReg, GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull)},
+ SPIRV::OpVectorExtractDynamic);
+ Result &= selectNAryOpWithSrcs(
+ LowReg, ResType, I,
+ {FBHReg, GR.getOrCreateConstInt(1, I, ResType, TII, ZeroAsNull)},
+ SPIRV::OpVectorExtractDynamic);
+ } else { // vector case do a shufflevector
+ auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpVectorShuffle))
+ .addDef(HighReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(FBHReg)
+ .addUse(FBHReg);
+ // ^^ this vector will not be selected from; could be empty
+ unsigned j;
+ for (j = 0; j < count * 2; j += 2) {
+ MIB.addImm(j);
+ }
+ Result &= MIB.constrainAllUses(TII, TRI, RBI);
+
+ // get low bits
+ MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpVectorShuffle))
+ .addDef(LowReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(FBHReg)
+ .addUse(FBHReg);
+ // ^^ this vector will not be selected from; could be empty
+ for (j = 1; j < count * 2; j += 2) {
+ MIB.addImm(j);
+ }
+ Result &= MIB.constrainAllUses(TII, TRI, RBI);
+ }
+
+ // 4. check if result of each top 32 bits is == -1
+ SPIRVType *BoolType = GR.getOrCreateSPIRVBoolType(I, TII);
+ if (!isScalarRes)
+ BoolType = GR.getOrCreateSPIRVVectorType(BoolType, count, MIRBuilder);
- Register HighReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
- auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
- TII.get(SPIRV::OpVectorShuffle))
- .addDef(HighReg)
- .addUse(GR.getSPIRVTypeID(VResType))
- .addUse(FBHReg)
- .addUse(FBHReg);
- // ^^ this vector will not be selected from; could be empty
- unsigned j;
- for (j = 0; j < count * 2; j += 2) {
- MIB.addImm(j);
- }
- Result &= MIB.constrainAllUses(TII, TRI, RBI);
-
- // get low bits
- Register LowReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
- MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
- TII.get(SPIRV::OpVectorShuffle))
- .addDef(LowReg)
- .addUse(GR.getSPIRVTypeID(VResType))
- .addUse(FBHReg)
- .addUse(FBHReg);
- // ^^ this vector will not be selected from; could be empty
- for (j = 1; j < count * 2; j += 2) {
- MIB.addImm(j);
- }
- Result &= MIB.constrainAllUses(TII, TRI, RBI);
-
- SPIRVType *BoolType = GR.getOrCreateSPIRVVectorType(
- GR.getOrCreateSPIRVBoolType(I, TII), count, MIRBuilder);
// check if the high bits are == -1;
- Register NegOneReg = GR.getOrCreateConstVector(-1, I, VResType, TII);
+ Register NegOneReg =
+ GR.getOrCreateConstScalarOrVector(-1, I, ResType, TII, ZeroAsNull);
// true if -1
Register BReg = MRI->createVirtualRegister(GR.getRegClass(BoolType));
Result &= selectNAryOpWithSrcs(BReg, BoolType, I, {HighReg, NegOneReg},
SPIRV::OpIEqual);
// Select low bits if true in BReg, otherwise high bits
- Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
- Result &= selectNAryOpWithSrcs(TmpReg, VResType, I, {BReg, LowReg, HighReg},
- SPIRV::OpSelectVIVCond);
+ unsigned selectOp =
+ isScalarRes ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond;
+ Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+ Result &= selectNAryOpWithSrcs(TmpReg, ResType, I, {BReg, LowReg, HighReg},
+ selectOp);
// Add 32 for high bits, 0 for low bits
- Register ValReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
- bool ZeroAsNull = STI.isOpenCLEnv();
- Register Reg32 = GR.getOrCreateConstVector(32, I, VResType, TII, ZeroAsNull);
- Register Reg0 = GR.getOrCreateConstVector(0, I, VResType, TII, ZeroAsNull);
- Result &= selectNAryOpWithSrcs(ValReg, VResType, I, {BReg, Reg0, Reg32},
- SPIRV::OpSelectVIVCond);
-
- Register AddReg = ResVReg;
- if (isScalarRes)
- AddReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
-
- Result &= selectNAryOpWithSrcs(AddReg, VResType, I, {ValReg, TmpReg},
- SPIRV::OpIAddV);
-
- // convert result back to scalar if necessary
- if (!isScalarRes)
- return Result;
- else
- return Result & selectNAryOpWithSrcs(
- ResVReg, ResType, I,
- {AddReg, GR.getOrCreateConstInt(0, I, ResType, TII)},
- SPIRV::OpVectorExtractDynamic);
+ Register ValReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+ Register Reg0 =
+ GR.getOrCreateConstScalarOrVector(0, I, ResType, TII, ZeroAsNull);
+ Register Reg32 =
+ GR.getOrCreateConstScalarOrVector(32, I, ResType, TII, ZeroAsNull);
+ Result &=
+ selectNAryOpWithSrcs(ValReg, ResType, I, {BReg, Reg0, Reg32}, selectOp);
+
+ return Result &=
+ selectNAryOpWithSrcs(ResVReg, ResType, I, {ValReg, TmpReg},
+ isScalarRes ? SPIRV::OpIAddS : SPIRV::OpIAddV);
}
bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
index 057b1a9c78722a..2b2012a5d8ba1e 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll
@@ -2,6 +2,8 @@
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; CHECK: OpMemoryModel Logical GLSL450
+; CHECK: [[Z:%.*]] = OpConstant %[[#]] 0
+; CHECK: [[X:%.*]] = OpConstant %[[#]] 1
define noundef i32 @firstbituhigh_i32(i32 noundef %a) {
entry:
@@ -37,13 +39,12 @@ define noundef i32 @firstbituhigh_i64(i64 noundef %a) {
entry:
; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindUMsb [[O]]
-; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
-; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
+; CHECK: [[M:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[Z]]
+; CHECK: [[L:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[X]]
; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
-; CHECK: [[#]] = OpVectorExtractDynamic %[[#]] [[B]] %[[#]]
%elt.firstbituhigh = call i32 @llvm.spv.firstbituhigh.i64(i64 %a)
ret i32 %elt.firstbituhigh
}
@@ -82,13 +83,12 @@ define noundef i32 @firstbitshigh_i64(i64 noundef %a) {
entry:
; CHECK: [[O:%.*]] = OpBitcast %[[#]] %[[#]]
; CHECK: [[N:%.*]] = OpExtInst %[[#]] %[[#]] FindSMsb [[O]]
-; CHECK: [[M:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 0
-; CHECK: [[L:%.*]] = OpVectorShuffle %[[#]] [[N]] [[N]] 1
+; CHECK: [[M:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[Z]]
+; CHECK: [[L:%.*]] = OpVectorExtractDynamic %[[#]] [[N]] [[X]]
; CHECK: [[I:%.*]] = OpIEqual %[[#]] [[M]] %[[#]]
; CHECK: [[H:%.*]] = OpSelect %[[#]] [[I]] [[L]] [[M]]
; CHECK: [[C:%.*]] = OpSelect %[[#]] [[I]] %[[#]] %[[#]]
; CHECK: [[B:%.*]] = OpIAdd %[[#]] [[C]] [[H]]
-; CHECK: [[#]] = OpVectorExtractDynamic %[[#]] [[B]] %[[#]]
%elt.firstbitshigh = call i32 @llvm.spv.firstbitshigh.i64(i64 %a)
ret i32 %elt.firstbitshigh
}
>From eb759952c94b38a7a89878c105e906034457b7a9 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 30 Oct 2024 20:11:23 +0000
Subject: [PATCH 09/10] address pr comments
---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 9 -----
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 4 --
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 39 +++++++++++--------
3 files changed, 23 insertions(+), 29 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 3e491733e0c62a..f66506beaa6ed6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -449,15 +449,6 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
return Res;
}
-Register SPIRVGlobalRegistry::getOrCreateConstScalarOrVector(
- uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
- const SPIRVInstrInfo &TII, bool ZeroAsNull) {
- if (SpvType->getOpcode() == SPIRV::OpTypeVector)
- return getOrCreateConstVector(Val, I, SpvType, TII, ZeroAsNull);
- else
- return getOrCreateConstInt(Val, I, SpvType, TII, ZeroAsNull);
-}
-
Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
MachineInstr &I,
SPIRVType *SpvType,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 7a6174523e6877..a95b488960c4c3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -492,10 +492,6 @@ class SPIRVGlobalRegistry {
Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType = nullptr);
- Register getOrCreateConstScalarOrVector(uint64_t Val, MachineInstr &I,
- SPIRVType *SpvType,
- const SPIRVInstrInfo &TII,
- bool ZeroAsNull = true);
Register getOrCreateConstVector(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index fbbf4f6d3f33a9..d0581f9bdef646 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2722,7 +2722,7 @@ bool SPIRVInstructionSelector::selectFirstBitHigh16(Register ResVReg,
Register ExtReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
bool Result =
selectUnOpWithSrc(ExtReg, ResType, I, I.getOperand(2).getReg(), Opcode);
- return Result & selectFirstBitHigh32(ResVReg, ResType, I, ExtReg, IsSigned);
+ return Result && selectFirstBitHigh32(ResVReg, ResType, I, ExtReg, IsSigned);
}
bool SPIRVInstructionSelector::selectFirstBitHigh32(Register ResVReg,
@@ -2805,36 +2805,43 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
// 4. check if result of each top 32 bits is == -1
SPIRVType *BoolType = GR.getOrCreateSPIRVBoolType(I, TII);
- if (!isScalarRes)
+ Register NegOneReg;
+ Register Reg0;
+ Register Reg32;
+ unsigned selectOp;
+ unsigned addOp;
+ if (isScalarRes) {
+ NegOneReg = GR.getOrCreateConstInt(-1, I, ResType, TII, ZeroAsNull);
+ Reg0 = GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
+ Reg32 = GR.getOrCreateConstInt(32, I, ResType, TII, ZeroAsNull);
+ selectOp = SPIRV::OpSelectSISCond;
+ addOp = SPIRV::OpIAddS;
+ } else {
BoolType = GR.getOrCreateSPIRVVectorType(BoolType, count, MIRBuilder);
-
- // check if the high bits are == -1;
- Register NegOneReg =
- GR.getOrCreateConstScalarOrVector(-1, I, ResType, TII, ZeroAsNull);
- // true if -1
+ NegOneReg = GR.getOrCreateConstVector(-1, I, ResType, TII, ZeroAsNull);
+ Reg0 = GR.getOrCreateConstVector(0, I, ResType, TII, ZeroAsNull);
+ Reg32 = GR.getOrCreateConstVector(32, I, ResType, TII, ZeroAsNull);
+ selectOp = SPIRV::OpSelectVIVCond;
+ addOp = SPIRV::OpIAddV;
+ }
+
+ // check if the high bits are == -1; true if -1
Register BReg = MRI->createVirtualRegister(GR.getRegClass(BoolType));
Result &= selectNAryOpWithSrcs(BReg, BoolType, I, {HighReg, NegOneReg},
SPIRV::OpIEqual);
// Select low bits if true in BReg, otherwise high bits
- unsigned selectOp =
- isScalarRes ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond;
Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
Result &= selectNAryOpWithSrcs(TmpReg, ResType, I, {BReg, LowReg, HighReg},
selectOp);
// Add 32 for high bits, 0 for low bits
Register ValReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
- Register Reg0 =
- GR.getOrCreateConstScalarOrVector(0, I, ResType, TII, ZeroAsNull);
- Register Reg32 =
- GR.getOrCreateConstScalarOrVector(32, I, ResType, TII, ZeroAsNull);
Result &=
selectNAryOpWithSrcs(ValReg, ResType, I, {BReg, Reg0, Reg32}, selectOp);
- return Result &=
- selectNAryOpWithSrcs(ResVReg, ResType, I, {ValReg, TmpReg},
- isScalarRes ? SPIRV::OpIAddS : SPIRV::OpIAddV);
+ return Result &&
+ selectNAryOpWithSrcs(ResVReg, ResType, I, {ValReg, TmpReg}, addOp);
}
bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
>From 8457cb258afb010122878698aac08ebbc26a62ea Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Tue, 5 Nov 2024 18:37:59 +0000
Subject: [PATCH 10/10] cast to unsigned
---
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index d0581f9bdef646..99642414312c0d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2811,14 +2811,16 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
unsigned selectOp;
unsigned addOp;
if (isScalarRes) {
- NegOneReg = GR.getOrCreateConstInt(-1, I, ResType, TII, ZeroAsNull);
+ NegOneReg =
+ GR.getOrCreateConstInt((unsigned)-1, I, ResType, TII, ZeroAsNull);
Reg0 = GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
Reg32 = GR.getOrCreateConstInt(32, I, ResType, TII, ZeroAsNull);
selectOp = SPIRV::OpSelectSISCond;
addOp = SPIRV::OpIAddS;
} else {
BoolType = GR.getOrCreateSPIRVVectorType(BoolType, count, MIRBuilder);
- NegOneReg = GR.getOrCreateConstVector(-1, I, ResType, TII, ZeroAsNull);
+ NegOneReg =
+ GR.getOrCreateConstVector((unsigned)-1, I, ResType, TII, ZeroAsNull);
Reg0 = GR.getOrCreateConstVector(0, I, ResType, TII, ZeroAsNull);
Reg32 = GR.getOrCreateConstVector(32, I, ResType, TII, ZeroAsNull);
selectOp = SPIRV::OpSelectVIVCond;
More information about the llvm-commits
mailing list