[llvm] 6ef3218 - [SPIRV] Add support for `bfloat16` atomics via the `SPV_INTEL_16bit_atomics` extension (#166257)

via llvm-commits llvm-commits at lists.llvm.org
Sun Nov 9 09:26:18 PST 2025


Author: Alex Voicu
Date: 2025-11-09T17:26:14Z
New Revision: 6ef32188b5a10167b94ac9e8f7bbac5dfc6c8730

URL: https://github.com/llvm/llvm-project/commit/6ef32188b5a10167b94ac9e8f7bbac5dfc6c8730
DIFF: https://github.com/llvm/llvm-project/commit/6ef32188b5a10167b94ac9e8f7bbac5dfc6c8730.diff

LOG: [SPIRV] Add support for `bfloat16` atomics via the `SPV_INTEL_16bit_atomics` extension (#166257)

This enables support for atomic RMW ops (add, sub, min and max to be
precise) with `bfloat16` operands, via the [SPV_INTEL_16bit_atomics
extension](https://github.com/intel/llvm/pull/20009). It's logically a
successor to #166031 (I should've used a stack), but I'm putting it up
for early review.

---------

Co-authored-by: Matt Arsenault <arsenm2 at gmail.com>

Added: 
    llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_faddfsub_bfloat16.ll
    llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_fminfmax_bfloat16.ll

Modified: 
    llvm/docs/SPIRVUsage.rst
    llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
    llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
    llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
    llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Removed: 
    


################################################################################
diff  --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 9ecd39025e781..5ee3d83bd7aac 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -167,6 +167,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
      - Adds atomic add instruction on floating-point numbers.
    * - ``SPV_EXT_shader_atomic_float_min_max``
      - Adds atomic min and max instruction on floating-point numbers.
+   * - ``SPV_INTEL_16bit_atomics``
+     - Extends the ``SPV_EXT_shader_atomic_float_add`` and ``SPV_EXT_shader_atomic_float_min_max`` extensions to support addition, minimum and maximum on 16-bit ``bfloat16`` floating-point numbers in memory.
    * - ``SPV_INTEL_2d_block_io``
      - Adds additional subgroup block prefetch, load, load transposed, load transformed and store instructions to read two-dimensional blocks of data from a two-dimensional region of memory, or to write two-dimensional blocks of data to a two dimensional region of memory.
    * - ``SPV_INTEL_arbitrary_precision_integers``

diff  --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 4f6a19fe6633b..d656f1071400d 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -3482,7 +3482,7 @@ bool IRTranslator::translateAtomicCmpXchg(const User &U,
 
 bool IRTranslator::translateAtomicRMW(const User &U,
                                       MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
     return false;
 
   const AtomicRMWInst &I = cast<AtomicRMWInst>(U);

diff  --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index f681b0d9fb433..ac09b937a584a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -29,6 +29,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
          SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float16_add},
         {"SPV_EXT_shader_atomic_float_min_max",
          SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max},
+        {"SPV_INTEL_16bit_atomics",
+         SPIRV::Extension::Extension::SPV_INTEL_16bit_atomics},
         {"SPV_EXT_arithmetic_fence",
          SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence},
         {"SPV_EXT_demote_to_helper_invocation",

diff  --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index af76016861761..fbb127df16dd9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1058,6 +1058,13 @@ static void addOpTypeImageReqs(const MachineInstr &MI,
   }
 }
 
+static bool isBFloat16Type(const SPIRVType *TypeDef) {
+  return TypeDef && TypeDef->getNumOperands() == 3 &&
+         TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
+         TypeDef->getOperand(1).getImm() == 16 &&
+         TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
+}
+
 // Add requirements for handling atomic float instructions
 #define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
   "The atomic float instruction requires the following SPIR-V "                \
@@ -1081,11 +1088,21 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
     Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
     switch (BitWidth) {
     case 16:
-      if (!ST.canUseExtension(
-              SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
-        report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
-      Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
-      Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
+      if (isBFloat16Type(TypeDef)) {
+        if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
+          report_fatal_error(
+              "The atomic bfloat16 instruction requires the following SPIR-V "
+              "extension: SPV_INTEL_16bit_atomics",
+              false);
+        Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
+        Reqs.addCapability(SPIRV::Capability::AtomicBFloat16AddINTEL);
+      } else {
+        if (!ST.canUseExtension(
+                SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
+          report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
+        Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
+        Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
+      }
       break;
     case 32:
       Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
@@ -1104,7 +1121,17 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
     Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
     switch (BitWidth) {
     case 16:
-      Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
+      if (isBFloat16Type(TypeDef)) {
+        if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
+          report_fatal_error(
+              "The atomic bfloat16 instruction requires the following SPIR-V "
+              "extension: SPV_INTEL_16bit_atomics",
+              false);
+        Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
+        Reqs.addCapability(SPIRV::Capability::AtomicBFloat16MinMaxINTEL);
+      } else {
+        Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
+      }
       break;
     case 32:
       Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
@@ -1328,13 +1355,6 @@ void addPrintfRequirements(const MachineInstr &MI,
   }
 }
 
-static bool isBFloat16Type(const SPIRVType *TypeDef) {
-  return TypeDef && TypeDef->getNumOperands() == 3 &&
-         TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
-         TypeDef->getOperand(1).getImm() == 16 &&
-         TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
-}
-
 void addInstrRequirements(const MachineInstr &MI,
                           SPIRV::ModuleAnalysisInfo &MAI,
                           const SPIRVSubtarget &ST) {

diff  --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 65a888529bb58..f02a587013856 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -389,6 +389,7 @@ defm SPV_INTEL_predicated_io : ExtensionOperand<127, [EnvOpenCL]>;
 defm SPV_KHR_maximal_reconvergence : ExtensionOperand<128, [EnvVulkan]>;
 defm SPV_INTEL_bfloat16_arithmetic
     : ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>;
+defm SPV_INTEL_16bit_atomics : ExtensionOperand<130, [EnvVulkan, EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -566,9 +567,11 @@ defm FloatControls2
 defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
 defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
 defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_float16_add], []>;
+defm AtomicBFloat16AddINTEL : CapabilityOperand<6255, 0, 0, [SPV_INTEL_16bit_atomics], []>;
 defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
 defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
 defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
+defm AtomicBFloat16MinMaxINTEL : CapabilityOperand<6256, 0, 0, [SPV_INTEL_16bit_atomics], []>;
 defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
 defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
 defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;

diff  --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_faddfsub_bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_faddfsub_bfloat16.ll
new file mode 100644
index 0000000000000..a189b2a655589
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_faddfsub_bfloat16.ll
@@ -0,0 +1,34 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR1
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_add,+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR2
+
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_add,+SPV_INTEL_16bit_atomics,+SPV_KHR_bfloat16,+SPV_INTEL_bfloat16_arithmetic %s -o - | FileCheck %s
+
+; CHECK-ERROR1: LLVM ERROR: The atomic float instruction requires the following SPIR-V extension: SPV_EXT_shader_atomic_float_add
+; CHECK-ERROR2: LLVM ERROR: The atomic bfloat16 instruction requires the following SPIR-V extension: SPV_INTEL_16bit_atomics
+
+; CHECK: Capability BFloat16TypeKHR
+; CHECK: Capability AtomicBFloat16AddINTEL
+; CHECK: Extension "SPV_KHR_bfloat16"
+; CHECK: Extension "SPV_EXT_shader_atomic_float_add"
+; CHECK: Extension "SPV_INTEL_16bit_atomics"
+; CHECK-DAG: %[[TyBF16:[0-9]+]] = OpTypeFloat 16 0
+; CHECK-DAG: %[[TyBF16Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyBF16]]
+; CHECK-DAG: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
+; CHECK-DAG: %[[ConstBF16:[0-9]+]] = OpConstant %[[TyBF16]] 16936{{$}}
+; CHECK-DAG: %[[Const0:[0-9]+]] = OpConstantNull %[[TyBF16]]
+; CHECK-DAG: %[[BF16Ptr:[0-9]+]] = OpVariable %[[TyBF16Ptr]] CrossWorkgroup %[[Const0]]
+; CHECK-DAG: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
+; CHECK-DAG: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
+; CHECK: OpAtomicFAddEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstBF16]]
+; CHECK: %[[NegatedConstBF16:[0-9]+]] = OpFNegate %[[TyBF16]] %[[ConstBF16]]
+; CHECK: OpAtomicFAddEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[NegatedConstBF16]]
+
+
+@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8
+
+define dso_local spir_func void @test1() local_unnamed_addr {
+entry:
+  %addval = atomicrmw fadd ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
+  %subval = atomicrmw fsub ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
+  ret void
+}

diff  --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_fminfmax_bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_fminfmax_bfloat16.ll
new file mode 100644
index 0000000000000..dd8448039ec62
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_16bit_atomics/atomicrmw_fminfmax_bfloat16.ll
@@ -0,0 +1,28 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_min_max,+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_EXT_shader_atomic_float_min_max,+SPV_INTEL_16bit_atomics,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+
+; CHECK-ERROR: LLVM ERROR: The atomic bfloat16 instruction requires the following SPIR-V extension: SPV_INTEL_16bit_atomics
+
+; CHECK: Capability AtomicBFloat16MinMaxINTEL
+; CHECK: Extension "SPV_KHR_bfloat16"
+; CHECK: Extension "SPV_EXT_shader_atomic_float_min_max"
+; CHECK: Extension "SPV_INTEL_16bit_atomics"
+; CHECK-DAG: %[[TyBF16:[0-9]+]] = OpTypeFloat 16 0
+; CHECK-DAG: %[[TyBF16Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyBF16]]
+; CHECK-DAG: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
+; CHECK-DAG: %[[ConstBF16:[0-9]+]] = OpConstant %[[TyBF16]] 16936{{$}}
+; CHECK-DAG: %[[Const0:[0-9]+]] = OpConstantNull %[[TyBF16]]
+; CHECK-DAG: %[[BF16Ptr:[0-9]+]] = OpVariable %[[TyBF16Ptr]] CrossWorkgroup %[[Const0]]
+; CHECK-DAG: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
+; CHECK-DAG: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
+; CHECK: OpAtomicFMinEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstBF16]]
+; CHECK: OpAtomicFMaxEXT %[[TyBF16]] %[[BF16Ptr]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstBF16]]
+
+@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8
+
+define spir_func void @test1() {
+entry:
+  %minval = atomicrmw fmin ptr addrspace(1) @f, bfloat 42.0e+00 seq_cst
+  %maxval = atomicrmw fmax ptr addrspace(1) @f, bfloat 42.0e+00 seq_cst
+  ret void
+}


        


More information about the llvm-commits mailing list