[llvm] [SPIR-V] Support cl_ext_float_atomics and fix errors in definition of atomic_fetch_*_explicit builtins (PR #96767)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 26 06:51:37 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)
<details>
<summary>Changes</summary>
This PR:
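* supports cl_ext_float_atomics by mapping atomic_fetch_add and atomic_fetch_sub applied to floating-point arguments to the corresponding instructions from SPV_EXT_shader_atomic_float*_add (a floating-point atomic_fetch_sub is emitted as OpAtomicFAddEXT with a negated value operand), and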
* fixes errors in the definition of the atomic_fetch_*_explicit builtins by correcting their valid number of arguments (3 to 4 instead of 4 to 6).
---
Full diff: https://github.com/llvm/llvm-project/pull/96767.diff
3 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+26-3)
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.td (+5-7)
- (modified) llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll (+15)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 71168d2d7dacd..0b93a4d85eedf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -765,7 +765,7 @@ static bool buildAtomicCompareExchangeInst(
return true;
}
-/// Helper function for building an atomic load instruction.
+/// Helper function for building atomic instructions.
static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
@@ -790,13 +790,36 @@ static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
Semantics, MIRBuilder, GR);
MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+ Register ValueReg = Call->Arguments[1];
+ Register ValueTypeReg = GR->getSPIRVTypeID(Call->ReturnType);
+ // support cl_ext_float_atomics
+ if (Call->ReturnType->getOpcode() == SPIRV::OpTypeFloat) {
+ if (Opcode == SPIRV::OpAtomicIAdd) {
+ Opcode = SPIRV::OpAtomicFAddEXT;
+ } else if (Opcode == SPIRV::OpAtomicISub) {
+ // Translate OpAtomicISub applied to a floating type argument to
+ // OpAtomicFAddEXT with the negative value operand
+ Opcode = SPIRV::OpAtomicFAddEXT;
+ Register NegValueReg =
+ MRI->createGenericVirtualRegister(MRI->getType(ValueReg));
+ MRI->setRegClass(NegValueReg, &SPIRV::IDRegClass);
+ GR->assignSPIRVTypeToVReg(Call->ReturnType, NegValueReg,
+ MIRBuilder.getMF());
+ MIRBuilder.buildInstr(TargetOpcode::G_FNEG)
+ .addDef(NegValueReg)
+ .addUse(ValueReg);
+ insertAssignInstr(NegValueReg, nullptr, Call->ReturnType, GR, MIRBuilder,
+ MIRBuilder.getMF().getRegInfo());
+ ValueReg = NegValueReg;
+ }
+ }
MIRBuilder.buildInstr(Opcode)
.addDef(Call->ReturnRegister)
- .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+ .addUse(ValueTypeReg)
.addUse(PtrRegister)
.addUse(ScopeRegister)
.addUse(MemSemanticsReg)
- .addUse(Call->Arguments[1]);
+ .addUse(ValueReg);
return true;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 4bd1104103664..f0d480c6a13f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -546,11 +546,11 @@ defm : DemangledNativeBuiltin<"atomic_fetch_sub", OpenCL_std, Atomic, 2, 4, OpAt
defm : DemangledNativeBuiltin<"atomic_fetch_or", OpenCL_std, Atomic, 2, 4, OpAtomicOr>;
defm : DemangledNativeBuiltin<"atomic_fetch_xor", OpenCL_std, Atomic, 2, 4, OpAtomicXor>;
defm : DemangledNativeBuiltin<"atomic_fetch_and", OpenCL_std, Atomic, 2, 4, OpAtomicAnd>;
-defm : DemangledNativeBuiltin<"atomic_fetch_add_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicIAdd>;
-defm : DemangledNativeBuiltin<"atomic_fetch_sub_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicISub>;
-defm : DemangledNativeBuiltin<"atomic_fetch_or_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicOr>;
-defm : DemangledNativeBuiltin<"atomic_fetch_xor_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicXor>;
-defm : DemangledNativeBuiltin<"atomic_fetch_and_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicAnd>;
+defm : DemangledNativeBuiltin<"atomic_fetch_add_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicIAdd>;
+defm : DemangledNativeBuiltin<"atomic_fetch_sub_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicISub>;
+defm : DemangledNativeBuiltin<"atomic_fetch_or_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicOr>;
+defm : DemangledNativeBuiltin<"atomic_fetch_xor_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicXor>;
+defm : DemangledNativeBuiltin<"atomic_fetch_and_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicAnd>;
defm : DemangledNativeBuiltin<"atomic_flag_test_and_set", OpenCL_std, Atomic, 1, 1, OpAtomicFlagTestAndSet>;
defm : DemangledNativeBuiltin<"__spirv_AtomicFlagTestAndSet", OpenCL_std, Atomic, 3, 3, OpAtomicFlagTestAndSet>;
defm : DemangledNativeBuiltin<"atomic_flag_test_and_set_explicit", OpenCL_std, Atomic, 2, 3, OpAtomicFlagTestAndSet>;
@@ -1037,8 +1037,6 @@ multiclass DemangledAtomicFloatingBuiltin<string name, bits<8> minNumArgs, bits<
defm : DemangledAtomicFloatingBuiltin<"AddEXT", 4, 4, OpAtomicFAddEXT>;
defm : DemangledAtomicFloatingBuiltin<"MinEXT", 4, 4, OpAtomicFMinEXT>;
defm : DemangledAtomicFloatingBuiltin<"MaxEXT", 4, 4, OpAtomicFMaxEXT>;
-// TODO: add support for cl_ext_float_atomics to enable performing atomic operations
-// on floating-point numbers in memory (float arguments for atomic_fetch_add, ...)
//===----------------------------------------------------------------------===//
// Class defining a sub group builtin that should be translated into a
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll
index 4fb99d9dfc76f..3fd5bd65853f7 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll
@@ -11,13 +11,18 @@
; CHECK-DAG: %[[Const0:[0-9]+]] = OpConstant %[[TyFP32]] 0
; CHECK-DAG: %[[Const42:[0-9]+]] = OpConstant %[[TyFP32]] 42
; CHECK-DAG: %[[ScopeDevice:[0-9]+]] = OpConstant %[[TyInt32]] 1
+; CHECK-DAG: %[[ScopeWorkgroup:[0-9]+]] = OpConstant %[[TyInt32]] 2
; CHECK-DAG: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16
+; CHECK-DAG: %[[WorkgroupMemory:[0-9]+]] = OpConstant %[[TyInt32]] 512
; CHECK-DAG: %[[TyFP32Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyFP32]]
; CHECK-DAG: %[[DblPtr:[0-9]+]] = OpVariable %[[TyFP32Ptr]] {{[a-zA-Z]+}} %[[Const0]]
; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeDevice]] %[[MemSeqCst]] %[[Const42]]
; CHECK: %[[Const42Neg:[0-9]+]] = OpFNegate %[[TyFP32]] %[[Const42]]
; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeDevice]] %[[MemSeqCst]] %[[Const42Neg]]
; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeDevice]] %[[MemSeqCst]] %[[Const42]]
+; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeWorkgroup]] %[[WorkgroupMemory]] %[[Const42]]
+; CHECK: %[[Neg42:[0-9]+]] = OpFNegate %[[TyFP32]] %[[Const42]]
+; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeWorkgroup]] %[[WorkgroupMemory]] %[[Neg42]]
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir64"
@@ -39,5 +44,15 @@ entry:
declare dso_local spir_func float @_Z21__spirv_AtomicFAddEXT(ptr addrspace(1), i32, i32, float)
+define dso_local spir_func void @test3() local_unnamed_addr {
+entry:
+ %r1 = tail call spir_func float @_Z25atomic_fetch_add_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1) @f, float 42.000000e+00, i32 0)
+ %r2 = tail call spir_func float @_Z25atomic_fetch_sub_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1) @f, float 42.000000e+00, i32 0)
+ ret void
+}
+
+declare spir_func float @_Z25atomic_fetch_add_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1), float, i32)
+declare spir_func float @_Z25atomic_fetch_sub_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1), float, i32)
+
!llvm.module.flags = !{!0}
!0 = !{i32 1, !"wchar_size", i32 4}
``````````
</details>
https://github.com/llvm/llvm-project/pull/96767
More information about the llvm-commits
mailing list