[llvm] 6ef3218 - [SPIRV] Add support for `bfloat16` atomics via the `SPV_INTEL_16bit_atomics` extension (#166257)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Nov 9 09:26:18 PST 2025
Author: Alex Voicu
Date: 2025-11-09T17:26:14Z
New Revision: 6ef32188b5a10167b94ac9e8f7bbac5dfc6c8730
URL: https://github.com/llvm/llvm-project/commit/6ef32188b5a10167b94ac9e8f7bbac5dfc6c8730
DIFF: https://github.com/llvm/llvm-project/commit/6ef32188b5a10167b94ac9e8f7bbac5dfc6c8730.diff
LOG: [SPIRV] Add support for `bfloat16` atomics via the `SPV_INTEL_16bit_atomics` extension (#166257)
This enables support for atomic RMW operations (specifically fadd, fsub, fmin
and fmax) with `bfloat16` operands, via the [SPV_INTEL_16bit_atomics
extension](https://github.com/intel/llvm/pull/20009). It's logically a
successor to #166031 (I should've used a stack), but I'm putting it up
for early review.
---------
Co-authored-by: Matt Arsenault <arsenm2 at gmail.com>
Added:
llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_faddfsub_bfloat16.ll
llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_fminfmax_bfloat16.ll
Modified:
llvm/docs/SPIRVUsage.rst
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Removed:
################################################################################
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 9ecd39025e781..5ee3d83bd7aac 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -167,6 +167,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
- Adds atomic add instruction on floating-point numbers.
* - ``SPV_EXT_shader_atomic_float_min_max``
- Adds atomic min and max instruction on floating-point numbers.
+ * - ``SPV_INTEL_16bit_atomics``
+ - Extends the ``SPV_EXT_shader_atomic_float_add`` and ``SPV_EXT_shader_atomic_float_min_max`` extensions to support atomic addition, minimum and maximum on 16-bit ``bfloat16`` floating-point numbers in memory.
* - ``SPV_INTEL_2d_block_io``
- Adds additional subgroup block prefetch, load, load transposed, load transformed and store instructions to read two-dimensional blocks of data from a two-dimensional region of memory, or to write two-dimensional blocks of data to a two dimensional region of memory.
* - ``SPV_INTEL_arbitrary_precision_integers``
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 4f6a19fe6633b..d656f1071400d 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -3482,7 +3482,7 @@ bool IRTranslator::translateAtomicCmpXchg(const User &U,
bool IRTranslator::translateAtomicRMW(const User &U,
MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(U))
+ if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
return false;
const AtomicRMWInst &I = cast<AtomicRMWInst>(U);
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index f681b0d9fb433..ac09b937a584a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -29,6 +29,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float16_add},
{"SPV_EXT_shader_atomic_float_min_max",
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max},
+ {"SPV_INTEL_16bit_atomics",
+ SPIRV::Extension::Extension::SPV_INTEL_16bit_atomics},
{"SPV_EXT_arithmetic_fence",
SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence},
{"SPV_EXT_demote_to_helper_invocation",
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index af76016861761..fbb127df16dd9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1058,6 +1058,13 @@ static void addOpTypeImageReqs(const MachineInstr &MI,
}
}
+static bool isBFloat16Type(const SPIRVType *TypeDef) {
+ return TypeDef && TypeDef->getNumOperands() == 3 &&
+ TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
+ TypeDef->getOperand(1).getImm() == 16 &&
+ TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
+}
+
// Add requirements for handling atomic float instructions
#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
"The atomic float instruction requires the following SPIR-V " \
@@ -1081,11 +1088,21 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
switch (BitWidth) {
case 16:
- if (!ST.canUseExtension(
- SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
- report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
- Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
- Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
+ if (isBFloat16Type(TypeDef)) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
+ report_fatal_error(
+ "The atomic bfloat16 instruction requires the following SPIR-V "
+ "extension: SPV_INTEL_16bit_atomics",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
+ Reqs.addCapability(SPIRV::Capability::AtomicBFloat16AddINTEL);
+ } else {
+ if (!ST.canUseExtension(
+ SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
+ report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
+ Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
+ }
break;
case 32:
Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
@@ -1104,7 +1121,17 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
switch (BitWidth) {
case 16:
- Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
+ if (isBFloat16Type(TypeDef)) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
+ report_fatal_error(
+ "The atomic bfloat16 instruction requires the following SPIR-V "
+ "extension: SPV_INTEL_16bit_atomics",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
+ Reqs.addCapability(SPIRV::Capability::AtomicBFloat16MinMaxINTEL);
+ } else {
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
+ }
break;
case 32:
Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
@@ -1328,13 +1355,6 @@ void addPrintfRequirements(const MachineInstr &MI,
}
}
-static bool isBFloat16Type(const SPIRVType *TypeDef) {
- return TypeDef && TypeDef->getNumOperands() == 3 &&
- TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
- TypeDef->getOperand(1).getImm() == 16 &&
- TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
-}
-
void addInstrRequirements(const MachineInstr &MI,
SPIRV::ModuleAnalysisInfo &MAI,
const SPIRVSubtarget &ST) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 65a888529bb58..f02a587013856 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -389,6 +389,7 @@ defm SPV_INTEL_predicated_io : ExtensionOperand<127, [EnvOpenCL]>;
defm SPV_KHR_maximal_reconvergence : ExtensionOperand<128, [EnvVulkan]>;
defm SPV_INTEL_bfloat16_arithmetic
: ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>;
+defm SPV_INTEL_16bit_atomics : ExtensionOperand<130, [EnvVulkan, EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -566,9 +567,11 @@ defm FloatControls2
defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_float16_add], []>;
+defm AtomicBFloat16AddINTEL : CapabilityOperand<6255, 0, 0, [SPV_INTEL_16bit_atomics], []>;
defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
+defm AtomicBFloat16MinMaxINTEL : CapabilityOperand<6256, 0, 0, [SPV_INTEL_16bit_atomics], []>;
defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_faddfsub_bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_faddfsub_bfloat16.ll
new file mode 100644
index 0000000000000..a189b2a655589
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_faddfsub_bfloat16.ll
@@ -0,0 +1,34 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR1
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_add,+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR2
+
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_add,+SPV_INTEL_16bit_atomics,+SPV_KHR_bfloat16,+SPV_INTEL_bfloat16_arithmetic %s -o - | FileCheck %s
+
+; CHECK-ERROR1: LLVM ERROR: The atomic float instruction requires the following SPIR-V extension: SPV_EXT_shader_atomic_float_add
+; CHECK-ERROR2: LLVM ERROR: The atomic bfloat16 instruction requires the following SPIR-V extension: SPV_INTEL_16bit_atomics
+
+; CHECK: Capability BFloat16TypeKHR
+; CHECK: Capability AtomicBFloat16AddINTEL
+; CHECK: Extension "SPV_KHR_bfloat16"
+; CHECK: Extension "SPV_EXT_shader_atomic_float_add"
+; CHECK: Extension "SPV_INTEL_16bit_atomics"
+; CHECK-DAG: %[[TyBF16:[0-9]+]] = OpTypeFloat 16 0
+; CHECK-DAG: %[[TyBF16Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyBF16]]
+; CHECK-DAG: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
+; CHECK-DAG: %[[ConstBF16:[0-9]+]] = OpConstant %[[TyBF16]] 16936{{$}}
+; CHECK-DAG: %[[Const0:[0-9]+]] = OpConstantNull %[[TyBF16]]
+; CHECK-DAG: %[[BF16Ptr:[0-9]+]] = OpVariable %[[TyBF16Ptr]] CrossWorkgroup %[[Const0]]
+; CHECK-DAG: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
+; CHECK-DAG: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
+; CHECK: OpAtomicFAddEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstBF16]]
+; CHECK: %[[NegatedConstBF16:[0-9]+]] = OpFNegate %[[TyBF16]] %[[ConstBF16]]
+; CHECK: OpAtomicFAddEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[NegatedConstBF16]]
+
+
+@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8
+
+define dso_local spir_func void @test1() local_unnamed_addr {
+entry:
+ %addval = atomicrmw fadd ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
+ %subval = atomicrmw fsub ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_fminfmax_bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_fminfmax_bfloat16.ll
new file mode 100644
index 0000000000000..dd8448039ec62
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_fminfmax_bfloat16.ll
@@ -0,0 +1,28 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_min_max,+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_min_max,+SPV_INTEL_16bit_atomics,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+
+; CHECK-ERROR: LLVM ERROR: The atomic bfloat16 instruction requires the following SPIR-V extension: SPV_INTEL_16bit_atomics
+
+; CHECK: Capability AtomicBFloat16MinMaxINTEL
+; CHECK: Extension "SPV_KHR_bfloat16"
+; CHECK: Extension "SPV_EXT_shader_atomic_float_min_max"
+; CHECK: Extension "SPV_INTEL_16bit_atomics"
+; CHECK-DAG: %[[TyBF16:[0-9]+]] = OpTypeFloat 16 0
+; CHECK-DAG: %[[TyBF16Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyBF16]]
+; CHECK-DAG: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
+; CHECK-DAG: %[[ConstBF16:[0-9]+]] = OpConstant %[[TyBF16]] 16936{{$}}
+; CHECK-DAG: %[[Const0:[0-9]+]] = OpConstantNull %[[TyBF16]]
+; CHECK-DAG: %[[BF16Ptr:[0-9]+]] = OpVariable %[[TyBF16Ptr]] CrossWorkgroup %[[Const0]]
+; CHECK-DAG: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
+; CHECK-DAG: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
+; CHECK: OpAtomicFMinEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstBF16]]
+; CHECK: OpAtomicFMaxEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstBF16]]
+
+@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8
+
+define spir_func void @test1() {
+entry:
+ %minval = atomicrmw fmin ptr addrspace(1) @f, bfloat 42.0e+00 seq_cst
+ %maxval = atomicrmw fmax ptr addrspace(1) @f, bfloat 42.0e+00 seq_cst
+ ret void
+}
More information about the llvm-commits
mailing list