[llvm] Add support for atomic instruction on floating-point numbers (PR #81683)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 13 15:11:46 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)
<details>
<summary>Changes</summary>
This PR adds support for atomic instructions on floating-point numbers:
* SPV_EXT_shader_atomic_float_add
* SPV_EXT_shader_atomic_float_min_max
* SPV_EXT_shader_atomic_float16_add
and fixes the asm printer output for the half-precision floating-point type.
---
Patch is 40.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81683.diff
18 Files Affected:
- (modified) llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp (+3-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+64-5)
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.td (+39)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.h (+8)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+3)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+27-4)
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+6)
- (modified) llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp (+2)
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+66)
- (modified) llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp (+17-5)
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+6)
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+3)
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_double.ll (+43)
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll (+43)
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_half.ll (+46)
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_min_max/atomicrmw_fminfmax_double.ll (+45)
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_min_max/atomicrmw_fminfmax_float.ll (+45)
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_min_max/atomicrmw_fminfmax_half.ll (+45)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index 163b2ec0fefe4d..b468b71cc0efb4 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -13,6 +13,7 @@
#include "SPIRVInstPrinter.h"
#include "SPIRV.h"
#include "SPIRVBaseInfo.h"
+#include "SPIRVInstrInfo.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/MC/MCAsmInfo.h"
@@ -50,6 +51,7 @@ void SPIRVInstPrinter::printRemainingVariableOps(const MCInst *MI,
void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
unsigned StartIndex,
raw_ostream &O) {
+ unsigned IsBitwidth16 = MI->getFlags() & SPIRV::ASM_PRINTER_WIDTH16;
const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
assert((NumVarOps == 1 || NumVarOps == 2) &&
@@ -65,7 +67,7 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
}
// Format and print float values.
- if (MI->getOpcode() == SPIRV::OpConstantF) {
+ if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) {
APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat())
: APFloat(APInt(64, Imm).bitsToDouble());
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 8721b900c8beee..545891b9ba6d73 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -93,6 +93,14 @@ struct IntelSubgroupsBuiltin {
#define GET_IntelSubgroupsBuiltins_DECL
#define GET_IntelSubgroupsBuiltins_IMPL
+struct AtomicFloatingBuiltin {
+ StringRef Name;
+ uint32_t Opcode;
+};
+
+#define GET_AtomicFloatingBuiltins_DECL
+#define GET_AtomicFloatingBuiltins_IMPL
+
struct GetBuiltin {
StringRef Name;
InstructionSet::InstructionSet Set;
@@ -402,7 +410,7 @@ getSPIRVMemSemantics(std::memory_order MemOrder) {
case std::memory_order::memory_order_seq_cst:
return SPIRV::MemorySemantics::SequentiallyConsistent;
default:
- llvm_unreachable("Unknown CL memory scope");
+ report_fatal_error("Unknown CL memory scope");
}
}
@@ -419,7 +427,7 @@ static SPIRV::Scope::Scope getSPIRVScope(SPIRV::CLMemoryScope ClScope) {
case SPIRV::CLMemoryScope::memory_scope_sub_group:
return SPIRV::Scope::Subgroup;
}
- llvm_unreachable("Unknown CL memory scope");
+ report_fatal_error("Unknown CL memory scope");
}
static Register buildConstantIntReg(uint64_t Val, MachineIRBuilder &MIRBuilder,
@@ -676,6 +684,38 @@ static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
return true;
}
+/// Helper function for building an atomic floating-type instruction.
+static bool buildAtomicFloatingRMWInst(const SPIRV::IncomingCall *Call,
+ unsigned Opcode,
+ MachineIRBuilder &MIRBuilder,
+ SPIRVGlobalRegistry *GR) {
+ assert(Call->Arguments.size() == 4 &&
+ "Wrong number of atomic floating-type builtin");
+
+ MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+
+ Register PtrReg = Call->Arguments[0];
+ MRI->setRegClass(PtrReg, &SPIRV::IDRegClass);
+
+ Register ScopeReg = Call->Arguments[1];
+ MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
+
+ Register MemSemanticsReg = Call->Arguments[2];
+ MRI->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
+
+ Register ValueReg = Call->Arguments[3];
+ MRI->setRegClass(ValueReg, &SPIRV::IDRegClass);
+
+ MIRBuilder.buildInstr(Opcode)
+ .addDef(Call->ReturnRegister)
+ .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+ .addUse(PtrReg)
+ .addUse(ScopeReg)
+ .addUse(MemSemanticsReg)
+ .addUse(ValueReg);
+ return true;
+}
+
/// Helper function for building atomic flag instructions (e.g.
/// OpAtomicFlagTestAndSet).
static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
@@ -786,7 +826,7 @@ static unsigned getNumComponentsForDim(SPIRV::Dim::Dim dim) {
case SPIRV::Dim::DIM_3D:
return 3;
default:
- llvm_unreachable("Cannot get num components for given Dim");
+ report_fatal_error("Cannot get num components for given Dim");
}
}
@@ -1157,6 +1197,23 @@ static bool generateAtomicInst(const SPIRV::IncomingCall *Call,
}
}
+static bool generateAtomicFloatingInst(const SPIRV::IncomingCall *Call,
+ MachineIRBuilder &MIRBuilder,
+ SPIRVGlobalRegistry *GR) {
+ // Lookup the instruction opcode in the TableGen records.
+ const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+ unsigned Opcode = SPIRV::lookupAtomicFloatingBuiltin(Builtin->Name)->Opcode;
+
+ switch (Opcode) {
+ case SPIRV::OpAtomicFAddEXT:
+ case SPIRV::OpAtomicFMinEXT:
+ case SPIRV::OpAtomicFMaxEXT:
+ return buildAtomicFloatingRMWInst(Call, Opcode, MIRBuilder, GR);
+ default:
+ return false;
+ }
+}
+
static bool generateBarrierInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
@@ -1311,7 +1368,7 @@ getSamplerAddressingModeFromBitmask(unsigned Bitmask) {
case SPIRV::CLK_ADDRESS_NONE:
return SPIRV::SamplerAddressingMode::None;
default:
- llvm_unreachable("Unknown CL address mode");
+ report_fatal_error("Unknown CL address mode");
}
}
@@ -2021,6 +2078,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateBuiltinVar(Call.get(), MIRBuilder, GR);
case SPIRV::Atomic:
return generateAtomicInst(Call.get(), MIRBuilder, GR);
+ case SPIRV::AtomicFloating:
+ return generateAtomicFloatingInst(Call.get(), MIRBuilder, GR);
case SPIRV::Barrier:
return generateBarrierInst(Call.get(), MIRBuilder, GR);
case SPIRV::Dot:
@@ -2089,7 +2148,7 @@ static Type *parseTypeString(const StringRef Name, LLVMContext &Context) {
return Type::getFloatTy(Context);
else if (Name.starts_with("half"))
return Type::getHalfTy(Context);
- llvm_unreachable("Unable to recognize type!");
+ report_fatal_error("Unable to recognize type!");
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 4013dd22f4ab57..92009ae33ef32f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -55,6 +55,7 @@ def AsyncCopy : BuiltinGroup;
def VectorLoadStore : BuiltinGroup;
def LoadStore : BuiltinGroup;
def IntelSubgroups : BuiltinGroup;
+def AtomicFloating : BuiltinGroup;
//===----------------------------------------------------------------------===//
// Class defining a demangled builtin record. The information in the record
@@ -872,6 +873,44 @@ defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_logical_xors", Wo
defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_logical_xors", WorkOrSub, OpGroupNonUniformLogicalXor>;
defm : DemangledGroupBuiltin<"group_clustered_reduce_logical_xor", WorkOrSub, OpGroupNonUniformLogicalXor>;
+//===----------------------------------------------------------------------===//
+// Class defining an atomic instruction on floating-point numbers.
+//
+// name is the demangled name of the given builtin.
+// opcode specifies the SPIR-V operation code of the generated instruction.
+//===----------------------------------------------------------------------===//
+class AtomicFloatingBuiltin<string name, Op operation> {
+ string Name = name;
+ Op Opcode = operation;
+}
+
+// Table gathering all the atomic floating-point builtins.
+def AtomicFloatingBuiltins : GenericTable {
+ let FilterClass = "AtomicFloatingBuiltin";
+ let Fields = ["Name", "Opcode"];
+}
+
+// Function to lookup atomic floating-point builtins by their name.
+def lookupAtomicFloatingBuiltin : SearchIndex {
+ let Table = AtomicFloatingBuiltins;
+ let Key = ["Name"];
+}
+
+// Multiclass used to define incoming builtin records for the atomic floating-point
+// extensions and the corresponding AtomicFloatingBuiltin lookup records.
+multiclass DemangledAtomicFloatingBuiltin<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> {
+ def : DemangledBuiltin<!strconcat("__spirv_AtomicF", name), OpenCL_std, AtomicFloating, minNumArgs, maxNumArgs>;
+ def : AtomicFloatingBuiltin<!strconcat("__spirv_AtomicF", name), operation>;
+}
+
+// SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max, SPV_EXT_shader_atomic_float16_add
+// Atomic add, min and max instruction on floating-point numbers:
+defm : DemangledAtomicFloatingBuiltin<"AddEXT", 4, 4, OpAtomicFAddEXT>;
+defm : DemangledAtomicFloatingBuiltin<"MinEXT", 4, 4, OpAtomicFMinEXT>;
+defm : DemangledAtomicFloatingBuiltin<"MaxEXT", 4, 4, OpAtomicFMaxEXT>;
+// TODO: add support for cl_ext_float_atomics to enable performing atomic operations
+// on floating-point numbers in memory (float arguments for atomic_fetch_add, ...)
+
//===----------------------------------------------------------------------===//
// Class defining a sub group builtin that should be translated into a
// SPIR-V instruction using the SPV_INTEL_subgroups extension.
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
index c01e30e109bd5b..b9a3425cb5a1c3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
@@ -53,6 +53,14 @@ class SPIRVInstrInfo : public SPIRVGenInstrInfo {
bool KillSrc) const override;
bool expandPostRAPseudo(MachineInstr &MI) const override;
};
+
+namespace SPIRV {
+enum AsmComments {
+ // It is a half type
+ ASM_PRINTER_WIDTH16 = MachineInstr::TAsmComments
+};
+}; // namespace SPIRV
+
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVINSTRINFO_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 904fef1d6c82f9..7965dd969e5cfa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -643,6 +643,9 @@ def OpAtomicAnd: AtomicOpVal<"OpAtomicAnd", 240>;
def OpAtomicOr: AtomicOpVal<"OpAtomicOr", 241>;
def OpAtomicXor: AtomicOpVal<"OpAtomicXor", 242>;
+def OpAtomicFAddEXT: AtomicOpVal<"OpAtomicFAddEXT", 6035>;
+def OpAtomicFMinEXT: AtomicOpVal<"OpAtomicFMinEXT", 5614>;
+def OpAtomicFMaxEXT: AtomicOpVal<"OpAtomicFMaxEXT", 5615>;
def OpAtomicFlagTestAndSet: AtomicOp<"OpAtomicFlagTestAndSet", 318>;
def OpAtomicFlagClear: Op<319, (outs), (ins ID:$ptr, ID:$sc, ID:$sem),
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 52eeb8a523e6f6..79ccf5bee7ff37 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -102,7 +102,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectMemOperation(Register ResVReg, MachineInstr &I) const;
bool selectAtomicRMW(Register ResVReg, const SPIRVType *ResType,
- MachineInstr &I, unsigned NewOpcode) const;
+ MachineInstr &I, unsigned NewOpcode,
+ unsigned NegateOpcode = 0) const;
bool selectAtomicCmpXchg(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -489,6 +490,17 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_ATOMIC_CMPXCHG:
return selectAtomicCmpXchg(ResVReg, ResType, I);
+ case TargetOpcode::G_ATOMICRMW_FADD:
+ return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT);
+ case TargetOpcode::G_ATOMICRMW_FSUB:
+ // Translate G_ATOMICRMW_FSUB to OpAtomicFAddEXT with negative value operand
+ return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT,
+ SPIRV::OpFNegate);
+ case TargetOpcode::G_ATOMICRMW_FMIN:
+ return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMinEXT);
+ case TargetOpcode::G_ATOMICRMW_FMAX:
+ return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMaxEXT);
+
case TargetOpcode::G_FENCE:
return selectFence(I);
@@ -686,7 +698,8 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
- unsigned NewOpcode) const {
+ unsigned NewOpcode,
+ unsigned NegateOpcode) const {
assert(I.hasOneMemOperand());
const MachineMemOperand *MemOp = *I.memoperands_begin();
uint32_t Scope = static_cast<uint32_t>(getScope(MemOp->getSyncScopeID()));
@@ -700,14 +713,24 @@ bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO));
Register MemSemReg = buildI32Constant(MemSem /*| ScSem*/, I);
- return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(NewOpcode))
+ bool Result = false;
+ Register ValueReg = I.getOperand(2).getReg();
+ if (NegateOpcode != 0) {
+ // Translation with negative value operand is requested
+ Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+ Result |= selectUnOpWithSrc(TmpReg, ResType, I, ValueReg, NegateOpcode);
+ ValueReg = TmpReg;
+ }
+
+ Result |= BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(NewOpcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Ptr)
.addUse(ScopeReg)
.addUse(MemSemReg)
- .addUse(I.getOperand(2).getReg())
+ .addUse(ValueReg)
.constrainAllUses(TII, TRI, RBI);
+ return Result;
}
bool SPIRVInstructionSelector::selectFence(MachineInstr &I) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 061bc967423712..011a550a7b3d9b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -125,6 +125,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
auto allIntScalars = {s8, s16, s32, s64};
+ auto allFloatScalars = {s16, s32, s64};
+
auto allFloatScalarsAndVectors = {
s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
@@ -205,6 +207,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
.legalForCartesianProduct(allIntScalars, allWritablePtrs);
+ getActionDefinitionsBuilder(
+ {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
+ .legalForCartesianProduct(allFloatScalars, allWritablePtrs);
+
getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
.legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
diff --git a/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp b/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
index 0fa05d377d9f10..8c6649bf628265 100644
--- a/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
@@ -23,6 +23,8 @@ using namespace llvm;
void SPIRVMCInstLower::lower(const MachineInstr *MI, MCInst &OutMI,
SPIRV::ModuleAnalysisInfo *MAI) const {
OutMI.setOpcode(MI->getOpcode());
+ // Propagate previously set flags
+ OutMI.setFlags(MI->getAsmPrinterFlags());
const MachineFunction *MF = MI->getMF();
for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
const MachineOperand &MO = MI->getOperand(i);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index a18aae1761c834..a716732ea57056 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -704,6 +704,67 @@ static void addOpTypeImageReqs(const MachineInstr &MI,
Reqs.addRequirements(SPIRV::Capability::ImageBasic);
}
+// Add requirements for handling atomic float instructions
+#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
+ "The atomic float instruction requires the following SPIR-V " \
+ "extension: SPV_EXT_shader_atomic_float" ExtName
+static void AddAtomicFloatRequirements(const MachineInstr &MI,
+ SPIRV::RequirementHandler &Reqs,
+ const SPIRVSubtarget &ST) {
+ assert(MI.getOperand(1).isReg() &&
+ "Expect register operand in atomic float instruction");
+ Register TypeReg = MI.getOperand(1).getReg();
+ SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
+ if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
+ report_fatal_error("Result type of an atomic float instruction must be a "
+ "floating-point type scalar");
+
+ unsigned BitWidth = TypeDef->getOperand(1).getImm();
+ unsigned Op = MI.getOpcode();
+ if (Op == SPIRV::OpAtomicFAddEXT) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
+ report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
+ Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
+ switch (BitWidth) {
+ case 16:
+ if (!ST.canUseExtension(
+ SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
+ report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
+ Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
+ break;
+ case 32:
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
+ break;
+ case 64:
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
+ break;
+ default:
+ report_fatal_error(
+ "Unexpected floating-point type width in atomic float instruction");
+ }
+ } else {
+ if (!ST.canUseExtension(
+ SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
+ report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
+ Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
+ switch (BitWidth) {
+ case 16:
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
+ break;
+ case 32:
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
+ break;
+ case 64:
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
+ break;
+ default:
+ report_fatal_error(
+ "Unexpected floating-point type width in atomic float instruction");
+ }
+ }
+}
+
void addInstrRequirements(const MachineInstr &MI,
SPIRV::RequirementHandler &Reqs,
const SPIRVSubtarget &ST) {
@@ -976,6 +1037,11 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
}
break;
+ case SPIRV::OpAtomicFAddEXT:
+ case SPIRV::OpAtomicFMinEXT:
+ case SPIRV::OpAtomicFMaxEXT:
+ AddAtomicFloatRequirements(MI, Reqs, ST);
+ break;
default:
break;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index effedc2f17d351..f8fd6309ce0e72 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -31,12 +31,24 @@ cl::list<SPIRV::Extension::Extension> Extensions(
"spirv-extensions", cl::desc("SPIR-V extensions"), cl::ZeroOrMore,
cl::Hidden,
cl::values(
+ clEnumValN(SPIRV::Extension::SPV_EXT_shader_atomic_float_add,
+ "SPV_EXT_shader_atomic_float_add",
+ "Adds atomic add instruction on floating-point numbers."),
+ clEnumValN(
+ SPIRV::Extension::SPV_EXT_shader_atomic_float16_add,
+ "SPV_EXT_shader_atomic_float16_add",
+ "Extends the SPV_EXT_shader_atomic_float_add extension to support "
+ "atomically adding to 16-bit floating-point numbers in memory."),
+ clEnumValN(
+ SPIRV::E...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/81683
More information about the llvm-commits
mailing list