[llvm] [SPIR-V] Support cl_ext_float_atomics and fix errors in definition of atomic_fetch_*_explicit builtins (PR #96767)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 26 06:51:37 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR:
* supports cl_ext_float_atomics by mapping atomic_fetch_add and atomic_fetch_sub applied to float arguments to the corresponding instructions from SPV_EXT_shader_atomic_float*_add, and
* fixes errors in the definitions of the atomic_fetch_*_explicit builtins by correcting their valid number of arguments.

---
Full diff: https://github.com/llvm/llvm-project/pull/96767.diff


3 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+26-3) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.td (+5-7) 
- (modified) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll (+15) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 71168d2d7dacd..0b93a4d85eedf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -765,7 +765,7 @@ static bool buildAtomicCompareExchangeInst(
   return true;
 }
 
-/// Helper function for building an atomic load instruction.
+/// Helper function for building atomic instructions.
 static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
                                MachineIRBuilder &MIRBuilder,
                                SPIRVGlobalRegistry *GR) {
@@ -790,13 +790,36 @@ static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
   MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
                                          Semantics, MIRBuilder, GR);
   MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+  Register ValueReg = Call->Arguments[1];
+  Register ValueTypeReg = GR->getSPIRVTypeID(Call->ReturnType);
+  // support cl_ext_float_atomics
+  if (Call->ReturnType->getOpcode() == SPIRV::OpTypeFloat) {
+    if (Opcode == SPIRV::OpAtomicIAdd) {
+      Opcode = SPIRV::OpAtomicFAddEXT;
+    } else if (Opcode == SPIRV::OpAtomicISub) {
+      // Translate OpAtomicISub applied to a floating type argument to
+      // OpAtomicFAddEXT with the negative value operand
+      Opcode = SPIRV::OpAtomicFAddEXT;
+      Register NegValueReg =
+          MRI->createGenericVirtualRegister(MRI->getType(ValueReg));
+      MRI->setRegClass(NegValueReg, &SPIRV::IDRegClass);
+      GR->assignSPIRVTypeToVReg(Call->ReturnType, NegValueReg,
+                                MIRBuilder.getMF());
+      MIRBuilder.buildInstr(TargetOpcode::G_FNEG)
+          .addDef(NegValueReg)
+          .addUse(ValueReg);
+      insertAssignInstr(NegValueReg, nullptr, Call->ReturnType, GR, MIRBuilder,
+                        MIRBuilder.getMF().getRegInfo());
+      ValueReg = NegValueReg;
+    }
+  }
   MIRBuilder.buildInstr(Opcode)
       .addDef(Call->ReturnRegister)
-      .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+      .addUse(ValueTypeReg)
       .addUse(PtrRegister)
       .addUse(ScopeRegister)
       .addUse(MemSemanticsReg)
-      .addUse(Call->Arguments[1]);
+      .addUse(ValueReg);
   return true;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 4bd1104103664..f0d480c6a13f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -546,11 +546,11 @@ defm : DemangledNativeBuiltin<"atomic_fetch_sub", OpenCL_std, Atomic, 2, 4, OpAt
 defm : DemangledNativeBuiltin<"atomic_fetch_or", OpenCL_std, Atomic, 2, 4, OpAtomicOr>;
 defm : DemangledNativeBuiltin<"atomic_fetch_xor", OpenCL_std, Atomic, 2, 4, OpAtomicXor>;
 defm : DemangledNativeBuiltin<"atomic_fetch_and", OpenCL_std, Atomic, 2, 4, OpAtomicAnd>;
