[llvm] [NVPTX] Lower bfloat16 add/mul/sub as fma on SM80 (PR #121065)

via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 27 11:59:53 PST 2024


https://github.com/peterbell10 updated https://github.com/llvm/llvm-project/pull/121065

>From 299a919684fde942eeed16c4ac4dc3ac60dba74b Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Tue, 24 Dec 2024 19:57:39 +0000
Subject: [PATCH 1/3] [NVPTX] Lower bfloat16 add/mul/sub as fma on SM80

SM80 has fma for bfloat16 but not add/mul/sub. Currently these are
just promoted to f32 but we can instead write them in terms of the
fma:
```
FADD(a, b) -> FMA(a, 1.0, b)
FMUL(a, b) -> FMA(a, b, 0.0)
FSUB(a, b) -> FMA(b, -1.0, a)
```

Unfortunately there is no `fma.ftz` for bf16, so when ftz is enabled we
still fall back to promotion to f32.

These rewrites are the inverse of some generic DAGCombiner folds (e.g.
`fma x, 1.0, y -> fadd x, y`), so I've had to add checks that stop the
combiner from reversing the legalization, which would otherwise cause an
infinite loop.
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  26 ++--
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |  73 +++++++++
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h     |   4 +
 llvm/test/CodeGen/NVPTX/atomics-sm90.ll       |  56 ++++---
 llvm/test/CodeGen/NVPTX/bf16-instructions.ll  |  91 ++++--------
 .../test/CodeGen/NVPTX/bf16x2-instructions.ll | 140 +++++-------------
 llvm/test/CodeGen/NVPTX/fma-relu-contract.ll  |  60 ++------
 .../CodeGen/NVPTX/fma-relu-fma-intrinsic.ll   |  50 ++-----
 .../NVPTX/fma-relu-instruction-flag.ll        | 110 +++-----------
 9 files changed, 229 insertions(+), 381 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6cbfef2d238bbe..a50ac311c82869 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -17534,10 +17534,13 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
       return N2;
   }
 
+  const bool PreferFMAAdd = (TLI.isOperationLegal(ISD::FMA, VT) &&
+                             !TLI.isOperationLegal(ISD::FADD, VT));
+
   // FIXME: Support splat of constant.
-  if (N0CFP && N0CFP->isExactlyValue(1.0))
+  if (!PreferFMAAdd && N0CFP && N0CFP->isExactlyValue(1.0))
     return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
-  if (N1CFP && N1CFP->isExactlyValue(1.0))
+  if (!PreferFMAAdd && N1CFP && N1CFP->isExactlyValue(1.0))
     return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
 
   // Canonicalize (fma c, x, y) -> (fma x, c, y)
@@ -17569,7 +17572,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
 
   // (fma x, -1, y) -> (fadd (fneg x), y)
   // FIXME: Support splat of constant.
-  if (N1CFP) {
+  if (N1CFP && !PreferFMAAdd) {
     if (N1CFP->isExactlyValue(1.0))
       return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
 
@@ -17579,15 +17582,14 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
       AddToWorklist(RHSNeg.getNode());
       return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
     }
-
-    // fma (fneg x), K, y -> fma x -K, y
-    if (matcher.match(N0, ISD::FNEG) &&
-        (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
-         (N1.hasOneUse() &&
-          !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
-      return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
-                             matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
-    }
+  }
+  // fma (fneg x), K, y -> fma x, -K, y
+  if (N1CFP && matcher.match(N0, ISD::FNEG) &&
+      (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
+       (N1.hasOneUse() &&
+        !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
+    return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
+                           matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
   }
 
   // FIXME: Support splat of constant.
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 5c1f717694a4c7..47f56abae3c056 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -853,6 +853,16 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
       AddPromotedToType(Op, MVT::bf16, MVT::f32);
   }
 
+  // Lower bf16 add/mul/sub as fma when it avoids promotion
+  for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
+    for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
+      if (getOperationAction(Op, VT) != Legal &&
+          getOperationAction(ISD::FMA, VT) == Legal) {
+        setOperationAction(Op, VT, Custom);
+      }
+    }
+  }
+
   // f16/f16x2 neg was introduced in PTX 60, SM_53.
   const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
                                         STI.getPTXVersion() >= 60 &&
@@ -2490,6 +2500,62 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
   return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
 }
 
