[llvm] Add patterns for fma.relu.{f16|bf16} (PR #114977)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 5 04:43:39 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Hugh Delaney (hdelan)

<details>
<summary>Changes</summary>

Add patterns to lower `fma(a, b, c) > 0 ? fma(a, b, c) : 0` to the PTX `fma.rn.relu` instruction for the f16 and bf16 types.

---
Full diff: https://github.com/llvm/llvm-project/pull/114977.diff


2 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+16) 
- (added) llvm/test/CodeGen/NVPTX/fma-relu.ll (+77) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 5f6cba397c5352..52312fa9afbd7e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3917,3 +3917,19 @@ def atomic_thread_fence_seq_cst_cta :
 def atomic_thread_fence_acq_rel_cta :
   NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
   Requires<[hasPTX<60>, hasSM<70>]>;
+
+def fpimm0 : FPImmLeaf<fAny, [{
+  return Imm.isExactlyValue(+0.0);
+}]>;
+
+def FMARELU :
+  NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+            "fma.rn.relu \t$dst, $a, $b, $c;", []>;
+
+def : Pat<(f16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+  (FMARELU Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+  Requires<[useFP16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
+
+def : Pat<(bf16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+  (FMARELU Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+  Requires<[hasBF16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu.ll b/llvm/test/CodeGen/NVPTX/fma-relu.ll
new file mode 100644
index 00000000000000..6c340ef9d53015
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fma-relu.ll
@@ -0,0 +1,77 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | %ptxas-verify -arch=sm_80 %}
+
+define half @fma_f16(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_f16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_f16_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_f16_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_f16_param_2];
+; CHECK-NEXT:    fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = call half @llvm.fma.f16(half %a, half %b, half %c)
+  %2 = fcmp ogt half %1, 0.0
+  %3 = select i1 %2, half %1, half 0.0
+  ret half %3
+}
+
+define half @fma_f16_expanded(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_f16_expanded(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_f16_expanded_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_f16_expanded_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_f16_expanded_param_2];
+; CHECK-NEXT:    fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = fmul half %a, %b
+  %2 = fadd half %1, %c
+  %3 = fcmp ogt half %2, 0.0
+  %4 = select i1 %3, half %2, half 0.0
+  ret half %4
+}
+
+define bfloat @fma_bf16(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_bf16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_bf16_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_bf16_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_bf16_param_2];
+; CHECK-NEXT:    fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = call bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %c)
+  %2 = fcmp ogt bfloat %1, 0.0
+  %3 = select i1 %2, bfloat %1, bfloat 0.0
+  ret bfloat %3
+}
+
+define bfloat @fma_bf16_expanded(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_bf16_expanded(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_bf16_expanded_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_bf16_expanded_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_bf16_expanded_param_2];
+; CHECK-NEXT:    fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = fmul bfloat %a, %b
+  %2 = fadd bfloat %1, %c
+  %3 = fcmp ogt bfloat %2, 0.0
+  %4 = select i1 %3, bfloat %2, bfloat 0.0
+  ret bfloat %4
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/114977


More information about the llvm-commits mailing list