[llvm] [NVPTX] Add patterns for fma.relu.{f16|f16x2|bf16|bf16x2} (PR #114977)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 13 11:33:57 PST 2024


================
@@ -3917,3 +3903,52 @@ def atomic_thread_fence_seq_cst_cta :
 def atomic_thread_fence_acq_rel_cta :
   NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
   Requires<[hasPTX<60>, hasSM<70>]>;
+
+def fpimm_any_zero : FPImmLeaf<fAny, [{
+  return Imm.isExactlyValue(+0.0) | Imm.isExactlyValue(-0.0);
+}]>;
+
+def fpimm_positive_zero_v2f16 : PatFrag<(ops), (v2f16 (bitconvert (i32 0)))>;
+def fpimm_positive_zero_v2bf16 : PatFrag<(ops), (v2bf16 (bitconvert (i32 0)))>;
+
+// Match only if the fma has a single use and either carries the nnan flag
+// or the target options enable NoNaNsFPMath.
+def NVPTX_fma_oneuse_and_nnan : PatFrag<(ops node:$a, node:$b, node:$c),
+                                  (fma node:$a, node:$b, node:$c), [{
+  return N->hasOneUse() &&
+    (N->getFlags().hasNoNaNs() || TM.Options.NoNaNsFPMath);
+}]>;
+
+multiclass FMARELU<RegisterClass RC, string OpVecString = "">  {
+  def _F16# !toupper(OpVecString) : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+    !strconcat("fma.rn.relu.f16", OpVecString, "\t$dst, $a, $b, $c;"), []>,
+    Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+  // FTZ variants are only supported by fp16, not bf16
+  def _F16# !toupper(OpVecString) #_FTZ : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+    !strconcat("fma.rn.ftz.relu.f16", OpVecString, "\t$dst, $a, $b, $c;"), []>,
+    Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+  def _BF16# !toupper(OpVecString) : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+    !strconcat("fma.rn.relu.bf16", OpVecString, "\t$dst, $a, $b, $c;"), []>,
+    Requires<[hasBF16Math, hasPTX<70>, hasSM<80>]>;
+}
+
+defm FMARELU : FMARELU<Int16Regs>;
+defm FMARELU : FMARELU<Int32Regs, "x2">;
----------------
Artem-B wrote:

I see. My concern was mostly this seemingly duplicate `defm FMARELU`, which is what got me to stop and poke at it. 

So, how can we express all of this better? Ideally there should be a single point where the multiclass is parametrized.

We have three instruction variants (f16, f16 with FTZ, and bf16) and two vector sizes (scalar and x2), and the register class follows from the vector size.
Since that's only six records, there's not much point in anything elaborate; it would not improve on simply enumerating them.

Actually, that may be the most concise approach: define a simple class for a single instruction record, parametrized by ftz/type/register class, and then define each of the six variants explicitly. You can see an example of this approach here:
https://github.com/llvm/llvm-project/blob/cb9481dbf902adc349757eca12a0a09396dc4a23/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td#L1222

(side note: FNEG_F16 and FNEG_BF16 could be further collapsed into a single class)

I think the net result will be shorter than what we have now.
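For illustration, here is a rough, untested sketch of what that could look like, reusing the predicates (`useFP16Math`, `hasBF16Math`, `hasPTX<70>`, `hasSM<80>`), register classes (`Int16Regs`, `Int32Regs`) and asm strings from the patch above. The class name `FMARELUInst` and the `F16Preds`/`BF16Preds` variables are hypothetical names chosen for this sketch. The record names are chosen to match the ones the two `defm FMARELU` lines currently generate, so patterns referring to them should not need to change.

```tablegen
// Hypothetical single-record class: the type suffix (including "x2"),
// whether .ftz is emitted, the register class and the predicates are all
// passed in explicitly, so there is exactly one place to parametrize.
class FMARELUInst<string Type, bit FTZ, RegisterClass RC, list<Predicate> Preds>
  : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
              "fma.rn" # !if(FTZ, ".ftz", "") # ".relu." # Type
                # "\t$dst, $a, $b, $c;",
              []>,
    Requires<Preds>;

// Hypothetical shorthand for the predicate lists used in the patch.
defvar F16Preds  = [useFP16Math, hasPTX<70>, hasSM<80>];
defvar BF16Preds = [hasBF16Math, hasPTX<70>, hasSM<80>];

// Scalar variants live in 16-bit registers, x2 variants in 32-bit ones.
// FTZ is only available for f16, not bf16.
def FMARELU_F16       : FMARELUInst<"f16",    false, Int16Regs, F16Preds>;
def FMARELU_F16_FTZ   : FMARELUInst<"f16",    true,  Int16Regs, F16Preds>;
def FMARELU_BF16      : FMARELUInst<"bf16",   false, Int16Regs, BF16Preds>;
def FMARELU_F16X2     : FMARELUInst<"f16x2",  false, Int32Regs, F16Preds>;
def FMARELU_F16X2_FTZ : FMARELUInst<"f16x2",  true,  Int32Regs, F16Preds>;
def FMARELU_BF16X2    : FMARELUInst<"bf16x2", false, Int32Regs, BF16Preds>;
```

Each `def` line spells out one of the six instructions, so the mapping from variant to register class and predicates can be read off directly instead of being reconstructed from `!toupper` and string pasting.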

https://github.com/llvm/llvm-project/pull/114977

