[llvm] Add support for atomic instruction on floating-point numbers (PR #81683)

via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 13 15:11:46 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR adds support for atomic instructions on floating-point numbers via the following SPIR-V extensions:
* SPV_EXT_shader_atomic_float_add
* SPV_EXT_shader_atomic_float_min_max
* SPV_EXT_shader_atomic_float16_add

It also fixes the asm printer output for floating-point constants of the 16-bit (half) type.

---

Patch is 40.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81683.diff


18 Files Affected:

- (modified) llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp (+3-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+64-5) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.td (+39) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.h (+8) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+3) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+27-4) 
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+6) 
- (modified) llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+66) 
- (modified) llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp (+17-5) 
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+6) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+3) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_double.ll (+43) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll (+43) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_half.ll (+46) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_min_max/atomicrmw_fminfmax_double.ll (+45) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_min_max/atomicrmw_fminfmax_float.ll (+45) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_min_max/atomicrmw_fminfmax_half.ll (+45) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index 163b2ec0fefe4d..b468b71cc0efb4 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -13,6 +13,7 @@
 #include "SPIRVInstPrinter.h"
 #include "SPIRV.h"
 #include "SPIRVBaseInfo.h"
+#include "SPIRVInstrInfo.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/CodeGen/Register.h"
 #include "llvm/MC/MCAsmInfo.h"
@@ -50,6 +51,7 @@ void SPIRVInstPrinter::printRemainingVariableOps(const MCInst *MI,
 void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
                                              unsigned StartIndex,
                                              raw_ostream &O) {
+  unsigned IsBitwidth16 = MI->getFlags() & SPIRV::ASM_PRINTER_WIDTH16;
   const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
 
   assert((NumVarOps == 1 || NumVarOps == 2) &&
@@ -65,7 +67,7 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
   }
 
   // Format and print float values.
-  if (MI->getOpcode() == SPIRV::OpConstantF) {
+  if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) {
     APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat())
                                 : APFloat(APInt(64, Imm).bitsToDouble());
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 8721b900c8beee..545891b9ba6d73 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -93,6 +93,14 @@ struct IntelSubgroupsBuiltin {
 #define GET_IntelSubgroupsBuiltins_DECL
 #define GET_IntelSubgroupsBuiltins_IMPL
 
+struct AtomicFloatingBuiltin {
+  StringRef Name;
+  uint32_t Opcode;
+};
+
+#define GET_AtomicFloatingBuiltins_DECL
+#define GET_AtomicFloatingBuiltins_IMPL
+
 struct GetBuiltin {
   StringRef Name;
   InstructionSet::InstructionSet Set;
@@ -402,7 +410,7 @@ getSPIRVMemSemantics(std::memory_order MemOrder) {
   case std::memory_order::memory_order_seq_cst:
     return SPIRV::MemorySemantics::SequentiallyConsistent;
   default:
-    llvm_unreachable("Unknown CL memory scope");
+    report_fatal_error("Unknown CL memory scope");
   }
 }
 
@@ -419,7 +427,7 @@ static SPIRV::Scope::Scope getSPIRVScope(SPIRV::CLMemoryScope ClScope) {
   case SPIRV::CLMemoryScope::memory_scope_sub_group:
     return SPIRV::Scope::Subgroup;
   }
-  llvm_unreachable("Unknown CL memory scope");
+  report_fatal_error("Unknown CL memory scope");
 }
 
 static Register buildConstantIntReg(uint64_t Val, MachineIRBuilder &MIRBuilder,
@@ -676,6 +684,38 @@ static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
   return true;
 }
 
+/// Helper function for building an atomic floating-type instruction.
+static bool buildAtomicFloatingRMWInst(const SPIRV::IncomingCall *Call,
+                                       unsigned Opcode,
+                                       MachineIRBuilder &MIRBuilder,
+                                       SPIRVGlobalRegistry *GR) {
+  assert(Call->Arguments.size() == 4 &&
+         "Wrong number of atomic floating-type builtin");
+
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+
+  Register PtrReg = Call->Arguments[0];
+  MRI->setRegClass(PtrReg, &SPIRV::IDRegClass);
+
+  Register ScopeReg = Call->Arguments[1];
+  MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
+
+  Register MemSemanticsReg = Call->Arguments[2];
+  MRI->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
+
+  Register ValueReg = Call->Arguments[3];
+  MRI->setRegClass(ValueReg, &SPIRV::IDRegClass);
+
+  MIRBuilder.buildInstr(Opcode)
+      .addDef(Call->ReturnRegister)
+      .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+      .addUse(PtrReg)
+      .addUse(ScopeReg)
+      .addUse(MemSemanticsReg)
+      .addUse(ValueReg);
+  return true;
+}
+
 /// Helper function for building atomic flag instructions (e.g.
 /// OpAtomicFlagTestAndSet).
 static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
@@ -786,7 +826,7 @@ static unsigned getNumComponentsForDim(SPIRV::Dim::Dim dim) {
   case SPIRV::Dim::DIM_3D:
     return 3;
   default:
-    llvm_unreachable("Cannot get num components for given Dim");
+    report_fatal_error("Cannot get num components for given Dim");
   }
 }
 
