[llvm] [NVPTX] Add patterns for fma.relu.{f16|bf16} (PR #114977)

Hugh Delaney via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 5 07:00:49 PST 2024


https://github.com/hdelan updated https://github.com/llvm/llvm-project/pull/114977

>From 2b9441b292309dda9107ef8edb08567c229aeaf7 Mon Sep 17 00:00:00 2001
From: Hugh Delaney <hugh.delaney at codeplay.com>
Date: Tue, 5 Nov 2024 12:25:41 +0000
Subject: [PATCH] Add patterns for fma.relu.{f16|bf16}

Add patterns that lower fma(a, b, c) > 0 ? fma(a, b, c) : 0 to the PTX
fma.rn.relu instruction for the f16 and bf16 types.
---
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 19 ++++++
 llvm/test/CodeGen/NVPTX/fma-relu.ll     | 77 +++++++++++++++++++++++++
 2 files changed, 96 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/fma-relu.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 5f6cba397c5352..c974967be8e5ad 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3917,3 +3917,22 @@ def atomic_thread_fence_seq_cst_cta :
 def atomic_thread_fence_acq_rel_cta :
   NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
   Requires<[hasPTX<60>, hasSM<70>]>;
+
+def fpimm0 : FPImmLeaf<fAny, [{
+  return Imm.isPosZero(); // +0.0 only; -0.0 is intentionally not matched
+}]>;
+
+def FMARELU_F16 :  // a * b + c, round-to-nearest, negative results clamped to 0
+  NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+            "fma.rn.relu.f16 \t$dst, $a, $b, $c;", []>;
+def FMARELU_BF16 : // bf16 variant of FMARELU_F16
+  NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+            "fma.rn.relu.bf16 \t$dst, $a, $b, $c;", []>;
+// fmaxnum(fma(a, b, c), +0.0) -> fma.rn.relu; needs PTX ISA 7.0 and sm_80.
+def : Pat<(f16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+  (FMARELU_F16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+  Requires<[useFP16Math, allowFMA, allowUnsafeFPMath, hasPTX<70>, hasSM<80>]>;
+// Unsafe-math only: fma.relu keeps NaN results, fmaxnum(NaN, 0) returns 0.
+def : Pat<(bf16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+  (FMARELU_BF16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+  Requires<[hasBF16Math, allowFMA, allowUnsafeFPMath, hasPTX<70>, hasSM<80>]>;
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu.ll b/llvm/test/CodeGen/NVPTX/fma-relu.ll
new file mode 100644
index 00000000000000..6c340ef9d53015
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fma-relu.ll
@@ -0,0 +1,77 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_80 -mattr=+ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_80 -mattr=+ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | %ptxas-verify -arch=sm_80 %}
+
+define half @fma_f16(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_f16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_f16_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_f16_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_f16_param_2];
+; CHECK-NEXT:    fma.rn.relu.f16 %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = call half @llvm.fma.f16(half %a, half %b, half %c) ; explicit fma
+  %2 = fcmp ogt half %1, 0.0
+  %3 = select i1 %2, half %1, half 0.0 ; relu: fma > 0 ? fma : 0
+  ret half %3
+}
+
+define half @fma_f16_expanded(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_f16_expanded(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_f16_expanded_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_f16_expanded_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_f16_expanded_param_2];
+; CHECK-NEXT:    fma.rn.relu.f16 %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = fmul half %a, %b
+  %2 = fadd half %1, %c ; contracted with %1 into an fma by -fp-contract=fast
+  %3 = fcmp ogt half %2, 0.0
+  %4 = select i1 %3, half %2, half 0.0 ; relu
+  ret half %4
+}
+
+define bfloat @fma_bf16(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_bf16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_bf16_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_bf16_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_bf16_param_2];
+; CHECK-NEXT:    fma.rn.relu.bf16 %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = call bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %c) ; explicit fma
+  %2 = fcmp ogt bfloat %1, 0.0
+  %3 = select i1 %2, bfloat %1, bfloat 0.0 ; relu: fma > 0 ? fma : 0
+  ret bfloat %3
+}
+
+define bfloat @fma_bf16_expanded(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_bf16_expanded(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_bf16_expanded_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_bf16_expanded_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_bf16_expanded_param_2];
+; CHECK-NEXT:    fma.rn.relu.bf16 %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = fmul bfloat %a, %b
+  %2 = fadd bfloat %1, %c ; contracted with %1 into an fma by -fp-contract=fast
+  %3 = fcmp ogt bfloat %2, 0.0
+  %4 = select i1 %3, bfloat %2, bfloat 0.0 ; relu
+  ret bfloat %4
+}



More information about the llvm-commits mailing list