[llvm] [NVPTX] Add patterns for fma.relu.{f16|f16x2|bf16|bf16x2} (PR #114977)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 13 14:33:21 PST 2024


================
@@ -3917,3 +3903,63 @@ def atomic_thread_fence_seq_cst_cta :
 def atomic_thread_fence_acq_rel_cta :
   NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
   Requires<[hasPTX<60>, hasSM<70>]>;
+
+def fpimm_positive_zero : FPImmLeaf<fAny, [{
+  return Imm.isExactlyValue(+0.0);
+}]>;
+
+def fpimm_positive_zero_v2f16 : PatFrag<(ops), (v2f16 (bitconvert (i32 0)))>;
+def fpimm_positive_zero_v2bf16 : PatFrag<(ops), (v2bf16 (bitconvert (i32 0)))>;
+
+// Patterns will only be used if FMA has a single use, in order to mitigate register pressure
+def NVPTX_fma_oneuse : PatFrag<(ops node:$a, node:$b, node:$c),
+                                  (fma node:$a, node:$b, node:$c), [{
+  return N->hasOneUse();
+}]>;
+
+def FMARELU_F16 :
+  NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+    "fma.rn.relu.f16 \t$dst, $a, $b, $c;", []>,
+    Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+def FMARELU_BF16 :
+  NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+    "fma.rn.relu.bf16 \t$dst, $a, $b, $c;", []>,
+    Requires<[hasBF16Math, hasPTX<70>, hasSM<80>]>;
+def FMARELU_F16X2 :
+  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
+    "fma.rn.relu.f16x2 \t$dst, $a, $b, $c;", []>,
+    Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+def FMARELU_BF16X2 :
+  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
+    "fma.rn.relu.bf16x2 \t$dst, $a, $b, $c;", []>,
+    Requires<[hasBF16Math, hasPTX<70>, hasSM<80>]>;
+// FTZ variants are only supported by fp16, not bf16
+def FMARELU_F16_FTZ :
+  NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+    "fma.rn.ftz.relu.f16 \t$dst, $a, $b, $c;", []>,
+    Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+def FMARELU_F16X2_FTZ :
+  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
+    "fma.rn.ftz.relu.f16x2 \t$dst, $a, $b, $c;", []>,
+    Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+
+// FTZ variants are only supported by fp16, not bf16
+def : Pat<(f16 (fmaxnum (NVPTX_fma_oneuse Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_positive_zero)),
----------------
arsenm wrote:

We are fixing the definition of fmaxnum/fminnum to require correctly ordered signed zeroes. The fuzzy behavior can be achieved with nsz, so it doesn't make sense to define it this way.

Best to respect this now to reduce the amount of code that needs changing later.
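
To illustrate the distinction, here is a minimal C++ sketch, not LLVM code. The helper names `max_ordered_zero` and `max_fuzzy_zero` are hypothetical, and the NaN handling (return the non-NaN operand) is an assumption made only to keep the example complete. The first helper treats -0.0 as strictly less than +0.0; the second compares zeroes as equal, so its result depends on operand order, which is roughly what nsz permits.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper (not an LLVM API): maximum with correctly ordered
// signed zeroes, i.e. -0.0 is treated as less than +0.0.
// Assumption: if exactly one operand is NaN, the other operand is returned.
inline float max_ordered_zero(float a, float b) {
  if (std::isnan(a)) return b;
  if (std::isnan(b)) return a;
  if (a == b) {
    // Equal values with possibly different signs can only be +0.0 / -0.0;
    // prefer the operand whose sign bit is clear.
    return std::signbit(a) ? b : a;
  }
  return a > b ? a : b;
}

// Hypothetical "fuzzy" maximum: +0.0 and -0.0 compare equal, so which
// zero comes back depends on operand order.
inline float max_fuzzy_zero(float a, float b) {
  if (std::isnan(a)) return b;
  if (std::isnan(b)) return a;
  return a < b ? b : a;
}
```

With these definitions, `max_ordered_zero(-0.0f, 0.0f)` and `max_ordered_zero(0.0f, -0.0f)` both yield +0.0, while `max_fuzzy_zero(-0.0f, 0.0f)` yields -0.0. A pattern that matches a maximum against +0.0 therefore means different things under the two definitions.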

https://github.com/llvm/llvm-project/pull/114977

