[clang] [llvm] [HLSL] Add WaveActiveBitOr HLSL function (PR #178056)
Joshua Batista via cfe-commits
cfe-commits at lists.llvm.org
Thu Jan 29 22:46:12 PST 2026
https://github.com/bob80905 updated https://github.com/llvm/llvm-project/pull/178056
>From fe1f4f48f6c5f6e0a41fe3ef1b1c77dc8b93c809 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Fri, 23 Jan 2026 12:20:11 -0800
Subject: [PATCH 1/7] handle arg promotion with customtypechecking
---
clang/include/clang/Basic/Builtins.td | 6 ++
clang/lib/CodeGen/CGHLSLBuiltins.cpp | 58 ++++++++++++++++
.../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 24 +++++++
clang/lib/Sema/SemaHLSL.cpp | 26 +++++++
.../CodeGenHLSL/builtins/WaveActiveBitOr.hlsl | 67 +++++++++++++++++++
.../CodeGenHLSL/builtins/WaveActiveSum.hlsl | 11 +--
.../BuiltIns/WaveActiveBitOr-errors.hlsl | 30 +++++++++
llvm/include/llvm/IR/IntrinsicsDirectX.td | 1 +
llvm/include/llvm/IR/IntrinsicsSPIRV.td | 1 +
llvm/lib/Target/DirectX/DXIL.td | 25 +++++++
.../DirectX/DirectXTargetTransformInfo.cpp | 1 +
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 32 +++++++++
12 files changed, 278 insertions(+), 4 deletions(-)
create mode 100644 clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl
create mode 100644 clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index bc8f1474493b0..e2c46634081a0 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5091,6 +5091,12 @@ def HLSLWaveActiveBallot : LangBuiltin<"HLSL_LANG"> {
let Prototype = "_ExtVector<4, unsigned int>(bool)";
}
+// Wave-wide bitwise OR (HLSL WaveActiveBitOr). The prototype is variadic
+// because the argument and return types are validated by custom type
+// checking in SemaHLSL, which sets the call's type to the argument type.
+def HLSLWaveActiveBitOr : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_wave_active_bit_or"];
+ let Attributes = [NoThrow, Const, CustomTypeChecking];
+ let Prototype = "void(...)";
+}
+
def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 36691c7b72efe..d3d925e6fc8f4 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -276,6 +276,51 @@ static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
return RT.getFirstBitUHighIntrinsic();
}
+// select and return a specific wave bit op intrinsic,
+// based on the provided op kind.
+// OpKinds:
+// And = 0, bitwise and of values
+// Or = 1, bitwise or of values
+// Xor = 2, bitwise xor of values
+static Intrinsic::ID getWaveBitOpIntrinsic(int OpKind,
+ llvm::Triple::ArchType Arch,
+ CGHLSLRuntime &RT, QualType QT) {
+ switch (Arch) {
+ case llvm::Triple::spirv:
+ switch (OpKind) {
+
+ case 0:
+ case 2: {
+ llvm_unreachable("Not implemented yet!");
+ }
+ case 1: {
+ return Intrinsic::spv_wave_bit_or;
+ }
+ default: {
+ llvm_unreachable("Unexpected SubOp ID");
+ }
+ }
+ case llvm::Triple::dxil: {
+ switch (OpKind) {
+
+ case 0:
+ case 2: {
+ llvm_unreachable("Not implemented yet!");
+ }
+ case 1: {
+ return Intrinsic::dx_wave_bit_or;
+ }
+ default: {
+ llvm_unreachable("Unexpected SubOp ID");
+ }
+ }
+ }
+ default:
+ llvm_unreachable("Intrinsic WaveActiveBitOr"
+ " not supported by target architecture");
+ }
+}
+
// Return wave active sum that corresponds to the QT scalar type
static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
CGHLSLRuntime &RT, QualType QT) {
@@ -872,6 +917,19 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return handleHlslWaveActiveBallot(*this, E);
}
+ case Builtin::BI__builtin_hlsl_wave_active_bit_or: {
+ Value *Op = EmitScalarExpr(E->getArg(0));
+ assert(Op->getType()->isIntegerTy() &&
+ "Intrinsic WaveActiveBitOr operand must be an integer type");
+
+ Intrinsic::ID IID = getWaveBitOpIntrinsic(
+ /* OpKind */ 1, getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
+ E->getArg(0)->getType());
+
+ return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
+ &CGM.getModule(), IID, {Op->getType()}),
+ ArrayRef{Op}, "hlsl.wave.active.bit.or");
+ }
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index e9a41b94d6c03..73481334b0abf 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2422,6 +2422,30 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_ballot)
__attribute__((convergent)) uint4 WaveActiveBallot(bool Val);
+/// \brief Returns the bitwise OR of all the values of Expr across all
+/// active non-helper lanes in the current wave, and replicates it back
+/// to all active non-helper lanes.
+/// \param Expr The integer expression to evaluate.
+/// \return The bitwise OR value of Expr across all active threads.
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) int16_t WaveActiveBitOr(int16_t Expr);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) int WaveActiveBitOr(int Expr);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) int64_t WaveActiveBitOr(int64_t Expr);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) uint16_t WaveActiveBitOr(uint16_t Expr);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) uint WaveActiveBitOr(uint Expr);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) uint64_t WaveActiveBitOr(uint64_t Expr);
+
/// \brief Counts the number of boolean variables which evaluate to true across
/// all active lanes in the current wave.
///
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 4d31e26d56e6b..1df66df1f3e70 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3583,6 +3583,32 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
TheCall->setType(ArgTyExpr);
break;
}
+ case Builtin::BI__builtin_hlsl_wave_active_bit_or: {
+ if (SemaRef.checkArgCount(TheCall, 1))
+ return true;
+
+ // Ensure input expr type is a scalar/vector and then
+ // set the return type to the arg type
+ QualType ArgType = TheCall->getArg(0)->getType();
+ auto *VTy = ArgType->getAs<VectorType>();
+ // not the scalar or vector<scalar>
+ if (!(ArgType->isScalarType())) {
+ SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+ diag::err_typecheck_expect_any_scalar_or_vector)
+ << ArgType << 0;
+ return true;
+ }
+
+ if (!(ArgType->isIntegerType())) {
+ SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+ diag::err_typecheck_expect_any_scalar_or_vector)
+ << ArgType << 0;
+ return true;
+ }
+
+ TheCall->setType(ArgType);
+ break;
+ }
// Note these are llvm builtins that we want to catch invalid intrinsic
// generation. Normal handling of these builtins will occur elsewhere.
case Builtin::BI__builtin_elementwise_bitreverse: {
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl
new file mode 100644
index 0000000000000..6b757ce60b389
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl
@@ -0,0 +1,67 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -fnative-int16-type -fnative-half-type \
+// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \
+// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -fnative-int16-type -fnative-half-type \
+// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \
+// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr) {
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i32([[TY]] %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i32([[TY]] %[[#]])
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveActiveBitOr(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.i32([[TY]]) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.i32([[TY]]) #[[#attr:]]
+
+// CHECK-LABEL: test_int16
+int16_t test_int16_t(int16_t expr) {
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i16([[TY]] %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i16([[TY]] %[[#]])
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveActiveBitOr(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.i16([[TY]]) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.i16([[TY]]) #[[#attr:]]
+
+// CHECK-LABEL: test_int64
+int64_t test_int64_t(int64_t expr) {
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i64([[TY]] %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i64([[TY]] %[[#]])
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveActiveBitOr(expr);
+}
+
+// CHECK-LABEL: test_uint
+uint test_uint(uint expr) {
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i32([[TY]] %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i32([[TY]] %[[#]])
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveActiveBitOr(expr);
+}
+
+// CHECK-LABEL: test_uint16
+uint16_t test_uint16_t(uint16_t expr) {
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i16([[TY]] %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i16([[TY]] %[[#]])
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveActiveBitOr(expr);
+}
+
+// CHECK-LABEL: test_uint64
+uint64_t test_uint64_t(uint64_t expr) {
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i64([[TY]] %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i64([[TY]] %[[#]])
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveActiveBitOr(expr);
+}
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl
index 1fc93c62c8db0..87ddd96e8368c 100644
--- a/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl
@@ -1,9 +1,12 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
-// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
-// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -fnative-int16-type -fnative-half-type \
+// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \
+// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
-// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
-// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -fnative-int16-type -fnative-half-type \
+// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \
+// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
// Test basic lowering to runtime function call.
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl
new file mode 100644
index 0000000000000..e3fd2eac28159
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl
@@ -0,0 +1,30 @@
+// RUN: %clang_cc1 -finclude-default-header -fnative-int16-type -fnative-half-type \
+// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \
+// RUN: -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+int test_too_few_arg() {
+ return __builtin_hlsl_wave_active_bit_or();
+ // expected-error at -1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+int test_too_many_arg(int p0) {
+ return __builtin_hlsl_wave_active_bit_or(p0, p0);
+ // expected-error at -1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+float test_expr_bool_type_check(float p0) {
+ return __builtin_hlsl_wave_active_bit_or(p0);
+ // expected-error at -1 {{invalid operand of type 'float'}}
+}
+
+float2 test_expr_bool_vec_type_check(float2 p0) {
+ return __builtin_hlsl_wave_active_bit_or(p0);
+ // expected-error at -1 {{invalid operand of type 'float2' (aka 'vector<float, 2>')}}
+}
+
+struct S { float f; };
+
+S test_expr_struct_type_check(S p0) {
+ return __builtin_hlsl_wave_active_bit_or(p0);
+ // expected-error at -1 {{invalid operand of type 'S'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index f79945785566c..acbabb128258d 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -154,6 +154,7 @@ def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_bit_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 293cb750cea98..303ee8e1a61bf 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -122,6 +122,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_subgroup_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">,
DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+ def int_spv_wave_bit_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_min : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 3a40d2c36139d..7f170faef204c 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -215,6 +215,7 @@ defset list<DXILOpClass> OpClasses = {
def waveActiveAllEqual : DXILOpClass;
def waveActiveBit : DXILOpClass;
def waveActiveOp : DXILOpClass;
+ def waveBitOp : DXILOpClass;
def waveAllOp : DXILOpClass;
def waveAllTrue : DXILOpClass;
def waveAnyTrue : DXILOpClass;
@@ -317,6 +318,10 @@ defvar WaveOpKind_Product = 1;
defvar WaveOpKind_Min = 2;
defvar WaveOpKind_Max = 3;
+// Op-kind immediates passed as the i8 "op" operand of the WaveBitOp DXIL
+// operation. In this patch only WaveBitOpKind_Or is referenced (by the
+// int_dx_wave_bit_or mapping); And/Xor are reserved for future intrinsics.
+defvar WaveBitOpKind_And = 0;
+defvar WaveBitOpKind_Or = 1;
+defvar WaveBitOpKind_Xor = 2;
+
defvar SignedOpKind_Signed = 0;
defvar SignedOpKind_Unsigned = 1;
@@ -1124,6 +1129,26 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
let attributes = [Attributes<DXIL1_0, []>];
}
+// Maps llvm.dx.wave.bit.or onto DXIL opcode 120 with op kind Or.
+// NOTE(review): DXC defines opcode 120 as WaveActiveBit with operands
+// (value, i8 op) only -- there is no signedness operand -- and this file
+// already declares a waveActiveBit op class. Confirm against DXIL.rst whether
+// the trailing IntrinArgI8<SignedOpKind_Signed> / extra Int8Ty argument should
+// be dropped and the existing op class reused before landing; the DirectX
+// WaveActiveBitOr.ll expectations would need to change to match.
+def WaveBitOp : DXILOp<120, waveBitOp> {
+ let Doc = "returns the result of the bitwise operation across waves";
+ let intrinsics = [
+ IntrinSelect<int_dx_wave_bit_or,
+ [
+ IntrinArgIndex<0>, IntrinArgI8<WaveBitOpKind_Or>,
+ IntrinArgI8<SignedOpKind_Signed>
+ ]>
+ ];
+
+ let arguments = [OverloadTy, Int8Ty, Int8Ty];
+ let result = OverloadTy;
+ let overloads = [
+ Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>
+ ];
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+ let attributes = [Attributes<DXIL1_0, []>];
+}
+
def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> {
let Doc = "returns the float16 stored in the low-half of the uint converted "
"to a float";
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index f54b48b91265e..4b69b5c5ce239 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -55,6 +55,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
case Intrinsic::dx_rsqrt:
case Intrinsic::dx_saturate:
case Intrinsic::dx_splitdouble:
+ case Intrinsic::dx_wave_bit_or:
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_wave_reduce_max:
case Intrinsic::dx_wave_reduce_min:
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 626393d4ecb40..09e0d25b5f639 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -224,6 +224,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ /// Selects OpGroupNonUniformBitwiseOr (Subgroup scope, Reduce operation)
+ /// for the spv_wave_bit_or intrinsic.
+ bool selectWaveBitOr(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, bool IsUnsigned) const;
@@ -2711,6 +2714,33 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
return Result;
}
+/// Lowers spv_wave_bit_or to OpGroupNonUniformBitwiseOr using Subgroup scope
+/// and the Reduce group operation; the result is written to ResVReg with type
+/// ResType. Reports a fatal error if the input type cannot be determined or
+/// is not an integer scalar/vector.
+/// NOTE(review): this revision reads the value from operand 1; for an
+/// intrinsic instruction operand 1 is presumably the intrinsic ID and the
+/// value operand is 2 -- confirm (a later revision switches to operand 2).
+bool SPIRVInstructionSelector::selectWaveBitOr(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+
+ assert(I.getNumOperands() == 3);
+ assert(I.getOperand(1).isReg());
+ MachineBasicBlock &BB = *I.getParent();
+ Register InputRegister = I.getOperand(1).getReg();
+ SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);
+
+ if (!InputType)
+ report_fatal_error("Input Type could not be determined.");
+ if (!GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeInt))
+ report_fatal_error("WaveActiveBitOr requires integer input");
+
+ // The execution-scope operand is an i32 constant (Subgroup), so make sure
+ // a 32-bit integer type exists to materialize it.
+ SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+
+ return BuildMI(BB, I, I.getDebugLoc(),
+ TII.get(SPIRV::OpGroupNonUniformBitwiseOr))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
+ .addImm(SPIRV::GroupOperation::Reduce)
+ .addUse(InputRegister)
+ .constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
@@ -3815,6 +3845,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
case Intrinsic::spv_wave_any:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
+ case Intrinsic::spv_wave_bit_or:
+ return selectWaveBitOr(ResVReg, ResType, I);
case Intrinsic::spv_subgroup_ballot:
return selectWaveOpInst(ResVReg, ResType, I,
SPIRV::OpGroupNonUniformBallot);
>From b8401525e8f117165f76c4b1b6366eb040799839 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Mon, 26 Jan 2026 14:41:42 -0800
Subject: [PATCH 2/7] remove unused var
---
clang/lib/Sema/SemaHLSL.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 1df66df1f3e70..97f833e921239 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3590,8 +3590,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
// Ensure input expr type is a scalar/vector and then
// set the return type to the arg type
QualType ArgType = TheCall->getArg(0)->getType();
- auto *VTy = ArgType->getAs<VectorType>();
- // not the scalar or vector<scalar>
+
if (!(ArgType->isScalarType())) {
SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
diag::err_typecheck_expect_any_scalar_or_vector)
>From 89a3bf9ab27717f5aabc3de1b2db431b65c0a12b Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Mon, 26 Jan 2026 16:01:40 -0800
Subject: [PATCH 3/7] add enable 16bit preprocess flag
---
clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 0887a4c5d7a64..5760c251aac56 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2448,18 +2448,22 @@ __attribute__((convergent)) uint4 WaveActiveBallot(bool Val);
/// to all active non-helper lanes.
/// \param Expr The integer expression to evaluate.
/// \return The bitwise OR value of Expr across all active threads.
+#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) int16_t WaveActiveBitOr(int16_t Expr);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
-__attribute__((convergent)) int WaveActiveBitOr(int Expr);
+__attribute__((convergent)) uint16_t WaveActiveBitOr(uint16_t Expr);
+#endif
+
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
-__attribute__((convergent)) int64_t WaveActiveBitOr(int64_t Expr);
+__attribute__((convergent)) int WaveActiveBitOr(int Expr);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
-__attribute__((convergent)) uint16_t WaveActiveBitOr(uint16_t Expr);
+__attribute__((convergent)) int64_t WaveActiveBitOr(int64_t Expr);
+
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint WaveActiveBitOr(uint Expr);
>From 41ea8b0a7cbe49f24641ab0f9ae0a5bd23c349f8 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Tue, 27 Jan 2026 16:04:48 -0800
Subject: [PATCH 4/7] add some more tests
---
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 4 +--
llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll | 19 ++++++++++++
.../SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll | 30 +++++++++++++++++++
3 files changed, 51 insertions(+), 2 deletions(-)
create mode 100644 llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll
create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ca63bf3e17782..151f8da509735 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2723,9 +2723,9 @@ bool SPIRVInstructionSelector::selectWaveBitOr(Register ResVReg,
MachineInstr &I) const {
assert(I.getNumOperands() == 3);
- assert(I.getOperand(1).isReg());
+ assert(I.getOperand(2).isReg());
MachineBasicBlock &BB = *I.getParent();
- Register InputRegister = I.getOperand(1).getReg();
+ Register InputRegister = I.getOperand(2).getReg();
SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);
if (!InputType)
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll b/llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll
new file mode 100644
index 0000000000000..e7bc6a5292c3d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll
@@ -0,0 +1,19 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+; Verifies that llvm.dx.wave.bit.or lowers to DXIL opcode 120 with op kind 1
+; (Or) for i32 and i64 operands. The trailing i8 0 mirrors the signedness
+; operand emitted by the current DXIL.td mapping for this op.
+
+define noundef i32 @wave_bitor_simple(i32 noundef %p1) {
+entry:
+; CHECK: call i32 @dx.op.waveBitOp.i32(i32 120, i32 %p1, i8 1, i8 0)
+ %ret = call i32 @llvm.dx.wave.bit.or.i32(i32 %p1)
+ ret i32 %ret
+}
+
+declare i32 @llvm.dx.wave.bit.or.i32(i32)
+
+define noundef i64 @wave_bitor_simple64(i64 noundef %p1) {
+entry:
+; CHECK: call i64 @dx.op.waveBitOp.i64(i32 120, i64 %p1, i8 1, i8 0)
+ %ret = call i64 @llvm.dx.wave.bit.or.i64(i64 %p1)
+ ret i64 %ret
+}
+
+declare i64 @llvm.dx.wave.bit.or.i64(i64)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll
new file mode 100644
index 0000000000000..81b0bfe03dbe7
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll
@@ -0,0 +1,30 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val --target-env spv1.4 %}
+
+; Test lowering of llvm.spv.wave.bit.or to OpGroupNonUniformBitwiseOr with the
+; Reduce group operation for scalar i32 and i64 inputs. The scope operand is
+; the i32 constant 3 (Subgroup).
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#uint64:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
+
+; CHECK-LABEL: Begin function test_uint
+; CHECK: %[[#iexpr:]] = OpFunctionParameter %[[#uint]]
+define i32 @test_uint(i32 %iexpr) {
+entry:
+; CHECK: %[[#iret:]] = OpGroupNonUniformBitwiseOr %[[#uint]] %[[#scope]] Reduce %[[#iexpr]]
+ %0 = call i32 @llvm.spv.wave.bit.or.i32(i32 %iexpr)
+ ret i32 %0
+}
+
+declare i32 @llvm.spv.wave.bit.or.i32(i32)
+
+; CHECK-LABEL: Begin function test_uint64
+; CHECK: %[[#iexpr64:]] = OpFunctionParameter %[[#uint64]]
+define i64 @test_uint64(i64 %iexpr64) {
+entry:
+; CHECK: %[[#iret:]] = OpGroupNonUniformBitwiseOr %[[#uint64]] %[[#scope]] Reduce %[[#iexpr64]]
+ %0 = call i64 @llvm.spv.wave.bit.or.i64(i64 %iexpr64)
+ ret i64 %0
+}
+
+declare i64 @llvm.spv.wave.bit.or.i64(i64)
>From fcf01d2a7d7e8ac186c415378d0e5a219b72180f Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Thu, 29 Jan 2026 17:32:31 -0800
Subject: [PATCH 5/7] address Tex and Farzon
---
clang/lib/CodeGen/CGHLSLBuiltins.cpp | 55 +++++--------------
clang/lib/Sema/SemaHLSL.cpp | 21 +++----
.../CodeGenHLSL/builtins/WaveActiveBitOr.hlsl | 33 +++++++++++
.../CodeGenHLSL/builtins/WaveActiveSum.hlsl | 11 ++--
.../BuiltIns/WaveActiveBitOr-errors.hlsl | 6 +-
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 43 +++++++--------
6 files changed, 82 insertions(+), 87 deletions(-)
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 15491368caf82..24b1c2caefe7d 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -329,45 +329,15 @@ static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
return RT.getFirstBitUHighIntrinsic();
}
-// select and return a specific wave bit op intrinsic,
-// based on the provided op kind.
-// OpKinds:
-// And = 0, bitwise and of values
-// Or = 1, bitwise or of values
-// Xor = 2, bitwise xor of values
-static Intrinsic::ID getWaveBitOpIntrinsic(int OpKind,
- llvm::Triple::ArchType Arch,
- CGHLSLRuntime &RT, QualType QT) {
+static Intrinsic::ID getWaveBitOpOrIntrinsic(llvm::Triple::ArchType Arch,
+ CGHLSLRuntime &RT, QualType QT) {
switch (Arch) {
case llvm::Triple::spirv:
- switch (OpKind) {
+ return Intrinsic::spv_wave_bit_or;
- case 0:
- case 2: {
- llvm_unreachable("Not implemented yet!");
- }
- case 1: {
- return Intrinsic::spv_wave_bit_or;
- }
- default: {
- llvm_unreachable("Unexpected SubOp ID");
- }
- }
- case llvm::Triple::dxil: {
- switch (OpKind) {
+ case llvm::Triple::dxil:
+ return Intrinsic::dx_wave_bit_or;
- case 0:
- case 2: {
- llvm_unreachable("Not implemented yet!");
- }
- case 1: {
- return Intrinsic::dx_wave_bit_or;
- }
- default: {
- llvm_unreachable("Unexpected SubOp ID");
- }
- }
- }
default:
llvm_unreachable("Intrinsic WaveActiveBitOr"
" not supported by target architecture");
@@ -975,12 +945,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
}
case Builtin::BI__builtin_hlsl_wave_active_bit_or: {
Value *Op = EmitScalarExpr(E->getArg(0));
- assert(Op->getType()->isIntegerTy() &&
- "Intrinsic WaveActiveBitOr operand must be an integer type");
-
- Intrinsic::ID IID = getWaveBitOpIntrinsic(
- /* OpKind */ 1, getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
- E->getArg(0)->getType());
+ llvm::Type *Ty = Op->getType();
+ // SemaHLSL requires an integer representation (scalar or vector), so the
+ // lowered operand must be an integer or a vector of integers. Using
+ // isIntOrIntVectorTy() keeps the message and-ed with the whole condition.
+ assert(Ty->isIntOrIntVectorTy() &&
+ "Intrinsic WaveActiveBitOr operand must be integer or "
+ "vector of integers");
+
+ Intrinsic::ID IID =
+ getWaveBitOpOrIntrinsic(getTarget().getTriple().getArch(),
+ CGM.getHLSLRuntime(), E->getArg(0)->getType());
return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
&CGM.getModule(), IID, {Op->getType()}),
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index d5ec41bedd242..349c6a967e534 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3601,24 +3601,21 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
if (SemaRef.checkArgCount(TheCall, 1))
return true;
- // Ensure input expr type is a scalar/vector and then
- // set the return type to the arg type
QualType ArgType = TheCall->getArg(0)->getType();
- if (!(ArgType->isScalarType())) {
+ // Require an integer scalar or a vector of integers. bool (and bool vectors)
+ // has an integer representation in Clang but would lower to i1, which the
+ // WaveActiveBitOr lowering overloads (i16/i32/i64) do not accept.
+ QualType EltTy = ArgType;
+ if (const auto *VTy = ArgType->getAs<VectorType>())
+ EltTy = VTy->getElementType();
+ if (!ArgType->hasIntegerRepresentation() || EltTy->isBooleanType()) {
SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
- diag::err_typecheck_expect_any_scalar_or_vector)
- << ArgType << 0;
- return true;
- }
-
- if (!(ArgType->isIntegerType())) {
- SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
- diag::err_typecheck_expect_any_scalar_or_vector)
- << ArgType << 0;
+ diag::err_builtin_invalid_arg_type)
+ << 1 // %ordinal0: 1st argument
+ << 5 // %select1: scalar or vector of
+ << 1 // %select2: integer
+ << 0 // %select3: no floating-point
+ << TheCall->getArg(0)->getType();
return true;
}
+ // Set the return type to the arg type
TheCall->setType(ArgType);
break;
}
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl
index 6b757ce60b389..f9966bc9ebf63 100644
--- a/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl
@@ -21,6 +21,39 @@ int test_int(int expr) {
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.i32([[TY]]) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.i32([[TY]]) #[[#attr:]]
+// CHECK-LABEL: test_int2
+int2 test_int2(int2 expr) {
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.v2i32([[TY]] %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.v2i32([[TY]] %[[#]])
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveActiveBitOr(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.v2i32([[TY]]) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.v2i32([[TY]]) #[[#attr:]]
+
+// CHECK-LABEL: test_int3
+int3 test_int3(int3 expr) {
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.v3i32([[TY]] %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.v3i32([[TY]] %[[#]])
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveActiveBitOr(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.v3i32([[TY]]) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.v3i32([[TY]]) #[[#attr:]]
+
+// CHECK-LABEL: test_int4
+int4 test_int4(int4 expr) {
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.v4i32([[TY]] %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.v4i32([[TY]] %[[#]])
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveActiveBitOr(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.v4i32([[TY]]) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.v4i32([[TY]]) #[[#attr:]]
+
// CHECK-LABEL: test_int16
int16_t test_int16_t(int16_t expr) {
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i16([[TY]] %[[#]])
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl
index 87ddd96e8368c..1fc93c62c8db0 100644
--- a/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl
@@ -1,12 +1,9 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
-// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -fnative-int16-type -fnative-half-type \
-// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \
-// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
-
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
-// RUN: spirv-pc-vulkan-compute %s -emit-llvm -fnative-int16-type -fnative-half-type \
-// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \
-// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
// Test basic lowering to runtime function call.
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl
index e3fd2eac28159..19c1f0ede4765 100644
--- a/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl
@@ -14,17 +14,17 @@ int test_too_many_arg(int p0) {
float test_expr_bool_type_check(float p0) {
return __builtin_hlsl_wave_active_bit_or(p0);
- // expected-error@-1 {{invalid operand of type 'float'}}
+ // expected-error@-1 {{1st argument must be a scalar or vector of integer types (was 'float')}}
}
float2 test_expr_bool_vec_type_check(float2 p0) {
return __builtin_hlsl_wave_active_bit_or(p0);
- // expected-error@-1 {{invalid operand of type 'float2' (aka 'vector<float, 2>')}}
+ // expected-error@-1 {{1st argument must be a scalar or vector of integer types (was 'float2' (aka 'vector<float, 2>'))}}
}
struct S { float f; };
S test_expr_struct_type_check(S p0) {
return __builtin_hlsl_wave_active_bit_or(p0);
- // expected-error@-1 {{invalid operand of type 'S'}}
+ // expected-error@-1 {{1st argument must be a scalar or vector of integer types (was 'S')}}
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 151f8da509735..8b0a06231ccca 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -224,8 +224,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
- bool selectWaveBitOr(Register ResVReg, const SPIRVType *ResType,
- MachineInstr &I) const;
+ bool selectWaveBitOpInst(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I, unsigned Opcode) const;
bool selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, bool IsUnsigned) const;
@@ -2718,31 +2718,25 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
return Result;
}
-bool SPIRVInstructionSelector::selectWaveBitOr(Register ResVReg,
- const SPIRVType *ResType,
- MachineInstr &I) const {
-
- assert(I.getNumOperands() == 3);
- assert(I.getOperand(2).isReg());
+bool SPIRVInstructionSelector::selectWaveBitOpInst(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ unsigned Opcode) const {
MachineBasicBlock &BB = *I.getParent();
- Register InputRegister = I.getOperand(2).getReg();
- SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);
+ SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
- if (!InputType)
- report_fatal_error("Input Type could not be determined.");
- if (!GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeInt))
- report_fatal_error("WaveActiveBitOr requires integer input");
+ auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I,
+ IntTy, TII, !STI.isShader()));
+ BMI.addImm(SPIRV::GroupOperation::Reduce);
- SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+ for (unsigned J = 2; J < I.getNumOperands(); J++) {
+ BMI.addUse(I.getOperand(J).getReg());
+ }
- return BuildMI(BB, I, I.getDebugLoc(),
- TII.get(SPIRV::OpGroupNonUniformBitwiseOr))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
- .addImm(SPIRV::GroupOperation::Reduce)
- .addUse(InputRegister)
- .constrainAllUses(TII, TRI, RBI);
+ return BMI.constrainAllUses(TII, TRI, RBI);
}
bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
@@ -3896,7 +3890,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_wave_any:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
case Intrinsic::spv_wave_bit_or:
- return selectWaveBitOr(ResVReg, ResType, I);
+ return selectWaveBitOpInst(ResVReg, ResType, I,
+ SPIRV::OpGroupNonUniformBitwiseOr);
case Intrinsic::spv_subgroup_ballot:
return selectWaveOpInst(ResVReg, ResType, I,
SPIRV::OpGroupNonUniformBallot);
>From ac22964f653268326511b671715d77fe04aa8c7b Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Thu, 29 Jan 2026 18:13:44 -0800
Subject: [PATCH 6/7] clang-format
---
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index c15fbd36a67fe..09353a127e033 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -226,7 +226,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectWaveBitOpInst(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, unsigned Opcode) const;
-
+
bool selectWavePrefixBitCount(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
>From 998c1cb10a15ea65f797722d11aaae70352ff26b Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Thu, 29 Jan 2026 22:45:33 -0800
Subject: [PATCH 7/7] add missing stages + attributes
---
llvm/lib/Target/DirectX/DXIL.td | 2 ++
1 file changed, 2 insertions(+)
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 5931a5218b811..43b4c47522cf2 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -1144,6 +1144,8 @@ def WaveBitOp : DXILOp<120, waveBitOp> {
let overloads = [
Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>
];
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+ let attributes = [Attributes<DXIL1_0, []>];
}
def WavePrefixBitCount : DXILOp<136, wavePrefixOp> {
More information about the cfe-commits mailing list