@@ -1157,6 +1197,23 @@ static bool generateAtomicInst(const SPIRV::IncomingCall *Call,
   }
 }
 
+static bool generateAtomicFloatingInst(const SPIRV::IncomingCall *Call,
+                                       MachineIRBuilder &MIRBuilder,
+                                       SPIRVGlobalRegistry *GR) {
+  // Lookup the instruction opcode in the TableGen records.
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  unsigned Opcode = SPIRV::lookupAtomicFloatingBuiltin(Builtin->Name)->Opcode;
+
+  switch (Opcode) {
+  case SPIRV::OpAtomicFAddEXT:
+  case SPIRV::OpAtomicFMinEXT:
+  case SPIRV::OpAtomicFMaxEXT:
+    return buildAtomicFloatingRMWInst(Call, Opcode, MIRBuilder, GR);
+  default:
+    return false;
+  }
+}
+
 static bool generateBarrierInst(const SPIRV::IncomingCall *Call,
                                 MachineIRBuilder &MIRBuilder,
                                 SPIRVGlobalRegistry *GR) {
@@ -1311,7 +1368,7 @@ getSamplerAddressingModeFromBitmask(unsigned Bitmask) {
   case SPIRV::CLK_ADDRESS_NONE:
     return SPIRV::SamplerAddressingMode::None;
   default:
-    llvm_unreachable("Unknown CL address mode");
+    report_fatal_error("Unknown CL address mode");
   }
 }
 
@@ -2021,6 +2078,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
     return generateBuiltinVar(Call.get(), MIRBuilder, GR);
   case SPIRV::Atomic:
     return generateAtomicInst(Call.get(), MIRBuilder, GR);
+  case SPIRV::AtomicFloating:
+    return generateAtomicFloatingInst(Call.get(), MIRBuilder, GR);
   case SPIRV::Barrier:
     return generateBarrierInst(Call.get(), MIRBuilder, GR);
   case SPIRV::Dot:
@@ -2089,7 +2148,7 @@ static Type *parseTypeString(const StringRef Name, LLVMContext &Context) {
     return Type::getFloatTy(Context);
   else if (Name.starts_with("half"))
     return Type::getHalfTy(Context);
-  llvm_unreachable("Unable to recognize type!");
+  report_fatal_error("Unable to recognize type!");
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 4013dd22f4ab57..92009ae33ef32f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -55,6 +55,7 @@ def AsyncCopy : BuiltinGroup;
 def VectorLoadStore : BuiltinGroup;
 def LoadStore : BuiltinGroup;
 def IntelSubgroups : BuiltinGroup;
+def AtomicFloating : BuiltinGroup;
 
 //===----------------------------------------------------------------------===//
 // Class defining a demangled builtin record. The information in the record
@@ -872,6 +873,44 @@ defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_logical_xors", Wo
 defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_logical_xors", WorkOrSub, OpGroupNonUniformLogicalXor>;
 defm : DemangledGroupBuiltin<"group_clustered_reduce_logical_xor", WorkOrSub, OpGroupNonUniformLogicalXor>;
 
+//===----------------------------------------------------------------------===//
+// Class defining an atomic instruction on floating-point numbers.
+//
+// name is the demangled name of the given builtin.
+// opcode specifies the SPIR-V operation code of the generated instruction.
+//===----------------------------------------------------------------------===//
+class AtomicFloatingBuiltin<string name, Op operation> {
+  string Name = name;
+  Op Opcode = operation;
+}
+
+// Table gathering all the atomic floating-point builtins.
+def AtomicFloatingBuiltins : GenericTable {
+  let FilterClass = "AtomicFloatingBuiltin";
+  let Fields = ["Name", "Opcode"];
+}
+
+// Function to lookup atomic floating-point builtins by their name.
+def lookupAtomicFloatingBuiltin : SearchIndex {
+  let Table = AtomicFloatingBuiltins;
+  let Key = ["Name"];
+}
+
+// Multiclass used to define incoming demangled builtin records for the atomic
+// floating-point instructions and the corresponding AtomicFloatingBuiltin records.
+multiclass DemangledAtomicFloatingBuiltin<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> {
+  def : DemangledBuiltin<!strconcat("__spirv_AtomicF", name), OpenCL_std, AtomicFloating, minNumArgs, maxNumArgs>;
+  def : AtomicFloatingBuiltin<!strconcat("__spirv_AtomicF", name), operation>;
+}
+
+// SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max, SPV_EXT_shader_atomic_float16_add
+// Atomic add, min and max instruction on floating-point numbers:
+defm : DemangledAtomicFloatingBuiltin<"AddEXT", 4, 4, OpAtomicFAddEXT>;
+defm : DemangledAtomicFloatingBuiltin<"MinEXT", 4, 4, OpAtomicFMinEXT>;
+defm : DemangledAtomicFloatingBuiltin<"MaxEXT", 4, 4, OpAtomicFMaxEXT>;
+// TODO: add support for cl_ext_float_atomics to enable performing atomic operations
+//       on floating-point numbers in memory (float arguments for atomic_fetch_add, ...)
+
 //===----------------------------------------------------------------------===//
 // Class defining a sub group builtin that should be translated into a
 // SPIR-V instruction using the SPV_INTEL_subgroups extension.
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
index c01e30e109bd5b..b9a3425cb5a1c3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
@@ -53,6 +53,14 @@ class SPIRVInstrInfo : public SPIRVGenInstrInfo {
                    bool KillSrc) const override;
   bool expandPostRAPseudo(MachineInstr &MI) const override;
 };
