[llvm] [NVPTX] Add patterns for fma.relu.{f16|f16x2|bf16|bf16x2} (PR #114977)
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 12 11:50:52 PST 2024
================
@@ -3917,3 +3903,52 @@ def atomic_thread_fence_seq_cst_cta :
def atomic_thread_fence_acq_rel_cta :
NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
Requires<[hasPTX<60>, hasSM<70>]>;
+
+def fpimm_any_zero : FPImmLeaf<fAny, [{
+ return Imm.isExactlyValue(+0.0) | Imm.isExactlyValue(-0.0);
+}]>;
+
+def fpimm_positive_zero_v2f16 : PatFrag<(ops), (v2f16 (bitconvert (i32 0)))>;
+def fpimm_positive_zero_v2bf16 : PatFrag<(ops), (v2bf16 (bitconvert (i32 0)))>;
+
+// Perform substitution if fma only has one use, and also if instruction has
+// nnan instruction flag or if the TM has NoNaNsFPMath
+def NVPTX_fma_oneuse_and_nnan : PatFrag<(ops node:$a, node:$b, node:$c),
+ (fma node:$a, node:$b, node:$c), [{
+ return N->hasOneUse() &&
+ (N->getFlags().hasNoNaNs() || TM.Options.NoNaNsFPMath);
+}]>;
+
+multiclass FMARELU<RegisterClass RC, string OpVecString = ""> {
+ def _F16# !toupper(OpVecString) : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+ !strconcat("fma.rn.relu.f16", OpVecString, "\t$dst, $a, $b, $c;"), []>,
+ Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+ // FTZ variants are only supported by fp16, not bf16
+ def _F16# !toupper(OpVecString) #_FTZ : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+ !strconcat("fma.rn.ftz.relu.f16", OpVecString, "\t$dst, $a, $b, $c;"), []>,
+ Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
+ def _BF16# !toupper(OpVecString) : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+ !strconcat("fma.rn.relu.bf16", OpVecString, "\t$dst, $a, $b, $c;"), []>,
+ Requires<[hasBF16Math, hasPTX<70>, hasSM<80>]>;
+}
+
+defm FMARELU : FMARELU<Int16Regs>;
+defm FMARELU : FMARELU<Int32Regs, "x2">;
----------------
Artem-B wrote:
Nit: I'd prefer to construct the name up front in the `defm`, rather than deriving the final record name from the `OpVecString` argument.
E.g.:
```
multiclass FMARELU<RegisterClass RC, string OpVecString = ""> {
  def _F16 : ...
  // FTZ variants are only supported by fp16, not bf16
  def _F16_FTZ : ...
  def _BF16 : ...
}
defm FMARELU : FMARELU<Int16Regs>;
defm FMARELUv2 : FMARELU<Int32Regs, "x2">;
```
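To make the effect on record names concrete, here is a minimal standalone sketch. `FakeInst` is a hypothetical stand-in for `NVPTXInst`, and the register classes, patterns, and `Requires` predicates are left out, so this is an illustration rather than the PR's code. Since TableGen prepends the `defm` name to each `def` inside the multiclass, this should produce `FMARELU_F16`, `FMARELU_F16_FTZ`, `FMARELU_BF16`, and the matching `FMARELUv2_*` records.

```tablegen
// Hypothetical stand-in for NVPTXInst; it only carries the asm string.
class FakeInst<string asm> {
  string AsmString = asm;
}

multiclass FMARELU<string OpVecString = ""> {
  // Record names are fixed here; OpVecString only affects the asm text.
  def _F16 : FakeInst<!strconcat("fma.rn.relu.f16", OpVecString,
                                 "\t$dst, $a, $b, $c;")>;
  // FTZ variants are only supported by fp16, not bf16.
  def _F16_FTZ : FakeInst<!strconcat("fma.rn.ftz.relu.f16", OpVecString,
                                     "\t$dst, $a, $b, $c;")>;
  def _BF16 : FakeInst<!strconcat("fma.rn.relu.bf16", OpVecString,
                                  "\t$dst, $a, $b, $c;")>;
}

// The scalar/vector distinction lives in the defm name.
defm FMARELU   : FMARELU;
defm FMARELUv2 : FMARELU<"x2">;
```

Running `llvm-tblgen` on such a file (its default action prints all records) should list the six records by their final names, which makes them easy to grep for, unlike names assembled with `!toupper(OpVecString)`.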
https://github.com/llvm/llvm-project/pull/114977