-defm : DemangledNativeBuiltin<"atomic_fetch_add_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicIAdd>;
-defm : DemangledNativeBuiltin<"atomic_fetch_sub_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicISub>;
-defm : DemangledNativeBuiltin<"atomic_fetch_or_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicOr>;
-defm : DemangledNativeBuiltin<"atomic_fetch_xor_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicXor>;
-defm : DemangledNativeBuiltin<"atomic_fetch_and_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicAnd>;
+defm : DemangledNativeBuiltin<"atomic_fetch_add_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicIAdd>;
+defm : DemangledNativeBuiltin<"atomic_fetch_sub_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicISub>;
+defm : DemangledNativeBuiltin<"atomic_fetch_or_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicOr>;
+defm : DemangledNativeBuiltin<"atomic_fetch_xor_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicXor>;
+defm : DemangledNativeBuiltin<"atomic_fetch_and_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicAnd>;
 defm : DemangledNativeBuiltin<"atomic_flag_test_and_set", OpenCL_std, Atomic, 1, 1, OpAtomicFlagTestAndSet>;
 defm : DemangledNativeBuiltin<"__spirv_AtomicFlagTestAndSet", OpenCL_std, Atomic, 3, 3, OpAtomicFlagTestAndSet>;
 defm : DemangledNativeBuiltin<"atomic_flag_test_and_set_explicit", OpenCL_std, Atomic, 2, 3, OpAtomicFlagTestAndSet>;
@@ -1037,8 +1037,6 @@ multiclass DemangledAtomicFloatingBuiltin<string name, bits<8> minNumArgs, bits<
 defm : DemangledAtomicFloatingBuiltin<"AddEXT", 4, 4, OpAtomicFAddEXT>;
 defm : DemangledAtomicFloatingBuiltin<"MinEXT", 4, 4, OpAtomicFMinEXT>;
 defm : DemangledAtomicFloatingBuiltin<"MaxEXT", 4, 4, OpAtomicFMaxEXT>;
-// TODO: add support for cl_ext_float_atomics to enable performing atomic operations
-//       on floating-point numbers in memory (float arguments for atomic_fetch_add, ...)
 
 //===----------------------------------------------------------------------===//
 // Class defining a sub group builtin that should be translated into a
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll
index 4fb99d9dfc76f..3fd5bd65853f7 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll
@@ -11,13 +11,18 @@
 ; CHECK-DAG: %[[Const0:[0-9]+]] = OpConstant %[[TyFP32]] 0
 ; CHECK-DAG: %[[Const42:[0-9]+]] = OpConstant %[[TyFP32]] 42
 ; CHECK-DAG: %[[ScopeDevice:[0-9]+]] = OpConstant %[[TyInt32]] 1
+; CHECK-DAG: %[[ScopeWorkgroup:[0-9]+]] = OpConstant %[[TyInt32]] 2
 ; CHECK-DAG: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16
+; CHECK-DAG: %[[WorkgroupMemory:[0-9]+]] = OpConstant %[[TyInt32]] 512
 ; CHECK-DAG: %[[TyFP32Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyFP32]]
 ; CHECK-DAG: %[[DblPtr:[0-9]+]] = OpVariable %[[TyFP32Ptr]] {{[a-zA-Z]+}} %[[Const0]]
 ; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeDevice]] %[[MemSeqCst]] %[[Const42]]
 ; CHECK: %[[Const42Neg:[0-9]+]] = OpFNegate %[[TyFP32]] %[[Const42]]
 ; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeDevice]] %[[MemSeqCst]] %[[Const42Neg]]
 ; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeDevice]] %[[MemSeqCst]] %[[Const42]]
+; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeWorkgroup]] %[[WorkgroupMemory]] %[[Const42]]
+; CHECK: %[[Neg42:[0-9]+]] = OpFNegate %[[TyFP32]] %[[Const42]]
+; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeWorkgroup]] %[[WorkgroupMemory]] %[[Neg42]]
 
 target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
 target triple = "spir64"
@@ -39,5 +44,15 @@ entry:
 
 declare dso_local spir_func float @_Z21__spirv_AtomicFAddEXT(ptr addrspace(1), i32, i32, float)
 
+define dso_local spir_func void @test3() local_unnamed_addr {
+entry:
+  %r1 = tail call spir_func float @_Z25atomic_fetch_add_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1) @f, float 42.000000e+00, i32 0)
+  %r2 = tail call spir_func float @_Z25atomic_fetch_sub_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1) @f, float 42.000000e+00, i32 0)
+  ret void
+}
+
+declare spir_func float @_Z25atomic_fetch_add_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1), float, i32)
+declare spir_func float @_Z25atomic_fetch_sub_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1), float, i32)
+
 !llvm.module.flags = !{!0}
 !0 = !{i32 1, !"wchar_size", i32 4}

``````````

</details>


https://github.com/llvm/llvm-project/pull/96767


More information about the llvm-commits mailing list