[llvm] [SPIR-V] Implement support of the SPV_EXT_arithmetic_fence SPIRV extension (PR #110500)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 30 05:30:46 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)
<details>
<summary>Changes</summary>
This PR adds support for the SPV_EXT_arithmetic_fence SPIR-V extension. Specification: https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/EXT/SPV_EXT_arithmetic_fence.html
---
Full diff: https://github.com/llvm/llvm-project/pull/110500.diff
7 Files Affected:
- (modified) llvm/docs/SPIRVUsage.rst (+2)
- (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+2)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+4)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+10)
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+8)
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+2)
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/arithmetic-fence.ll (+60)
``````````diff
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index bb12b05246afba..38c41b0fad12e6 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -147,6 +147,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
- Adds atomic add instruction on floating-point numbers.
* - ``SPV_EXT_shader_atomic_float_min_max``
- Adds atomic min and max instruction on floating-point numbers.
+ * - ``SPV_EXT_arithmetic_fence``
+ - Adds an instruction that prevents fast-math optimizations between its argument and the expression that contains it.
* - ``SPV_INTEL_arbitrary_precision_integers``
- Allows generating arbitrary width integer types.
* - ``SPV_INTEL_bfloat16_conversion``
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 127585f85915fb..a9efd5448fdf62 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -28,6 +28,8 @@ static const std::map<std::string, SPIRV::Extension::Extension>
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float16_add},
{"SPV_EXT_shader_atomic_float_min_max",
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max},
+ {"SPV_EXT_arithmetic_fence",
+ SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence},
{"SPV_INTEL_arbitrary_precision_integers",
SPIRV::Extension::Extension::SPV_INTEL_arbitrary_precision_integers},
{"SPV_INTEL_cache_controls",
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index fe45be4daba650..1d63b5b69c641b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -878,3 +878,7 @@ def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res),
"$res = OpCooperativeMatrixMulAddKHR $type $A $B $C">;
def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type),
"$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">;
+
+// SPV_EXT_arithmetic_fence
+def OpArithmeticFenceEXT: Op<6145, (outs ID:$res), (ins TYPE:$type, ID:$target),
+ "$res = OpArithmeticFenceEXT $type $target">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 2f7efbdc81f845..7dda99e4bed08d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2511,6 +2511,16 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
.addUse(I.getOperand(2).getReg())
.addUse(I.getOperand(3).getReg());
break;
+ case Intrinsic::arithmetic_fence:
+ if (STI.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpArithmeticFenceEXT))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(I.getOperand(2).getReg());
+ else
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(TargetOpcode::COPY), ResVReg)
+ .addUse(I.getOperand(2).getReg());
+ break;
case Intrinsic::spv_thread_id:
return selectSpvThreadId(ResVReg, ResType, I);
case Intrinsic::spv_fdot:
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index fa71223a341b16..8908d8965b67c5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1200,6 +1200,14 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
break;
+ case SPIRV::OpArithmeticFenceEXT:
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
+ report_fatal_error("OpArithmeticFenceEXT requires the "
+ "following SPIR-V extension: SPV_EXT_arithmetic_fence",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
+ Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
+ break;
default:
break;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 23cd32eff45d5b..56ff253d28796c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -303,6 +303,7 @@ defm SPV_INTEL_cache_controls : ExtensionOperand<108>;
defm SPV_INTEL_global_variable_host_access : ExtensionOperand<109>;
defm SPV_INTEL_global_variable_fpga_decorations : ExtensionOperand<110>;
defm SPV_KHR_cooperative_matrix : ExtensionOperand<111>;
+defm SPV_EXT_arithmetic_fence : ExtensionOperand<112>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -480,6 +481,7 @@ defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_
defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>;
defm CacheControlsINTEL : CapabilityOperand<6441, 0, 0, [SPV_INTEL_cache_controls], []>;
defm CooperativeMatrixKHR : CapabilityOperand<6022, 0, 0, [SPV_KHR_cooperative_matrix], []>;
+defm ArithmeticFenceEXT : CapabilityOperand<6144, 0, 0, [SPV_EXT_arithmetic_fence], []>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/arithmetic-fence.ll b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/arithmetic-fence.ll
new file mode 100644
index 00000000000000..5d8f547054dbf8
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/arithmetic-fence.ll
@@ -0,0 +1,60 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-linux %s -o - | FileCheck %s --check-prefixes=CHECK-NOEXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-linux %s -o - --spirv-ext=+SPV_EXT_arithmetic_fence | FileCheck %s --check-prefixes=CHECK-EXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - --spirv-ext=+SPV_EXT_arithmetic_fence -filetype=obj | spirv-val %}
+
+; CHECK-NOEXT-NOT: OpCapability ArithmeticFenceEXT
+; CHECK-NOEXT-NOT: OpExtension "SPV_EXT_arithmetic_fence"
+; CHECK-NOEXT: OpFunction
+; CHECK-NOEXT: OpFMul
+; CHECK-NOEXT: OpFAdd
+; CHECK-NOEXT-NOT: OpArithmeticFenceEXT
+; CHECK-NOEXT: OpFunction
+; CHECK-NOEXT-NOT: OpArithmeticFenceEXT
+; CHECK-NOEXT: OpFunction
+; CHECK-NOEXT-NOT: OpArithmeticFenceEXT
+
+; CHECK-EXT: OpCapability ArithmeticFenceEXT
+; CHECK-EXT: OpExtension "SPV_EXT_arithmetic_fence"
+; CHECK-EXT: OpFunction
+; CHECK-EXT: [[R1:%.*]] = OpFMul [[F32Ty:%.*]] %[[#]] %[[#]]
+; CHECK-EXT: [[R2:%.*]] = OpArithmeticFenceEXT [[F32Ty]] [[R1]]
+; CHECK-EXT: %[[#]] = OpFAdd [[F32Ty]] [[R2]] %[[#]]
+; CHECK-EXT: OpFunction
+; CHECK-EXT: [[R3:%.*]] = OpFAdd [[F64Ty:%.*]] [[A1:%.*]] [[A1]]
+; CHECK-EXT: [[R4:%.*]] = OpArithmeticFenceEXT [[F64Ty]] [[R3]]
+; CHECK-EXT: [[R5:%.*]] = OpFAdd [[F64Ty]] [[A1]] [[A1]]
+; CHECK-EXT: %[[#]] = OpFAdd [[F64Ty]] [[R4]] [[R5]]
+; CHECK-EXT: OpFunction
+; CHECK-EXT: [[R6:%.*]] = OpFAdd [[F32VecTy:%.*]] [[A2:%.*]] [[A2]]
+; CHECK-EXT: [[R7:%.*]] = OpArithmeticFenceEXT [[F32VecTy]] [[R6]]
+; CHECK-EXT: [[R8:%.*]] = OpFAdd [[F32VecTy]] [[A2]] [[A2]]
+; CHECK-EXT: %[[#]] = OpFAdd [[F32VecTy]] [[R7]] [[R8]]
+
+define float @f1(float %a, float %b, float %c) { ; scalar float case: a fence sits between an fmul and the fadd that consumes it
+ %mul = fmul fast float %b, %a
+ %tmp = call float @llvm.arithmetic.fence.f32(float %mul) ; CHECK-EXT expects OpArithmeticFenceEXT applied to the OpFMul result
+ %add = fadd fast float %tmp, %c ; consumes the fenced value %tmp rather than %mul directly
+ ret float %add
+}
+
+define double @f2(double %a) { ; double case: the same sum is computed twice, only one copy is fenced
+ %1 = fadd fast double %a, %a
+ %t = call double @llvm.arithmetic.fence.f64(double %1) ; fenced copy of %a + %a
+ %2 = fadd fast double %a, %a ; same operands as %1 but unfenced; CHECK-EXT expects a separate OpFAdd for it
+ %3 = fadd fast double %t, %2 ; adds the fenced and the unfenced sums
+ ret double %3
+}
+
+define <2 x float> @f3(<2 x float> %a) { ; vector case: same shape as @f2 but on <2 x float>
+ %1 = fadd fast <2 x float> %a, %a
+ %t = call <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float> %1) ; fenced copy of %a + %a
+ %2 = fadd fast <2 x float> %a, %a ; same operands as %1 but unfenced; CHECK-EXT expects a separate OpFAdd for it
+ %3 = fadd fast <2 x float> %t, %2 ; adds the fenced and the unfenced sums
+ ret <2 x float> %3
+}
+
+declare float @llvm.arithmetic.fence.f32(float)
+declare double @llvm.arithmetic.fence.f64(double)
+declare <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float>)
``````````
</details>
https://github.com/llvm/llvm-project/pull/110500
More information about the llvm-commits
mailing list