+static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) { // Redo binary op N in f32, then convert back.
+  EVT VT = N->getValueType(0); // Result type of the original node.
+  EVT NVT = MVT::f32; // Type the operation is actually performed in.
+  if (VT.isVector()) { // Vector case: f32 vector with the same element count.
+    NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
+  }
+  SDLoc DL(N);
+  SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT); // LHS as NVT.
+  SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT); // RHS as NVT.
+  SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
+  return DAG.getFPExtendOrRound(Res, DL, VT); // Convert the result back to VT.
+}
+
+SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
+  // bf16 has no fma.ftz, so when the function uses f32 FTZ promote to f32.
+  if (useF32FTZ(DAG.getMachineFunction())) {
+    return PromoteBinOpToF32(Op.getNode(), DAG);
+  }
+
+  // FADD(a, b) -> FMA(a, 1.0, b), i.e. a * 1.0 + b in the original type.
+  SDLoc DL(Op);
+  auto VT = Op.getValueType();
+  auto One = DAG.getConstantFP(1.0, DL, VT);
+  SmallVector<SDValue, 3> Operands{Op->getOperand(0), One, Op->getOperand(1)};
+  return DAG.getNode(ISD::FMA, DL, VT, Operands); // NOTE(review): Op's node flags are not forwarded here, unlike the promotion path — confirm intended.
+}
+
+SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
+  // bf16 has no fma.ftz, so when the function uses f32 FTZ promote to f32.
+  if (useF32FTZ(DAG.getMachineFunction())) {
+    return PromoteBinOpToF32(Op.getNode(), DAG);
+  }
+
+  // FSUB(a, b) -> FMA(b, -1.0, a), i.e. b * -1.0 + a == a - b.
+  SDLoc DL(Op);
+  auto VT = Op.getValueType();
+  auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
+  SmallVector<SDValue, 3> Operands{Op->getOperand(1), NegOne,
+                                   Op->getOperand(0)};
+  return DAG.getNode(ISD::FMA, DL, VT, Operands); // NOTE(review): Op's node flags are not forwarded here, unlike the promotion path — confirm intended.
+}
+
+SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+  // bf16 has no fma.ftz, so when the function uses f32 FTZ promote to f32.
+  if (useF32FTZ(DAG.getMachineFunction())) {
+    return PromoteBinOpToF32(Op.getNode(), DAG);
+  }
+
+  // FMUL(a, b) -> FMA(a, b, 0.0), i.e. a * b + 0.0 in the original type.
+  SDLoc DL(Op);
+  auto VT = Op.getValueType();
+  auto Zero = DAG.getConstantFP(0.0, DL, VT); // NOTE(review): -0.0 + 0.0 == +0.0, so a negative-zero product loses its sign here; confirm whether -0.0 should be the addend.
+  SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1), Zero};
+  return DAG.getNode(ISD::FMA, DL, VT, Operands); // NOTE(review): Op's node flags are not forwarded here, unlike the promotion path — confirm intended.
+}
+
 SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
                                             SelectionDAG &DAG) const {
   assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
@@ -2681,6 +2747,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerSTACKSAVE(Op, DAG);
   case ISD::CopyToReg:
     return LowerCopyToReg_128(Op, DAG);
+  case ISD::FADD:
+    return LowerFADD(Op, DAG);
+  case ISD::FSUB:
+    return LowerFSUB(Op, DAG);
+  case ISD::FMUL:
+    return LowerFMUL(Op, DAG);
+
   default:
     llvm_unreachable("Custom lowering not defined for operation");
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 4a98fe21b81dc6..b7d32dd5327646 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -279,6 +279,10 @@ class NVPTXTargetLowering : public TargetLowering {
   SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerFADD(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFSUB(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
 
diff --git a/llvm/test/CodeGen/NVPTX/atomics-sm90.ll b/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
index f81b785f13225c..67552b95e04915 100644
--- a/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
+++ b/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
@@ -46,58 +46,52 @@ define void @test(ptr %dp0, ptr addrspace(1) %dp1, ptr addrspace(3) %dp3, bfloat
 ; CHECKPTX71-LABEL: test(
 ; CHECKPTX71:       {
 ; CHECKPTX71-NEXT:    .reg .pred %p<5>;
-; CHECKPTX71-NEXT:    .reg .b16 %rs<22>;
+; CHECKPTX71-NEXT:    .reg .b16 %rs<26>;
 ; CHECKPTX71-NEXT:    .reg .b32 %r<4>;
-; CHECKPTX71-NEXT:    .reg .f32 %f<12>;
 ; CHECKPTX71-EMPTY:
 ; CHECKPTX71-NEXT:  // %bb.0:
 ; CHECKPTX71-NEXT:    ld.param.b16 %rs13, [test_param_3];
 ; CHECKPTX71-NEXT:    ld.param.u32 %r3, [test_param_2];
 ; CHECKPTX71-NEXT:    ld.param.u32 %r2, [test_param_1];
 ; CHECKPTX71-NEXT:    ld.param.u32 %r1, [test_param_0];
-; CHECKPTX71-NEXT:    ld.b16 %rs18, [%r1];
-; CHECKPTX71-NEXT:    cvt.f32.bf16 %f1, %rs13;
+; CHECKPTX71-NEXT:    ld.b16 %rs22, [%r1];
 ; CHECKPTX71-NEXT:  $L__BB0_1: // %atomicrmw.start14
 ; CHECKPTX71-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT:    cvt.f32.bf16 %f2, %rs18;
-; CHECKPTX71-NEXT:    add.rn.f32 %f3, %f2, %f1;
-; CHECKPTX71-NEXT:    cvt.rn.bf16.f32 %rs14, %f3;
-; CHECKPTX71-NEXT:    atom.cas.b16 %rs3, [%r1], %rs18, %rs14;
-; CHECKPTX71-NEXT:    setp.ne.s16 %p1, %rs3, %rs18;
-; CHECKPTX71-NEXT:    mov.u16 %rs18, %rs3;
+; CHECKPTX71-NEXT:    mov.b16 %rs14, 0x3F80;
+; CHECKPTX71-NEXT:    fma.rn.bf16 %rs15, %rs22, %rs14, %rs13;
+; CHECKPTX71-NEXT:    atom.cas.b16 %rs3, [%r1], %rs22, %rs15;
+; CHECKPTX71-NEXT:    setp.ne.s16 %p1, %rs3, %rs22;
+; CHECKPTX71-NEXT:    mov.u16 %rs22, %rs3;
 ; CHECKPTX71-NEXT:    @%p1 bra $L__BB0_1;
 ; CHECKPTX71-NEXT:  // %bb.2: // %atomicrmw.end13
-; CHECKPTX71-NEXT:    ld.b16 %rs19, [%r1];
+; CHECKPTX71-NEXT:    ld.b16 %rs23, [%r1];
 ; CHECKPTX71-NEXT:  $L__BB0_3: // %atomicrmw.start8
 ; CHECKPTX71-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT:    cvt.f32.bf16 %f4, %rs19;
-; CHECKPTX71-NEXT:    add.rn.f32 %f5, %f4, 0f3F800000;
-; CHECKPTX71-NEXT:    cvt.rn.bf16.f32 %rs15, %f5;
-; CHECKPTX71-NEXT:    atom.cas.b16 %rs6, [%r1], %rs19, %rs15;
-; CHECKPTX71-NEXT:    setp.ne.s16 %p2, %rs6, %rs19;
-; CHECKPTX71-NEXT:    mov.u16 %rs19, %rs6;
+; CHECKPTX71-NEXT:    mov.b16 %rs16, 0x3F80;
+; CHECKPTX71-NEXT:    fma.rn.bf16 %rs17, %rs23, %rs16, %rs16;
+; CHECKPTX71-NEXT:    atom.cas.b16 %rs6, [%r1], %rs23, %rs17;
+; CHECKPTX71-NEXT:    setp.ne.s16 %p2, %rs6, %rs23;
+; CHECKPTX71-NEXT:    mov.u16 %rs23, %rs6;
 ; CHECKPTX71-NEXT:    @%p2 bra $L__BB0_3;
 ; CHECKPTX71-NEXT:  // %bb.4: // %atomicrmw.end7
-; CHECKPTX71-NEXT:    ld.global.b16 %rs20, [%r2];
+; CHECKPTX71-NEXT:    ld.global.b16 %rs24, [%r2];
 ; CHECKPTX71-NEXT:  $L__BB0_5: // %atomicrmw.start2
 ; CHECKPTX71-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT:    cvt.f32.bf16 %f7, %rs20;
-; CHECKPTX71-NEXT:    add.rn.f32 %f8, %f7, %f1;
-; CHECKPTX71-NEXT:    cvt.rn.bf16.f32 %rs16, %f8;
-; CHECKPTX71-NEXT:    atom.global.cas.b16 %rs9, [%r2], %rs20, %rs16;
-; CHECKPTX71-NEXT:    setp.ne.s16 %p3, %rs9, %rs20;
-; CHECKPTX71-NEXT:    mov.u16 %rs20, %rs9;
+; CHECKPTX71-NEXT:    mov.b16 %rs18, 0x3F80;
+; CHECKPTX71-NEXT:    fma.rn.bf16 %rs19, %rs24, %rs18, %rs13;
+; CHECKPTX71-NEXT:    atom.global.cas.b16 %rs9, [%r2], %rs24, %rs19;
+; CHECKPTX71-NEXT:    setp.ne.s16 %p3, %rs9, %rs24;
+; CHECKPTX71-NEXT:    mov.u16 %rs24, %rs9;
 ; CHECKPTX71-NEXT:    @%p3 bra $L__BB0_5;
 ; CHECKPTX71-NEXT:  // %bb.6: // %atomicrmw.end1
-; CHECKPTX71-NEXT:    ld.shared.b16 %rs21, [%r3];
+; CHECKPTX71-NEXT:    ld.shared.b16 %rs25, [%r3];
 ; CHECKPTX71-NEXT:  $L__BB0_7: // %atomicrmw.start
 ; CHECKPTX71-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT:    cvt.f32.bf16 %f10, %rs21;
-; CHECKPTX71-NEXT:    add.rn.f32 %f11, %f10, %f1;
-; CHECKPTX71-NEXT:    cvt.rn.bf16.f32 %rs17, %f11;
-; CHECKPTX71-NEXT:    atom.shared.cas.b16 %rs12, [%r3], %rs21, %rs17;
-; CHECKPTX71-NEXT:    setp.ne.s16 %p4, %rs12, %rs21;
-; CHECKPTX71-NEXT:    mov.u16 %rs21, %rs12;
+; CHECKPTX71-NEXT:    mov.b16 %rs20, 0x3F80;
+; CHECKPTX71-NEXT:    fma.rn.bf16 %rs21, %rs25, %rs20, %rs13;
+; CHECKPTX71-NEXT:    atom.shared.cas.b16 %rs12, [%r3], %rs25, %rs21;
+; CHECKPTX71-NEXT:    setp.ne.s16 %p4, %rs12, %rs25;
+; CHECKPTX71-NEXT:    mov.u16 %rs25, %rs12;
 ; CHECKPTX71-NEXT:    @%p4 bra $L__BB0_7;
 ; CHECKPTX71-NEXT:  // %bb.8: // %atomicrmw.end
 ; CHECKPTX71-NEXT:    ret;
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 6828bac18cad7f..eeb13b52130042 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -42,17 +42,14 @@ define bfloat @test_fadd(bfloat %0, bfloat %1) {
 ;
 ; SM80-LABEL: test_fadd(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<4>;
-; SM80-NEXT:    .reg .f32 %f<4>;
+; SM80-NEXT:    .reg .b16 %rs<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b16 %rs1, [test_fadd_param_0];
 ; SM80-NEXT:    ld.param.b16 %rs2, [test_fadd_param_1];
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs1;
-; SM80-NEXT:    add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs3, %f3;
-; SM80-NEXT:    st.param.b16 [func_retval0], %rs3;
+; SM80-NEXT:    mov.b16 %rs3, 0x3F80;
+; SM80-NEXT:    fma.rn.bf16 %rs4, %rs1, %rs3, %rs2;
+; SM80-NEXT:    st.param.b16 [func_retval0], %rs4;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fadd(
@@ -113,17 +110,14 @@ define bfloat @test_fsub(bfloat %0, bfloat %1) {
 ;
 ; SM80-LABEL: test_fsub(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<4>;
-; SM80-NEXT:    .reg .f32 %f<4>;
+; SM80-NEXT:    .reg .b16 %rs<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b16 %rs1, [test_fsub_param_0];
 ; SM80-NEXT:    ld.param.b16 %rs2, [test_fsub_param_1];
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs1;
-; SM80-NEXT:    sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs3, %f3;
-; SM80-NEXT:    st.param.b16 [func_retval0], %rs3;
+; SM80-NEXT:    mov.b16 %rs3, 0xBF80;
+; SM80-NEXT:    fma.rn.bf16 %rs4, %rs2, %rs3, %rs1;
+; SM80-NEXT:    st.param.b16 [func_retval0], %rs4;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fsub(
@@ -202,23 +196,14 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_faddx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<5>;
-; SM80-NEXT:    .reg .b32 %r<4>;
-; SM80-NEXT:    .reg .f32 %f<7>;
+; SM80-NEXT:    .reg .b32 %r<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
-; SM80-NEXT:    ld.param.b32 %r1, [test_faddx2_param_0];
-; SM80-NEXT:    ld.param.b32 %r2, [test_faddx2_param_1];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT:    add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT:    add.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
+; SM80-NEXT:    ld.param.b32 %r1, [test_faddx2_param_1];
+; SM80-NEXT:    ld.param.b32 %r2, [test_faddx2_param_0];
+; SM80-NEXT:    mov.b32 %r3, 1065369472;
+; SM80-NEXT:    fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_faddx2(
@@ -303,23 +288,14 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_fsubx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<5>;
-; SM80-NEXT:    .reg .b32 %r<4>;
-; SM80-NEXT:    .reg .f32 %f<7>;
+; SM80-NEXT:    .reg .b32 %r<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b32 %r1, [test_fsubx2_param_0];
 ; SM80-NEXT:    ld.param.b32 %r2, [test_fsubx2_param_1];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT:    sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT:    sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
+; SM80-NEXT:    mov.b32 %r3, -1082081408;
+; SM80-NEXT:    fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fsubx2(
@@ -404,23 +380,14 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_fmulx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<5>;
-; SM80-NEXT:    .reg .b32 %r<4>;
-; SM80-NEXT:    .reg .f32 %f<7>;
+; SM80-NEXT:    .reg .b32 %r<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
-; SM80-NEXT:    ld.param.b32 %r1, [test_fmulx2_param_0];
-; SM80-NEXT:    ld.param.b32 %r2, [test_fmulx2_param_1];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT:    mul.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT:    mul.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
+; SM80-NEXT:    ld.param.b32 %r1, [test_fmulx2_param_1];
+; SM80-NEXT:    ld.param.b32 %r2, [test_fmulx2_param_0];
+; SM80-NEXT:    mov.b32 %r3, 0;
+; SM80-NEXT:    fma.rn.bf16x2 %r4, %r2, %r1, %r3;
+; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fmulx2(
@@ -727,15 +694,13 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
 ;
 ; SM80-LABEL: test_fadd_imm_1(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<3>;
-; SM80-NEXT:    .reg .f32 %f<3>;
+; SM80-NEXT:    .reg .b16 %rs<4>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b16 %rs1, [test_fadd_imm_1_param_0];
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs2, %f2;
-; SM80-NEXT:    st.param.b16 [func_retval0], %rs2;
+; SM80-NEXT:    mov.b16 %rs2, 0x3F80;
+; SM80-NEXT:    fma.rn.bf16 %rs3, %rs1, %rs2, %rs2;
+; SM80-NEXT:    st.param.b16 [func_retval0], %rs3;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fadd_imm_1(
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
index 03cdeb9683abae..31d089a19450e1 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
@@ -22,19 +22,14 @@ define <2 x bfloat> @test_ret_const() #0 {
 define <2 x bfloat> @test_fadd_imm_0(<2 x bfloat> %a) #0 {
 ; SM80-LABEL: test_fadd_imm_0(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<3>;
-; SM80-NEXT:    .reg .b32 %r<3>;
-; SM80-NEXT:    .reg .f32 %f<5>;
+; SM80-NEXT:    .reg .b32 %r<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b32 %r1, [test_fadd_imm_0_param_0];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT:    cvt.f32.bf16 %f3, %rs2;
-; SM80-NEXT:    add.rn.f32 %f4, %f3, 0f40000000;
-; SM80-NEXT:    cvt.rn.bf16x2.f32 %r2, %f4, %f2;
-; SM80-NEXT:    st.param.b32 [func_retval0], %r2;
+; SM80-NEXT:    mov.b32 %r2, 1073758080;
+; SM80-NEXT:    mov.b32 %r3, 1065369472;
+; SM80-NEXT:    fma.rn.bf16x2 %r4, %r1, %r3, %r2;
+; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;
 ;
 ; SM90-LABEL: test_fadd_imm_0(
@@ -54,15 +49,13 @@ define <2 x bfloat> @test_fadd_imm_0(<2 x bfloat> %a) #0 {
 define bfloat @test_fadd_imm_1(bfloat %a) #0 {
 ; SM80-LABEL: test_fadd_imm_1(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<3>;
-; SM80-NEXT:    .reg .f32 %f<3>;
+; SM80-NEXT:    .reg .b16 %rs<4>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b16 %rs1, [test_fadd_imm_1_param_0];
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs2, %f2;
-; SM80-NEXT:    st.param.b16 [func_retval0], %rs2;
+; SM80-NEXT:    mov.b16 %rs2, 0x3F80;
+; SM80-NEXT:    fma.rn.bf16 %rs3, %rs1, %rs2, %rs2;
+; SM80-NEXT:    st.param.b16 [func_retval0], %rs3;
 ; SM80-NEXT:    ret;
 ;
 ; SM90-LABEL: test_fadd_imm_1(
@@ -82,23 +75,14 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
 define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-LABEL: test_fsubx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<5>;
-; SM80-NEXT:    .reg .b32 %r<4>;
-; SM80-NEXT:    .reg .f32 %f<7>;
+; SM80-NEXT:    .reg .b32 %r<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b32 %r1, [test_fsubx2_param_0];
 ; SM80-NEXT:    ld.param.b32 %r2, [test_fsubx2_param_1];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT:    sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT:    sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
+; SM80-NEXT:    mov.b32 %r3, -1082081408;
+; SM80-NEXT:    fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;
 ;
 ; SM90-LABEL: test_fsubx2(
@@ -118,23 +102,14 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-LABEL: test_fmulx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<5>;
-; SM80-NEXT:    .reg .b32 %r<4>;
-; SM80-NEXT:    .reg .f32 %f<7>;
+; SM80-NEXT:    .reg .b32 %r<5>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
-; SM80-NEXT:    ld.param.b32 %r1, [test_fmulx2_param_0];
-; SM80-NEXT:    ld.param.b32 %r2, [test_fmulx2_param_1];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT:    mul.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT:    mul.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
+; SM80-NEXT:    ld.param.b32 %r1, [test_fmulx2_param_1];
+; SM80-NEXT:    ld.param.b32 %r2, [test_fmulx2_param_0];
+; SM80-NEXT:    mov.b32 %r3, 0;
+; SM80-NEXT:    fma.rn.bf16x2 %r4, %r2, %r1, %r3;
+; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;
 ;
 ; SM90-LABEL: test_fmulx2(
@@ -543,30 +518,16 @@ define <2 x bfloat> @test_fabs(<2 x bfloat> %a) #0 {
 define <2 x bfloat> @test_fabs_add(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-LABEL: test_fabs_add(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<7>;
-; SM80-NEXT:    .reg .b32 %r<6>;
-; SM80-NEXT:    .reg .f32 %f<11>;
+; SM80-NEXT:    .reg .b32 %r<7>;
 ; SM80-EMPTY:
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b32 %r1, [test_fabs_add_param_1];
 ; SM80-NEXT:    ld.param.b32 %r2, [test_fabs_add_param_0];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT:    add.rn.f32 %f2, %f1, %f1;
-; SM80-NEXT:    cvt.f32.bf16 %f3, %rs2;
-; SM80-NEXT:    add.rn.f32 %f4, %f3, %f3;
-; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f4, %f2;
-; SM80-NEXT:    abs.bf16x2 %r4, %r3;
-; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r4;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs3;
-; SM80-NEXT:    mov.b32 {%rs5, %rs6}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f6, %rs5;
-; SM80-NEXT:    add.rn.f32 %f7, %f5, %f6;
-; SM80-NEXT:    cvt.f32.bf16 %f8, %rs4;
-; SM80-NEXT:    cvt.f32.bf16 %f9, %rs6;
-; SM80-NEXT:    add.rn.f32 %f10, %f8, %f9;
-; SM80-NEXT:    cvt.rn.bf16x2.f32 %r5, %f10, %f7;
-; SM80-NEXT:    st.param.b32 [func_retval0], %r5;
+; SM80-NEXT:    mov.b32 %r3, 1065369472;
+; SM80-NEXT:    fma.rn.bf16x2 %r4, %r2, %r3, %r2;
+; SM80-NEXT:    abs.bf16x2 %r5, %r4;
+; SM80-NEXT:    fma.rn.bf16x2 %r6, %r5, %r3, %r1;
+; SM80-NEXT:    st.param.b32 [func_retval0], %r6;
 ; SM80-NEXT:    ret;
 ;
 ; SM90-LABEL: test_fabs_add(
@@ -802,45 +763,18 @@ define <2 x bfloat> @test_round(<2 x bfloat> %a) #0 {
 }
 
 define <2 x bfloat> @test_copysign(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
-; SM80-LABEL: test_copysign(
-; SM80:       {
-; SM80-NEXT:    .reg .pred %p<3>;
-; SM80-NEXT:    .reg .b16 %rs<15>;
-; SM80-NEXT:    .reg .b32 %r<4>;
-; SM80-EMPTY:
-; SM80-NEXT:  // %bb.0:
-; SM80-NEXT:    ld.param.b32 %r1, [test_copysign_param_1];
-; SM80-NEXT:    ld.param.b32 %r2, [test_copysign_param_0];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    abs.bf16 %rs3, %rs2;
-; SM80-NEXT:    neg.bf16 %rs4, %rs3;
-; SM80-NEXT:    mov.b32 {%rs5, %rs6}, %r1;
-; SM80-NEXT:    shr.u16 %rs7, %rs6, 15;
-; SM80-NEXT:    and.b16 %rs8, %rs7, 1;
-; SM80-NEXT:    setp.eq.b16 %p1, %rs8, 1;
-; SM80-NEXT:    selp.b16 %rs9, %rs4, %rs3, %p1;
-; SM80-NEXT:    abs.bf16 %rs10, %rs1;
-; SM80-NEXT:    neg.bf16 %rs11, %rs10;
-; SM80-NEXT:    shr.u16 %rs12, %rs5, 15;
-; SM80-NEXT:    and.b16 %rs13, %rs12, 1;
-; SM80-NEXT:    setp.eq.b16 %p2, %rs13, 1;
-; SM80-NEXT:    selp.b16 %rs14, %rs11, %rs10, %p2;
-; SM80-NEXT:    mov.b32 %r3, {%rs14, %rs9};
-; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
-; SM80-NEXT:    ret;
-;
-; SM90-LABEL: test_copysign(
-; SM90:       {
-; SM90-NEXT:    .reg .b32 %r<6>;
-; SM90-EMPTY:
-; SM90-NEXT:  // %bb.0:
-; SM90-NEXT:    ld.param.b32 %r1, [test_copysign_param_0];
-; SM90-NEXT:    ld.param.b32 %r2, [test_copysign_param_1];
-; SM90-NEXT:    and.b32 %r3, %r2, -2147450880;
-; SM90-NEXT:    and.b32 %r4, %r1, 2147450879;
-; SM90-NEXT:    or.b32 %r5, %r4, %r3;
-; SM90-NEXT:    st.param.b32 [func_retval0], %r5;
-; SM90-NEXT:    ret;
+; CHECK-LABEL: test_copysign(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [test_copysign_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_copysign_param_1];
+; CHECK-NEXT:    and.b32 %r3, %r2, -2147450880;
+; CHECK-NEXT:    and.b32 %r4, %r1, 2147450879;
+; CHECK-NEXT:    or.b32 %r5, %r4, %r3;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r5;
+; CHECK-NEXT:    ret;
   %r = call <2 x bfloat> @llvm.copysign.f16(<2 x bfloat> %a, <2 x bfloat> %b)
   ret <2 x bfloat> %r
 }
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll b/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
index 48c94f275274bd..1643704e2ff95c 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
@@ -352,9 +352,7 @@ define bfloat @fma_bf16_expanded_no_nans(bfloat %a, bfloat %b, bfloat %c) #0 {
 define bfloat @fma_bf16_expanded_no_nans_multiple_uses_of_fma(bfloat %a, bfloat %b, bfloat %c) #0 {
 ; CHECK-LABEL: fma_bf16_expanded_no_nans_multiple_uses_of_fma(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<9>;
-; CHECK-NEXT:    .reg .b32 %r<7>;
-; CHECK-NEXT:    .reg .f32 %f<6>;
+; CHECK-NEXT:    .reg .b16 %rs<11>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.b16 %rs1, [fma_bf16_expanded_no_nans_multiple_uses_of_fma_param_0];
@@ -363,20 +361,11 @@ define bfloat @fma_bf16_expanded_no_nans_multiple_uses_of_fma(bfloat %a, bfloat
 ; CHECK-NEXT:    fma.rn.bf16 %rs4, %rs1, %rs2, %rs3;
 ; CHECK-NEXT:    mov.b16 %rs5, 0x0000;
 ; CHECK-NEXT:    max.bf16 %rs6, %rs4, %rs5;
-; CHECK-NEXT:    cvt.u32.u16 %r1, %rs4;
-; CHECK-NEXT:    shl.b32 %r2, %r1, 16;
-; CHECK-NEXT:    mov.b32 %f1, %r2;
-; CHECK-NEXT:    add.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs7, %f2;
-; CHECK-NEXT:    cvt.u32.u16 %r3, %rs6;
-; CHECK-NEXT:    shl.b32 %r4, %r3, 16;
-; CHECK-NEXT:    mov.b32 %f3, %r4;
-; CHECK-NEXT:    cvt.u32.u16 %r5, %rs7;
-; CHECK-NEXT:    shl.b32 %r6, %r5, 16;
-; CHECK-NEXT:    mov.b32 %f4, %r6;
-; CHECK-NEXT:    add.f32 %f5, %f3, %f4;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs8, %f5;
-; CHECK-NEXT:    st.param.b16 [func_retval0], %rs8;
+; CHECK-NEXT:    mov.b16 %rs7, 0x40E0;
+; CHECK-NEXT:    mov.b16 %rs8, 0x3F80;
+; CHECK-NEXT:    fma.rn.bf16 %rs9, %rs4, %rs8, %rs7;
+; CHECK-NEXT:    fma.rn.bf16 %rs10, %rs6, %rs8, %rs9;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs10;
 ; CHECK-NEXT:    ret;
 ;
 ; CHECK-FTZ-LABEL: fma_bf16_expanded_no_nans_multiple_uses_of_fma(
@@ -959,9 +948,7 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans(<2 x bfloat> %a, <2 x bfloat> %
 define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) #0 {
 ; CHECK-LABEL: fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<7>;
-; CHECK-NEXT:    .reg .b32 %r<20>;
-; CHECK-NEXT:    .reg .f32 %f<11>;
+; CHECK-NEXT:    .reg .b32 %r<11>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.b32 %r1, [fma_bf16x2_expanded_no_nans_multiple_uses_of_fma_param_2];
@@ -970,34 +957,11 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
 ; CHECK-NEXT:    fma.rn.bf16x2 %r4, %r3, %r2, %r1;
 ; CHECK-NEXT:    mov.b32 %r5, 0;
 ; CHECK-NEXT:    max.bf16x2 %r6, %r4, %r5;
-; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
-; CHECK-NEXT:    cvt.u32.u16 %r7, %rs2;
-; CHECK-NEXT:    shl.b32 %r8, %r7, 16;
-; CHECK-NEXT:    mov.b32 %f1, %r8;
-; CHECK-NEXT:    add.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs3, %f2;
-; CHECK-NEXT:    cvt.u32.u16 %r9, %rs1;
-; CHECK-NEXT:    shl.b32 %r10, %r9, 16;
-; CHECK-NEXT:    mov.b32 %f3, %r10;
-; CHECK-NEXT:    add.f32 %f4, %f3, 0f40E00000;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs4, %f4;
-; CHECK-NEXT:    mov.b32 {%rs5, %rs6}, %r6;
-; CHECK-NEXT:    cvt.u32.u16 %r11, %rs5;
-; CHECK-NEXT:    shl.b32 %r12, %r11, 16;
-; CHECK-NEXT:    mov.b32 %f5, %r12;
-; CHECK-NEXT:    cvt.u32.u16 %r13, %rs4;
-; CHECK-NEXT:    shl.b32 %r14, %r13, 16;
-; CHECK-NEXT:    mov.b32 %f6, %r14;
-; CHECK-NEXT:    add.f32 %f7, %f5, %f6;
-; CHECK-NEXT:    cvt.u32.u16 %r15, %rs6;
-; CHECK-NEXT:    shl.b32 %r16, %r15, 16;
-; CHECK-NEXT:    mov.b32 %f8, %r16;
-; CHECK-NEXT:    cvt.u32.u16 %r17, %rs3;
-; CHECK-NEXT:    shl.b32 %r18, %r17, 16;
-; CHECK-NEXT:    mov.b32 %f9, %r18;
-; CHECK-NEXT:    add.f32 %f10, %f8, %f9;
-; CHECK-NEXT:    cvt.rn.bf16x2.f32 %r19, %f10, %f7;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r19;
+; CHECK-NEXT:    mov.b32 %r7, 1088438496;
+; CHECK-NEXT:    mov.b32 %r8, 1065369472;
+; CHECK-NEXT:    fma.rn.bf16x2 %r9, %r4, %r8, %r7;
+; CHECK-NEXT:    fma.rn.bf16x2 %r10, %r6, %r8, %r9;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r10;
 ; CHECK-NEXT:    ret;
 ;
 ; CHECK-FTZ-LABEL: fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll b/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
index 561f2b0cc06730..e1e34ee9b1c159 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
@@ -221,26 +221,18 @@ define bfloat @fma_bf16_no_nans(bfloat %a, bfloat %b, bfloat %c) #0 {
 define bfloat @fma_bf16_no_nans_multiple_uses_of_fma(bfloat %a, bfloat %b, bfloat %c) #0 {
 ; CHECK-LABEL: fma_bf16_no_nans_multiple_uses_of_fma(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<7>;
-; CHECK-NEXT:    .reg .b32 %r<5>;
-; CHECK-NEXT:    .reg .f32 %f<5>;
+; CHECK-NEXT:    .reg .b16 %rs<9>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.b16 %rs1, [fma_bf16_no_nans_multiple_uses_of_fma_param_0];
 ; CHECK-NEXT:    ld.param.b16 %rs2, [fma_bf16_no_nans_multiple_uses_of_fma_param_1];
 ; CHECK-NEXT:    ld.param.b16 %rs3, [fma_bf16_no_nans_multiple_uses_of_fma_param_2];
 ; CHECK-NEXT:    fma.rn.bf16 %rs4, %rs1, %rs2, %rs3;
-; CHECK-NEXT:    cvt.u32.u16 %r1, %rs4;
-; CHECK-NEXT:    shl.b32 %r2, %r1, 16;
-; CHECK-NEXT:    mov.b32 %f1, %r2;
-; CHECK-NEXT:    add.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs5, %f2;
-; CHECK-NEXT:    cvt.u32.u16 %r3, %rs5;
-; CHECK-NEXT:    shl.b32 %r4, %r3, 16;
-; CHECK-NEXT:    mov.b32 %f3, %r4;
-; CHECK-NEXT:    add.f32 %f4, %f3, %f1;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs6, %f4;
-; CHECK-NEXT:    st.param.b16 [func_retval0], %rs6;
+; CHECK-NEXT:    mov.b16 %rs5, 0x40E0;
+; CHECK-NEXT:    mov.b16 %rs6, 0x3F80;
+; CHECK-NEXT:    fma.rn.bf16 %rs7, %rs4, %rs6, %rs5;
+; CHECK-NEXT:    fma.rn.bf16 %rs8, %rs7, %rs6, %rs4;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs8;
 ; CHECK-NEXT:    ret;
 ;
 ; CHECK-FTZ-LABEL: fma_bf16_no_nans_multiple_uses_of_fma(
@@ -642,36 +634,18 @@ define <2 x bfloat> @fma_bf16x2_no_nans(<2 x bfloat> %a, <2 x bfloat> %b, <2 x b
 define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) #0 {
 ; CHECK-LABEL: fma_bf16x2_no_nans_multiple_uses_of_fma(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<5>;
-; CHECK-NEXT:    .reg .b32 %r<14>;
-; CHECK-NEXT:    .reg .f32 %f<9>;
+; CHECK-NEXT:    .reg .b32 %r<9>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.b32 %r1, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_2];
 ; CHECK-NEXT:    ld.param.b32 %r2, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_1];
 ; CHECK-NEXT:    ld.param.b32 %r3, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_0];
 ; CHECK-NEXT:    fma.rn.bf16x2 %r4, %r3, %r2, %r1;
-; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
-; CHECK-NEXT:    cvt.u32.u16 %r5, %rs2;
-; CHECK-NEXT:    shl.b32 %r6, %r5, 16;
-; CHECK-NEXT:    mov.b32 %f1, %r6;
-; CHECK-NEXT:    add.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs3, %f2;
-; CHECK-NEXT:    cvt.u32.u16 %r7, %rs1;
-; CHECK-NEXT:    shl.b32 %r8, %r7, 16;
-; CHECK-NEXT:    mov.b32 %f3, %r8;
-; CHECK-NEXT:    add.f32 %f4, %f3, 0f40E00000;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs4, %f4;
-; CHECK-NEXT:    cvt.u32.u16 %r9, %rs4;
-; CHECK-NEXT:    shl.b32 %r10, %r9, 16;
-; CHECK-NEXT:    mov.b32 %f5, %r10;
-; CHECK-NEXT:    add.f32 %f6, %f5, %f3;
-; CHECK-NEXT:    cvt.u32.u16 %r11, %rs3;
-; CHECK-NEXT:    shl.b32 %r12, %r11, 16;
-; CHECK-NEXT:    mov.b32 %f7, %r12;
-; CHECK-NEXT:    add.f32 %f8, %f7, %f1;
-; CHECK-NEXT:    cvt.rn.bf16x2.f32 %r13, %f8, %f6;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r13;
+; CHECK-NEXT:    mov.b32 %r5, 1088438496;
+; CHECK-NEXT:    mov.b32 %r6, 1065369472;
+; CHECK-NEXT:    fma.rn.bf16x2 %r7, %r4, %r6, %r5;
+; CHECK-NEXT:    fma.rn.bf16x2 %r8, %r7, %r6, %r4;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r8;
 ; CHECK-NEXT:    ret;
 ;
 ; CHECK-FTZ-LABEL: fma_bf16x2_no_nans_multiple_uses_of_fma(
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll b/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
index b20ca24dd91a0c..ea046dc90b23f2 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
@@ -233,9 +233,7 @@ define bfloat @fma_bf16_expanded_no_nans(bfloat %a, bfloat %b, bfloat %c)  {
 define bfloat @fma_bf16_expanded_no_nans_multiple_uses_of_fma(bfloat %a, bfloat %b, bfloat %c)  {
 ; CHECK-LABEL: fma_bf16_expanded_no_nans_multiple_uses_of_fma(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<9>;
-; CHECK-NEXT:    .reg .b32 %r<7>;
-; CHECK-NEXT:    .reg .f32 %f<6>;
+; CHECK-NEXT:    .reg .b16 %rs<11>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.b16 %rs1, [fma_bf16_expanded_no_nans_multiple_uses_of_fma_param_0];
@@ -244,20 +242,11 @@ define bfloat @fma_bf16_expanded_no_nans_multiple_uses_of_fma(bfloat %a, bfloat
 ; CHECK-NEXT:    fma.rn.bf16 %rs4, %rs1, %rs2, %rs3;
 ; CHECK-NEXT:    mov.b16 %rs5, 0x0000;
 ; CHECK-NEXT:    max.bf16 %rs6, %rs4, %rs5;
-; CHECK-NEXT:    cvt.u32.u16 %r1, %rs4;
-; CHECK-NEXT:    shl.b32 %r2, %r1, 16;
-; CHECK-NEXT:    mov.b32 %f1, %r2;
-; CHECK-NEXT:    add.rn.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs7, %f2;
-; CHECK-NEXT:    cvt.u32.u16 %r3, %rs6;
-; CHECK-NEXT:    shl.b32 %r4, %r3, 16;
-; CHECK-NEXT:    mov.b32 %f3, %r4;
-; CHECK-NEXT:    cvt.u32.u16 %r5, %rs7;
-; CHECK-NEXT:    shl.b32 %r6, %r5, 16;
-; CHECK-NEXT:    mov.b32 %f4, %r6;
-; CHECK-NEXT:    add.rn.f32 %f5, %f3, %f4;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs8, %f5;
-; CHECK-NEXT:    st.param.b16 [func_retval0], %rs8;
+; CHECK-NEXT:    mov.b16 %rs7, 0x40E0;
+; CHECK-NEXT:    mov.b16 %rs8, 0x3F80;
+; CHECK-NEXT:    fma.rn.bf16 %rs9, %rs4, %rs8, %rs7;
+; CHECK-NEXT:    fma.rn.bf16 %rs10, %rs6, %rs8, %rs9;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs10;
 ; CHECK-NEXT:    ret;
 ;
 ; CHECK-FTZ-LABEL: fma_bf16_expanded_no_nans_multiple_uses_of_fma(
@@ -694,9 +683,7 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans(<2 x bfloat> %a, <2 x bfloat> %
 define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)  {
 ; CHECK-LABEL: fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<7>;
-; CHECK-NEXT:    .reg .b32 %r<20>;
-; CHECK-NEXT:    .reg .f32 %f<11>;
+; CHECK-NEXT:    .reg .b32 %r<11>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.b32 %r1, [fma_bf16x2_expanded_no_nans_multiple_uses_of_fma_param_2];
@@ -705,34 +692,11 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
 ; CHECK-NEXT:    fma.rn.bf16x2 %r4, %r3, %r2, %r1;
 ; CHECK-NEXT:    mov.b32 %r5, 0;
 ; CHECK-NEXT:    max.bf16x2 %r6, %r4, %r5;
-; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
-; CHECK-NEXT:    cvt.u32.u16 %r7, %rs2;
-; CHECK-NEXT:    shl.b32 %r8, %r7, 16;
-; CHECK-NEXT:    mov.b32 %f1, %r8;
-; CHECK-NEXT:    add.rn.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs3, %f2;
-; CHECK-NEXT:    cvt.u32.u16 %r9, %rs1;
-; CHECK-NEXT:    shl.b32 %r10, %r9, 16;
-; CHECK-NEXT:    mov.b32 %f3, %r10;
-; CHECK-NEXT:    add.rn.f32 %f4, %f3, 0f40E00000;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs4, %f4;
-; CHECK-NEXT:    mov.b32 {%rs5, %rs6}, %r6;
-; CHECK-NEXT:    cvt.u32.u16 %r11, %rs5;
-; CHECK-NEXT:    shl.b32 %r12, %r11, 16;
-; CHECK-NEXT:    mov.b32 %f5, %r12;
-; CHECK-NEXT:    cvt.u32.u16 %r13, %rs4;
-; CHECK-NEXT:    shl.b32 %r14, %r13, 16;
-; CHECK-NEXT:    mov.b32 %f6, %r14;
-; CHECK-NEXT:    add.rn.f32 %f7, %f5, %f6;
-; CHECK-NEXT:    cvt.u32.u16 %r15, %rs6;
-; CHECK-NEXT:    shl.b32 %r16, %r15, 16;
-; CHECK-NEXT:    mov.b32 %f8, %r16;
-; CHECK-NEXT:    cvt.u32.u16 %r17, %rs3;
-; CHECK-NEXT:    shl.b32 %r18, %r17, 16;
-; CHECK-NEXT:    mov.b32 %f9, %r18;
-; CHECK-NEXT:    add.rn.f32 %f10, %f8, %f9;
-; CHECK-NEXT:    cvt.rn.bf16x2.f32 %r19, %f10, %f7;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r19;
+; CHECK-NEXT:    mov.b32 %r7, 1088438496;
+; CHECK-NEXT:    mov.b32 %r8, 1065369472;
+; CHECK-NEXT:    fma.rn.bf16x2 %r9, %r4, %r8, %r7;
+; CHECK-NEXT:    fma.rn.bf16x2 %r10, %r6, %r8, %r9;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r10;
 ; CHECK-NEXT:    ret;
 ;
 ; CHECK-FTZ-LABEL: fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(
@@ -1204,26 +1168,18 @@ define bfloat @fma_bf16_no_nans(bfloat %a, bfloat %b, bfloat %c)  {
 define bfloat @fma_bf16_no_nans_multiple_uses_of_fma(bfloat %a, bfloat %b, bfloat %c)  {
 ; CHECK-LABEL: fma_bf16_no_nans_multiple_uses_of_fma(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<7>;
-; CHECK-NEXT:    .reg .b32 %r<5>;
-; CHECK-NEXT:    .reg .f32 %f<5>;
+; CHECK-NEXT:    .reg .b16 %rs<9>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.b16 %rs1, [fma_bf16_no_nans_multiple_uses_of_fma_param_0];
 ; CHECK-NEXT:    ld.param.b16 %rs2, [fma_bf16_no_nans_multiple_uses_of_fma_param_1];
 ; CHECK-NEXT:    ld.param.b16 %rs3, [fma_bf16_no_nans_multiple_uses_of_fma_param_2];
 ; CHECK-NEXT:    fma.rn.bf16 %rs4, %rs1, %rs2, %rs3;
-; CHECK-NEXT:    cvt.u32.u16 %r1, %rs4;
-; CHECK-NEXT:    shl.b32 %r2, %r1, 16;
-; CHECK-NEXT:    mov.b32 %f1, %r2;
-; CHECK-NEXT:    add.rn.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs5, %f2;
-; CHECK-NEXT:    cvt.u32.u16 %r3, %rs5;
-; CHECK-NEXT:    shl.b32 %r4, %r3, 16;
-; CHECK-NEXT:    mov.b32 %f3, %r4;
-; CHECK-NEXT:    add.rn.f32 %f4, %f3, %f1;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs6, %f4;
-; CHECK-NEXT:    st.param.b16 [func_retval0], %rs6;
+; CHECK-NEXT:    mov.b16 %rs5, 0x40E0;
+; CHECK-NEXT:    mov.b16 %rs6, 0x3F80;
+; CHECK-NEXT:    fma.rn.bf16 %rs7, %rs4, %rs6, %rs5;
+; CHECK-NEXT:    fma.rn.bf16 %rs8, %rs7, %rs6, %rs4;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs8;
 ; CHECK-NEXT:    ret;
 ;
 ; CHECK-FTZ-LABEL: fma_bf16_no_nans_multiple_uses_of_fma(
@@ -1629,36 +1585,18 @@ define <2 x bfloat> @fma_bf16x2_no_nans(<2 x bfloat> %a, <2 x bfloat> %b, <2 x b
 define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)  {
 ; CHECK-LABEL: fma_bf16x2_no_nans_multiple_uses_of_fma(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<5>;
-; CHECK-NEXT:    .reg .b32 %r<14>;
-; CHECK-NEXT:    .reg .f32 %f<9>;
+; CHECK-NEXT:    .reg .b32 %r<9>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.b32 %r1, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_2];
 ; CHECK-NEXT:    ld.param.b32 %r2, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_1];
 ; CHECK-NEXT:    ld.param.b32 %r3, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_0];
 ; CHECK-NEXT:    fma.rn.bf16x2 %r4, %r3, %r2, %r1;
-; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
-; CHECK-NEXT:    cvt.u32.u16 %r5, %rs2;
-; CHECK-NEXT:    shl.b32 %r6, %r5, 16;
-; CHECK-NEXT:    mov.b32 %f1, %r6;
-; CHECK-NEXT:    add.rn.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs3, %f2;
-; CHECK-NEXT:    cvt.u32.u16 %r7, %rs1;
-; CHECK-NEXT:    shl.b32 %r8, %r7, 16;
-; CHECK-NEXT:    mov.b32 %f3, %r8;
-; CHECK-NEXT:    add.rn.f32 %f4, %f3, 0f40E00000;
-; CHECK-NEXT:    cvt.rn.bf16.f32 %rs4, %f4;
-; CHECK-NEXT:    cvt.u32.u16 %r9, %rs4;
-; CHECK-NEXT:    shl.b32 %r10, %r9, 16;
-; CHECK-NEXT:    mov.b32 %f5, %r10;
-; CHECK-NEXT:    add.rn.f32 %f6, %f5, %f3;
-; CHECK-NEXT:    cvt.u32.u16 %r11, %rs3;
-; CHECK-NEXT:    shl.b32 %r12, %r11, 16;
-; CHECK-NEXT:    mov.b32 %f7, %r12;
-; CHECK-NEXT:    add.rn.f32 %f8, %f7, %f1;
-; CHECK-NEXT:    cvt.rn.bf16x2.f32 %r13, %f8, %f6;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r13;
+; CHECK-NEXT:    mov.b32 %r5, 1088438496;
+; CHECK-NEXT:    mov.b32 %r6, 1065369472;
+; CHECK-NEXT:    fma.rn.bf16x2 %r7, %r4, %r6, %r5;
+; CHECK-NEXT:    fma.rn.bf16x2 %r8, %r7, %r6, %r4;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r8;
 ; CHECK-NEXT:    ret;
 ;
 ; CHECK-FTZ-LABEL: fma_bf16x2_no_nans_multiple_uses_of_fma(

>From 0806ba584ce51991aed57d3c8dce134940a70529 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Thu, 26 Dec 2024 21:42:50 +0000
Subject: [PATCH 2/3] Fix mul for negative zero

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp    | 8 +++++---
 llvm/test/CodeGen/NVPTX/bf16-instructions.ll   | 2 +-
 llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll | 2 +-
 3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 47f56abae3c056..39623de749ac4b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2548,11 +2548,13 @@ SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
     return PromoteBinOpToF32(Op.getNode(), DAG);
   }
 
-  // FMUL(a, b) -> FMA(a, b, 0.0)
+  // FMUL(a, b) -> FMA(a, b, -0.0)
+  // NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats
   SDLoc DL(Op);
   auto VT = Op.getValueType();
-  auto Zero = DAG.getConstantFP(0.0, DL, VT);
-  SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1), Zero};
+  auto NegZero = DAG.getConstantFP(-0.0, DL, VT);
+  SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1),
+                                   NegZero};
   return DAG.getNode(ISD::FMA, DL, VT, Operands);
 }
 
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index eeb13b52130042..b53f82403af5f3 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -385,7 +385,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b32 %r1, [test_fmulx2_param_1];
 ; SM80-NEXT:    ld.param.b32 %r2, [test_fmulx2_param_0];
-; SM80-NEXT:    mov.b32 %r3, 0;
+; SM80-NEXT:    mov.b32 %r3, -2147450880;
 ; SM80-NEXT:    fma.rn.bf16x2 %r4, %r2, %r1, %r3;
 ; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
index 31d089a19450e1..f7ffba385df764 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
@@ -107,7 +107,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-NEXT:  // %bb.0:
 ; SM80-NEXT:    ld.param.b32 %r1, [test_fmulx2_param_1];
 ; SM80-NEXT:    ld.param.b32 %r2, [test_fmulx2_param_0];
-; SM80-NEXT:    mov.b32 %r3, 0;
+; SM80-NEXT:    mov.b32 %r3, -2147450880;
 ; SM80-NEXT:    fma.rn.bf16x2 %r4, %r2, %r1, %r3;
 ; SM80-NEXT:    st.param.b32 [func_retval0], %r4;
 ; SM80-NEXT:    ret;

>From ee6635b20c2625756945dfadf6e0d7803978ca57 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Fri, 27 Dec 2024 19:49:51 +0000
Subject: [PATCH 3/3] Move fma expansions to default expand rule

---
 llvm/include/llvm/CodeGen/TargetLowering.h    | 17 ++++++
 llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 21 +++++---
 .../CodeGen/SelectionDAG/TargetLowering.cpp   | 54 +++++++++++++++++++
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 21 ++------
 4 files changed, 88 insertions(+), 25 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 3751aac4df8ead..d5ed219af639d7 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5325,6 +5325,23 @@ class TargetLowering : public TargetLoweringBase {
                            SDNodeFlags Flags, const SDLoc &DL,
                            SelectionDAG &DAG) const;
 
+  /// Expand floating point add
+  /// \param N Node to expand
+  /// \returns The expansion result or SDValue() if it fails.
+  SDValue expandFADD(SDNode *N, SelectionDAG &DAG) const;
+
+  /// Expand floating point multiply
+  /// \param N Node to expand
+  /// \param DAG SelectionDAG in which the replacement nodes are created
+  /// \returns The expansion result or SDValue() if it fails.
+  SDValue expandFMUL(SDNode *N, SelectionDAG &DAG) const;
+
+  /// Expand floating point subtract
+  /// \param N Node to expand
+  /// \param DAG SelectionDAG in which the replacement nodes are created
+  /// \returns The expansion result or SDValue() if it fails.
+  SDValue expandFSUB(SDNode *N, SelectionDAG &DAG) const;
+
   /// Expand CTPOP nodes. Expands vector/scalar CTPOP nodes,
   /// vector nodes can only succeed if all operations are legal/custom.
   /// \param N Node to expand
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 595a410101eca1..a91c0c531e188e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3671,14 +3671,21 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
     Results.push_back(ExpandConstant(CP));
     break;
   }
+  case ISD::FADD: {
+    if (SDValue Expand = TLI.expandFADD(Node, DAG)) {
+      Results.push_back(Expand);
+    }
+    break;
+  }
+  case ISD::FMUL: {
+    if (SDValue Expand = TLI.expandFMUL(Node, DAG)) {
+      Results.push_back(Expand);
+    }
+    break;
+  }
   case ISD::FSUB: {
-    EVT VT = Node->getValueType(0);
-    if (TLI.isOperationLegalOrCustom(ISD::FADD, VT) &&
-        TLI.isOperationLegalOrCustom(ISD::FNEG, VT)) {
-      const SDNodeFlags Flags = Node->getFlags();
-      Tmp1 = DAG.getNode(ISD::FNEG, dl, VT, Node->getOperand(1));
-      Tmp1 = DAG.getNode(ISD::FADD, dl, VT, Node->getOperand(0), Tmp1, Flags);
-      Results.push_back(Tmp1);
+    if (SDValue Expand = TLI.expandFSUB(Node, DAG)) {
+      Results.push_back(Expand);
     }
     break;
   }
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index e87d809f88eb97..845f2319f67426 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -9071,6 +9071,60 @@ SDValue TargetLowering::expandIS_FPCLASS(EVT ResultVT, SDValue Op,
   return Res;
 }
 
+SDValue TargetLowering::expandFADD(SDNode *Node, SelectionDAG &DAG) const {
+  auto VT = Node->getValueType(0);
+  if (!isOperationLegalOrCustom(ISD::FMA, VT)) {
+    return {};
+  }
+
+  // FADD(a, b) -> FMA(a, 1.0, b)
+  SDLoc DL(Node);
+  auto One = DAG.getConstantFP(1.0, DL, VT);
+  SmallVector<SDValue, 3> Operands{Node->getOperand(0), One,
+                                   Node->getOperand(1)};
+  return DAG.getNode(ISD::FMA, DL, VT, Operands, Node->getFlags());
+}
+
+SDValue TargetLowering::expandFMUL(SDNode *Node, SelectionDAG &DAG) const {
+  auto VT = Node->getValueType(0);
+  if (!isOperationLegalOrCustom(ISD::FMA, VT)) {
+    return {};
+  }
+
+  // FMUL(a, b) -> FMA(a, b, -0.0)
+  // NOTE: The identity is -0.0, not +0.0, because (-0.0) + (+0.0) == +0.0
+  SDLoc DL(Node);
+  auto NegZero = DAG.getConstantFP(-0.0, DL, VT);
+  SmallVector<SDValue, 3> Operands{Node->getOperand(0), Node->getOperand(1),
+                                   NegZero};
+  return DAG.getNode(ISD::FMA, DL, VT, Operands, Node->getFlags());
+}
+
+SDValue TargetLowering::expandFSUB(SDNode *Node, SelectionDAG &DAG) const {
+  SDLoc DL(Node);
+  SDNodeFlags SDFlags = Node->getFlags();
+  auto VT = Node->getValueType(0);
+
+  bool CanUseFMA = isOperationLegalOrCustom(ISD::FMA, VT);
+  bool CanUseAddSub = (isOperationLegalOrCustom(ISD::FADD, VT) &&
+                       isOperationLegalOrCustom(ISD::FNEG, VT));
+  bool PreferAddSub = CanUseAddSub && isFNegFree(VT);
+
+  // FSUB(a, b) -> FMA(b, -1.0, a)
+  if (CanUseFMA && !PreferAddSub) {
+    auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
+    SmallVector<SDValue, 3> Operands{Node->getOperand(1), NegOne,
+                                     Node->getOperand(0)};
+    return DAG.getNode(ISD::FMA, DL, VT, Operands, SDFlags);
+  }
+  // FSUB(a, b) -> FADD(a, FNEG(b))
+  if (CanUseAddSub) {
+    auto Neg = DAG.getNode(ISD::FNEG, DL, VT, Node->getOperand(1));
+    return DAG.getNode(ISD::FADD, DL, VT, Node->getOperand(0), Neg, SDFlags);
+  }
+  return {};
+}
+
 // Only expand vector types if we have the appropriate vector bit operations.
 static bool canExpandVectorCTPOP(const TargetLowering &TLI, EVT VT) {
   assert(VT.isVector() && "Expected vector type");
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 39623de749ac4b..3c2cb18c946efe 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2520,11 +2520,7 @@ SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
   }
 
   // FADD(a, b) -> FMA(a, 1.0, b)
-  SDLoc DL(Op);
-  auto VT = Op.getValueType();
-  auto One = DAG.getConstantFP(1.0, DL, VT);
-  SmallVector<SDValue, 3> Operands{Op->getOperand(0), One, Op->getOperand(1)};
-  return DAG.getNode(ISD::FMA, DL, VT, Operands);
+  return expandFADD(Op.getNode(), DAG);
 }
 
 SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
@@ -2534,12 +2530,7 @@ SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
   }
 
   // FSUB(a, b) -> FMA(b, -1.0, a)
-  SDLoc DL(Op);
-  auto VT = Op.getValueType();
-  auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
-  SmallVector<SDValue, 3> Operands{Op->getOperand(1), NegOne,
-                                   Op->getOperand(0)};
-  return DAG.getNode(ISD::FMA, DL, VT, Operands);
+  return expandFSUB(Op.getNode(), DAG);
 }
 
 SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
@@ -2549,13 +2540,7 @@ SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
   }
 
   // FMUL(a, b) -> FMA(a, b, -0.0)
-  // NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats
-  SDLoc DL(Op);
-  auto VT = Op.getValueType();
-  auto NegZero = DAG.getConstantFP(-0.0, DL, VT);
-  SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1),
-                                   NegZero};
-  return DAG.getNode(ISD::FMA, DL, VT, Operands);
+  return expandFMUL(Op.getNode(), DAG);
 }
 
 SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,



More information about the llvm-commits mailing list