[llvm] [NVPTX] Add patterns for fma.relu.{f16|bf16} (PR #114977)
Hugh Delaney via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 5 11:29:32 PST 2024
================
@@ -3917,3 +3917,22 @@ def atomic_thread_fence_seq_cst_cta :
def atomic_thread_fence_acq_rel_cta :
NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
Requires<[hasPTX<60>, hasSM<70>]>;
+
+def fpimm0 : FPImmLeaf<fAny, [{
+ return Imm.isExactlyValue(+0.0);
+}]>;
+
+def FMARELU_F16 :
+ NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+ "fma.rn.relu.f16 \t$dst, $a, $b, $c;", []>;
+def FMARELU_BF16 :
+ NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+ "fma.rn.relu.bf16 \t$dst, $a, $b, $c;", []>;
+
+def : Pat<(f16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+ (FMARELU_F16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+ Requires<[useFP16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
+
+def : Pat<(bf16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+ (FMARELU_BF16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+ Requires<[hasBF16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
----------------
hdelan wrote:
Oops, good catch. Thanks.
https://github.com/llvm/llvm-project/pull/114977