[llvm] [SPIR-V] Implement support of the SPV_EXT_arithmetic_fence SPIRV extension (PR #110500)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 30 05:30:46 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR implements support for the SPV_EXT_arithmetic_fence SPIR-V extension: https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/EXT/SPV_EXT_arithmetic_fence.html.

---
Full diff: https://github.com/llvm/llvm-project/pull/110500.diff


7 Files Affected:

- (modified) llvm/docs/SPIRVUsage.rst (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+4) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+10) 
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+8) 
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+2) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/arithmetic-fence.ll (+60) 


``````````diff
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index bb12b05246afba..38c41b0fad12e6 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -147,6 +147,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
      - Adds atomic add instruction on floating-point numbers.
    * - ``SPV_EXT_shader_atomic_float_min_max``
      - Adds atomic min and max instruction on floating-point numbers.
+   * - ``SPV_EXT_arithmetic_fence``
+     - Adds an instruction that prevents fast-math optimizations between its argument and the expression that contains it.
    * - ``SPV_INTEL_arbitrary_precision_integers``
      - Allows generating arbitrary width integer types.
    * - ``SPV_INTEL_bfloat16_conversion``
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 127585f85915fb..a9efd5448fdf62 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -28,6 +28,8 @@ static const std::map<std::string, SPIRV::Extension::Extension>
          SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float16_add},
         {"SPV_EXT_shader_atomic_float_min_max",
          SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max},
+        {"SPV_EXT_arithmetic_fence",
+         SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence},
         {"SPV_INTEL_arbitrary_precision_integers",
          SPIRV::Extension::Extension::SPV_INTEL_arbitrary_precision_integers},
         {"SPV_INTEL_cache_controls",
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index fe45be4daba650..1d63b5b69c641b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -878,3 +878,7 @@ def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res),
                   "$res = OpCooperativeMatrixMulAddKHR $type $A $B $C">;
 def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type),
                   "$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">;
+
+// SPV_EXT_arithmetic_fence
+def OpArithmeticFenceEXT: Op<6145, (outs ID:$res), (ins TYPE:$type, ID:$target),
+                  "$res = OpArithmeticFenceEXT $type $target">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 2f7efbdc81f845..7dda99e4bed08d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2511,6 +2511,16 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
           .addUse(I.getOperand(2).getReg())
           .addUse(I.getOperand(3).getReg());
     break;
