[llvm] [NVPTX] Add patterns for fma.relu.{f16|f16x2|bf16|bf16x2} (PR #114977)

Hugh Delaney via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 12 02:31:44 PST 2024


================
@@ -3917,3 +3904,86 @@ def atomic_thread_fence_seq_cst_cta :
 def atomic_thread_fence_acq_rel_cta :
   NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
   Requires<[hasPTX<60>, hasSM<70>]>;
+
+def fpimm_any_zero : FPImmLeaf<fAny, [{
+  return Imm.isExactlyValue(+0.0) || Imm.isExactlyValue(-0.0);
+}]>;
+
+def fpimm_positive_zero_v2f16 : PatFrag<(ops), (v2f16 (bitconvert (i32 0)))>;
+def fpimm_positive_zero_v2bf16 : PatFrag<(ops), (v2bf16 (bitconvert (i32 0)))>;
+
+// Patterns will only be used if FMA has a single use, in order to mitigate register pressure
+def NVPTX_fma_oneuse : PatFrag<(ops node:$a, node:$b, node:$c),
+                                  (fma node:$a, node:$b, node:$c), [{
+  return N->hasOneUse();
+}]>;
+// We can use the instruction flag nnan instead of relying on a function attribute
+def NVPTX_fma_oneuse_and_nnan : PatFrag<(ops node:$a, node:$b, node:$c),
+                                  (fma node:$a, node:$b, node:$c), [{
+  return N->hasOneUse() && N->getFlags().hasNoNaNs();
+}]>;
+
+def FMARELU_F16 :
+  NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+    "fma.rn.relu.f16 \t$dst, $a, $b, $c;", []>,
+    Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+def FMARELU_BF16 :
+  NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+    "fma.rn.relu.bf16 \t$dst, $a, $b, $c;", []>,
+    Requires<[hasBF16Math, hasPTX<70>, hasSM<80>]>;
+def FMARELU_F16X2 :
+  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
+    "fma.rn.relu.f16x2 \t$dst, $a, $b, $c;", []>,
+    Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+def FMARELU_BF16X2 :
+  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
+    "fma.rn.relu.bf16x2 \t$dst, $a, $b, $c;", []>,
+    Requires<[hasBF16Math, hasPTX<70>, hasSM<80>]>;
+// FTZ variants are only supported by fp16, not bf16
+def FMARELU_F16_FTZ :
+  NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+    "fma.rn.ftz.relu.f16 \t$dst, $a, $b, $c;", []>,
+    Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+def FMARELU_F16X2_FTZ :
+  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
+    "fma.rn.ftz.relu.f16x2 \t$dst, $a, $b, $c;", []>,
+    Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+
+// Use the nnan instruction flag instead of function attributes
+// FTZ variants are only supported by fp16, not bf16
+def : Pat<(f16 (fmaxnum (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
+  (FMARELU_F16_FTZ Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+  Requires<[doF32FTZ]>;
+def : Pat<(v2f16 (fmaxnum (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2f16)),
+  (FMARELU_F16X2_FTZ Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>,
+  Requires<[doF32FTZ]>;
+// No FTZ
+def : Pat<(f16 (fmaxnum (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
+  (FMARELU_F16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>;
+def : Pat<(bf16 (fmaxnum (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
+  (FMARELU_BF16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>;
+def : Pat<(v2f16 (fmaxnum (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2f16)),
+  (FMARELU_F16X2 Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>;
+def : Pat<(v2bf16 (fmaxnum (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2bf16)),
+  (FMARELU_BF16X2 Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>;
+
+// Use the noNaNsFPMath function attribute instead of the instruction flag
----------------
hdelan wrote:

@Artem-B do you have an opinion on this?

https://github.com/llvm/llvm-project/pull/114977

