[llvm] [NVPTX] Add patterns for fma.relu.{f16|f16x2|bf16|bf16x2} (PR #114977)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 12 11:50:52 PST 2024


================
@@ -3917,3 +3903,52 @@ def atomic_thread_fence_seq_cst_cta :
 def atomic_thread_fence_acq_rel_cta :
   NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
   Requires<[hasPTX<60>, hasSM<70>]>;
+
+def fpimm_any_zero : FPImmLeaf<fAny, [{
+  return Imm.isExactlyValue(+0.0) | Imm.isExactlyValue(-0.0);
+}]>;
+
+def fpimm_positive_zero_v2f16 : PatFrag<(ops), (v2f16 (bitconvert (i32 0)))>;
+def fpimm_positive_zero_v2bf16 : PatFrag<(ops), (v2bf16 (bitconvert (i32 0)))>;
+
+// Perform substitution if fma only has one use, and also if instruction has
+// nnan instruction flag or if the TM has NoNaNsFPMath
+def NVPTX_fma_oneuse_and_nnan : PatFrag<(ops node:$a, node:$b, node:$c),
+                                  (fma node:$a, node:$b, node:$c), [{
+  return N->hasOneUse() &&
+    (N->getFlags().hasNoNaNs() || TM.Options.NoNaNsFPMath);
+}]>;
+
+multiclass FMARELU<RegisterClass RC, string OpVecString = "">  {
+  def _F16# !toupper(OpVecString) : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+    !strconcat("fma.rn.relu.f16", OpVecString, "\t$dst, $a, $b, $c;"), []>,
+    Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+  // FTZ variants are only supported by fp16, not bf16
+  def _F16# !toupper(OpVecString) #_FTZ : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+    !strconcat("fma.rn.ftz.relu.f16", OpVecString, "\t$dst, $a, $b, $c;"), []>,
+    Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+  def _BF16# !toupper(OpVecString) : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+    !strconcat("fma.rn.relu.bf16", OpVecString, "\t$dst, $a, $b, $c;"), []>,
+    Requires<[hasBF16Math, hasPTX<70>, hasSM<80>]>;
+}
+
+defm FMARELU : FMARELU<Int16Regs>;
+defm FMARELU : FMARELU<Int32Regs, "x2">;
----------------
Artem-B wrote:

Nit: I'd prefer spelling out the record names up front here, rather than deriving the final record name from the `OpVecString` argument.

E.g.:
```
multiclass FMARELU<RegisterClass RC, string OpVecString = "">  {
  def _F16 : ...
  // FTZ variants are only supported by fp16, not bf16
  def _F16_FTZ : ...
  def _BF16 : ...
}

defm FMARELU : FMARELU<Int16Regs>;
defm FMARELUv2 : FMARELU<Int32Regs, "x2">;
```
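For illustration, here is a minimal self-contained sketch of the suggested layout. `FakeInst` is a hypothetical stand-in for `NVPTXInst` that keeps only the asm string and drops the operand lists and `Requires<...>` predicates. This lets the file be run through `llvm-tblgen` on its own to inspect the generated record names:

```
// Hypothetical stand-in for NVPTXInst: it only records the asm string, so
// this file can be run through llvm-tblgen without the NVPTX backend.
class FakeInst<string asm> {
  string AsmString = asm;
}

multiclass FMARELU<string OpVecString = ""> {
  def _F16 : FakeInst<!strconcat("fma.rn.relu.f16", OpVecString, "\t$dst, $a, $b, $c;")>;
  // FTZ variants are only supported by fp16, not bf16
  def _F16_FTZ : FakeInst<!strconcat("fma.rn.ftz.relu.f16", OpVecString, "\t$dst, $a, $b, $c;")>;
  def _BF16 : FakeInst<!strconcat("fma.rn.relu.bf16", OpVecString, "\t$dst, $a, $b, $c;")>;
}

// Scalar records: FMARELU_F16, FMARELU_F16_FTZ, FMARELU_BF16
defm FMARELU : FMARELU<"">;
// Vector records: FMARELUv2_F16, FMARELUv2_F16_FTZ, FMARELUv2_BF16
defm FMARELUv2 : FMARELU<"x2">;
```

With the patch as posted, the vector records instead come out as `FMARELU_F16X2`, `FMARELU_F16X2_FTZ` and `FMARELU_BF16X2`, because the suffix is pasted from `!toupper(OpVecString)`.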

https://github.com/llvm/llvm-project/pull/114977