+
+namespace SPIRV {
+enum AsmComments {
+  // It is a half type
+  ASM_PRINTER_WIDTH16 = MachineInstr::TAsmComments
+};
+}; // namespace SPIRV
+
 } // namespace llvm
 
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVINSTRINFO_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 904fef1d6c82f9..7965dd969e5cfa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -643,6 +643,9 @@ def OpAtomicAnd: AtomicOpVal<"OpAtomicAnd", 240>;
 def OpAtomicOr: AtomicOpVal<"OpAtomicOr", 241>;
 def OpAtomicXor: AtomicOpVal<"OpAtomicXor", 242>;
 
+def OpAtomicFAddEXT: AtomicOpVal<"OpAtomicFAddEXT", 6035>;
+def OpAtomicFMinEXT: AtomicOpVal<"OpAtomicFMinEXT", 5614>;
+def OpAtomicFMaxEXT: AtomicOpVal<"OpAtomicFMaxEXT", 5615>;
 
 def OpAtomicFlagTestAndSet: AtomicOp<"OpAtomicFlagTestAndSet", 318>;
 def OpAtomicFlagClear: Op<319, (outs), (ins ID:$ptr, ID:$sc, ID:$sem),
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 52eeb8a523e6f6..79ccf5bee7ff37 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -102,7 +102,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectMemOperation(Register ResVReg, MachineInstr &I) const;
 
   bool selectAtomicRMW(Register ResVReg, const SPIRVType *ResType,
-                       MachineInstr &I, unsigned NewOpcode) const;
+                       MachineInstr &I, unsigned NewOpcode,
+                       unsigned NegateOpcode = 0) const;
 
   bool selectAtomicCmpXchg(Register ResVReg, const SPIRVType *ResType,
                            MachineInstr &I) const;
@@ -489,6 +490,17 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_ATOMIC_CMPXCHG:
     return selectAtomicCmpXchg(ResVReg, ResType, I);
 
+  case TargetOpcode::G_ATOMICRMW_FADD:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT);
+  case TargetOpcode::G_ATOMICRMW_FSUB:
+    // Translate G_ATOMICRMW_FSUB to OpAtomicFAddEXT with negative value operand
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT,
+                           SPIRV::OpFNegate);
+  case TargetOpcode::G_ATOMICRMW_FMIN:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMinEXT);
+  case TargetOpcode::G_ATOMICRMW_FMAX:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMaxEXT);
+
   case TargetOpcode::G_FENCE:
     return selectFence(I);
 
@@ -686,7 +698,8 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
 bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
                                                const SPIRVType *ResType,
                                                MachineInstr &I,
-                                               unsigned NewOpcode) const {
+                                               unsigned NewOpcode,
+                                               unsigned NegateOpcode) const {
   assert(I.hasOneMemOperand());
   const MachineMemOperand *MemOp = *I.memoperands_begin();
   uint32_t Scope = static_cast<uint32_t>(getScope(MemOp->getSyncScopeID()));
@@ -700,14 +713,24 @@ bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
   uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO));
   Register MemSemReg = buildI32Constant(MemSem /*| ScSem*/, I);
 
-  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(NewOpcode))
+  bool Result = false;
+  Register ValueReg = I.getOperand(2).getReg();
+  if (NegateOpcode != 0) {
+    // Translation with negative value operand is requested
+    Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    Result |= selectUnOpWithSrc(TmpReg, ResType, I, ValueReg, NegateOpcode);
+    ValueReg = TmpReg;
+  }
+
+  Result |= BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(NewOpcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
       .addUse(Ptr)
       .addUse(ScopeReg)
       .addUse(MemSemReg)
-      .addUse(I.getOperand(2).getReg())
+      .addUse(ValueReg)
       .constrainAllUses(TII, TRI, RBI);
+  return Result;
 }
 
 bool SPIRVInstructionSelector::selectFence(MachineInstr &I) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 061bc967423712..011a550a7b3d9b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -125,6 +125,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   auto allIntScalars = {s8, s16, s32, s64};
 
