[clang] [llvm] [HLSL] Add WaveActiveAllEqual functions (PR #183634)
Joshua Batista via cfe-commits
cfe-commits at lists.llvm.org
Mon Mar 2 14:49:39 PST 2026
https://github.com/bob80905 updated https://github.com/llvm/llvm-project/pull/183634
>From 8f3b1eca648517042b010d92c380995212afd59b Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Mon, 23 Feb 2026 17:04:58 -0800
Subject: [PATCH 1/6] first attempt
---
clang/include/clang/Basic/Builtins.td | 6 +
clang/lib/CodeGen/CGHLSLBuiltins.cpp | 7 +
clang/lib/CodeGen/CGHLSLRuntime.h | 27 +++-
.../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 124 ++++++++++++++++++
clang/lib/Sema/SemaHLSL.cpp | 13 ++
.../builtins/WaveActiveAllEqual.hlsl | 45 +++++++
.../BuiltIns/WaveActiveAllEqual-errors.hlsl | 28 ++++
.../BuiltIns/WaveActiveAllTrue-errors.hlsl | 49 ++++---
llvm/include/llvm/IR/IntrinsicsDirectX.td | 1 +
llvm/include/llvm/IR/IntrinsicsSPIRV.td | 1 +
llvm/lib/Target/DirectX/DXIL.td | 10 ++
llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 2 +-
.../DirectX/DirectXTargetTransformInfo.cpp | 1 +
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 3 +
.../CodeGen/DirectX/ShaderFlags/wave-ops.ll | 7 +
.../CodeGen/DirectX/WaveActiveAllEqual.ll | 87 ++++++++++++
.../hlsl-intrinsics/WaveActiveAllEqual.ll | 41 ++++++
17 files changed, 426 insertions(+), 26 deletions(-)
create mode 100644 clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl
create mode 100644 clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl
create mode 100644 llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll
create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 78dd26aa2c455..c66c029900453 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5132,6 +5132,12 @@ def HLSLAsDouble : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLWaveActiveAllEqual : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_wave_active_all_equal"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def HLSLWaveActiveAllTrue : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_all_true"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 70891eac39425..09dae2ab931ee 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -1088,6 +1088,13 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
}
+ case Builtin::BI__builtin_hlsl_wave_active_all_equal: {
+ Value *Op = EmitScalarExpr(E->getArg(0));
+
+ Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllEqualIntrinsic();
+ return EmitRuntimeCall(
+ Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
+ }
case Builtin::BI__builtin_hlsl_wave_active_all_true: {
Value *Op = EmitScalarExpr(E->getArg(0));
assert(Op->getType()->isIntegerTy(1) &&
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index dbbc887353cec..940e3bbae8df2 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -34,16 +34,33 @@
// A function generator macro for picking the right intrinsic
// for the target backend
-#define GENERATE_HLSL_INTRINSIC_FUNCTION(FunctionName, IntrinsicPostfix) \
+#define _GEN_INTRIN_CHOOSER(_1, _2, _3, NAME, ...) NAME
+
+#define GENERATE_HLSL_INTRINSIC_FUNCTION(...) \
+ _GEN_INTRIN_CHOOSER(__VA_ARGS__, GENERATE_HLSL_INTRINSIC_FUNCTION3, \
+ GENERATE_HLSL_INTRINSIC_FUNCTION2, \
+ /* dummy to solve pre-C++20 errors */ ignored)( \
+ __VA_ARGS__)
+
+// 2-arg form: same postfix for both backends (uses the identity)
+#define GENERATE_HLSL_INTRINSIC_FUNCTION2(FunctionName, IntrinsicPostfix) \
+ llvm::Intrinsic::ID get##FunctionName##Intrinsic() { \
+ llvm::Triple::ArchType Arch = getArch(); \
+ switch (Arch) {} \
+ }
+
+// 3-arg form: explicit SPIR-V postfix override (perfect for wave->subgroup)
+#define GENERATE_HLSL_INTRINSIC_FUNCTION3(FunctionName, DxilPostfix, \
+ SpirvPostfix) \
llvm::Intrinsic::ID get##FunctionName##Intrinsic() { \
llvm::Triple::ArchType Arch = getArch(); \
switch (Arch) { \
case llvm::Triple::dxil: \
- return llvm::Intrinsic::dx_##IntrinsicPostfix; \
+ return llvm::Intrinsic::dx_##DxilPostfix; \
case llvm::Triple::spirv: \
- return llvm::Intrinsic::spv_##IntrinsicPostfix; \
+ return llvm::Intrinsic::spv_##SpirvPostfix; \
default: \
- llvm_unreachable("Intrinsic " #IntrinsicPostfix \
+ llvm_unreachable("Intrinsic " #DxilPostfix \
" not supported by target architecture"); \
} \
}
@@ -144,6 +161,8 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllEqual, wave_all_equal,
+ subgroup_all_equal)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveMax, wave_reduce_max)
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 2543401bdfbf9..e4a9c5dc7b4a8 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2413,6 +2413,130 @@ float4 trunc(float4);
// Wave* builtins
//===----------------------------------------------------------------------===//
+/// \brief Evaluates a value for all active invocations in the group. The
+/// result is true if Value is equal for all active invocations in the
+/// group. Otherwise, the result is false.
+/// \param Value The value to compare with
+/// \return True if all values across all lanes are equal, false otherwise
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) half WaveActiveAllEqual(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) half2 WaveActiveAllEqual(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) half3 WaveActiveAllEqual(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) half4 WaveActiveAllEqual(half4);
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int16_t WaveActiveAllEqual(int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int16_t2 WaveActiveAllEqual(int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int16_t3 WaveActiveAllEqual(int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int16_t4 WaveActiveAllEqual(int16_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint16_t WaveActiveAllEqual(uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint16_t2 WaveActiveAllEqual(uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint16_t3 WaveActiveAllEqual(uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint16_t4 WaveActiveAllEqual(uint16_t4);
+#endif
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int WaveActiveAllEqual(int);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int2 WaveActiveAllEqual(int2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int3 WaveActiveAllEqual(int3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int4 WaveActiveAllEqual(int4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint WaveActiveAllEqual(uint);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint2 WaveActiveAllEqual(uint2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint3 WaveActiveAllEqual(uint3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint4 WaveActiveAllEqual(uint4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int64_t WaveActiveAllEqual(int64_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int64_t2 WaveActiveAllEqual(int64_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int64_t3 WaveActiveAllEqual(int64_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int64_t4 WaveActiveAllEqual(int64_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint64_t WaveActiveAllEqual(uint64_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint64_t2 WaveActiveAllEqual(uint64_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint64_t3 WaveActiveAllEqual(uint64_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint64_t4 WaveActiveAllEqual(uint64_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) float WaveActiveAllEqual(float);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) float2 WaveActiveAllEqual(float2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) float3 WaveActiveAllEqual(float3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) float4 WaveActiveAllEqual(float4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) double WaveActiveAllEqual(double);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) double2 WaveActiveAllEqual(double2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) double3 WaveActiveAllEqual(double3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) double4 WaveActiveAllEqual(double4);
+
/// \brief Returns true if the expression is true in all active lanes in the
/// current wave.
///
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 802a1bdbccfdd..249d8dc58b866 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3809,6 +3809,19 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
TheCall->setType(ArgTyA);
break;
}
+ case Builtin::BI__builtin_hlsl_wave_active_all_true: {
+ if (SemaRef.checkArgCount(TheCall, 1))
+ return true;
+
+ // Ensure input expr type is a scalar/vector
+ if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
+ return true;
+
+ // set return type to bool
+ TheCall->setType(getASTContext().BoolTy);
+
+ break;
+ }
case Builtin::BI__builtin_hlsl_wave_active_max:
case Builtin::BI__builtin_hlsl_wave_active_min:
case Builtin::BI__builtin_hlsl_wave_active_sum: {
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl
new file mode 100644
index 0000000000000..4b4149d05eb3f
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl
@@ -0,0 +1,45 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+bool test_int(int expr) {
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.all.equal.i32([[TY]] %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.all.equal.i32([[TY]] %[[#]])
+ // CHECK: ret i1 %[[RET]]
+ return WaveActiveAllEqual(expr);
+}
+
+// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.i32([[TY]]) #[[#attr:]]
+// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.i32([[TY]]) #[[#attr:]]
+
+// CHECK-LABEL: test_uint64_t
+bool test_uint64_t(uint64_t expr) {
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.all.equal.i64(i64 %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.uproduct.i64(i64 %[[#]])
+ // CHECK: ret i1 %[[RET]]
+ return WaveActiveAllEqual(expr);
+}
+
+// CHECK-DXIL: declare i1 @llvm.dx.wave.uproduct.i64(i64 #[[#attr:]]
+// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.i64(i64) #[[#attr:]]
+
+// Test basic lowering to runtime function call with a float vector value.
+
+// CHECK-LABEL: test_floatv4
+bool test_floatv4(float4 expr) {
+ // CHECK-SPIRV: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn spir_func i1 @llvm.spv.wave.all.equal.v4f32(i32 %[[#]]
+ // CHECK-DXIL: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn i1 @llvm.dx.wave.all.equal.v4f32(i32 %[[#]])
+ // CHECK: ret [[TY1]] %[[RET1]]
+ return WaveActiveAllEqual(expr);
+}
+
+// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.v4f32(i32) #[[#attr]]
+// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.v4f32(i32) #[[#attr]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl
new file mode 100644
index 0000000000000..2c838cb51dd78
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+int test_too_few_arg() {
+ return __builtin_hlsl_wave_active_all_equal();
+ // expected-error at -1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+ return __builtin_hlsl_wave_active_all_equal(p0, p0);
+ // expected-error at -1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool test_expr_bool_type_check(bool p0) {
+ return __builtin_hlsl_wave_active_all_equal(p0);
+ // expected-error at -1 {{invalid operand of type 'bool'}}
+}
+
+bool2 test_expr_bool_vec_type_check(bool2 p0) {
+ return __builtin_hlsl_wave_active_all_equal(p0);
+ // expected-error at -1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
+}
+
+struct S { float f; };
+
+S test_expr_struct_type_check(S p0) {
+ return __builtin_hlsl_wave_active_all_equal(p0);
+ // expected-error at -1 {{invalid operand of type 'S' where a scalar or vector is required}}
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
index b0d0fdfca5e18..af926d60624c6 100644
--- a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
@@ -1,21 +1,28 @@
-// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
-
-bool test_too_few_arg() {
- return __builtin_hlsl_wave_active_all_true();
- // expected-error at -1 {{too few arguments to function call, expected 1, have 0}}
-}
-
-bool test_too_many_arg(bool p0) {
- return __builtin_hlsl_wave_active_all_true(p0, p0);
- // expected-error at -1 {{too many arguments to function call, expected 1, have 2}}
-}
-
-struct Foo
-{
- int a;
-};
-
-bool test_type_check(Foo p0) {
- return __builtin_hlsl_wave_active_all_true(p0);
- // expected-error at -1 {{no viable conversion from 'Foo' to 'bool'}}
-}
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+int test_too_few_arg() {
+ return __builtin_hlsl_wave_active_product();
+ // expected-error at -1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+ return __builtin_hlsl_wave_active_product(p0, p0);
+ // expected-error at -1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool test_expr_bool_type_check(bool p0) {
+ return __builtin_hlsl_wave_active_product(p0);
+ // expected-error at -1 {{invalid operand of type 'bool'}}
+}
+
+bool2 test_expr_bool_vec_type_check(bool2 p0) {
+ return __builtin_hlsl_wave_active_product(p0);
+ // expected-error at -1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
+}
+
+struct S { float f; };
+
+S test_expr_struct_type_check(S p0) {
+ return __builtin_hlsl_wave_active_product(p0);
+ // expected-error at -1 {{invalid operand of type 'S' where a scalar or vector is required}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 909482d72aa88..a688da131ce75 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -213,6 +213,7 @@ def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_
def int_dx_wave_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_all_equal : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 77f49ae721ad5..59a9612d1ff50 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -120,6 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_spv_subgroup_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+ def int_spv_subgroup_all_equal : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_subgroup_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">,
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 59a5b7fe4d508..a378f0d665d44 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -217,6 +217,7 @@ defset list<DXILOpClass> OpClasses = {
def waveActiveOp : DXILOpClass;
def waveAllOp : DXILOpClass;
def waveAllTrue : DXILOpClass;
+ def waveAllEqual : DXILOpClass;
def waveAnyTrue : DXILOpClass;
def waveActiveBallot : DXILOpClass;
def waveGetLaneCount : DXILOpClass;
@@ -1062,6 +1063,15 @@ def WaveActiveAllTrue : DXILOp<114, waveAllTrue> {
let stages = [Stages<DXIL1_0, [all_stages]>];
}
+def WaveActiveAllEqual : DXILOp<115, waveAllEqual> {
+ let Doc = "returns true if the expression is equal in all of the active lanes "
+ "in the current wave";
+ let intrinsics = [IntrinSelect<int_dx_wave_all_equal>];
+ let arguments = [OverloadTy];
+ let result = Int1Ty;
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
def WaveActiveBallot : DXILOp<116, waveActiveBallot> {
let Doc = "returns uint4 containing a bitmask of the evaluation of the boolean expression for all active lanes in the current wave.";
let intrinsics = [IntrinSelect<int_dx_wave_ballot>];
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 52993ee1c1220..1d14079407cbe 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -64,7 +64,6 @@ static bool hasUAVsAtEveryStage(const DXILResourceMap &DRM,
static bool checkWaveOps(Intrinsic::ID IID) {
// Currently unsupported intrinsics
// case Intrinsic::dx_wave_getlanecount:
- // case Intrinsic::dx_wave_allequal:
// case Intrinsic::dx_wave_readfirst:
// case Intrinsic::dx_wave_reduce.and:
// case Intrinsic::dx_wave_reduce.or:
@@ -85,6 +84,7 @@ static bool checkWaveOps(Intrinsic::ID IID) {
case Intrinsic::dx_wave_is_first_lane:
case Intrinsic::dx_wave_getlaneindex:
case Intrinsic::dx_wave_any:
+ case Intrinsic::dx_wave_all_equal:
case Intrinsic::dx_wave_all:
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_wave_active_countbits:
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 8018b09c9f248..eca2343227577 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -57,6 +57,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
case Intrinsic::dx_rsqrt:
case Intrinsic::dx_saturate:
case Intrinsic::dx_splitdouble:
+ case Intrinsic::dx_wave_all_equal:
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_wave_reduce_max:
case Intrinsic::dx_wave_reduce_min:
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3d3e311eeedb7..b9c6cb1e67595 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -4084,6 +4084,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectWavePrefixBitCount(ResVReg, ResType, I);
case Intrinsic::spv_wave_active_countbits:
return selectWaveActiveCountBits(ResVReg, ResType, I);
+ case Intrinsic::spv_subgroup_all_equal:
+ return selectWaveOpInst(ResVReg, ResType, I,
+ SPIRV::OpGroupNonUniformAllEqual);
case Intrinsic::spv_wave_all:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
case Intrinsic::spv_wave_any:
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll
index be53d19aca8f2..6c29ac73719e6 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll
@@ -42,6 +42,13 @@ entry:
ret i1 %ret
}
+define noundef i1 @wave_all_equal(i1 %x) {
+entry:
+ ; CHECK: Function wave_all_equal : [[WAVE_FLAG]]
+ %ret = call i1 @llvm.dx.wave.all.equal(i1 %x)
+ ret i1 %ret
+}
+
define noundef i1 @wave_readlane(i1 %x, i32 %idx) {
entry:
; CHECK: Function wave_readlane : [[WAVE_FLAG]]
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll b/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll
new file mode 100644
index 0000000000000..702f2ad1dde5f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll
@@ -0,0 +1,87 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s
+
+; Test that for scalar values, WaveActiveAllEqual maps down to the DirectX op
+
+define noundef half @wave_active_all_equal_half(half noundef %expr) {
+entry:
+; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr, i8 1, i8 0)
+ %ret = call half @llvm.dx.wave.all.equal.f16(half %expr)
+ ret half %ret
+}
+
+define noundef float @wave_active_all_equal_float(float noundef %expr) {
+entry:
+; CHECK: call float @dx.op.waveActiveAllEqual.f32(i32 119, float %expr, i8 1, i8 0)
+ %ret = call float @llvm.dx.wave.all.equal.f32(float %expr)
+ ret float %ret
+}
+
+define noundef double @wave_active_all_equal_double(double noundef %expr) {
+entry:
+; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr, i8 1, i8 0)
+ %ret = call double @llvm.dx.wave.all.equal.f64(double %expr)
+ ret double %ret
+}
+
+define noundef i16 @wave_active_all_equal_i16(i16 noundef %expr) {
+entry:
+; CHECK: call i16 @dx.op.waveActiveAllEqual.i16(i32 119, i16 %expr, i8 1, i8 0)
+ %ret = call i16 @llvm.dx.wave.all.equal.i16(i16 %expr)
+ ret i16 %ret
+}
+
+define noundef i32 @wave_active_all_equal_i32(i32 noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr, i8 1, i8 0)
+ %ret = call i32 @llvm.dx.wave.all.equal.i32(i32 %expr)
+ ret i32 %ret
+}
+
+define noundef i64 @wave_active_all_equal_i64(i64 noundef %expr) {
+entry:
+; CHECK: call i64 @dx.op.waveActiveAllEqual.i64(i32 119, i64 %expr, i8 1, i8 0)
+ %ret = call i64 @llvm.dx.wave.all.equal.i64(i64 %expr)
+ ret i64 %ret
+}
+
+declare half @llvm.dx.wave.all.equal.f16(half)
+declare float @llvm.dx.wave.all.equal.f32(float)
+declare double @llvm.dx.wave.all.equal.f64(double)
+
+declare i16 @llvm.dx.wave.all.equal.i16(i16)
+declare i32 @llvm.dx.wave.all.equal.i32(i32)
+declare i64 @llvm.dx.wave.all.equal.i64(i64)
+
+; Test that for vector values, WaveActiveAllEqual scalarizes and maps down to
+; the DirectX op
+
+define noundef <2 x half> @wave_active_all_equal_v2half(<2 x half> noundef %expr) {
+entry:
+; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr.i0, i8 1, i8 0)
+; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr.i1, i8 1, i8 0)
+ %ret = call <2 x half> @llvm.dx.wave.all.equal.v2f16(<2 x half> %expr)
+ ret <2 x half> %ret
+}
+
+define noundef <3 x i32> @wave_active_all_equal_v3i32(<3 x i32> noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i0, i8 1, i8 0)
+; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i1, i8 1, i8 0)
+; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i2, i8 1, i8 0)
+ %ret = call <3 x i32> @llvm.dx.wave.all.equal.v3i32(<3 x i32> %expr)
+ ret <3 x i32> %ret
+}
+
+define noundef <4 x double> @wave_active_all_equal_v4f64(<4 x double> noundef %expr) {
+entry:
+; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i0, i8 1, i8 0)
+; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i1, i8 1, i8 0)
+; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i2, i8 1, i8 0)
+; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i3, i8 1, i8 0)
+ %ret = call <4 x double> @llvm.dx.wave.all.equal.v464(<4 x double> %expr)
+ ret <4 x double> %ret
+}
+
+declare <2 x half> @llvm.dx.wave.all.equal.v2f16(<2 x half>)
+declare <3 x i32> @llvm.dx.wave.all.equal.v3i32(<3 x i32>)
+declare <4 x double> @llvm.dx.wave.all.equal.v4f64(<4 x double>)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
new file mode 100644
index 0000000000000..e871dc9a7aa28
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
@@ -0,0 +1,41 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Test lowering to spir-v backend for various types and scalar/vector
+
+; CHECK-DAG: %[[#f16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#f32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#v4_half:]] = OpTypeVector %[[#f16]] 4
+; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
+
+; CHECK-LABEL: Begin function test_float
+; CHECK: %[[#fexpr:]] = OpFunctionParameter %[[#f32]]
+define i1 @test_float(float %fexpr) {
+entry:
+; CHECK: %[[#fret:]] = OpGroupNonUniformAllEqual %[[#f32]] %[[#scope]] Reduce %[[#fexpr]]
+ %0 = call i1 @llvm.spv.wave.all.equal.f32(float %fexpr)
+ ret i1 %0
+}
+
+; CHECK-LABEL: Begin function test_int
+; CHECK: %[[#iexpr:]] = OpFunctionParameter %[[#uint]]
+define i1 @test_int(i32 %iexpr) {
+entry:
+; CHECK: %[[#iret:]] = OpGroupNonUniformAllEqual %[[#uint]] %[[#scope]] Reduce %[[#iexpr]]
+ %0 = call i1 @llvm.spv.wave.all.equal.i32(i32 %iexpr)
+ ret i1 %0
+}
+
+; CHECK-LABEL: Begin function test_vhalf
+; CHECK: %[[#vbexpr:]] = OpFunctionParameter %[[#v4_half]]
+define i1 @test_vhalf(<4 x half> %vbexpr) {
+entry:
+; CHECK: %[[#vhalfret:]] = OpGroupNonUniformAllEqual %[[#v4_half]] %[[#scope]] Reduce %[[#vbexpr]]
+ %0 = call i1 @llvm.spv.wave.all.equal.v4half(<4 x half> %vbexpr)
+ ret i1 %0
+}
+
+declare i1 @llvm.spv.wave.all.equal.f32(float)
+declare i1 @llvm.spv.wave.all.equal.i32(i32)
+declare i1 @llvm.spv.wave.all.equal.v4half(<4 x half>)
>From c523c88934ae38d62d4b65729f4374e6f9eb617e Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Fri, 27 Feb 2026 13:39:08 -0800
Subject: [PATCH 2/6] fix return type, self review
---
clang/lib/CodeGen/CGHLSLBuiltins.cpp | 5 +-
clang/lib/CodeGen/CGHLSLRuntime.h | 10 +-
.../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 72 +++++-----
clang/lib/Sema/SemaHLSL.cpp | 18 ++-
.../builtins/WaveActiveAllEqual.hlsl | 30 ++---
.../BuiltIns/WaveActiveAllEqual-errors.hlsl | 16 +--
.../BuiltIns/WaveActiveAllTrue-errors.hlsl | 31 ++---
llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 +-
llvm/include/llvm/IR/IntrinsicsSPIRV.td | 2 +-
llvm/lib/Target/DirectX/DXIL.td | 7 +-
.../DirectX/DirectXTargetTransformInfo.cpp | 3 +
.../CodeGen/DirectX/WaveActiveAllEqual.ll | 124 +++++++++++-------
.../hlsl-intrinsics/WaveActiveAllEqual.ll | 24 ++--
13 files changed, 187 insertions(+), 157 deletions(-)
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 09dae2ab931ee..47b7e2b18d942 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -1092,8 +1092,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
Value *Op = EmitScalarExpr(E->getArg(0));
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllEqualIntrinsic();
- return EmitRuntimeCall(
- Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
+ return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
+ &CGM.getModule(), ID, {Op->getType()}),
+ {Op});
}
case Builtin::BI__builtin_hlsl_wave_active_all_true: {
Value *Op = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 940e3bbae8df2..d6055d89e3c84 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -46,7 +46,15 @@
#define GENERATE_HLSL_INTRINSIC_FUNCTION2(FunctionName, IntrinsicPostfix) \
llvm::Intrinsic::ID get##FunctionName##Intrinsic() { \
llvm::Triple::ArchType Arch = getArch(); \
- switch (Arch) {} \
+ switch (Arch) { \
+ case llvm::Triple::dxil: \
+ return llvm::Intrinsic::dx_##IntrinsicPostfix; \
+ case llvm::Triple::spirv: \
+ return llvm::Intrinsic::spv_##IntrinsicPostfix; \
+ default: \
+ llvm_unreachable("Intrinsic " #IntrinsicPostfix \
+ " not supported by target architecture"); \
+ } \
}
// 3-arg form: explicit SPIR-V postfix override (perfect for wave->subgroup)
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index e4a9c5dc7b4a8..5ca4713c2d520 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2420,122 +2420,122 @@ float4 trunc(float4);
/// \return True if all values across all lanes are equal, false otherwise
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) half WaveActiveAllEqual(half);
+__attribute__((convergent)) bool WaveActiveAllEqual(half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) half2 WaveActiveAllEqual(half2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) half3 WaveActiveAllEqual(half3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) half4 WaveActiveAllEqual(half4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(half4);
#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int16_t WaveActiveAllEqual(int16_t);
+__attribute__((convergent)) bool WaveActiveAllEqual(int16_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int16_t2 WaveActiveAllEqual(int16_t2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(int16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int16_t3 WaveActiveAllEqual(int16_t3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int16_t4 WaveActiveAllEqual(int16_t4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(int16_t4);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint16_t WaveActiveAllEqual(uint16_t);
+__attribute__((convergent)) bool WaveActiveAllEqual(uint16_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint16_t2 WaveActiveAllEqual(uint16_t2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(uint16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint16_t3 WaveActiveAllEqual(uint16_t3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint16_t4 WaveActiveAllEqual(uint16_t4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(uint16_t4);
#endif
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int WaveActiveAllEqual(int);
+__attribute__((convergent)) bool WaveActiveAllEqual(int);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int2 WaveActiveAllEqual(int2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(int2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int3 WaveActiveAllEqual(int3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(int3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int4 WaveActiveAllEqual(int4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(int4);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint WaveActiveAllEqual(uint);
+__attribute__((convergent)) bool WaveActiveAllEqual(uint);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint2 WaveActiveAllEqual(uint2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(uint2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint3 WaveActiveAllEqual(uint3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(uint3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint4 WaveActiveAllEqual(uint4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(uint4);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int64_t WaveActiveAllEqual(int64_t);
+__attribute__((convergent)) bool WaveActiveAllEqual(int64_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int64_t2 WaveActiveAllEqual(int64_t2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(int64_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int64_t3 WaveActiveAllEqual(int64_t3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(int64_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int64_t4 WaveActiveAllEqual(int64_t4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(int64_t4);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint64_t WaveActiveAllEqual(uint64_t);
+__attribute__((convergent)) bool WaveActiveAllEqual(uint64_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint64_t2 WaveActiveAllEqual(uint64_t2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(uint64_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint64_t3 WaveActiveAllEqual(uint64_t3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(uint64_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint64_t4 WaveActiveAllEqual(uint64_t4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(uint64_t4);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) float WaveActiveAllEqual(float);
+__attribute__((convergent)) bool WaveActiveAllEqual(float);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) float2 WaveActiveAllEqual(float2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(float2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) float3 WaveActiveAllEqual(float3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(float3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) float4 WaveActiveAllEqual(float4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(float4);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) double WaveActiveAllEqual(double);
+__attribute__((convergent)) bool WaveActiveAllEqual(double);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) double2 WaveActiveAllEqual(double2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(double2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) double3 WaveActiveAllEqual(double3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(double3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) double4 WaveActiveAllEqual(double4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(double4);
/// \brief Returns true if the expression is true in all active lanes in the
/// current wave.
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 249d8dc58b866..46cc3835c85a8 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3809,7 +3809,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
TheCall->setType(ArgTyA);
break;
}
- case Builtin::BI__builtin_hlsl_wave_active_all_true: {
+ case Builtin::BI__builtin_hlsl_wave_active_all_equal: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;
@@ -3817,9 +3817,21 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
return true;
- // set return type to bool
- TheCall->setType(getASTContext().BoolTy);
+ QualType InputTy = TheCall->getArg(0)->getType();
+ ASTContext &Ctx = getASTContext();
+ QualType RetTy;
+
+ // If vector, construct bool vector of same size
+ if (const auto *VecTy = InputTy->getAs<ExtVectorType>()) {
+ unsigned NumElts = VecTy->getNumElements();
+ RetTy = Ctx.getExtVectorType(Ctx.BoolTy, NumElts);
+ } else {
+ // Scalar case
+ RetTy = Ctx.BoolTy;
+ }
+
+ TheCall->setType(RetTy);
break;
}
case Builtin::BI__builtin_hlsl_wave_active_max:
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl
index 4b4149d05eb3f..65d15633eb6cf 100644
--- a/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl
@@ -9,37 +9,37 @@
// CHECK-LABEL: test_int
bool test_int(int expr) {
- // CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.all.equal.i32([[TY]] %[[#]])
- // CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.all.equal.i32([[TY]] %[[#]])
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.subgroup.all.equal.i32(i32
+ // CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.all.equal.i32(i32
// CHECK: ret i1 %[[RET]]
return WaveActiveAllEqual(expr);
}
-// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.i32([[TY]]) #[[#attr:]]
-// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.i32([[TY]]) #[[#attr:]]
+// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.i32(i32) #[[attr:.*]]
+// CHECK-SPIRV: declare i1 @llvm.spv.subgroup.all.equal.i32(i32) #[[attr:.*]]
// CHECK-LABEL: test_uint64_t
bool test_uint64_t(uint64_t expr) {
- // CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.all.equal.i64(i64 %[[#]])
- // CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.uproduct.i64(i64 %[[#]])
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.subgroup.all.equal.i64(i64
+ // CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.all.equal.i64(i64
// CHECK: ret i1 %[[RET]]
return WaveActiveAllEqual(expr);
}
-// CHECK-DXIL: declare i1 @llvm.dx.wave.uproduct.i64(i64 #[[#attr:]]
-// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.i64(i64) #[[#attr:]]
+// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.i64(i64) #[[attr]]
+// CHECK-SPIRV: declare i1 @llvm.spv.subgroup.all.equal.i64(i64) #[[attr]]
// Test basic lowering to runtime function call with array and float value.
// CHECK-LABEL: test_floatv4
-bool test_floatv4(float4 expr) {
- // CHECK-SPIRV: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn spir_func i1 @llvm.spv.wave.all.equal.v4f32(i32 %[[#]]
- // CHECK-DXIL: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn i1 @llvm.dx.wave.all.equal.v4f32(i32 %[[#]])
- // CHECK: ret [[TY1]] %[[RET1]]
+bool4 test_floatv4(float4 expr) {
+ // CHECK-SPIRV: %[[RET1:.*]] = call spir_func <4 x i1> @llvm.spv.subgroup.all.equal.v4f32(<4 x float>
+ // CHECK-DXIL: %[[RET1:.*]] = call <4 x i1> @llvm.dx.wave.all.equal.v4f32(<4 x float>
+ // CHECK: ret <4 x i1> %[[RET1]]
return WaveActiveAllEqual(expr);
}
-// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.v4f32(i32) #[[#attr]]
-// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.v4f32(i32) #[[#attr]]
+// CHECK-DXIL: declare <4 x i1> @llvm.dx.wave.all.equal.v4f32(<4 x float>) #[[attr]]
+// CHECK-SPIRV: declare <4 x i1> @llvm.spv.subgroup.all.equal.v4f32(<4 x float>) #[[attr]]
-// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
+// CHECK: attributes #[[attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl
index 2c838cb51dd78..1b5d7955baffc 100644
--- a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl
@@ -1,28 +1,18 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
-int test_too_few_arg() {
+bool test_too_few_arg() {
return __builtin_hlsl_wave_active_all_equal();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}
-float2 test_too_many_arg(float2 p0) {
+bool test_too_many_arg(float2 p0) {
return __builtin_hlsl_wave_active_all_equal(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}
-bool test_expr_bool_type_check(bool p0) {
- return __builtin_hlsl_wave_active_all_equal(p0);
- // expected-error@-1 {{invalid operand of type 'bool'}}
-}
-
-bool2 test_expr_bool_vec_type_check(bool2 p0) {
- return __builtin_hlsl_wave_active_all_equal(p0);
- // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
-}
-
struct S { float f; };
-S test_expr_struct_type_check(S p0) {
+bool test_expr_struct_type_check(S p0) {
return __builtin_hlsl_wave_active_all_equal(p0);
// expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
index af926d60624c6..0975ad649e714 100644
--- a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
@@ -1,28 +1,21 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
-int test_too_few_arg() {
- return __builtin_hlsl_wave_active_product();
+bool test_too_few_arg() {
+ return __builtin_hlsl_wave_active_all_true();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}
-float2 test_too_many_arg(float2 p0) {
- return __builtin_hlsl_wave_active_product(p0, p0);
+bool test_too_many_arg(bool p0) {
+ return __builtin_hlsl_wave_active_all_true(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}
-bool test_expr_bool_type_check(bool p0) {
- return __builtin_hlsl_wave_active_product(p0);
- // expected-error@-1 {{invalid operand of type 'bool'}}
-}
-
-bool2 test_expr_bool_vec_type_check(bool2 p0) {
- return __builtin_hlsl_wave_active_product(p0);
- // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
-}
+struct Foo
+{
+ int a;
+};
-struct S { float f; };
-
-S test_expr_struct_type_check(S p0) {
- return __builtin_hlsl_wave_active_product(p0);
- // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
-}
+bool test_type_check(Foo p0) {
+ return __builtin_hlsl_wave_active_all_true(p0);
+ // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
+}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index a688da131ce75..6774a33556c09 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -213,7 +213,7 @@ def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_
def int_dx_wave_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
-def int_dx_wave_all_equal : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_all_equal : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], [llvm_any_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 59a9612d1ff50..b91905f350506 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -120,7 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_spv_subgroup_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
- def int_spv_subgroup_all_equal : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrConvergent, IntrNoMem]>;
+ def int_spv_subgroup_all_equal : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], [llvm_any_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_subgroup_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">,
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index a378f0d665d44..e64909b059d29 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -217,7 +217,6 @@ defset list<DXILOpClass> OpClasses = {
def waveActiveOp : DXILOpClass;
def waveAllOp : DXILOpClass;
def waveAllTrue : DXILOpClass;
- def waveAllEqual : DXILOpClass;
def waveAnyTrue : DXILOpClass;
def waveActiveBallot : DXILOpClass;
def waveGetLaneCount : DXILOpClass;
@@ -1063,9 +1062,9 @@ def WaveActiveAllTrue : DXILOp<114, waveAllTrue> {
let stages = [Stages<DXIL1_0, [all_stages]>];
}
-def WaveActiveAllEqual : DXILOp<115, waveAllEqual> {
- let Doc = "returns true if the expression is equal in all of the active lanes "
- "in the current wave";
+def WaveActiveAllEqual : DXILOp<115, waveActiveAllEqual> {
+ let Doc = "returns true for each scalar element of the expression if the "
+ "expression is equal in all of the active lanes in the current wave";
let intrinsics = [IntrinSelect<int_dx_wave_all_equal>];
let arguments = [OverloadTy];
let result = Int1Ty;
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index eca2343227577..a2d7ffefbb5a2 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -36,8 +36,11 @@ bool DirectXTTIImpl::isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
case Intrinsic::dx_isnan:
case Intrinsic::dx_legacyf16tof32:
case Intrinsic::dx_legacyf32tof16:
+ case Intrinsic::dx_wave_all_equal:
return OpdIdx == 0;
default:
+ // All DX intrinsics are overloaded on return type unless specified
+ // otherwise
return OpdIdx == -1;
}
}
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll b/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll
index 702f2ad1dde5f..f6dcd59c33958 100644
--- a/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll
+++ b/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll
@@ -2,86 +2,108 @@
; Test that for scalar values, WaveActiveAllEqual maps down to the DirectX op
-define noundef half @wave_active_all_equal_half(half noundef %expr) {
+define noundef i1 @wave_active_all_equal_half(half noundef %expr) {
entry:
-; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr, i8 1, i8 0)
- %ret = call half @llvm.dx.wave.all.equal.f16(half %expr)
- ret half %ret
+; CHECK: call i1 @dx.op.waveActiveAllEqual.f16(i32 115, half %expr)
+ %ret = call i1 @llvm.dx.wave.all.equal.f16(half %expr)
+ ret i1 %ret
}
-define noundef float @wave_active_all_equal_float(float noundef %expr) {
+define noundef i1 @wave_active_all_equal_float(float noundef %expr) {
entry:
-; CHECK: call float @dx.op.waveActiveAllEqual.f32(i32 119, float %expr, i8 1, i8 0)
- %ret = call float @llvm.dx.wave.all.equal.f32(float %expr)
- ret float %ret
+; CHECK: call i1 @dx.op.waveActiveAllEqual.f32(i32 115, float %expr)
+ %ret = call i1 @llvm.dx.wave.all.equal.f32(float %expr)
+ ret i1 %ret
}
-define noundef double @wave_active_all_equal_double(double noundef %expr) {
+define noundef i1 @wave_active_all_equal_double(double noundef %expr) {
entry:
-; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr, i8 1, i8 0)
- %ret = call double @llvm.dx.wave.all.equal.f64(double %expr)
- ret double %ret
+; CHECK: call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %expr)
+ %ret = call i1 @llvm.dx.wave.all.equal.f64(double %expr)
+ ret i1 %ret
}
-define noundef i16 @wave_active_all_equal_i16(i16 noundef %expr) {
+define noundef i1 @wave_active_all_equal_i16(i16 noundef %expr) {
entry:
-; CHECK: call i16 @dx.op.waveActiveAllEqual.i16(i32 119, i16 %expr, i8 1, i8 0)
- %ret = call i16 @llvm.dx.wave.all.equal.i16(i16 %expr)
- ret i16 %ret
+; CHECK: call i1 @dx.op.waveActiveAllEqual.i16(i32 115, i16 %expr)
+ %ret = call i1 @llvm.dx.wave.all.equal.i16(i16 %expr)
+ ret i1 %ret
}
-define noundef i32 @wave_active_all_equal_i32(i32 noundef %expr) {
+define noundef i1 @wave_active_all_equal_i32(i32 noundef %expr) {
entry:
-; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr, i8 1, i8 0)
- %ret = call i32 @llvm.dx.wave.all.equal.i32(i32 %expr)
- ret i32 %ret
+; CHECK: call i1 @dx.op.waveActiveAllEqual.i32(i32 115, i32 %expr)
+ %ret = call i1 @llvm.dx.wave.all.equal.i32(i32 %expr)
+ ret i1 %ret
}
-define noundef i64 @wave_active_all_equal_i64(i64 noundef %expr) {
+define noundef i1 @wave_active_all_equal_i64(i64 noundef %expr) {
entry:
-; CHECK: call i64 @dx.op.waveActiveAllEqual.i64(i32 119, i64 %expr, i8 1, i8 0)
- %ret = call i64 @llvm.dx.wave.all.equal.i64(i64 %expr)
- ret i64 %ret
+; CHECK: call i1 @dx.op.waveActiveAllEqual.i64(i32 115, i64 %expr)
+ %ret = call i1 @llvm.dx.wave.all.equal.i64(i64 %expr)
+ ret i1 %ret
}
-declare half @llvm.dx.wave.all.equal.f16(half)
-declare float @llvm.dx.wave.all.equal.f32(float)
-declare double @llvm.dx.wave.all.equal.f64(double)
+declare i1 @llvm.dx.wave.all.equal.f16(half)
+declare i1 @llvm.dx.wave.all.equal.f32(float)
+declare i1 @llvm.dx.wave.all.equal.f64(double)
-declare i16 @llvm.dx.wave.all.equal.i16(i16)
-declare i32 @llvm.dx.wave.all.equal.i32(i32)
-declare i64 @llvm.dx.wave.all.equal.i64(i64)
+declare i1 @llvm.dx.wave.all.equal.i16(i16)
+declare i1 @llvm.dx.wave.all.equal.i32(i32)
+declare i1 @llvm.dx.wave.all.equal.i64(i64)
; Test that for vector values, WaveActiveAllEqual scalarizes and maps down to the
; DirectX op
-define noundef <2 x half> @wave_active_all_equal_v2half(<2 x half> noundef %expr) {
+define noundef <2 x i1> @wave_active_all_equal_v2half(<2 x half> noundef %expr) {
entry:
-; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr.i0, i8 1, i8 0)
-; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr.i1, i8 1, i8 0)
- %ret = call <2 x half> @llvm.dx.wave.all.equal.v2f16(<2 x half> %expr)
- ret <2 x half> %ret
+; CHECK: %[[EXPR0:.*]] = extractelement <2 x half> %expr, i64 0
+; CHECK: %[[RET0:.*]] = call i1 @dx.op.waveActiveAllEqual.f16(i32 115, half %[[EXPR0]])
+; CHECK: %[[EXPR1:.*]] = extractelement <2 x half> %expr, i64 1
+; CHECK: %[[RET1:.*]] = call i1 @dx.op.waveActiveAllEqual.f16(i32 115, half %[[EXPR1]])
+; CHECK: %[[RETUPTO0:.*]] = insertelement <2 x i1> poison, i1 %[[RET0]], i64 0
+; CHECK: %ret = insertelement <2 x i1> %[[RETUPTO0]], i1 %[[RET1]], i64 1
+; CHECK: ret <2 x i1> %ret
+
+ %ret = call <2 x i1> @llvm.dx.wave.all.equal.v2f16(<2 x half> %expr)
+ ret <2 x i1> %ret
}
-define noundef <3 x i32> @wave_active_all_equal_v3i32(<3 x i32> noundef %expr) {
+define noundef <3 x i1> @wave_active_all_equal_v3i32(<3 x i32> noundef %expr) {
entry:
-; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i0, i8 1, i8 0)
-; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i1, i8 1, i8 0)
-; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i2, i8 1, i8 0)
- %ret = call <3 x i32> @llvm.dx.wave.all.equal.v3i32(<3 x i32> %expr)
- ret <3 x i32> %ret
+; CHECK: %[[EXPR0:.*]] = extractelement <3 x i32> %expr, i64 0
+; CHECK: %[[RET0:.*]] = call i1 @dx.op.waveActiveAllEqual.i32(i32 115, i32 %[[EXPR0]])
+; CHECK: %[[EXPR1:.*]] = extractelement <3 x i32> %expr, i64 1
+; CHECK: %[[RET1:.*]] = call i1 @dx.op.waveActiveAllEqual.i32(i32 115, i32 %[[EXPR1]])
+; CHECK: %[[EXPR2:.*]] = extractelement <3 x i32> %expr, i64 2
+; CHECK: %[[RET2:.*]] = call i1 @dx.op.waveActiveAllEqual.i32(i32 115, i32 %[[EXPR2]])
+; CHECK: %[[RETUPTO0:.*]] = insertelement <3 x i1> poison, i1 %[[RET0]], i64 0
+; CHECK: %[[RETUPTO1:.*]] = insertelement <3 x i1> %[[RETUPTO0]], i1 %[[RET1]], i64 1
+; CHECK: %ret = insertelement <3 x i1> %[[RETUPTO1]], i1 %[[RET2]], i64 2
+
+ %ret = call <3 x i1> @llvm.dx.wave.all.equal.v3i32(<3 x i32> %expr)
+ ret <3 x i1> %ret
}
-define noundef <4 x double> @wave_active_all_equal_v4f64(<4 x double> noundef %expr) {
+define noundef <4 x i1> @wave_active_all_equal_v4f64(<4 x double> noundef %expr) {
entry:
-; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i0, i8 1, i8 0)
-; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i1, i8 1, i8 0)
-; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i2, i8 1, i8 0)
-; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i3, i8 1, i8 0)
- %ret = call <4 x double> @llvm.dx.wave.all.equal.v464(<4 x double> %expr)
- ret <4 x double> %ret
+; CHECK: %[[EXPR0:.*]] = extractelement <4 x double> %expr, i64 0
+; CHECK: %[[RET0:.*]] = call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %[[EXPR0]])
+; CHECK: %[[EXPR1:.*]] = extractelement <4 x double> %expr, i64 1
+; CHECK: %[[RET1:.*]] = call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %[[EXPR1]])
+; CHECK: %[[EXPR2:.*]] = extractelement <4 x double> %expr, i64 2
+; CHECK: %[[RET2:.*]] = call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %[[EXPR2]])
+; CHECK: %[[EXPR3:.*]] = extractelement <4 x double> %expr, i64 3
+; CHECK: %[[RET3:.*]] = call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %[[EXPR3]])
+; CHECK: %[[RETUPTO0:.*]] = insertelement <4 x i1> poison, i1 %[[RET0]], i64 0
+; CHECK: %[[RETUPTO1:.*]] = insertelement <4 x i1> %[[RETUPTO0]], i1 %[[RET1]], i64 1
+; CHECK: %[[RETUPTO2:.*]] = insertelement <4 x i1> %[[RETUPTO1]], i1 %[[RET2]], i64 2
+; CHECK: %ret = insertelement <4 x i1> %[[RETUPTO2]], i1 %[[RET3]], i64 3
+
+ %ret = call <4 x i1> @llvm.dx.wave.all.equal.v4f64(<4 x double> %expr)
+ ret <4 x i1> %ret
}
-declare <2 x half> @llvm.dx.wave.all.equal.v2f16(<2 x half>)
-declare <3 x i32> @llvm.dx.wave.all.equal.v3i32(<3 x i32>)
-declare <4 x double> @llvm.dx.wave.all.equal.v4f64(<4 x double>)
+declare <2 x i1> @llvm.dx.wave.all.equal.v2f16(<2 x half>)
+declare <3 x i1> @llvm.dx.wave.all.equal.v3i32(<3 x i32>)
+declare <4 x i1> @llvm.dx.wave.all.equal.v4f64(<4 x double>)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
index e871dc9a7aa28..c64e5770b2d6d 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
@@ -5,6 +5,8 @@
; CHECK-DAG: %[[#f16:]] = OpTypeFloat 16
; CHECK-DAG: %[[#f32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#bool:]] = OpTypeBool
+; CHECK-DAG: %[[#bool4:]] = OpTypeVector %[[#bool]] 4
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#v4_half:]] = OpTypeVector %[[#f16]] 4
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
@@ -13,8 +15,8 @@
; CHECK: %[[#fexpr:]] = OpFunctionParameter %[[#f32]]
define i1 @test_float(float %fexpr) {
entry:
-; CHECK: %[[#fret:]] = OpGroupNonUniformAllEqual %[[#f32]] %[[#scope]] Reduce %[[#fexpr]]
- %0 = call i1 @llvm.spv.wave.all.equal.f32(float %fexpr)
+; CHECK: %[[#fret:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#fexpr]]
+ %0 = call i1 @llvm.spv.subgroup.all.equal.f32(float %fexpr)
ret i1 %0
}
@@ -22,20 +24,20 @@ entry:
; CHECK: %[[#iexpr:]] = OpFunctionParameter %[[#uint]]
define i1 @test_int(i32 %iexpr) {
entry:
-; CHECK: %[[#iret:]] = OpGroupNonUniformAllEqual %[[#uint]] %[[#scope]] Reduce %[[#iexpr]]
- %0 = call i1 @llvm.spv.wave.all.equal.i32(i32 %iexpr)
+; CHECK: %[[#iret:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#iexpr]]
+ %0 = call i1 @llvm.spv.subgroup.all.equal.i32(i32 %iexpr)
ret i1 %0
}
; CHECK-LABEL: Begin function test_vhalf
; CHECK: %[[#vbexpr:]] = OpFunctionParameter %[[#v4_half]]
-define i1 @test_vhalf(<4 x half> %vbexpr) {
+define <4 x i1> @test_vhalf(<4 x half> %vbexpr) {
entry:
-; CHECK: %[[#vhalfret:]] = OpGroupNonUniformAllEqual %[[#v4_half]] %[[#scope]] Reduce %[[#vbexpr]]
- %0 = call i1 @llvm.spv.wave.all.equal.v4half(<4 x half> %vbexpr)
- ret i1 %0
+; CHECK: %[[#vhalfret:]] = OpGroupNonUniformAllEqual %[[#bool4]] %[[#scope]] %[[#vbexpr]]
+ %0 = call <4 x i1> @llvm.spv.subgroup.all.equal.v4half(<4 x half> %vbexpr)
+ ret <4 x i1> %0
}
-declare i1 @llvm.spv.wave.all.equal.f32(float)
-declare i1 @llvm.spv.wave.all.equal.i32(i32)
-declare i1 @llvm.spv.wave.all.equal.v4half(<4 x half>)
+declare i1 @llvm.spv.subgroup.all.equal.f32(float)
+declare i1 @llvm.spv.subgroup.all.equal.i32(i32)
+declare <4 x i1> @llvm.spv.subgroup.all.equal.v4half(<4 x half>)
>From c78d96eecf4435cae10014da20025c8ceef4d384 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Fri, 27 Feb 2026 13:41:26 -0800
Subject: [PATCH 3/6] revert file changes
---
.../BuiltIns/WaveActiveAllTrue-errors.hlsl | 42 +++++++++----------
1 file changed, 21 insertions(+), 21 deletions(-)
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
index 0975ad649e714..b0d0fdfca5e18 100644
--- a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
@@ -1,21 +1,21 @@
-// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
-
-bool test_too_few_arg() {
- return __builtin_hlsl_wave_active_all_true();
- // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
-}
-
-bool test_too_many_arg(bool p0) {
- return __builtin_hlsl_wave_active_all_true(p0, p0);
- // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
-}
-
-struct Foo
-{
- int a;
-};
-
-bool test_type_check(Foo p0) {
- return __builtin_hlsl_wave_active_all_true(p0);
- // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
-}
\ No newline at end of file
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+bool test_too_few_arg() {
+ return __builtin_hlsl_wave_active_all_true();
+ // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+bool test_too_many_arg(bool p0) {
+ return __builtin_hlsl_wave_active_all_true(p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+struct Foo
+{
+ int a;
+};
+
+bool test_type_check(Foo p0) {
+ return __builtin_hlsl_wave_active_all_true(p0);
+ // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
+}
>From df45116b4e350d1bfa36f7f3d737cca9fb0b8879 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Mon, 2 Mar 2026 12:51:59 -0800
Subject: [PATCH 4/6] perform manual scalarization
---
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 96 ++++++++++++++++++-
.../hlsl-intrinsics/WaveActiveAllEqual.ll | 15 ++-
2 files changed, 107 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b9c6cb1e67595..3794092703470 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -337,6 +337,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectWaveActiveCountBits(Register ResVReg, SPIRVTypeInst ResType,
MachineInstr &I) const;
+ bool selectWaveActiveAllEqual(Register ResVReg, SPIRVTypeInst ResType,
+ MachineInstr &I) const;
+
bool selectUnmergeValues(MachineInstr &I) const;
bool selectHandleFromBinding(Register &ResVReg, SPIRVTypeInst ResType,
@@ -2830,6 +2833,96 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
return true;
}
+unsigned getVectorSizeOrOne(SPIRVTypeInst Type) {
+
+ if (Type->getOpcode() != SPIRV::OpTypeVector)
+ return 1;
+
+ // Operand(2) is the vector size
+ return Type->getOperand(2).getImm();
+}
+
+bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg,
+ SPIRVTypeInst ResType,
+ MachineInstr &I) const {
+
+ MachineBasicBlock &BB = *I.getParent();
+ const DebugLoc &DL = I.getDebugLoc();
+
+ SPIRVTypeInst SpvTy = GR.getSPIRVTypeForVReg(ResVReg);
+ unsigned NumElems = getVectorSizeOrOne(SpvTy);
+ bool IsVector = NumElems > 1;
+
+ // Subgroup scope constant
+ SPIRVTypeInst IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+
+ Register ScopeConst = GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy,
+ TII, !STI.isShader());
+
+ Register InputReg = I.getOperand(2).getReg();
+
+ SmallVector<Register, 4> ElementResults;
+
+ // If vector, determine element type once
+ SPIRVTypeInst ElemInputType = SpvTy;
+ SPIRVTypeInst ElemBoolType = ResType;
+
+ if (IsVector) {
+ Register ElemTypeReg = SpvTy->getOperand(1).getReg();
+ ElemInputType = GR.getSPIRVTypeForVReg(ElemTypeReg);
+
+ Register BoolElemReg = ResType->getOperand(1).getReg();
+ ElemBoolType = GR.getSPIRVTypeForVReg(BoolElemReg);
+ }
+
+ for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
+
+ Register ElemInput = InputReg;
+
+ if (IsVector) {
+ Register Extracted =
+ MRI->createVirtualRegister(GR.getRegClass(ElemInputType));
+
+ BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeExtract))
+ .addDef(Extracted)
+ .addUse(GR.getSPIRVTypeID(ElemInputType))
+ .addUse(InputReg)
+ .addImm(Idx)
+ .constrainAllUses(TII, TRI, RBI);
+
+ ElemInput = Extracted;
+ }
+
+ Register ElemResult =
+ IsVector ? MRI->createVirtualRegister(GR.getRegClass(ElemBoolType))
+ : ResVReg;
+
+ BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformAllEqual))
+ .addDef(ElemResult)
+ .addUse(GR.getSPIRVTypeID(ElemBoolType))
+ .addUse(ScopeConst)
+ .addUse(ElemInput)
+ .constrainAllUses(TII, TRI, RBI);
+
+ ElementResults.push_back(ElemResult);
+ }
+
+ if (!IsVector)
+ return true;
+
+ auto MIB = BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeConstruct))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType));
+
+ for (Register R : ElementResults)
+ MIB.addUse(R);
+
+ MIB.constrainAllUses(TII, TRI, RBI);
+
+ return true;
+}
+
+
bool SPIRVInstructionSelector::selectWavePrefixBitCount(Register ResVReg,
SPIRVTypeInst ResType,
MachineInstr &I) const {
@@ -4085,8 +4178,7 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_wave_active_countbits:
return selectWaveActiveCountBits(ResVReg, ResType, I);
case Intrinsic::spv_subgroup_all_equal:
- return selectWaveOpInst(ResVReg, ResType, I,
- SPIRV::OpGroupNonUniformAllEqual);
+ return selectWaveActiveAllEqual(ResVReg, ResType, I);
case Intrinsic::spv_wave_all:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
case Intrinsic::spv_wave_any:
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
index c64e5770b2d6d..9c63539e4b9e4 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
@@ -30,10 +30,21 @@ entry:
}
; CHECK-LABEL: Begin function test_vhalf
-; CHECK: %[[#vbexpr:]] = OpFunctionParameter %[[#v4_half]]
+; Here there's a vector, so we scalarize and then recombine the
+; result back into one vector
define <4 x i1> @test_vhalf(<4 x half> %vbexpr) {
entry:
-; CHECK: %[[#vhalfret:]] = OpGroupNonUniformAllEqual %[[#bool4]] %[[#scope]] %[[#vbexpr]]
+; CHECK: %[[#param:]] = OpFunctionParameter %[[#v4float:]]
+; CHECK: %[[#ext1:]] = OpCompositeExtract %[[#bool]] %[[#param]] 0
+; CHECK-NEXT: %[[#res1:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext1]]
+; CHECK-NEXT: %[[#ext2:]] = OpCompositeExtract %[[#bool]] %[[#param]] 1
+; CHECK-NEXT: %[[#res2:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext2]]
+; CHECK-NEXT: %[[#ext3:]] = OpCompositeExtract %[[#bool]] %[[#param]] 2
+; CHECK-NEXT: %[[#res3:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext3]]
+; CHECK-NEXT: %[[#ext4:]] = OpCompositeExtract %[[#bool]] %[[#param]] 3
+; CHECK-NEXT: %[[#res4:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext4]]
+; CHECK-NEXT: %[[#ret:]] = OpCompositeConstruct %[[#bool4]] %[[#res1:]] %[[#res2:]] %[[#res3:]] %[[#res4:]]
+; CHECK-NEXT: OpReturnValue %[[#ret]]
%0 = call <4 x i1> @llvm.spv.subgroup.all.equal.v4half(<4 x half> %vbexpr)
ret <4 x i1> %0
}
>From 8ce17b4829f1d8c32c4e027c607ef5511460d695 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Mon, 2 Mar 2026 13:27:59 -0800
Subject: [PATCH 5/6] clang format
---
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3794092703470..fe16ceaf28c14 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -338,7 +338,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
MachineInstr &I) const;
bool selectWaveActiveAllEqual(Register ResVReg, SPIRVTypeInst ResType,
- MachineInstr &I) const;
+ MachineInstr &I) const;
bool selectUnmergeValues(MachineInstr &I) const;
@@ -2922,7 +2922,6 @@ bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg,
return true;
}
-
bool SPIRVInstructionSelector::selectWavePrefixBitCount(Register ResVReg,
SPIRVTypeInst ResType,
MachineInstr &I) const {
>From 9a58dc845943b20c414c238dee6440a2f9e96d1d Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Mon, 2 Mar 2026 14:49:19 -0800
Subject: [PATCH 6/6] more repairs to pass spirv-val
---
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 77 ++++++++++---------
.../hlsl-intrinsics/WaveActiveAllEqual.ll | 8 +-
2 files changed, 44 insertions(+), 41 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index fe16ceaf28c14..1a9e07eb54e8c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2845,57 +2845,63 @@ unsigned getVectorSizeOrOne(SPIRVTypeInst Type) {
bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg,
SPIRVTypeInst ResType,
MachineInstr &I) const {
-
MachineBasicBlock &BB = *I.getParent();
const DebugLoc &DL = I.getDebugLoc();
- SPIRVTypeInst SpvTy = GR.getSPIRVTypeForVReg(ResVReg);
- unsigned NumElems = getVectorSizeOrOne(SpvTy);
+ // Input to the intrinsic
+ Register InputReg = I.getOperand(2).getReg();
+ SPIRVTypeInst InputType = GR.getSPIRVTypeForVReg(InputReg);
+
+ // Determine if input is vector
+ unsigned NumElems = getVectorSizeOrOne(InputType);
bool IsVector = NumElems > 1;
+ // Determine element types
+ SPIRVTypeInst ElemInputType = InputType;
+ SPIRVTypeInst ElemBoolType = ResType;
+ if (IsVector) {
+ ElemInputType = GR.getSPIRVTypeForVReg(InputType->getOperand(1).getReg());
+ ElemBoolType = GR.getSPIRVTypeForVReg(ResType->getOperand(1).getReg());
+ }
+
// Subgroup scope constant
SPIRVTypeInst IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
-
Register ScopeConst = GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy,
TII, !STI.isShader());
- Register InputReg = I.getOperand(2).getReg();
+ // === Scalar case ===
+ if (!IsVector) {
+ BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformAllEqual))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ElemBoolType))
+ .addUse(ScopeConst)
+ .addUse(InputReg)
+ .constrainAllUses(TII, TRI, RBI);
+ return true;
+ }
+ // === Vector case ===
SmallVector<Register, 4> ElementResults;
-
- // If vector, determine element type once
- SPIRVTypeInst ElemInputType = SpvTy;
- SPIRVTypeInst ElemBoolType = ResType;
-
- if (IsVector) {
- Register ElemTypeReg = SpvTy->getOperand(1).getReg();
- ElemInputType = GR.getSPIRVTypeForVReg(ElemTypeReg);
-
- Register BoolElemReg = ResType->getOperand(1).getReg();
- ElemBoolType = GR.getSPIRVTypeForVReg(BoolElemReg);
- }
+ ElementResults.reserve(NumElems);
for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
-
+ // Extract element
Register ElemInput = InputReg;
+ Register Extracted =
+ MRI->createVirtualRegister(GR.getRegClass(ElemInputType));
+
+ BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeExtract))
+ .addDef(Extracted)
+ .addUse(GR.getSPIRVTypeID(ElemInputType))
+ .addUse(InputReg)
+ .addImm(Idx)
+ .constrainAllUses(TII, TRI, RBI);
- if (IsVector) {
- Register Extracted =
- MRI->createVirtualRegister(GR.getRegClass(ElemInputType));
-
- BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeExtract))
- .addDef(Extracted)
- .addUse(GR.getSPIRVTypeID(ElemInputType))
- .addUse(InputReg)
- .addImm(Idx)
- .constrainAllUses(TII, TRI, RBI);
-
- ElemInput = Extracted;
- }
+ ElemInput = Extracted;
+ // Emit per-element AllEqual
Register ElemResult =
- IsVector ? MRI->createVirtualRegister(GR.getRegClass(ElemBoolType))
- : ResVReg;
+ MRI->createVirtualRegister(GR.getRegClass(ElemBoolType));
BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformAllEqual))
.addDef(ElemResult)
@@ -2907,13 +2913,10 @@ bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg,
ElementResults.push_back(ElemResult);
}
- if (!IsVector)
- return true;
-
+ // Reconstruct vector<bool>
auto MIB = BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeConstruct))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType));
-
for (Register R : ElementResults)
MIB.addUse(R);
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
index 9c63539e4b9e4..8733505942c4c 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
@@ -35,13 +35,13 @@ entry:
define <4 x i1> @test_vhalf(<4 x half> %vbexpr) {
entry:
; CHECK: %[[#param:]] = OpFunctionParameter %[[#v4float:]]
-; CHECK: %[[#ext1:]] = OpCompositeExtract %[[#bool]] %[[#param]] 0
+; CHECK: %[[#ext1:]] = OpCompositeExtract %[[#f16]] %[[#param]] 0
; CHECK-NEXT: %[[#res1:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext1]]
-; CHECK-NEXT: %[[#ext2:]] = OpCompositeExtract %[[#bool]] %[[#param]] 1
+; CHECK-NEXT: %[[#ext2:]] = OpCompositeExtract %[[#f16]] %[[#param]] 1
; CHECK-NEXT: %[[#res2:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext2]]
-; CHECK-NEXT: %[[#ext3:]] = OpCompositeExtract %[[#bool]] %[[#param]] 2
+; CHECK-NEXT: %[[#ext3:]] = OpCompositeExtract %[[#f16]] %[[#param]] 2
; CHECK-NEXT: %[[#res3:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext3]]
-; CHECK-NEXT: %[[#ext4:]] = OpCompositeExtract %[[#bool]] %[[#param]] 3
+; CHECK-NEXT: %[[#ext4:]] = OpCompositeExtract %[[#f16]] %[[#param]] 3
; CHECK-NEXT: %[[#res4:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext4]]
; CHECK-NEXT: %[[#ret:]] = OpCompositeConstruct %[[#bool4]] %[[#res1:]] %[[#res2:]] %[[#res3:]] %[[#res4:]]
; CHECK-NEXT: OpReturnValue %[[#ret]]
More information about the cfe-commits
mailing list