[llvm] 93d64a5 - [SPIRV] Add `<2 x half>` and `<4 x half>` atomics via `SPV_NV_shader_atomic_fp16_vector` (#170213)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 5 12:23:29 PST 2025
Author: Alex Voicu
Date: 2025-12-05T20:23:25Z
New Revision: 93d64a5c4d912e8898d1e7c7e5a95f6f44f7e983
URL: https://github.com/llvm/llvm-project/commit/93d64a5c4d912e8898d1e7c7e5a95f6f44f7e983
DIFF: https://github.com/llvm/llvm-project/commit/93d64a5c4d912e8898d1e7c7e5a95f6f44f7e983.diff
LOG: [SPIRV] Add `<2 x half>` and `<4 x half>` atomics via `SPV_NV_shader_atomic_fp16_vector` (#170213)
This adds support for the `SPV_NV_shader_atomic_fp16_vector` extension,
and then uses it to enable lowering of atomic add, sub, min and max on 2
and 4 component vectors of FP16, which are rather common options in ML
workloads. Even though `bfloat16` also works in practice, we do not
enable it since it's not specified in the extension (which might need
updating / promoting to KHR at least). A `TODO` is also inserted in
`SPIRVModuleAnalysis.cpp` regarding the need to upgrade its extensive use
of `report_fatal_error`; I have a WiP patch for that, but it still needs
a bit of baking. Finally, a paired patch will be necessary in the
Translator, as it's not aware of the extension either - I'll update this
review to reference the PR once I create it.
Added:
llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_faddfsub_vec_float16.ll
llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_fminfmax_vec_float16.ll
Modified:
llvm/docs/SPIRVUsage.rst
llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Removed:
################################################################################
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 88164e6fa53d8..e2f85ba3c2774 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -169,6 +169,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
- Adds atomic min and max instruction on floating-point numbers.
* - ``SPV_INTEL_16bit_atomics``
- Extends the SPV_EXT_shader_atomic_float_add and SPV_EXT_shader_atomic_float_min_max to support addition, minimum and maximum on 16-bit `bfloat16` floating-point numbers in memory.
+ * - ``SPV_NV_shader_atomic_fp16_vector``
+    - Adds atomic add, min and max instructions on 2- or 4-component vectors with 16-bit float components.
* - ``SPV_INTEL_2d_block_io``
- Adds additional subgroup block prefetch, load, load transposed, load transformed and store instructions to read two-dimensional blocks of data from a two-dimensional region of memory, or to write two-dimensional blocks of data to a two dimensional region of memory.
* - ``SPV_ALTERA_arbitrary_precision_integers``
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 146384f4bf08c..d2a8fddc5d8e4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -31,6 +31,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max},
{"SPV_INTEL_16bit_atomics",
SPIRV::Extension::Extension::SPV_INTEL_16bit_atomics},
+ {"SPV_NV_shader_atomic_fp16_vector",
+ SPIRV::Extension::Extension::SPV_NV_shader_atomic_fp16_vector},
{"SPV_EXT_arithmetic_fence",
SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence},
{"SPV_EXT_demote_to_helper_invocation",
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index df958d2c86b33..e0f7b19c91fbd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1193,7 +1193,9 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_ATOMICRMW_FSUB:
// Translate G_ATOMICRMW_FSUB to OpAtomicFAddEXT with negative value operand
return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT,
- SPIRV::OpFNegate);
+ ResType->getOpcode() == SPIRV::OpTypeVector
+ ? SPIRV::OpFNegateV
+ : SPIRV::OpFNegate);
case TargetOpcode::G_ATOMICRMW_FMIN:
return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMinEXT);
case TargetOpcode::G_ATOMICRMW_FMAX:
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index b5912c27316c9..71df4cced434e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -127,7 +127,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
auto allIntScalars = {s8, s16, s32, s64};
- auto allFloatScalars = {s16, s32, s64};
+ auto allFloatScalarsAndF16Vector2AndVector4s = {s16, s32, s64, v2s16, v4s16};
auto allFloatScalarsAndVectors = {
s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
@@ -351,7 +351,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
getActionDefinitionsBuilder(
{G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
- .legalForCartesianProduct(allFloatScalars, allPtrs);
+ .legalForCartesianProduct(allFloatScalarsAndF16Vector2AndVector4s,
+ allPtrs);
getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
.legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 2feb73d8dedfa..73432279c3306 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -14,6 +14,10 @@
//
//===----------------------------------------------------------------------===//
+// TODO: uses of report_fatal_error (which is also deprecated) /
+// reportFatalUsageError in this file should be refactored, as per LLVM
+// best practices, to rely on the Diagnostic infrastructure.
+
#include "SPIRVModuleAnalysis.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "MCTargetDesc/SPIRVMCTargetDesc.h"
@@ -1071,6 +1075,39 @@ static bool isBFloat16Type(const SPIRVType *TypeDef) {
#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
"The atomic float instruction requires the following SPIR-V " \
"extension: SPV_EXT_shader_atomic_float" ExtName
+static void AddAtomicVectorFloatRequirements(const MachineInstr &MI,
+                                             SPIRV::RequirementHandler &Reqs,
+                                             const SPIRVSubtarget &ST) {
+  SPIRVType *VecTypeDef =
+      MI.getMF()->getRegInfo().getVRegDef(MI.getOperand(1).getReg());
+  // OpTypeVector: operand 1 is the component type, operand 2 the count.
+  const unsigned ComponentCount = VecTypeDef->getOperand(2).getImm();
+  if (ComponentCount != 2 && ComponentCount != 4)
+    reportFatalUsageError("Result type of an atomic vector float instruction "
+                          "must be a 2-component or 4-component vector");
+  // SPV_NV_shader_atomic_fp16_vector only covers 16-bit float components.
+  SPIRVType *EltTypeDef =
+      MI.getMF()->getRegInfo().getVRegDef(VecTypeDef->getOperand(1).getReg());
+  // Operand 1 of OpTypeFloat is its bit width.
+  if (EltTypeDef->getOpcode() != SPIRV::OpTypeFloat ||
+      EltTypeDef->getOperand(1).getImm() != 16)
+    reportFatalUsageError(
+        "The element type for the result type of an atomic vector float "
+        "instruction must be a 16-bit floating-point scalar");
+  // bfloat16 is also 16 bits wide but is not specified by the extension.
+  if (isBFloat16Type(EltTypeDef))
+    reportFatalUsageError(
+        "The element type for the result type of an atomic vector float "
+        "instruction cannot be a bfloat16 scalar");
+  if (!ST.canUseExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector))
+    reportFatalUsageError(
+        "The atomic float16 vector instruction requires the following SPIR-V "
+        "extension: SPV_NV_shader_atomic_fp16_vector");
+  // Declare the extension and the AtomicFloat16VectorNV capability it enables.
+  Reqs.addExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector);
+  Reqs.addCapability(SPIRV::Capability::AtomicFloat16VectorNV);
+}
+
static void AddAtomicFloatRequirements(const MachineInstr &MI,
SPIRV::RequirementHandler &Reqs,
const SPIRVSubtarget &ST) {
@@ -1078,6 +1115,10 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
"Expect register operand in atomic float instruction");
Register TypeReg = MI.getOperand(1).getReg();
SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
+
+ if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
+ return AddAtomicVectorFloatRequirements(MI, Reqs, ST);
+
if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
report_fatal_error("Result type of an atomic float instruction must be a "
"floating-point type scalar");
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 94e0138c66487..078f1dff839ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -391,6 +391,8 @@ defm SPV_INTEL_bfloat16_arithmetic
: ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_16bit_atomics : ExtensionOperand<130, [EnvVulkan, EnvOpenCL]>;
defm SPV_ALTERA_arbitrary_precision_fixed_point : ExtensionOperand<131, [EnvOpenCL, EnvVulkan]>;
+defm SPV_NV_shader_atomic_fp16_vector
+ : ExtensionOperand<132, [EnvVulkan, EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -573,6 +575,7 @@ defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atom
defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm AtomicBFloat16MinMaxINTEL : CapabilityOperand<6256, 0, 0, [SPV_INTEL_16bit_atomics], []>;
+defm AtomicFloat16VectorNV : CapabilityOperand<5404, 0, 0, [SPV_NV_shader_atomic_fp16_vector], []>;
defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_faddfsub_vec_float16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_faddfsub_vec_float16.ll
new file mode 100644
index 0000000000000..36f6e38fc75de
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_faddfsub_vec_float16.ll
@@ -0,0 +1,47 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_NV_shader_atomic_fp16_vector %s -o - | FileCheck %s
+
+; CHECK-ERROR: LLVM ERROR: The atomic float16 vector instruction requires the following SPIR-V extension: SPV_NV_shader_atomic_fp16_vector
+
+; CHECK: Capability Float16
+; CHECK-DAG: Capability AtomicFloat16VectorNV
+; CHECK: Extension "SPV_NV_shader_atomic_fp16_vector"
+; CHECK-DAG: %[[TyF16:[0-9]+]] = OpTypeFloat 16
+; CHECK: %[[TyF16Vec2:[0-9]+]] = OpTypeVector %[[TyF16]] 2
+; CHECK: %[[TyF16Vec4:[0-9]+]] = OpTypeVector %[[TyF16]] 4
+; CHECK: %[[TyF16Vec4Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec4]]
+; CHECK: %[[TyF16Vec2Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec2]]
+; CHECK: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
+; CHECK: %[[ConstF16:[0-9]+]] = OpConstant %[[TyF16]] 20800{{$}}
+; CHECK: %[[Const0F16Vec2:[0-9]+]] = OpConstantNull %[[TyF16Vec2]]
+; CHECK: %[[f:[0-9]+]] = OpVariable %[[TyF16Vec2Ptr]] CrossWorkgroup %[[Const0F16Vec2]]
+; CHECK: %[[Const0F16Vec4:[0-9]+]] = OpConstantNull %[[TyF16Vec4]]
+; CHECK: %[[g:[0-9]+]] = OpVariable %[[TyF16Vec4Ptr]] CrossWorkgroup %[[Const0F16Vec4]]
+; CHECK: %[[ConstF16Vec2:[0-9]+]] = OpConstantComposite %[[TyF16Vec2]] %[[ConstF16]] %[[ConstF16]]
+; CHECK: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
+; CHECK: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
+; CHECK: %[[ConstF16Vec4:[0-9]+]] = OpConstantComposite %[[TyF16Vec4]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]]
+
+@f = common dso_local local_unnamed_addr addrspace(1) global <2 x half> <half 0.000000e+00, half 0.000000e+00>
+@g = common dso_local local_unnamed_addr addrspace(1) global <4 x half> <half 0.000000e+00, half 0.000000e+00, half 0.000000e+00, half 0.000000e+00>
+
+; CHECK-DAG: OpAtomicFAddEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec2]]
+; CHECK: %[[NegatedConstF16Vec2:[0-9]+]] = OpFNegate %[[TyF16Vec2]] %[[ConstF16Vec2]]
+; CHECK: OpAtomicFAddEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[NegatedConstF16Vec2]]
+define dso_local spir_func void @test1() local_unnamed_addr {
+entry:
+ %addval = atomicrmw fadd ptr addrspace(1) @f, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
+ %subval = atomicrmw fsub ptr addrspace(1) @f, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
+ ret void
+}
+
+; CHECK-DAG: OpAtomicFAddEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec4]]
+; CHECK: %[[NegatedConstF16Vec4:[0-9]+]] = OpFNegate %[[TyF16Vec4]] %[[ConstF16Vec4]]
+; CHECK: OpAtomicFAddEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[NegatedConstF16Vec4]]
+define dso_local spir_func void @test2() local_unnamed_addr {
+entry:
+ %addval = atomicrmw fadd ptr addrspace(1) @g, <4 x half> <half 42.000000e+00, half 42.000000e+00, half 42.000000e+00, half 42.000000e+00> seq_cst
+ %subval = atomicrmw fsub ptr addrspace(1) @g, <4 x half> <half 42.000000e+00, half 42.000000e+00, half 42.000000e+00, half 42.000000e+00> seq_cst
+ ret void
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_fminfmax_vec_float16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_fminfmax_vec_float16.ll
new file mode 100644
index 0000000000000..7ac772bf5d094
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_NV_shader_atomic_fp16_vector/atomicrmw_fminfmax_vec_float16.ll
@@ -0,0 +1,45 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_NV_shader_atomic_fp16_vector %s -o - | FileCheck %s
+
+; CHECK-ERROR: LLVM ERROR: The atomic float16 vector instruction requires the following SPIR-V extension: SPV_NV_shader_atomic_fp16_vector
+
+; CHECK: Capability Float16
+; CHECK-DAG: Capability AtomicFloat16VectorNV
+; CHECK: Extension "SPV_NV_shader_atomic_fp16_vector"
+; CHECK-DAG: %[[TyF16:[0-9]+]] = OpTypeFloat 16
+; CHECK: %[[TyF16Vec2:[0-9]+]] = OpTypeVector %[[TyF16]] 2
+; CHECK: %[[TyF16Vec4:[0-9]+]] = OpTypeVector %[[TyF16]] 4
+; CHECK: %[[TyF16Vec4Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec4]]
+; CHECK: %[[TyF16Vec2Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec2]]
+; CHECK: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
+; CHECK: %[[ConstF16:[0-9]+]] = OpConstant %[[TyF16]] 20800{{$}}
+; CHECK: %[[Const0F16Vec2:[0-9]+]] = OpConstantNull %[[TyF16Vec2]]
+; CHECK: %[[f:[0-9]+]] = OpVariable %[[TyF16Vec2Ptr]] CrossWorkgroup %[[Const0F16Vec2]]
+; CHECK: %[[Const0F16Vec4:[0-9]+]] = OpConstantNull %[[TyF16Vec4]]
+; CHECK: %[[g:[0-9]+]] = OpVariable %[[TyF16Vec4Ptr]] CrossWorkgroup %[[Const0F16Vec4]]
+; CHECK: %[[ConstF16Vec2:[0-9]+]] = OpConstantComposite %[[TyF16Vec2]] %[[ConstF16]] %[[ConstF16]]
+; CHECK: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
+; CHECK: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
+; CHECK: %[[ConstF16Vec4:[0-9]+]] = OpConstantComposite %[[TyF16Vec4]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]]
+
+@f = common dso_local local_unnamed_addr addrspace(1) global <2 x half> <half 0.000000e+00, half 0.000000e+00>
+@g = common dso_local local_unnamed_addr addrspace(1) global <4 x half> <half 0.000000e+00, half 0.000000e+00, half 0.000000e+00, half 0.000000e+00>
+
+; CHECK-DAG: OpAtomicFMinEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec2]]
+; CHECK: OpAtomicFMaxEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec2]]
+define dso_local spir_func void @test1() local_unnamed_addr {
+entry:
+ %minval = atomicrmw fmin ptr addrspace(1) @f, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
+ %maxval = atomicrmw fmax ptr addrspace(1) @f, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
+ ret void
+}
+
+; CHECK-DAG: OpAtomicFMinEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec4]]
+; CHECK: OpAtomicFMaxEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec4]]
+define dso_local spir_func void @test2() local_unnamed_addr {
+entry:
+ %minval = atomicrmw fmin ptr addrspace(1) @g, <4 x half> <half 42.000000e+00, half 42.000000e+00, half 42.000000e+00, half 42.000000e+00> seq_cst
+ %maxval = atomicrmw fmax ptr addrspace(1) @g, <4 x half> <half 42.000000e+00, half 42.000000e+00, half 42.000000e+00, half 42.000000e+00> seq_cst
+ ret void
+}
\ No newline at end of file
More information about the llvm-commits
mailing list