+  auto allFloatScalars = {s16, s32, s64};
+
   auto allFloatScalarsAndVectors = {
       s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
       v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
@@ -205,6 +207,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                                G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
       .legalForCartesianProduct(allIntScalars, allWritablePtrs);
 
+  getActionDefinitionsBuilder(
+      {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
+      .legalForCartesianProduct(allFloatScalars, allWritablePtrs);
+
   getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
       .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp b/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
index 0fa05d377d9f10..8c6649bf628265 100644
--- a/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp
@@ -23,6 +23,8 @@ using namespace llvm;
 void SPIRVMCInstLower::lower(const MachineInstr *MI, MCInst &OutMI,
                              SPIRV::ModuleAnalysisInfo *MAI) const {
   OutMI.setOpcode(MI->getOpcode());
+  // Propagate previously set flags
+  OutMI.setFlags(MI->getAsmPrinterFlags());
   const MachineFunction *MF = MI->getMF();
   for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
     const MachineOperand &MO = MI->getOperand(i);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index a18aae1761c834..a716732ea57056 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -704,6 +704,67 @@ static void addOpTypeImageReqs(const MachineInstr &MI,
     Reqs.addRequirements(SPIRV::Capability::ImageBasic);
 }
 
+// Add requirements for handling atomic float instructions
+#define ATOM_FLT_REQ_EXT_MSG(ExtName)                                          \
+  "The atomic float instruction requires the following SPIR-V "                \
+  "extension: SPV_EXT_shader_atomic_float" ExtName
+static void AddAtomicFloatRequirements(const MachineInstr &MI,
+                                       SPIRV::RequirementHandler &Reqs,
+                                       const SPIRVSubtarget &ST) {
+  assert(MI.getOperand(1).isReg() &&
+         "Expect register operand in atomic float instruction");
+  Register TypeReg = MI.getOperand(1).getReg();
+  SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
+  if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
+    report_fatal_error("Result type of an atomic float instruction must be a "
+                       "floating-point type scalar");
+
+  unsigned BitWidth = TypeDef->getOperand(1).getImm();
+  unsigned Op = MI.getOpcode();
+  if (Op == SPIRV::OpAtomicFAddEXT) {
+    if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
+      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
+    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
+    switch (BitWidth) {
+    case 16:
+      if (!ST.canUseExtension(
+              SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
+        report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
+      Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
+      Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
+      break;
+    case 32:
+      Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
+      break;
+    case 64:
+      Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
+      break;
+    default:
+      report_fatal_error(
+          "Unexpected floating-point type width in atomic float instruction");
+    }
+  } else {
+    if (!ST.canUseExtension(
+            SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
+      report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
+    Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
+    switch (BitWidth) {
+    case 16:
+      Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
+      break;
+    case 32:
+      Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
+      break;
+    case 64:
+      Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
+      break;
+    default:
+      report_fatal_error(
+          "Unexpected floating-point type width in atomic float instruction");
+    }
+  }
+}
+
 void addInstrRequirements(const MachineInstr &MI,
                           SPIRV::RequirementHandler &Reqs,
                           const SPIRVSubtarget &ST) {
@@ -976,6 +1037,11 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
     }
     break;
+  case SPIRV::OpAtomicFAddEXT:
+  case SPIRV::OpAtomicFMinEXT:
+  case SPIRV::OpAtomicFMaxEXT:
+    AddAtomicFloatRequirements(MI, Reqs, ST);
+    break;
   default:
     break;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index effedc2f17d351..f8fd6309ce0e72 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -31,12 +31,24 @@ cl::list<SPIRV::Extension::Extension> Extensions(
     "spirv-extensions", cl::desc("SPIR-V extensions"), cl::ZeroOrMore,
     cl::Hidden,
     cl::values(
+        clEnumValN(SPIRV::Extension::SPV_EXT_shader_atomic_float_add,
+                   "SPV_EXT_shader_atomic_float_add",
+                   "Adds atomic add instruction on floating-point numbers."),
+        clEnumValN(
+            SPIRV::Extension::SPV_EXT_shader_atomic_float16_add,
+            "SPV_EXT_shader_atomic_float16_add",
+            "Extends the SPV_EXT_shader_atomic_float_add extension to support "
+            "atomically adding to 16-bit floating-point numbers in memory."),
+        clEnumValN(
+            SPIRV::E...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/81683


More information about the llvm-commits mailing list