+  case Intrinsic::arithmetic_fence:
+    if (STI.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
+      BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpArithmeticFenceEXT))
+          .addDef(ResVReg)
+          .addUse(GR.getSPIRVTypeID(ResType))
+          .addUse(I.getOperand(2).getReg());
+    else
+      BuildMI(BB, I, I.getDebugLoc(), TII.get(TargetOpcode::COPY), ResVReg)
+          .addUse(I.getOperand(2).getReg());
+    break;
   case Intrinsic::spv_thread_id:
     return selectSpvThreadId(ResVReg, ResType, I);
   case Intrinsic::spv_fdot:
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index fa71223a341b16..8908d8965b67c5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1200,6 +1200,14 @@ void addInstrRequirements(const MachineInstr &MI,
     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
     break;
+  case SPIRV::OpArithmeticFenceEXT:
+    if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
+      report_fatal_error("OpArithmeticFenceEXT requires the "
+                         "following SPIR-V extension: SPV_EXT_arithmetic_fence",
+                         false);
+    Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
+    Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
+    break;
   default:
     break;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 23cd32eff45d5b..56ff253d28796c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -303,6 +303,7 @@ defm SPV_INTEL_cache_controls : ExtensionOperand<108>;
 defm SPV_INTEL_global_variable_host_access : ExtensionOperand<109>;
 defm SPV_INTEL_global_variable_fpga_decorations : ExtensionOperand<110>;
 defm SPV_KHR_cooperative_matrix : ExtensionOperand<111>;
+defm SPV_EXT_arithmetic_fence : ExtensionOperand<112>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -480,6 +481,7 @@ defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_
 defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>;
 defm CacheControlsINTEL : CapabilityOperand<6441, 0, 0, [SPV_INTEL_cache_controls], []>;
 defm CooperativeMatrixKHR : CapabilityOperand<6022, 0, 0, [SPV_KHR_cooperative_matrix], []>;
+defm ArithmeticFenceEXT : CapabilityOperand<6144, 0, 0, [SPV_EXT_arithmetic_fence], []>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/arithmetic-fence.ll b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/arithmetic-fence.ll
new file mode 100644
index 00000000000000..5d8f547054dbf8
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/arithmetic-fence.ll
@@ -0,0 +1,60 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-linux %s -o - | FileCheck %s --check-prefixes=CHECK-NOEXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-linux %s -o - --spirv-ext=+SPV_EXT_arithmetic_fence | FileCheck %s --check-prefixes=CHECK-EXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-NOEXT-NOT: OpCapability ArithmeticFenceEXT
+; CHECK-NOEXT-NOT: OpExtension "SPV_EXT_arithmetic_fence"
+; CHECK-NOEXT: OpFunction
+; CHECK-NOEXT: OpFMul
+; CHECK-NOEXT: OpFAdd
+; CHECK-NOEXT-NOT: OpArithmeticFenceEXT
+; CHECK-NOEXT: OpFunction
+; CHECK-NOEXT-NOT: OpArithmeticFenceEXT
+; CHECK-NOEXT: OpFunction
+; CHECK-NOEXT-NOT: OpArithmeticFenceEXT
+
+; CHECK-EXT: OpCapability ArithmeticFenceEXT
+; CHECK-EXT: OpExtension "SPV_EXT_arithmetic_fence"
+; CHECK-EXT: OpFunction
+; CHECK-EXT: [[R1:%.*]] = OpFMul [[I32Ty:%.*]] %[[#]] %[[#]]
+; CHECK-EXT: [[R2:%.*]] = OpArithmeticFenceEXT [[I32Ty]] [[R1]]
+; CHECK-EXT: %[[#]] = OpFAdd [[I32Ty]] [[R2]] %[[#]]
+; CHECK-EXT: OpFunction
+; CHECK-EXT: [[R3:%.*]] = OpFAdd [[I64Ty:%.*]] [[A1:%.*]] [[A1]]
+; CHECK-EXT: [[R4:%.*]] = OpArithmeticFenceEXT [[I64Ty]] [[R3]]
+; CHECK-EXT: [[R5:%.*]] = OpFAdd [[I64Ty]] [[A1]] [[A1]]
+; CHECK-EXT: %[[#]] = OpFAdd [[I64Ty]] [[R4]] [[R5]]
+; CHECK-EXT: OpFunction
+; CHECK-EXT: [[R6:%.*]] = OpFAdd [[I32VecTy:%.*]] [[A2:%.*]] [[A2]]
+; CHECK-EXT: [[R7:%.*]] = OpArithmeticFenceEXT [[I32VecTy]] [[R6]]
+; CHECK-EXT: [[R8:%.*]] = OpFAdd [[I32VecTy]] [[A2]] [[A2]]
+; CHECK-EXT: %[[#]] = OpFAdd [[I32VecTy]] [[R7]] [[R8]]
+
+define float @f1(float %a, float %b, float %c) {
+  %mul = fmul fast float %b, %a
+  %tmp = call float @llvm.arithmetic.fence.f32(float %mul)
+  %add = fadd fast float %tmp, %c
+  ret float %add
+}
+
+define double @f2(double %a) {
+  %1 = fadd fast double %a, %a
+  %t = call double @llvm.arithmetic.fence.f64(double %1)
+  %2 = fadd fast double %a, %a
+  %3 = fadd fast double %t, %2
+  ret double %3
+}
+
+define <2 x float> @f3(<2 x float> %a) {
+  %1 = fadd fast <2 x float> %a, %a
+  %t = call <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float> %1)
+  %2 = fadd fast <2 x float> %a, %a
+  %3 = fadd fast <2 x float> %t, %2
+  ret <2 x float> %3
+}
+
+declare float @llvm.arithmetic.fence.f32(float)
+declare double @llvm.arithmetic.fence.f64(double)
+declare <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/110500


More information about the llvm-commits mailing list