[llvm] [NVPTX] Add patterns for fma.relu.{f16|f16x2|bf16|bf16x2} (PR #114977)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 14 10:35:23 PST 2024


================
@@ -3959,3 +3945,54 @@ def atomic_thread_fence_seq_cst_cta :
 def atomic_thread_fence_acq_rel_cta :
   NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
   Requires<[hasPTX<60>, hasSM<70>]>;
+
+def fpimm_any_zero : FPImmLeaf<fAny, [{
+  return Imm.isZero();
+}]>;
+
+def fpimm_positive_zero_v2f16 : PatFrag<(ops), (v2f16 (bitconvert (i32 0)))>;
+def fpimm_positive_zero_v2bf16 : PatFrag<(ops), (v2bf16 (bitconvert (i32 0)))>;
+
+// Perform substitution if fma only has one use, and also if instruction has
+// nnan instruction flag or if the TM has NoNaNsFPMath
+def NVPTX_fma_oneuse_and_nnan : PatFrag<(ops node:$a, node:$b, node:$c),
+                                  (fma node:$a, node:$b, node:$c), [{
+  return N->hasOneUse() &&
+    (N->getFlags().hasNoNaNs() || TM.Options.NoNaNsFPMath);
+}]>;
+// fmaxnum will differentiate between signed and unsigned zeros soon, so this
+// PatFrag is for a fmaxnum node with nsz
+def NVPTX_fmaxnum_nsz : PatFrag<(ops node:$a, node:$b),
+                                  (fmaxnum node:$a, node:$b), [{
+  return N->getFlags().hasNoSignedZeros() || TM.Options.NoSignedZerosFPMath;
+}]>;
+
+class FMARELU<RegisterClass RC, string OpVecString, list<Predicate> Preds>
----------------
Artem-B wrote:

Nit: `OpVecString` should probably be called `Instruction` now.

Side note: this class has been simplified into a fairly generic three-operand instruction helper. That is a common pattern in other back-ends, so we could probably reuse it elsewhere in NVPTX.
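
For readers following the hunk above, here is a minimal Python sketch of why the `nnan` guard on the `fma` node and the `nsz` guard on the `fmaxnum` node matter. It is an illustration only, not code from the PR. The function names are hypothetical, `a * b + c` stands in for a fused multiply-add (the rounding difference is irrelevant here), and the model of `fma.relu` (NaN stays NaN, negative results clamp to +0.0) is an assumption based on my reading of the PTX documentation.

```python
import math


def fmaxnum(x, y):
    """maxNum-style maximum: if exactly one operand is NaN, return the other."""
    if math.isnan(x):
        return y
    if math.isnan(y):
        return x
    # For equal operands (including -0.0 vs 0.0) Python's max returns the first.
    return max(x, y)


def max_of_fma(a, b, c):
    """The unfused pattern being matched: fmaxnum(fma(a, b, c), 0)."""
    return fmaxnum(a * b + c, 0.0)


def fma_relu_model(a, b, c):
    """Assumed model of fma.relu: NaN propagates, non-positive results give +0.0."""
    r = a * b + c
    if math.isnan(r):
        return math.nan
    return r if r > 0.0 else 0.0


# NaN case: inf * 0 is NaN, so the two forms disagree unless nnan holds.
nan_unfused = max_of_fma(math.inf, 0.0, 0.0)     # fmaxnum drops the NaN
nan_fused = fma_relu_model(math.inf, 0.0, 0.0)   # the model keeps the NaN

# Signed-zero case: the unfused form may yield -0.0, the model yields +0.0.
zero_unfused = max_of_fma(-1.0, 0.0, -0.0)
zero_fused = fma_relu_model(-1.0, 0.0, -0.0)
```

With ordinary finite inputs the two forms agree, which is why the substitution is safe once NaNs and the sign of zero are allowed to be ignored.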


https://github.com/llvm/llvm-project/pull/114977

