[llvm] [NVPTX] Select bfloat16 add/mul/sub as fma on SM80 (PR #121065)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 15 21:19:42 PST 2025
https://github.com/peterbell10 updated https://github.com/llvm/llvm-project/pull/121065
>From b529e0c669665530304d7d718fec8a130eb43551 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Tue, 24 Dec 2024 19:57:39 +0000
Subject: [PATCH 1/7] [NVPTX] Lower bfloat16 add/mul/sub as fma on SM80
SM80 has fma for bfloat16 but not add/mul/sub. Currently these are
just promoted to f32, but we can instead express them in terms of
fma:
```
FADD(a, b) -> FMA(a, 1.0, b)
FMUL(a, b) -> FMA(a, b, 0.0)
FSUB(a, b) -> FMA(b, -1.0, a)
```
Unfortunately there is no `fma.ftz`, so when ftz is enabled we still
fall back to promotion.
These rewrites are also the inverse of some generic DAGCombiner
patterns, so I've had to add checks to stop the combiner from undoing
the legalization, which would otherwise cause an infinite loop.
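For reference, the immediates that show up in the updated test
expectations below are just the bf16 bit patterns of these identity
constants, and for the `bf16x2` forms the same pattern packed into both
halves of a 32-bit register. A minimal standalone sketch (not part of
the patch) that derives them:
```cpp
// Derives the bf16 immediates used in the test expectations below.
// bf16 is the high 16 bits of an IEEE-754 binary32 value.
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint16_t toBF16Bits(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  return static_cast<uint16_t>(Bits >> 16); // truncation is exact for these values
}

int main() {
  uint16_t One = toBF16Bits(1.0f);     // 0x3F80
  uint16_t NegOne = toBF16Bits(-1.0f); // 0xBF80
  // bf16x2 packs two bf16 lanes into one .b32 register.
  uint32_t OneX2 = (uint32_t(One) << 16) | One;          // 0x3F803F80
  uint32_t NegOneX2 = (uint32_t(NegOne) << 16) | NegOne; // 0xBF80BF80
  std::printf("%#x %#x %#x %#x\n", (unsigned)One, (unsigned)NegOne,
              (unsigned)OneX2, (unsigned)NegOneX2);
  // As signed i32 immediates: 1065369472 and -1082081408.
  std::printf("%d %d\n", (int32_t)OneX2, (int32_t)NegOneX2);
  return 0;
}
```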
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 26 ++--
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 73 +++++++++
llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 4 +
llvm/test/CodeGen/NVPTX/atomics-sm90.ll | 56 ++++---
llvm/test/CodeGen/NVPTX/bf16-instructions.ll | 91 ++++--------
.../test/CodeGen/NVPTX/bf16x2-instructions.ll | 140 +++++-------------
llvm/test/CodeGen/NVPTX/fma-relu-contract.ll | 60 ++------
.../CodeGen/NVPTX/fma-relu-fma-intrinsic.ll | 50 ++-----
.../NVPTX/fma-relu-instruction-flag.ll | 110 +++-----------
9 files changed, 229 insertions(+), 381 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6805e0cb23ace0..3d32fb77eb5917 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -17559,10 +17559,13 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
return N2;
}
+ const bool PreferFMAAdd = (TLI.isOperationLegal(ISD::FMA, VT) &&
+ !TLI.isOperationLegal(ISD::FADD, VT));
+
// FIXME: Support splat of constant.
- if (N0CFP && N0CFP->isExactlyValue(1.0))
+ if (!PreferFMAAdd && N0CFP && N0CFP->isExactlyValue(1.0))
return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
- if (N1CFP && N1CFP->isExactlyValue(1.0))
+ if (!PreferFMAAdd && N1CFP && N1CFP->isExactlyValue(1.0))
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
// Canonicalize (fma c, x, y) -> (fma x, c, y)
@@ -17594,7 +17597,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
// (fma x, -1, y) -> (fadd (fneg x), y)
// FIXME: Support splat of constant.
- if (N1CFP) {
+ if (N1CFP && !PreferFMAAdd) {
if (N1CFP->isExactlyValue(1.0))
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
@@ -17604,15 +17607,14 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
AddToWorklist(RHSNeg.getNode());
return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
}
-
- // fma (fneg x), K, y -> fma x -K, y
- if (matcher.match(N0, ISD::FNEG) &&
- (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
- (N1.hasOneUse() &&
- !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
- return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
- matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
- }
+ }
+ // fma (fneg x), K, y -> fma x -K, y
+ if (N1CFP && matcher.match(N0, ISD::FNEG) &&
+ (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
+ (N1.hasOneUse() &&
+ !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
+ return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
+ matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
}
// FIXME: Support splat of constant.
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 184f96b872aa62..5b41287bff842a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -862,6 +862,16 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
AddPromotedToType(Op, MVT::bf16, MVT::f32);
}
+ // Lower bf16 add/mul/sub as fma when it avoids promotion
+ for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
+ for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
+ if (getOperationAction(Op, VT) != Legal &&
+ getOperationAction(ISD::FMA, VT) == Legal) {
+ setOperationAction(Op, VT, Custom);
+ }
+ }
+ }
+
// f16/f16x2 neg was introduced in PTX 60, SM_53.
const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
STI.getPTXVersion() >= 60 &&
@@ -2498,6 +2508,62 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
}
+static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
+ EVT VT = N->getValueType(0);
+ EVT NVT = MVT::f32;
+ if (VT.isVector()) {
+ NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
+ }
+ SDLoc DL(N);
+ SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
+ SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
+ SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
+ return DAG.getFPExtendOrRound(Res, DL, VT);
+}
+
+SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
+ // No fma.ftz for bf16, so fall back to promotion
+ if (useF32FTZ(DAG.getMachineFunction())) {
+ return PromoteBinOpToF32(Op.getNode(), DAG);
+ }
+
+ // FADD(a, b) -> FMA(a, 1.0, b)
+ SDLoc DL(Op);
+ auto VT = Op.getValueType();
+ auto One = DAG.getConstantFP(1.0, DL, VT);
+ SmallVector<SDValue, 3> Operands{Op->getOperand(0), One, Op->getOperand(1)};
+ return DAG.getNode(ISD::FMA, DL, VT, Operands);
+}
+
+SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
+ // No fma.ftz for bf16, so fall back to promotion
+ if (useF32FTZ(DAG.getMachineFunction())) {
+ return PromoteBinOpToF32(Op.getNode(), DAG);
+ }
+
+ // FSUB(a, b) -> FMA(b, -1.0, a)
+ SDLoc DL(Op);
+ auto VT = Op.getValueType();
+ auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
+ SmallVector<SDValue, 3> Operands{Op->getOperand(1), NegOne,
+ Op->getOperand(0)};
+ return DAG.getNode(ISD::FMA, DL, VT, Operands);
+}
+
+SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+ // No fma.ftz for bf16, so fall back to promotion
+ if (useF32FTZ(DAG.getMachineFunction())) {
+ return PromoteBinOpToF32(Op.getNode(), DAG);
+ }
+
+ // FMUL(a, b) -> FMA(a, b, 0.0)
+ SDLoc DL(Op);
+ auto VT = Op.getValueType();
+ auto Zero = DAG.getConstantFP(0.0, DL, VT);
+ SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1), Zero};
+ return DAG.getNode(ISD::FMA, DL, VT, Operands);
+}
+
SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
SelectionDAG &DAG) const {
assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
@@ -2689,6 +2755,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerSTACKSAVE(Op, DAG);
case ISD::CopyToReg:
return LowerCopyToReg_128(Op, DAG);
+ case ISD::FADD:
+ return LowerFADD(Op, DAG);
+ case ISD::FSUB:
+ return LowerFSUB(Op, DAG);
+ case ISD::FMUL:
+ return LowerFMUL(Op, DAG);
+
default:
llvm_unreachable("Custom lowering not defined for operation");
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 51265ed2179d88..514bf7cc240c9c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -278,6 +278,10 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFADD(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFSUB(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
+
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/NVPTX/atomics-sm90.ll b/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
index f81b785f13225c..67552b95e04915 100644
--- a/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
+++ b/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
@@ -46,58 +46,52 @@ define void @test(ptr %dp0, ptr addrspace(1) %dp1, ptr addrspace(3) %dp3, bfloat
; CHECKPTX71-LABEL: test(
; CHECKPTX71: {
; CHECKPTX71-NEXT: .reg .pred %p<5>;
-; CHECKPTX71-NEXT: .reg .b16 %rs<22>;
+; CHECKPTX71-NEXT: .reg .b16 %rs<26>;
; CHECKPTX71-NEXT: .reg .b32 %r<4>;
-; CHECKPTX71-NEXT: .reg .f32 %f<12>;
; CHECKPTX71-EMPTY:
; CHECKPTX71-NEXT: // %bb.0:
; CHECKPTX71-NEXT: ld.param.b16 %rs13, [test_param_3];
; CHECKPTX71-NEXT: ld.param.u32 %r3, [test_param_2];
; CHECKPTX71-NEXT: ld.param.u32 %r2, [test_param_1];
; CHECKPTX71-NEXT: ld.param.u32 %r1, [test_param_0];
-; CHECKPTX71-NEXT: ld.b16 %rs18, [%r1];
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f1, %rs13;
+; CHECKPTX71-NEXT: ld.b16 %rs22, [%r1];
; CHECKPTX71-NEXT: $L__BB0_1: // %atomicrmw.start14
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f2, %rs18;
-; CHECKPTX71-NEXT: add.rn.f32 %f3, %f2, %f1;
-; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs14, %f3;
-; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs18, %rs14;
-; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs18;
-; CHECKPTX71-NEXT: mov.u16 %rs18, %rs3;
+; CHECKPTX71-NEXT: mov.b16 %rs14, 0x3F80;
+; CHECKPTX71-NEXT: fma.rn.bf16 %rs15, %rs22, %rs14, %rs13;
+; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs22, %rs15;
+; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs22;
+; CHECKPTX71-NEXT: mov.u16 %rs22, %rs3;
; CHECKPTX71-NEXT: @%p1 bra $L__BB0_1;
; CHECKPTX71-NEXT: // %bb.2: // %atomicrmw.end13
-; CHECKPTX71-NEXT: ld.b16 %rs19, [%r1];
+; CHECKPTX71-NEXT: ld.b16 %rs23, [%r1];
; CHECKPTX71-NEXT: $L__BB0_3: // %atomicrmw.start8
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f4, %rs19;
-; CHECKPTX71-NEXT: add.rn.f32 %f5, %f4, 0f3F800000;
-; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs15, %f5;
-; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs19, %rs15;
-; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs19;
-; CHECKPTX71-NEXT: mov.u16 %rs19, %rs6;
+; CHECKPTX71-NEXT: mov.b16 %rs16, 0x3F80;
+; CHECKPTX71-NEXT: fma.rn.bf16 %rs17, %rs23, %rs16, %rs16;
+; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs23, %rs17;
+; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs23;
+; CHECKPTX71-NEXT: mov.u16 %rs23, %rs6;
; CHECKPTX71-NEXT: @%p2 bra $L__BB0_3;
; CHECKPTX71-NEXT: // %bb.4: // %atomicrmw.end7
-; CHECKPTX71-NEXT: ld.global.b16 %rs20, [%r2];
+; CHECKPTX71-NEXT: ld.global.b16 %rs24, [%r2];
; CHECKPTX71-NEXT: $L__BB0_5: // %atomicrmw.start2
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f7, %rs20;
-; CHECKPTX71-NEXT: add.rn.f32 %f8, %f7, %f1;
-; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs16, %f8;
-; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs20, %rs16;
-; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs20;
-; CHECKPTX71-NEXT: mov.u16 %rs20, %rs9;
+; CHECKPTX71-NEXT: mov.b16 %rs18, 0x3F80;
+; CHECKPTX71-NEXT: fma.rn.bf16 %rs19, %rs24, %rs18, %rs13;
+; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs24, %rs19;
+; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs24;
+; CHECKPTX71-NEXT: mov.u16 %rs24, %rs9;
; CHECKPTX71-NEXT: @%p3 bra $L__BB0_5;
; CHECKPTX71-NEXT: // %bb.6: // %atomicrmw.end1
-; CHECKPTX71-NEXT: ld.shared.b16 %rs21, [%r3];
+; CHECKPTX71-NEXT: ld.shared.b16 %rs25, [%r3];
; CHECKPTX71-NEXT: $L__BB0_7: // %atomicrmw.start
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f10, %rs21;
-; CHECKPTX71-NEXT: add.rn.f32 %f11, %f10, %f1;
-; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs17, %f11;
-; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs21, %rs17;
-; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs21;
-; CHECKPTX71-NEXT: mov.u16 %rs21, %rs12;
+; CHECKPTX71-NEXT: mov.b16 %rs20, 0x3F80;
+; CHECKPTX71-NEXT: fma.rn.bf16 %rs21, %rs25, %rs20, %rs13;
+; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs25, %rs21;
+; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs25;
+; CHECKPTX71-NEXT: mov.u16 %rs25, %rs12;
; CHECKPTX71-NEXT: @%p4 bra $L__BB0_7;
; CHECKPTX71-NEXT: // %bb.8: // %atomicrmw.end
; CHECKPTX71-NEXT: ret;
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 6828bac18cad7f..eeb13b52130042 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -42,17 +42,14 @@ define bfloat @test_fadd(bfloat %0, bfloat %1) {
;
; SM80-LABEL: test_fadd(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<4>;
-; SM80-NEXT: .reg .f32 %f<4>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_param_0];
; SM80-NEXT: ld.param.b16 %rs2, [test_fadd_param_1];
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs1;
-; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3;
-; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
+; SM80-NEXT: mov.b16 %rs3, 0x3F80;
+; SM80-NEXT: fma.rn.bf16 %rs4, %rs1, %rs3, %rs2;
+; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fadd(
@@ -113,17 +110,14 @@ define bfloat @test_fsub(bfloat %0, bfloat %1) {
;
; SM80-LABEL: test_fsub(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<4>;
-; SM80-NEXT: .reg .f32 %f<4>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b16 %rs1, [test_fsub_param_0];
; SM80-NEXT: ld.param.b16 %rs2, [test_fsub_param_1];
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs1;
-; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3;
-; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
+; SM80-NEXT: mov.b16 %rs3, 0xBF80;
+; SM80-NEXT: fma.rn.bf16 %rs4, %rs2, %rs3, %rs1;
+; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fsub(
@@ -202,23 +196,14 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_faddx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<5>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-NEXT: .reg .f32 %f<7>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
-; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
-; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT: add.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
+; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_1];
+; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_0];
+; SM80-NEXT: mov.b32 %r3, 1065369472;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_faddx2(
@@ -303,23 +288,14 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_fsubx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<5>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-NEXT: .reg .f32 %f<7>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b32 %r1, [test_fsubx2_param_0];
; SM80-NEXT: ld.param.b32 %r2, [test_fsubx2_param_1];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
+; SM80-NEXT: mov.b32 %r3, -1082081408;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fsubx2(
@@ -404,23 +380,14 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_fmulx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<5>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-NEXT: .reg .f32 %f<7>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
-; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_0];
-; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_1];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT: mul.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT: mul.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
+; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_1];
+; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_0];
+; SM80-NEXT: mov.b32 %r3, 0;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r1, %r3;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fmulx2(
@@ -727,15 +694,13 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
;
; SM80-LABEL: test_fadd_imm_1(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<3>;
-; SM80-NEXT: .reg .f32 %f<3>;
+; SM80-NEXT: .reg .b16 %rs<4>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_imm_1_param_0];
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs2, %f2;
-; SM80-NEXT: st.param.b16 [func_retval0], %rs2;
+; SM80-NEXT: mov.b16 %rs2, 0x3F80;
+; SM80-NEXT: fma.rn.bf16 %rs3, %rs1, %rs2, %rs2;
+; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fadd_imm_1(
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
index 03cdeb9683abae..31d089a19450e1 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
@@ -22,19 +22,14 @@ define <2 x bfloat> @test_ret_const() #0 {
define <2 x bfloat> @test_fadd_imm_0(<2 x bfloat> %a) #0 {
; SM80-LABEL: test_fadd_imm_0(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<3>;
-; SM80-NEXT: .reg .b32 %r<3>;
-; SM80-NEXT: .reg .f32 %f<5>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b32 %r1, [test_fadd_imm_0_param_0];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT: cvt.f32.bf16 %f3, %rs2;
-; SM80-NEXT: add.rn.f32 %f4, %f3, 0f40000000;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r2, %f4, %f2;
-; SM80-NEXT: st.param.b32 [func_retval0], %r2;
+; SM80-NEXT: mov.b32 %r2, 1073758080;
+; SM80-NEXT: mov.b32 %r3, 1065369472;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r1, %r3, %r2;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM90-LABEL: test_fadd_imm_0(
@@ -54,15 +49,13 @@ define <2 x bfloat> @test_fadd_imm_0(<2 x bfloat> %a) #0 {
define bfloat @test_fadd_imm_1(bfloat %a) #0 {
; SM80-LABEL: test_fadd_imm_1(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<3>;
-; SM80-NEXT: .reg .f32 %f<3>;
+; SM80-NEXT: .reg .b16 %rs<4>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_imm_1_param_0];
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs2, %f2;
-; SM80-NEXT: st.param.b16 [func_retval0], %rs2;
+; SM80-NEXT: mov.b16 %rs2, 0x3F80;
+; SM80-NEXT: fma.rn.bf16 %rs3, %rs1, %rs2, %rs2;
+; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
; SM80-NEXT: ret;
;
; SM90-LABEL: test_fadd_imm_1(
@@ -82,23 +75,14 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-LABEL: test_fsubx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<5>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-NEXT: .reg .f32 %f<7>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b32 %r1, [test_fsubx2_param_0];
; SM80-NEXT: ld.param.b32 %r2, [test_fsubx2_param_1];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
+; SM80-NEXT: mov.b32 %r3, -1082081408;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM90-LABEL: test_fsubx2(
@@ -118,23 +102,14 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-LABEL: test_fmulx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<5>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-NEXT: .reg .f32 %f<7>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
-; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_0];
-; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_1];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT: mul.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT: mul.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
+; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_1];
+; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_0];
+; SM80-NEXT: mov.b32 %r3, 0;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r1, %r3;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM90-LABEL: test_fmulx2(
@@ -543,30 +518,16 @@ define <2 x bfloat> @test_fabs(<2 x bfloat> %a) #0 {
define <2 x bfloat> @test_fabs_add(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-LABEL: test_fabs_add(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<7>;
-; SM80-NEXT: .reg .b32 %r<6>;
-; SM80-NEXT: .reg .f32 %f<11>;
+; SM80-NEXT: .reg .b32 %r<7>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b32 %r1, [test_fabs_add_param_1];
; SM80-NEXT: ld.param.b32 %r2, [test_fabs_add_param_0];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: add.rn.f32 %f2, %f1, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f3, %rs2;
-; SM80-NEXT: add.rn.f32 %f4, %f3, %f3;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f4, %f2;
-; SM80-NEXT: abs.bf16x2 %r4, %r3;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r4;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs3;
-; SM80-NEXT: mov.b32 {%rs5, %rs6}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f6, %rs5;
-; SM80-NEXT: add.rn.f32 %f7, %f5, %f6;
-; SM80-NEXT: cvt.f32.bf16 %f8, %rs4;
-; SM80-NEXT: cvt.f32.bf16 %f9, %rs6;
-; SM80-NEXT: add.rn.f32 %f10, %f8, %f9;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r5, %f10, %f7;
-; SM80-NEXT: st.param.b32 [func_retval0], %r5;
+; SM80-NEXT: mov.b32 %r3, 1065369472;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r2;
+; SM80-NEXT: abs.bf16x2 %r5, %r4;
+; SM80-NEXT: fma.rn.bf16x2 %r6, %r5, %r3, %r1;
+; SM80-NEXT: st.param.b32 [func_retval0], %r6;
; SM80-NEXT: ret;
;
; SM90-LABEL: test_fabs_add(
@@ -802,45 +763,18 @@ define <2 x bfloat> @test_round(<2 x bfloat> %a) #0 {
}
define <2 x bfloat> @test_copysign(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
-; SM80-LABEL: test_copysign(
-; SM80: {
-; SM80-NEXT: .reg .pred %p<3>;
-; SM80-NEXT: .reg .b16 %rs<15>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-EMPTY:
-; SM80-NEXT: // %bb.0:
-; SM80-NEXT: ld.param.b32 %r1, [test_copysign_param_1];
-; SM80-NEXT: ld.param.b32 %r2, [test_copysign_param_0];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: abs.bf16 %rs3, %rs2;
-; SM80-NEXT: neg.bf16 %rs4, %rs3;
-; SM80-NEXT: mov.b32 {%rs5, %rs6}, %r1;
-; SM80-NEXT: shr.u16 %rs7, %rs6, 15;
-; SM80-NEXT: and.b16 %rs8, %rs7, 1;
-; SM80-NEXT: setp.eq.b16 %p1, %rs8, 1;
-; SM80-NEXT: selp.b16 %rs9, %rs4, %rs3, %p1;
-; SM80-NEXT: abs.bf16 %rs10, %rs1;
-; SM80-NEXT: neg.bf16 %rs11, %rs10;
-; SM80-NEXT: shr.u16 %rs12, %rs5, 15;
-; SM80-NEXT: and.b16 %rs13, %rs12, 1;
-; SM80-NEXT: setp.eq.b16 %p2, %rs13, 1;
-; SM80-NEXT: selp.b16 %rs14, %rs11, %rs10, %p2;
-; SM80-NEXT: mov.b32 %r3, {%rs14, %rs9};
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
-; SM80-NEXT: ret;
-;
-; SM90-LABEL: test_copysign(
-; SM90: {
-; SM90-NEXT: .reg .b32 %r<6>;
-; SM90-EMPTY:
-; SM90-NEXT: // %bb.0:
-; SM90-NEXT: ld.param.b32 %r1, [test_copysign_param_0];
-; SM90-NEXT: ld.param.b32 %r2, [test_copysign_param_1];
-; SM90-NEXT: and.b32 %r3, %r2, -2147450880;
-; SM90-NEXT: and.b32 %r4, %r1, 2147450879;
-; SM90-NEXT: or.b32 %r5, %r4, %r3;
-; SM90-NEXT: st.param.b32 [func_retval0], %r5;
-; SM90-NEXT: ret;
+; CHECK-LABEL: test_copysign(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [test_copysign_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_copysign_param_1];
+; CHECK-NEXT: and.b32 %r3, %r2, -2147450880;
+; CHECK-NEXT: and.b32 %r4, %r1, 2147450879;
+; CHECK-NEXT: or.b32 %r5, %r4, %r3;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r5;
+; CHECK-NEXT: ret;
%r = call <2 x bfloat> @llvm.copysign.f16(<2 x bfloat> %a, <2 x bfloat> %b)
ret <2 x bfloat> %r
}
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll b/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
index 48c94f275274bd..1643704e2ff95c 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
@@ -352,9 +352,7 @@ define bfloat @fma_bf16_expanded_no_nans(bfloat %a, bfloat %b, bfloat %c) #0 {
define bfloat @fma_bf16_expanded_no_nans_multiple_uses_of_fma(bfloat %a, bfloat %b, bfloat %c) #0 {
; CHECK-LABEL: fma_bf16_expanded_no_nans_multiple_uses_of_fma(
; CHECK: {
-; CHECK-NEXT: .reg .b16 %rs<9>;
-; CHECK-NEXT: .reg .b32 %r<7>;
-; CHECK-NEXT: .reg .f32 %f<6>;
+; CHECK-NEXT: .reg .b16 %rs<11>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [fma_bf16_expanded_no_nans_multiple_uses_of_fma_param_0];
@@ -363,20 +361,11 @@ define bfloat @fma_bf16_expanded_no_nans_multiple_uses_of_fma(bfloat %a, bfloat
; CHECK-NEXT: fma.rn.bf16 %rs4, %rs1, %rs2, %rs3;
; CHECK-NEXT: mov.b16 %rs5, 0x0000;
; CHECK-NEXT: max.bf16 %rs6, %rs4, %rs5;
-; CHECK-NEXT: cvt.u32.u16 %r1, %rs4;
-; CHECK-NEXT: shl.b32 %r2, %r1, 16;
-; CHECK-NEXT: mov.b32 %f1, %r2;
-; CHECK-NEXT: add.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs7, %f2;
-; CHECK-NEXT: cvt.u32.u16 %r3, %rs6;
-; CHECK-NEXT: shl.b32 %r4, %r3, 16;
-; CHECK-NEXT: mov.b32 %f3, %r4;
-; CHECK-NEXT: cvt.u32.u16 %r5, %rs7;
-; CHECK-NEXT: shl.b32 %r6, %r5, 16;
-; CHECK-NEXT: mov.b32 %f4, %r6;
-; CHECK-NEXT: add.f32 %f5, %f3, %f4;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs8, %f5;
-; CHECK-NEXT: st.param.b16 [func_retval0], %rs8;
+; CHECK-NEXT: mov.b16 %rs7, 0x40E0;
+; CHECK-NEXT: mov.b16 %rs8, 0x3F80;
+; CHECK-NEXT: fma.rn.bf16 %rs9, %rs4, %rs8, %rs7;
+; CHECK-NEXT: fma.rn.bf16 %rs10, %rs6, %rs8, %rs9;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs10;
; CHECK-NEXT: ret;
;
; CHECK-FTZ-LABEL: fma_bf16_expanded_no_nans_multiple_uses_of_fma(
@@ -959,9 +948,7 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans(<2 x bfloat> %a, <2 x bfloat> %
define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) #0 {
; CHECK-LABEL: fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(
; CHECK: {
-; CHECK-NEXT: .reg .b16 %rs<7>;
-; CHECK-NEXT: .reg .b32 %r<20>;
-; CHECK-NEXT: .reg .f32 %f<11>;
+; CHECK-NEXT: .reg .b32 %r<11>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [fma_bf16x2_expanded_no_nans_multiple_uses_of_fma_param_2];
@@ -970,34 +957,11 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
; CHECK-NEXT: fma.rn.bf16x2 %r4, %r3, %r2, %r1;
; CHECK-NEXT: mov.b32 %r5, 0;
; CHECK-NEXT: max.bf16x2 %r6, %r4, %r5;
-; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r4;
-; CHECK-NEXT: cvt.u32.u16 %r7, %rs2;
-; CHECK-NEXT: shl.b32 %r8, %r7, 16;
-; CHECK-NEXT: mov.b32 %f1, %r8;
-; CHECK-NEXT: add.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs3, %f2;
-; CHECK-NEXT: cvt.u32.u16 %r9, %rs1;
-; CHECK-NEXT: shl.b32 %r10, %r9, 16;
-; CHECK-NEXT: mov.b32 %f3, %r10;
-; CHECK-NEXT: add.f32 %f4, %f3, 0f40E00000;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs4, %f4;
-; CHECK-NEXT: mov.b32 {%rs5, %rs6}, %r6;
-; CHECK-NEXT: cvt.u32.u16 %r11, %rs5;
-; CHECK-NEXT: shl.b32 %r12, %r11, 16;
-; CHECK-NEXT: mov.b32 %f5, %r12;
-; CHECK-NEXT: cvt.u32.u16 %r13, %rs4;
-; CHECK-NEXT: shl.b32 %r14, %r13, 16;
-; CHECK-NEXT: mov.b32 %f6, %r14;
-; CHECK-NEXT: add.f32 %f7, %f5, %f6;
-; CHECK-NEXT: cvt.u32.u16 %r15, %rs6;
-; CHECK-NEXT: shl.b32 %r16, %r15, 16;
-; CHECK-NEXT: mov.b32 %f8, %r16;
-; CHECK-NEXT: cvt.u32.u16 %r17, %rs3;
-; CHECK-NEXT: shl.b32 %r18, %r17, 16;
-; CHECK-NEXT: mov.b32 %f9, %r18;
-; CHECK-NEXT: add.f32 %f10, %f8, %f9;
-; CHECK-NEXT: cvt.rn.bf16x2.f32 %r19, %f10, %f7;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r19;
+; CHECK-NEXT: mov.b32 %r7, 1088438496;
+; CHECK-NEXT: mov.b32 %r8, 1065369472;
+; CHECK-NEXT: fma.rn.bf16x2 %r9, %r4, %r8, %r7;
+; CHECK-NEXT: fma.rn.bf16x2 %r10, %r6, %r8, %r9;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r10;
; CHECK-NEXT: ret;
;
; CHECK-FTZ-LABEL: fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll b/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
index 561f2b0cc06730..e1e34ee9b1c159 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
@@ -221,26 +221,18 @@ define bfloat @fma_bf16_no_nans(bfloat %a, bfloat %b, bfloat %c) #0 {
define bfloat @fma_bf16_no_nans_multiple_uses_of_fma(bfloat %a, bfloat %b, bfloat %c) #0 {
; CHECK-LABEL: fma_bf16_no_nans_multiple_uses_of_fma(
; CHECK: {
-; CHECK-NEXT: .reg .b16 %rs<7>;
-; CHECK-NEXT: .reg .b32 %r<5>;
-; CHECK-NEXT: .reg .f32 %f<5>;
+; CHECK-NEXT: .reg .b16 %rs<9>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [fma_bf16_no_nans_multiple_uses_of_fma_param_0];
; CHECK-NEXT: ld.param.b16 %rs2, [fma_bf16_no_nans_multiple_uses_of_fma_param_1];
; CHECK-NEXT: ld.param.b16 %rs3, [fma_bf16_no_nans_multiple_uses_of_fma_param_2];
; CHECK-NEXT: fma.rn.bf16 %rs4, %rs1, %rs2, %rs3;
-; CHECK-NEXT: cvt.u32.u16 %r1, %rs4;
-; CHECK-NEXT: shl.b32 %r2, %r1, 16;
-; CHECK-NEXT: mov.b32 %f1, %r2;
-; CHECK-NEXT: add.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs5, %f2;
-; CHECK-NEXT: cvt.u32.u16 %r3, %rs5;
-; CHECK-NEXT: shl.b32 %r4, %r3, 16;
-; CHECK-NEXT: mov.b32 %f3, %r4;
-; CHECK-NEXT: add.f32 %f4, %f3, %f1;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs6, %f4;
-; CHECK-NEXT: st.param.b16 [func_retval0], %rs6;
+; CHECK-NEXT: mov.b16 %rs5, 0x40E0;
+; CHECK-NEXT: mov.b16 %rs6, 0x3F80;
+; CHECK-NEXT: fma.rn.bf16 %rs7, %rs4, %rs6, %rs5;
+; CHECK-NEXT: fma.rn.bf16 %rs8, %rs7, %rs6, %rs4;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs8;
; CHECK-NEXT: ret;
;
; CHECK-FTZ-LABEL: fma_bf16_no_nans_multiple_uses_of_fma(
@@ -642,36 +634,18 @@ define <2 x bfloat> @fma_bf16x2_no_nans(<2 x bfloat> %a, <2 x bfloat> %b, <2 x b
define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) #0 {
; CHECK-LABEL: fma_bf16x2_no_nans_multiple_uses_of_fma(
; CHECK: {
-; CHECK-NEXT: .reg .b16 %rs<5>;
-; CHECK-NEXT: .reg .b32 %r<14>;
-; CHECK-NEXT: .reg .f32 %f<9>;
+; CHECK-NEXT: .reg .b32 %r<9>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_2];
; CHECK-NEXT: ld.param.b32 %r2, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_1];
; CHECK-NEXT: ld.param.b32 %r3, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_0];
; CHECK-NEXT: fma.rn.bf16x2 %r4, %r3, %r2, %r1;
-; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r4;
-; CHECK-NEXT: cvt.u32.u16 %r5, %rs2;
-; CHECK-NEXT: shl.b32 %r6, %r5, 16;
-; CHECK-NEXT: mov.b32 %f1, %r6;
-; CHECK-NEXT: add.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs3, %f2;
-; CHECK-NEXT: cvt.u32.u16 %r7, %rs1;
-; CHECK-NEXT: shl.b32 %r8, %r7, 16;
-; CHECK-NEXT: mov.b32 %f3, %r8;
-; CHECK-NEXT: add.f32 %f4, %f3, 0f40E00000;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs4, %f4;
-; CHECK-NEXT: cvt.u32.u16 %r9, %rs4;
-; CHECK-NEXT: shl.b32 %r10, %r9, 16;
-; CHECK-NEXT: mov.b32 %f5, %r10;
-; CHECK-NEXT: add.f32 %f6, %f5, %f3;
-; CHECK-NEXT: cvt.u32.u16 %r11, %rs3;
-; CHECK-NEXT: shl.b32 %r12, %r11, 16;
-; CHECK-NEXT: mov.b32 %f7, %r12;
-; CHECK-NEXT: add.f32 %f8, %f7, %f1;
-; CHECK-NEXT: cvt.rn.bf16x2.f32 %r13, %f8, %f6;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r13;
+; CHECK-NEXT: mov.b32 %r5, 1088438496;
+; CHECK-NEXT: mov.b32 %r6, 1065369472;
+; CHECK-NEXT: fma.rn.bf16x2 %r7, %r4, %r6, %r5;
+; CHECK-NEXT: fma.rn.bf16x2 %r8, %r7, %r6, %r4;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r8;
; CHECK-NEXT: ret;
;
; CHECK-FTZ-LABEL: fma_bf16x2_no_nans_multiple_uses_of_fma(
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll b/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
index b20ca24dd91a0c..ea046dc90b23f2 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
@@ -233,9 +233,7 @@ define bfloat @fma_bf16_expanded_no_nans(bfloat %a, bfloat %b, bfloat %c) {
define bfloat @fma_bf16_expanded_no_nans_multiple_uses_of_fma(bfloat %a, bfloat %b, bfloat %c) {
; CHECK-LABEL: fma_bf16_expanded_no_nans_multiple_uses_of_fma(
; CHECK: {
-; CHECK-NEXT: .reg .b16 %rs<9>;
-; CHECK-NEXT: .reg .b32 %r<7>;
-; CHECK-NEXT: .reg .f32 %f<6>;
+; CHECK-NEXT: .reg .b16 %rs<11>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [fma_bf16_expanded_no_nans_multiple_uses_of_fma_param_0];
@@ -244,20 +242,11 @@ define bfloat @fma_bf16_expanded_no_nans_multiple_uses_of_fma(bfloat %a, bfloat
; CHECK-NEXT: fma.rn.bf16 %rs4, %rs1, %rs2, %rs3;
; CHECK-NEXT: mov.b16 %rs5, 0x0000;
; CHECK-NEXT: max.bf16 %rs6, %rs4, %rs5;
-; CHECK-NEXT: cvt.u32.u16 %r1, %rs4;
-; CHECK-NEXT: shl.b32 %r2, %r1, 16;
-; CHECK-NEXT: mov.b32 %f1, %r2;
-; CHECK-NEXT: add.rn.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs7, %f2;
-; CHECK-NEXT: cvt.u32.u16 %r3, %rs6;
-; CHECK-NEXT: shl.b32 %r4, %r3, 16;
-; CHECK-NEXT: mov.b32 %f3, %r4;
-; CHECK-NEXT: cvt.u32.u16 %r5, %rs7;
-; CHECK-NEXT: shl.b32 %r6, %r5, 16;
-; CHECK-NEXT: mov.b32 %f4, %r6;
-; CHECK-NEXT: add.rn.f32 %f5, %f3, %f4;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs8, %f5;
-; CHECK-NEXT: st.param.b16 [func_retval0], %rs8;
+; CHECK-NEXT: mov.b16 %rs7, 0x40E0;
+; CHECK-NEXT: mov.b16 %rs8, 0x3F80;
+; CHECK-NEXT: fma.rn.bf16 %rs9, %rs4, %rs8, %rs7;
+; CHECK-NEXT: fma.rn.bf16 %rs10, %rs6, %rs8, %rs9;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs10;
; CHECK-NEXT: ret;
;
; CHECK-FTZ-LABEL: fma_bf16_expanded_no_nans_multiple_uses_of_fma(
@@ -694,9 +683,7 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans(<2 x bfloat> %a, <2 x bfloat> %
define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) {
; CHECK-LABEL: fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(
; CHECK: {
-; CHECK-NEXT: .reg .b16 %rs<7>;
-; CHECK-NEXT: .reg .b32 %r<20>;
-; CHECK-NEXT: .reg .f32 %f<11>;
+; CHECK-NEXT: .reg .b32 %r<11>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [fma_bf16x2_expanded_no_nans_multiple_uses_of_fma_param_2];
@@ -705,34 +692,11 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
; CHECK-NEXT: fma.rn.bf16x2 %r4, %r3, %r2, %r1;
; CHECK-NEXT: mov.b32 %r5, 0;
; CHECK-NEXT: max.bf16x2 %r6, %r4, %r5;
-; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r4;
-; CHECK-NEXT: cvt.u32.u16 %r7, %rs2;
-; CHECK-NEXT: shl.b32 %r8, %r7, 16;
-; CHECK-NEXT: mov.b32 %f1, %r8;
-; CHECK-NEXT: add.rn.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs3, %f2;
-; CHECK-NEXT: cvt.u32.u16 %r9, %rs1;
-; CHECK-NEXT: shl.b32 %r10, %r9, 16;
-; CHECK-NEXT: mov.b32 %f3, %r10;
-; CHECK-NEXT: add.rn.f32 %f4, %f3, 0f40E00000;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs4, %f4;
-; CHECK-NEXT: mov.b32 {%rs5, %rs6}, %r6;
-; CHECK-NEXT: cvt.u32.u16 %r11, %rs5;
-; CHECK-NEXT: shl.b32 %r12, %r11, 16;
-; CHECK-NEXT: mov.b32 %f5, %r12;
-; CHECK-NEXT: cvt.u32.u16 %r13, %rs4;
-; CHECK-NEXT: shl.b32 %r14, %r13, 16;
-; CHECK-NEXT: mov.b32 %f6, %r14;
-; CHECK-NEXT: add.rn.f32 %f7, %f5, %f6;
-; CHECK-NEXT: cvt.u32.u16 %r15, %rs6;
-; CHECK-NEXT: shl.b32 %r16, %r15, 16;
-; CHECK-NEXT: mov.b32 %f8, %r16;
-; CHECK-NEXT: cvt.u32.u16 %r17, %rs3;
-; CHECK-NEXT: shl.b32 %r18, %r17, 16;
-; CHECK-NEXT: mov.b32 %f9, %r18;
-; CHECK-NEXT: add.rn.f32 %f10, %f8, %f9;
-; CHECK-NEXT: cvt.rn.bf16x2.f32 %r19, %f10, %f7;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r19;
+; CHECK-NEXT: mov.b32 %r7, 1088438496;
+; CHECK-NEXT: mov.b32 %r8, 1065369472;
+; CHECK-NEXT: fma.rn.bf16x2 %r9, %r4, %r8, %r7;
+; CHECK-NEXT: fma.rn.bf16x2 %r10, %r6, %r8, %r9;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r10;
; CHECK-NEXT: ret;
;
; CHECK-FTZ-LABEL: fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(
@@ -1204,26 +1168,18 @@ define bfloat @fma_bf16_no_nans(bfloat %a, bfloat %b, bfloat %c) {
define bfloat @fma_bf16_no_nans_multiple_uses_of_fma(bfloat %a, bfloat %b, bfloat %c) {
; CHECK-LABEL: fma_bf16_no_nans_multiple_uses_of_fma(
; CHECK: {
-; CHECK-NEXT: .reg .b16 %rs<7>;
-; CHECK-NEXT: .reg .b32 %r<5>;
-; CHECK-NEXT: .reg .f32 %f<5>;
+; CHECK-NEXT: .reg .b16 %rs<9>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [fma_bf16_no_nans_multiple_uses_of_fma_param_0];
; CHECK-NEXT: ld.param.b16 %rs2, [fma_bf16_no_nans_multiple_uses_of_fma_param_1];
; CHECK-NEXT: ld.param.b16 %rs3, [fma_bf16_no_nans_multiple_uses_of_fma_param_2];
; CHECK-NEXT: fma.rn.bf16 %rs4, %rs1, %rs2, %rs3;
-; CHECK-NEXT: cvt.u32.u16 %r1, %rs4;
-; CHECK-NEXT: shl.b32 %r2, %r1, 16;
-; CHECK-NEXT: mov.b32 %f1, %r2;
-; CHECK-NEXT: add.rn.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs5, %f2;
-; CHECK-NEXT: cvt.u32.u16 %r3, %rs5;
-; CHECK-NEXT: shl.b32 %r4, %r3, 16;
-; CHECK-NEXT: mov.b32 %f3, %r4;
-; CHECK-NEXT: add.rn.f32 %f4, %f3, %f1;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs6, %f4;
-; CHECK-NEXT: st.param.b16 [func_retval0], %rs6;
+; CHECK-NEXT: mov.b16 %rs5, 0x40E0;
+; CHECK-NEXT: mov.b16 %rs6, 0x3F80;
+; CHECK-NEXT: fma.rn.bf16 %rs7, %rs4, %rs6, %rs5;
+; CHECK-NEXT: fma.rn.bf16 %rs8, %rs7, %rs6, %rs4;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs8;
; CHECK-NEXT: ret;
;
; CHECK-FTZ-LABEL: fma_bf16_no_nans_multiple_uses_of_fma(
@@ -1629,36 +1585,18 @@ define <2 x bfloat> @fma_bf16x2_no_nans(<2 x bfloat> %a, <2 x bfloat> %b, <2 x b
define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) {
; CHECK-LABEL: fma_bf16x2_no_nans_multiple_uses_of_fma(
; CHECK: {
-; CHECK-NEXT: .reg .b16 %rs<5>;
-; CHECK-NEXT: .reg .b32 %r<14>;
-; CHECK-NEXT: .reg .f32 %f<9>;
+; CHECK-NEXT: .reg .b32 %r<9>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_2];
; CHECK-NEXT: ld.param.b32 %r2, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_1];
; CHECK-NEXT: ld.param.b32 %r3, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_0];
; CHECK-NEXT: fma.rn.bf16x2 %r4, %r3, %r2, %r1;
-; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r4;
-; CHECK-NEXT: cvt.u32.u16 %r5, %rs2;
-; CHECK-NEXT: shl.b32 %r6, %r5, 16;
-; CHECK-NEXT: mov.b32 %f1, %r6;
-; CHECK-NEXT: add.rn.f32 %f2, %f1, 0f40E00000;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs3, %f2;
-; CHECK-NEXT: cvt.u32.u16 %r7, %rs1;
-; CHECK-NEXT: shl.b32 %r8, %r7, 16;
-; CHECK-NEXT: mov.b32 %f3, %r8;
-; CHECK-NEXT: add.rn.f32 %f4, %f3, 0f40E00000;
-; CHECK-NEXT: cvt.rn.bf16.f32 %rs4, %f4;
-; CHECK-NEXT: cvt.u32.u16 %r9, %rs4;
-; CHECK-NEXT: shl.b32 %r10, %r9, 16;
-; CHECK-NEXT: mov.b32 %f5, %r10;
-; CHECK-NEXT: add.rn.f32 %f6, %f5, %f3;
-; CHECK-NEXT: cvt.u32.u16 %r11, %rs3;
-; CHECK-NEXT: shl.b32 %r12, %r11, 16;
-; CHECK-NEXT: mov.b32 %f7, %r12;
-; CHECK-NEXT: add.rn.f32 %f8, %f7, %f1;
-; CHECK-NEXT: cvt.rn.bf16x2.f32 %r13, %f8, %f6;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r13;
+; CHECK-NEXT: mov.b32 %r5, 1088438496;
+; CHECK-NEXT: mov.b32 %r6, 1065369472;
+; CHECK-NEXT: fma.rn.bf16x2 %r7, %r4, %r6, %r5;
+; CHECK-NEXT: fma.rn.bf16x2 %r8, %r7, %r6, %r4;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r8;
; CHECK-NEXT: ret;
;
; CHECK-FTZ-LABEL: fma_bf16x2_no_nans_multiple_uses_of_fma(
>From cda2fa4ae3643374aacc02db44ac6645f6df23c0 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Thu, 26 Dec 2024 21:42:50 +0000
Subject: [PATCH 2/7] Fix mul for negative zero
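The FMUL identity needs to be -0.0 rather than +0.0: when the product
a*b is an exact -0.0, adding +0.0 rounds to +0.0 and flips the sign of
the zero, whereas adding -0.0 preserves it. A standalone check (not
part of the patch), under default rounding:
```cpp
// Standalone illustration (not part of the patch): +0.0 is not the additive
// identity for signed zeros, but -0.0 is, so fma(a, b, -0.0) preserves a*b
// exactly, including when the product is -0.0.
#include <cmath>
#include <cstdio>

int main() {
  float Product = -0.0f;                              // e.g. a = -0.0, b = 1.0
  std::printf("%d\n", std::signbit(Product + 0.0f));  // 0: sign of -0 is lost
  std::printf("%d\n", std::signbit(Product + -0.0f)); // 1: sign of -0 is kept
  std::printf("%d\n", std::signbit(std::fmaf(-0.0f, 1.0f, -0.0f))); // 1
  return 0;
}
```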
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 8 +++++---
llvm/test/CodeGen/NVPTX/bf16-instructions.ll | 2 +-
llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll | 2 +-
3 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 5b41287bff842a..95f47ac12880dc 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2556,11 +2556,13 @@ SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
return PromoteBinOpToF32(Op.getNode(), DAG);
}
- // FMUL(a, b) -> FMA(a, b, 0.0)
+ // FMUL(a, b) -> FMA(a, b, -0.0)
+ // NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats
SDLoc DL(Op);
auto VT = Op.getValueType();
- auto Zero = DAG.getConstantFP(0.0, DL, VT);
- SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1), Zero};
+ auto NegZero = DAG.getConstantFP(-0.0, DL, VT);
+ SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1),
+ NegZero};
return DAG.getNode(ISD::FMA, DL, VT, Operands);
}
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index eeb13b52130042..b53f82403af5f3 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -385,7 +385,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_1];
; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_0];
-; SM80-NEXT: mov.b32 %r3, 0;
+; SM80-NEXT: mov.b32 %r3, -2147450880;
; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r1, %r3;
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
index 31d089a19450e1..f7ffba385df764 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
@@ -107,7 +107,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_1];
; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_0];
-; SM80-NEXT: mov.b32 %r3, 0;
+; SM80-NEXT: mov.b32 %r3, -2147450880;
; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r1, %r3;
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
>From 3805dd60627def3839be7d8c1338640f9f8ac44e Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Fri, 27 Dec 2024 19:49:51 +0000
Subject: [PATCH 3/7] Move fma expansions to default expand rule
---
llvm/include/llvm/CodeGen/TargetLowering.h | 17 ++++++
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 21 +++++---
.../CodeGen/SelectionDAG/TargetLowering.cpp | 54 +++++++++++++++++++
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 21 ++------
4 files changed, 88 insertions(+), 25 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index ce58777655e063..885eca15bf4bc0 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5325,6 +5325,23 @@ class TargetLowering : public TargetLoweringBase {
SDNodeFlags Flags, const SDLoc &DL,
SelectionDAG &DAG) const;
+ /// Expand floating point add
+ /// \param N Node to expand
+ /// \returns The expansion result or SDValue() if it fails.
+ SDValue expandFADD(SDNode *N, SelectionDAG &DAG) const;
+
+ /// Expand floating point multiply
+ /// \param N Node to expand
+ /// \param Result output after conversion
+ /// \returns The expansion result or SDValue() if it fails.
+ SDValue expandFMUL(SDNode *N, SelectionDAG &DAG) const;
+
+ /// Expand floating point subtract
+ /// \param N Node to expand
+ /// \param Result output after conversion
+ /// \returns The expansion result or SDValue() if it fails.
+ SDValue expandFSUB(SDNode *N, SelectionDAG &DAG) const;
+
/// Expand CTPOP nodes. Expands vector/scalar CTPOP nodes,
/// vector nodes can only succeed if all operations are legal/custom.
/// \param N Node to expand
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index c6475f02199033..912458ec0e6612 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3672,14 +3672,21 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
Results.push_back(ExpandConstant(CP));
break;
}
+ case ISD::FADD: {
+ if (SDValue Expand = TLI.expandFADD(Node, DAG)) {
+ Results.push_back(Expand);
+ }
+ break;
+ }
+ case ISD::FMUL: {
+ if (SDValue Expand = TLI.expandFMUL(Node, DAG)) {
+ Results.push_back(Expand);
+ }
+ break;
+ }
case ISD::FSUB: {
- EVT VT = Node->getValueType(0);
- if (TLI.isOperationLegalOrCustom(ISD::FADD, VT) &&
- TLI.isOperationLegalOrCustom(ISD::FNEG, VT)) {
- const SDNodeFlags Flags = Node->getFlags();
- Tmp1 = DAG.getNode(ISD::FNEG, dl, VT, Node->getOperand(1));
- Tmp1 = DAG.getNode(ISD::FADD, dl, VT, Node->getOperand(0), Tmp1, Flags);
- Results.push_back(Tmp1);
+ if (SDValue Expand = TLI.expandFSUB(Node, DAG)) {
+ Results.push_back(Expand);
}
break;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 56194e2614af2d..4f6594c4980688 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -9068,6 +9068,60 @@ SDValue TargetLowering::expandIS_FPCLASS(EVT ResultVT, SDValue Op,
return Res;
}
+SDValue TargetLowering::expandFADD(SDNode *Node, SelectionDAG &DAG) const {
+ auto VT = Node->getValueType(0);
+ if (!isOperationLegalOrCustom(ISD::FMA, VT)) {
+ return {};
+ }
+
+ // FADD(a, b) -> FMA(a, 1.0, b)
+ SDLoc DL(Node);
+ auto One = DAG.getConstantFP(1.0, DL, VT);
+ SmallVector<SDValue, 3> Operands{Node->getOperand(0), One,
+ Node->getOperand(1)};
+ return DAG.getNode(ISD::FMA, DL, VT, Operands, Node->getFlags());
+}
+
+SDValue TargetLowering::expandFMUL(SDNode *Node, SelectionDAG &DAG) const {
+ auto VT = Node->getValueType(0);
+ if (!isOperationLegalOrCustom(ISD::FMA, VT)) {
+ return {};
+ }
+
+ // FMUL(a, b) -> FMA(a, b, -0.0)
+ // NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats
+ SDLoc DL(Node);
+ auto NegZero = DAG.getConstantFP(-0.0, DL, VT);
+ SmallVector<SDValue, 3> Operands{Node->getOperand(0), Node->getOperand(1),
+ NegZero};
+ return DAG.getNode(ISD::FMA, DL, VT, Operands, Node->getFlags());
+}
+
+SDValue TargetLowering::expandFSUB(SDNode *Node, SelectionDAG &DAG) const {
+ SDLoc DL(Node);
+ SDNodeFlags SDFlags = Node->getFlags();
+ auto VT = Node->getValueType(0);
+
+ bool CanUseFMA = isOperationLegalOrCustom(ISD::FMA, VT);
+ bool CanUseAddSub = (isOperationLegalOrCustom(ISD::FADD, VT) &&
+ isOperationLegalOrCustom(ISD::FNEG, VT));
+ bool PreferAddSub = CanUseAddSub && isFNegFree(VT);
+
+ // FSUB(a, b) -> FMA(b, -1.0, a)
+ if (CanUseFMA && !PreferAddSub) {
+ auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
+ SmallVector<SDValue, 3> Operands{Node->getOperand(1), NegOne,
+ Node->getOperand(0)};
+ return DAG.getNode(ISD::FMA, DL, VT, Operands, SDFlags);
+ }
+ // FSUB(a, b) -> FADD(a, FNEG(b))
+ if (CanUseAddSub) {
+ auto Neg = DAG.getNode(ISD::FNEG, DL, VT, Node->getOperand(1));
+ return DAG.getNode(ISD::FADD, DL, VT, Node->getOperand(0), Neg, SDFlags);
+ }
+ return {};
+}
+
// Only expand vector types if we have the appropriate vector bit operations.
static bool canExpandVectorCTPOP(const TargetLowering &TLI, EVT VT) {
assert(VT.isVector() && "Expected vector type");
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 95f47ac12880dc..62f4e4cbfcefff 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2528,11 +2528,7 @@ SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
}
// FADD(a, b) -> FMA(a, 1.0, b)
- SDLoc DL(Op);
- auto VT = Op.getValueType();
- auto One = DAG.getConstantFP(1.0, DL, VT);
- SmallVector<SDValue, 3> Operands{Op->getOperand(0), One, Op->getOperand(1)};
- return DAG.getNode(ISD::FMA, DL, VT, Operands);
+ return expandFADD(Op.getNode(), DAG);
}
SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
@@ -2542,12 +2538,7 @@ SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
}
// FSUB(a, b) -> FMA(b, -1.0, a)
- SDLoc DL(Op);
- auto VT = Op.getValueType();
- auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
- SmallVector<SDValue, 3> Operands{Op->getOperand(1), NegOne,
- Op->getOperand(0)};
- return DAG.getNode(ISD::FMA, DL, VT, Operands);
+ return expandFSUB(Op.getNode(), DAG);
}
SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
@@ -2557,13 +2548,7 @@ SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
}
// FMUL(a, b) -> FMA(a, b, -0.0)
- // NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats
- SDLoc DL(Op);
- auto VT = Op.getValueType();
- auto NegZero = DAG.getConstantFP(-0.0, DL, VT);
- SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1),
- NegZero};
- return DAG.getNode(ISD::FMA, DL, VT, Operands);
+ return expandFMUL(Op.getNode(), DAG);
}
SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
>From 1c5ef7af56cc651955862341d45a837c3a5e95c4 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Thu, 9 Jan 2025 18:26:41 +0000
Subject: [PATCH 4/7] Move fma to selection stage
---
llvm/include/llvm/CodeGen/TargetLowering.h | 17 -----
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 26 ++++---
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 21 ++----
.../CodeGen/SelectionDAG/TargetLowering.cpp | 54 ---------------
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 67 +++++++++++++++++++
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h | 1 +
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 12 ++--
llvm/test/CodeGen/NVPTX/bf16-instructions.ll | 6 +-
.../test/CodeGen/NVPTX/bf16x2-instructions.ll | 6 +-
llvm/test/CodeGen/NVPTX/fma-relu-contract.ll | 16 ++---
.../CodeGen/NVPTX/fma-relu-fma-intrinsic.ll | 16 ++---
.../NVPTX/fma-relu-instruction-flag.ll | 32 ++++-----
12 files changed, 131 insertions(+), 143 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 885eca15bf4bc0..ce58777655e063 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5325,23 +5325,6 @@ class TargetLowering : public TargetLoweringBase {
SDNodeFlags Flags, const SDLoc &DL,
SelectionDAG &DAG) const;
- /// Expand floating point add
- /// \param N Node to expand
- /// \returns The expansion result or SDValue() if it fails.
- SDValue expandFADD(SDNode *N, SelectionDAG &DAG) const;
-
- /// Expand floating point multiply
- /// \param N Node to expand
- /// \param Result output after conversion
- /// \returns The expansion result or SDValue() if it fails.
- SDValue expandFMUL(SDNode *N, SelectionDAG &DAG) const;
-
- /// Expand floating point subtract
- /// \param N Node to expand
- /// \param Result output after conversion
- /// \returns The expansion result or SDValue() if it fails.
- SDValue expandFSUB(SDNode *N, SelectionDAG &DAG) const;
-
/// Expand CTPOP nodes. Expands vector/scalar CTPOP nodes,
/// vector nodes can only succeed if all operations are legal/custom.
/// \param N Node to expand
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 3d32fb77eb5917..6805e0cb23ace0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -17559,13 +17559,10 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
return N2;
}
- const bool PreferFMAAdd = (TLI.isOperationLegal(ISD::FMA, VT) &&
- !TLI.isOperationLegal(ISD::FADD, VT));
-
// FIXME: Support splat of constant.
- if (!PreferFMAAdd && N0CFP && N0CFP->isExactlyValue(1.0))
+ if (N0CFP && N0CFP->isExactlyValue(1.0))
return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
- if (!PreferFMAAdd && N1CFP && N1CFP->isExactlyValue(1.0))
+ if (N1CFP && N1CFP->isExactlyValue(1.0))
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
// Canonicalize (fma c, x, y) -> (fma x, c, y)
@@ -17597,7 +17594,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
// (fma x, -1, y) -> (fadd (fneg x), y)
// FIXME: Support splat of constant.
- if (N1CFP && !PreferFMAAdd) {
+ if (N1CFP) {
if (N1CFP->isExactlyValue(1.0))
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
@@ -17607,14 +17604,15 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
AddToWorklist(RHSNeg.getNode());
return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
}
- }
- // fma (fneg x), K, y -> fma x -K, y
- if (N1CFP && matcher.match(N0, ISD::FNEG) &&
- (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
- (N1.hasOneUse() &&
- !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
- return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
- matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
+
+ // fma (fneg x), K, y -> fma x -K, y
+ if (matcher.match(N0, ISD::FNEG) &&
+ (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
+ (N1.hasOneUse() &&
+ !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
+ return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
+ matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
+ }
}
// FIXME: Support splat of constant.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 912458ec0e6612..c6475f02199033 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3672,21 +3672,14 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
Results.push_back(ExpandConstant(CP));
break;
}
- case ISD::FADD: {
- if (SDValue Expand = TLI.expandFADD(Node, DAG)) {
- Results.push_back(Expand);
- }
- break;
- }
- case ISD::FMUL: {
- if (SDValue Expand = TLI.expandFMUL(Node, DAG)) {
- Results.push_back(Expand);
- }
- break;
- }
case ISD::FSUB: {
- if (SDValue Expand = TLI.expandFSUB(Node, DAG)) {
- Results.push_back(Expand);
+ EVT VT = Node->getValueType(0);
+ if (TLI.isOperationLegalOrCustom(ISD::FADD, VT) &&
+ TLI.isOperationLegalOrCustom(ISD::FNEG, VT)) {
+ const SDNodeFlags Flags = Node->getFlags();
+ Tmp1 = DAG.getNode(ISD::FNEG, dl, VT, Node->getOperand(1));
+ Tmp1 = DAG.getNode(ISD::FADD, dl, VT, Node->getOperand(0), Tmp1, Flags);
+ Results.push_back(Tmp1);
}
break;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 4f6594c4980688..56194e2614af2d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -9068,60 +9068,6 @@ SDValue TargetLowering::expandIS_FPCLASS(EVT ResultVT, SDValue Op,
return Res;
}
-SDValue TargetLowering::expandFADD(SDNode *Node, SelectionDAG &DAG) const {
- auto VT = Node->getValueType(0);
- if (!isOperationLegalOrCustom(ISD::FMA, VT)) {
- return {};
- }
-
- // FADD(a, b) -> FMA(a, 1.0, b)
- SDLoc DL(Node);
- auto One = DAG.getConstantFP(1.0, DL, VT);
- SmallVector<SDValue, 3> Operands{Node->getOperand(0), One,
- Node->getOperand(1)};
- return DAG.getNode(ISD::FMA, DL, VT, Operands, Node->getFlags());
-}
-
-SDValue TargetLowering::expandFMUL(SDNode *Node, SelectionDAG &DAG) const {
- auto VT = Node->getValueType(0);
- if (!isOperationLegalOrCustom(ISD::FMA, VT)) {
- return {};
- }
-
- // FMUL(a, b) -> FMA(a, b, -0.0)
- // NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats
- SDLoc DL(Node);
- auto NegZero = DAG.getConstantFP(-0.0, DL, VT);
- SmallVector<SDValue, 3> Operands{Node->getOperand(0), Node->getOperand(1),
- NegZero};
- return DAG.getNode(ISD::FMA, DL, VT, Operands, Node->getFlags());
-}
-
-SDValue TargetLowering::expandFSUB(SDNode *Node, SelectionDAG &DAG) const {
- SDLoc DL(Node);
- SDNodeFlags SDFlags = Node->getFlags();
- auto VT = Node->getValueType(0);
-
- bool CanUseFMA = isOperationLegalOrCustom(ISD::FMA, VT);
- bool CanUseAddSub = (isOperationLegalOrCustom(ISD::FADD, VT) &&
- isOperationLegalOrCustom(ISD::FNEG, VT));
- bool PreferAddSub = CanUseAddSub && isFNegFree(VT);
-
- // FSUB(a, b) -> FMA(b, -1.0, a)
- if (CanUseFMA && !PreferAddSub) {
- auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
- SmallVector<SDValue, 3> Operands{Node->getOperand(1), NegOne,
- Node->getOperand(0)};
- return DAG.getNode(ISD::FMA, DL, VT, Operands, SDFlags);
- }
- // FSUB(a, b) -> FADD(a, FNEG(b))
- if (CanUseAddSub) {
- auto Neg = DAG.getNode(ISD::FNEG, DL, VT, Node->getOperand(1));
- return DAG.getNode(ISD::FADD, DL, VT, Node->getOperand(0), Neg, SDFlags);
- }
- return {};
-}
-
// Only expand vector types if we have the appropriate vector bit operations.
static bool canExpandVectorCTPOP(const TargetLowering &TLI, EVT VT) {
assert(VT.isVector() && "Expected vector type");
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 2e66b67dfdcc76..4b98f920330bc6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "NVPTXISelDAGToDAG.h"
+#include "NVPTX.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/ISDOpcodes.h"
@@ -190,6 +191,12 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
return;
}
break;
+ case ISD::FADD:
+ case ISD::FMUL:
+ case ISD::FSUB:
+ if (tryBF16ArithToFMA(N))
+ return;
+ break;
}
default:
break;
@@ -2450,6 +2457,66 @@ bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
return true;
}
+// Select bf16/v2bf16 FADD, FSUB, FMUL as fma on targets with only fma
+bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
+ EVT VT = SDValue(N, 0).getValueType();
+ if (VT.getScalarType() != MVT::bf16)
+ return false;
+
+ const NVPTXSubtarget *STI = TM.getSubtargetImpl();
+ const bool IsNativelySupported =
+ STI->getSmVersion() >= 90 && STI->getPTXVersion() >= 78;
+ if (IsNativelySupported)
+ return false;
+
+ assert(VT == MVT::bf16 || VT == MVT::v2bf16);
+ const bool IsVec = VT == MVT::v2bf16;
+ SDLoc DL(N);
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SmallVector<SDValue, 3> Operands;
+ auto GetConstant = [&](float Value) -> SDValue {
+ APFloat APF(Value);
+ bool LosesInfo;
+ APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &LosesInfo);
+ assert(!LosesInfo);
+ if (IsVec) {
+ auto API = APF.bitcastToAPInt();
+ API = API.concat(API);
+ auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32);
+ return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32ri, DL, VT, Const), 0);
+ }
+ auto Const = CurDAG->getTargetConstantFP(APF, DL, VT);
+ return SDValue(CurDAG->getMachineNode(NVPTX::BFMOV16ri, DL, VT, Const), 0);
+ };
+
+ switch (N->getOpcode()) {
+ case ISD::FADD: {
+ // add(a, b) -> fma(a, 1.0, b)
+ Operands = {N0, GetConstant(1.0), N1};
+ break;
+ }
+ case ISD::FSUB: {
+ // sub(a, b) -> fma(b, -1.0, a)
+ Operands = {N1, GetConstant(-1.0), N0};
+ break;
+ }
+ case ISD::FMUL: {
+ // mul(a, b) -> fma(a, b, -0.0)
+ // NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats
+ Operands = {N0, N1, GetConstant(-0.0)};
+ break;
+ }
+ default:
+ llvm_unreachable("Unexpected opcode");
+ };
+
+ int Opcode = IsVec ? NVPTX::BFMA16x2rrr : NVPTX::BFMA16rrr;
+ MachineSDNode *FMA = CurDAG->getMachineNode(Opcode, DL, VT, Operands);
+ ReplaceNode(N, FMA);
+ return true;
+}
+
static inline bool isAddLike(const SDValue V) {
return V.getOpcode() == ISD::ADD ||
(V->getOpcode() == ISD::OR && V->getFlags().hasDisjoint());
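The -0.0 additive identity used for the FMUL rewrite in tryBF16ArithToFMA above is easy to sanity-check on the host: when a*b is -0.0, adding +0.0 rounds to +0.0 and the sign of the product is lost, while adding -0.0 keeps it. A minimal C++ check (illustrative only, not part of the patch):
```
#include <cmath>
#include <cstdio>

int main() {
  double wrong = std::fma(-2.0, 0.0, 0.0);   // product is -0.0, result becomes +0.0
  double right = std::fma(-2.0, 0.0, -0.0);  // product is -0.0, result stays -0.0
  std::printf("%d %d\n", std::signbit(wrong), std::signbit(right)); // prints "0 1"
  return 0;
}
```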
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 8cadde8a822647..7661f153238fcd 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -84,6 +84,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool tryFence(SDNode *N);
void SelectAddrSpaceCast(SDNode *N);
bool tryBFE(SDNode *N);
+ bool tryBF16ArithToFMA(SDNode *N);
bool tryConstantFP(SDNode *N);
bool SelectSETP_F16X2(SDNode *N);
bool SelectSETP_BF16X2(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 62f4e4cbfcefff..8bdb14193874d6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2527,8 +2527,8 @@ SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
return PromoteBinOpToF32(Op.getNode(), DAG);
}
- // FADD(a, b) -> FMA(a, 1.0, b)
- return expandFADD(Op.getNode(), DAG);
+ // Legal
+ return Op;
}
SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
@@ -2537,8 +2537,8 @@ SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
return PromoteBinOpToF32(Op.getNode(), DAG);
}
- // FSUB(a, b) -> FMA(b, -1.0, a)
- return expandFSUB(Op.getNode(), DAG);
+ // Legal
+ return Op;
}
SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
@@ -2547,8 +2547,8 @@ SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
return PromoteBinOpToF32(Op.getNode(), DAG);
}
- // FMUL(a, b) -> FMA(a, b, -0.0)
- return expandFMUL(Op.getNode(), DAG);
+ // Legal
+ return Op;
}
SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index b53f82403af5f3..0c1b1e21669286 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -114,9 +114,9 @@ define bfloat @test_fsub(bfloat %0, bfloat %1) {
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b16 %rs1, [test_fsub_param_0];
-; SM80-NEXT: ld.param.b16 %rs2, [test_fsub_param_1];
-; SM80-NEXT: mov.b16 %rs3, 0xBF80;
-; SM80-NEXT: fma.rn.bf16 %rs4, %rs2, %rs3, %rs1;
+; SM80-NEXT: mov.b16 %rs2, 0xBF80;
+; SM80-NEXT: ld.param.b16 %rs3, [test_fsub_param_1];
+; SM80-NEXT: fma.rn.bf16 %rs4, %rs3, %rs2, %rs1;
; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
; SM80-NEXT: ret;
;
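The 0xBF80 immediate in the SM80 CHECK lines above is bf16 -1.0, matching the sub(a, b) -> fma(b, -1.0, a) rewrite. A small stand-alone C++ sketch of the decoding and the equivalence (sample values assumed, not part of the patch):
```
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Widening bf16 to f32 is a left shift of the raw bits by 16.
  uint32_t wide = uint32_t(0xBF80) << 16;  // 0xBF800000
  float neg_one;
  std::memcpy(&neg_one, &wide, sizeof(neg_one));
  // sub(a, b) -> fma(b, -1.0, a): both expressions print 2.25 here.
  double a = 3.5, b = 1.25;
  std::printf("%g  %g  %g\n", neg_one, a - b, std::fma(b, double(neg_one), a));
  return 0;
}
```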
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
index f7ffba385df764..e6d35bd5ba536b 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
@@ -26,9 +26,9 @@ define <2 x bfloat> @test_fadd_imm_0(<2 x bfloat> %a) #0 {
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b32 %r1, [test_fadd_imm_0_param_0];
-; SM80-NEXT: mov.b32 %r2, 1073758080;
-; SM80-NEXT: mov.b32 %r3, 1065369472;
-; SM80-NEXT: fma.rn.bf16x2 %r4, %r1, %r3, %r2;
+; SM80-NEXT: mov.b32 %r2, 1065369472;
+; SM80-NEXT: mov.b32 %r3, 1073758080;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r1, %r2, %r3;
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
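In the v2bf16 case, GetConstant concatenates the bf16 bit pattern with itself and emits a single 32-bit mov, which is where immediates like 1065369472 above come from: 0x3F803F80, i.e. bf16 1.0 in both halves (the other operand, 1073758080, is 0x40003F80, whose halves decode to 1.0 and 2.0). A tiny C++ sketch of the splat (hypothetical helper name, not from the patch):
```
#include <cstdint>
#include <cstdio>

// Splat one bf16 bit pattern into both halves of a 32-bit immediate,
// mirroring the API.concat(API) step in tryBF16ArithToFMA.
static uint32_t splat_bf16(uint16_t bits) {
  return (static_cast<uint32_t>(bits) << 16) | bits;
}

int main() {
  std::printf("%u\n", splat_bf16(0x3F80)); // 1065369472: packed <1.0, 1.0>
  std::printf("%u\n", splat_bf16(0x40E0)); // 1088438496: packed <7.0, 7.0>, seen in the fma-relu tests
  return 0;
}
```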
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll b/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
index 1643704e2ff95c..7dce894620e6bb 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
@@ -361,10 +361,10 @@ define bfloat @fma_bf16_expanded_no_nans_multiple_uses_of_fma(bfloat %a, bfloat
; CHECK-NEXT: fma.rn.bf16 %rs4, %rs1, %rs2, %rs3;
; CHECK-NEXT: mov.b16 %rs5, 0x0000;
; CHECK-NEXT: max.bf16 %rs6, %rs4, %rs5;
-; CHECK-NEXT: mov.b16 %rs7, 0x40E0;
-; CHECK-NEXT: mov.b16 %rs8, 0x3F80;
-; CHECK-NEXT: fma.rn.bf16 %rs9, %rs4, %rs8, %rs7;
-; CHECK-NEXT: fma.rn.bf16 %rs10, %rs6, %rs8, %rs9;
+; CHECK-NEXT: mov.b16 %rs7, 0x3F80;
+; CHECK-NEXT: mov.b16 %rs8, 0x40E0;
+; CHECK-NEXT: fma.rn.bf16 %rs9, %rs4, %rs7, %rs8;
+; CHECK-NEXT: fma.rn.bf16 %rs10, %rs6, %rs7, %rs9;
; CHECK-NEXT: st.param.b16 [func_retval0], %rs10;
; CHECK-NEXT: ret;
;
@@ -957,10 +957,10 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
; CHECK-NEXT: fma.rn.bf16x2 %r4, %r3, %r2, %r1;
; CHECK-NEXT: mov.b32 %r5, 0;
; CHECK-NEXT: max.bf16x2 %r6, %r4, %r5;
-; CHECK-NEXT: mov.b32 %r7, 1088438496;
-; CHECK-NEXT: mov.b32 %r8, 1065369472;
-; CHECK-NEXT: fma.rn.bf16x2 %r9, %r4, %r8, %r7;
-; CHECK-NEXT: fma.rn.bf16x2 %r10, %r6, %r8, %r9;
+; CHECK-NEXT: mov.b32 %r7, 1065369472;
+; CHECK-NEXT: mov.b32 %r8, 1088438496;
+; CHECK-NEXT: fma.rn.bf16x2 %r9, %r4, %r7, %r8;
+; CHECK-NEXT: fma.rn.bf16x2 %r10, %r6, %r7, %r9;
; CHECK-NEXT: st.param.b32 [func_retval0], %r10;
; CHECK-NEXT: ret;
;
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll b/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
index e1e34ee9b1c159..eb51d7db81372d 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
@@ -228,10 +228,10 @@ define bfloat @fma_bf16_no_nans_multiple_uses_of_fma(bfloat %a, bfloat %b, bfloa
; CHECK-NEXT: ld.param.b16 %rs2, [fma_bf16_no_nans_multiple_uses_of_fma_param_1];
; CHECK-NEXT: ld.param.b16 %rs3, [fma_bf16_no_nans_multiple_uses_of_fma_param_2];
; CHECK-NEXT: fma.rn.bf16 %rs4, %rs1, %rs2, %rs3;
-; CHECK-NEXT: mov.b16 %rs5, 0x40E0;
-; CHECK-NEXT: mov.b16 %rs6, 0x3F80;
-; CHECK-NEXT: fma.rn.bf16 %rs7, %rs4, %rs6, %rs5;
-; CHECK-NEXT: fma.rn.bf16 %rs8, %rs7, %rs6, %rs4;
+; CHECK-NEXT: mov.b16 %rs5, 0x3F80;
+; CHECK-NEXT: mov.b16 %rs6, 0x40E0;
+; CHECK-NEXT: fma.rn.bf16 %rs7, %rs4, %rs5, %rs6;
+; CHECK-NEXT: fma.rn.bf16 %rs8, %rs7, %rs5, %rs4;
; CHECK-NEXT: st.param.b16 [func_retval0], %rs8;
; CHECK-NEXT: ret;
;
@@ -641,10 +641,10 @@ define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2
; CHECK-NEXT: ld.param.b32 %r2, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_1];
; CHECK-NEXT: ld.param.b32 %r3, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_0];
; CHECK-NEXT: fma.rn.bf16x2 %r4, %r3, %r2, %r1;
-; CHECK-NEXT: mov.b32 %r5, 1088438496;
-; CHECK-NEXT: mov.b32 %r6, 1065369472;
-; CHECK-NEXT: fma.rn.bf16x2 %r7, %r4, %r6, %r5;
-; CHECK-NEXT: fma.rn.bf16x2 %r8, %r7, %r6, %r4;
+; CHECK-NEXT: mov.b32 %r5, 1065369472;
+; CHECK-NEXT: mov.b32 %r6, 1088438496;
+; CHECK-NEXT: fma.rn.bf16x2 %r7, %r4, %r5, %r6;
+; CHECK-NEXT: fma.rn.bf16x2 %r8, %r7, %r5, %r4;
; CHECK-NEXT: st.param.b32 [func_retval0], %r8;
; CHECK-NEXT: ret;
;
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll b/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
index ea046dc90b23f2..a3545f51714259 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
@@ -242,10 +242,10 @@ define bfloat @fma_bf16_expanded_no_nans_multiple_uses_of_fma(bfloat %a, bfloat
; CHECK-NEXT: fma.rn.bf16 %rs4, %rs1, %rs2, %rs3;
; CHECK-NEXT: mov.b16 %rs5, 0x0000;
; CHECK-NEXT: max.bf16 %rs6, %rs4, %rs5;
-; CHECK-NEXT: mov.b16 %rs7, 0x40E0;
-; CHECK-NEXT: mov.b16 %rs8, 0x3F80;
-; CHECK-NEXT: fma.rn.bf16 %rs9, %rs4, %rs8, %rs7;
-; CHECK-NEXT: fma.rn.bf16 %rs10, %rs6, %rs8, %rs9;
+; CHECK-NEXT: mov.b16 %rs7, 0x3F80;
+; CHECK-NEXT: mov.b16 %rs8, 0x40E0;
+; CHECK-NEXT: fma.rn.bf16 %rs9, %rs4, %rs7, %rs8;
+; CHECK-NEXT: fma.rn.bf16 %rs10, %rs6, %rs7, %rs9;
; CHECK-NEXT: st.param.b16 [func_retval0], %rs10;
; CHECK-NEXT: ret;
;
@@ -692,10 +692,10 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
; CHECK-NEXT: fma.rn.bf16x2 %r4, %r3, %r2, %r1;
; CHECK-NEXT: mov.b32 %r5, 0;
; CHECK-NEXT: max.bf16x2 %r6, %r4, %r5;
-; CHECK-NEXT: mov.b32 %r7, 1088438496;
-; CHECK-NEXT: mov.b32 %r8, 1065369472;
-; CHECK-NEXT: fma.rn.bf16x2 %r9, %r4, %r8, %r7;
-; CHECK-NEXT: fma.rn.bf16x2 %r10, %r6, %r8, %r9;
+; CHECK-NEXT: mov.b32 %r7, 1065369472;
+; CHECK-NEXT: mov.b32 %r8, 1088438496;
+; CHECK-NEXT: fma.rn.bf16x2 %r9, %r4, %r7, %r8;
+; CHECK-NEXT: fma.rn.bf16x2 %r10, %r6, %r7, %r9;
; CHECK-NEXT: st.param.b32 [func_retval0], %r10;
; CHECK-NEXT: ret;
;
@@ -1175,10 +1175,10 @@ define bfloat @fma_bf16_no_nans_multiple_uses_of_fma(bfloat %a, bfloat %b, bfloa
; CHECK-NEXT: ld.param.b16 %rs2, [fma_bf16_no_nans_multiple_uses_of_fma_param_1];
; CHECK-NEXT: ld.param.b16 %rs3, [fma_bf16_no_nans_multiple_uses_of_fma_param_2];
; CHECK-NEXT: fma.rn.bf16 %rs4, %rs1, %rs2, %rs3;
-; CHECK-NEXT: mov.b16 %rs5, 0x40E0;
-; CHECK-NEXT: mov.b16 %rs6, 0x3F80;
-; CHECK-NEXT: fma.rn.bf16 %rs7, %rs4, %rs6, %rs5;
-; CHECK-NEXT: fma.rn.bf16 %rs8, %rs7, %rs6, %rs4;
+; CHECK-NEXT: mov.b16 %rs5, 0x3F80;
+; CHECK-NEXT: mov.b16 %rs6, 0x40E0;
+; CHECK-NEXT: fma.rn.bf16 %rs7, %rs4, %rs5, %rs6;
+; CHECK-NEXT: fma.rn.bf16 %rs8, %rs7, %rs5, %rs4;
; CHECK-NEXT: st.param.b16 [func_retval0], %rs8;
; CHECK-NEXT: ret;
;
@@ -1592,10 +1592,10 @@ define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2
; CHECK-NEXT: ld.param.b32 %r2, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_1];
; CHECK-NEXT: ld.param.b32 %r3, [fma_bf16x2_no_nans_multiple_uses_of_fma_param_0];
; CHECK-NEXT: fma.rn.bf16x2 %r4, %r3, %r2, %r1;
-; CHECK-NEXT: mov.b32 %r5, 1088438496;
-; CHECK-NEXT: mov.b32 %r6, 1065369472;
-; CHECK-NEXT: fma.rn.bf16x2 %r7, %r4, %r6, %r5;
-; CHECK-NEXT: fma.rn.bf16x2 %r8, %r7, %r6, %r4;
+; CHECK-NEXT: mov.b32 %r5, 1065369472;
+; CHECK-NEXT: mov.b32 %r6, 1088438496;
+; CHECK-NEXT: fma.rn.bf16x2 %r7, %r4, %r5, %r6;
+; CHECK-NEXT: fma.rn.bf16x2 %r8, %r7, %r5, %r4;
; CHECK-NEXT: st.param.b32 [func_retval0], %r8;
; CHECK-NEXT: ret;
;
>From 67a0d7f3e55e32b0c532dee18ba667fe8a15ceeb Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Thu, 9 Jan 2025 22:35:18 +0000
Subject: [PATCH 5/7] Remove brackets from switch cases
---
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 11 ++++-------
1 file changed, 4 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 4b98f920330bc6..29ab4111b953a0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -191,13 +191,13 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
return;
}
break;
+ }
case ISD::FADD:
case ISD::FMUL:
case ISD::FSUB:
if (tryBF16ArithToFMA(N))
return;
break;
- }
default:
break;
}
@@ -2491,22 +2491,19 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
};
switch (N->getOpcode()) {
- case ISD::FADD: {
+ case ISD::FADD:
// add(a, b) -> fma(a, 1.0, b)
Operands = {N0, GetConstant(1.0), N1};
break;
- }
- case ISD::FSUB: {
+ case ISD::FSUB:
// sub(a, b) -> fma(b, -1.0, a)
Operands = {N1, GetConstant(-1.0), N0};
break;
- }
- case ISD::FMUL: {
+ case ISD::FMUL:
// mul(a, b) -> fma(a, b, -0.0)
// NOTE: The identity is -0, not 0, because -0 + 0 == 0 for floats
Operands = {N0, N1, GetConstant(-0.0)};
break;
- }
default:
llvm_unreachable("Unexpected opcode");
};
>From 12b5115a9bb74d7bde1be13d87d0d0f7a11d646d Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Tue, 14 Jan 2025 15:14:26 +0000
Subject: [PATCH 6/7] Cleanup duplicate lowerings
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 31 +++------------------
llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 4 +--
2 files changed, 5 insertions(+), 30 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 8bdb14193874d6..4ea167d4f3c4ae 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2521,33 +2521,11 @@ static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
return DAG.getFPExtendOrRound(Res, DL, VT);
}
-SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
- // No fma.ftz for bf16, so fall back to promotion
- if (useF32FTZ(DAG.getMachineFunction())) {
- return PromoteBinOpToF32(Op.getNode(), DAG);
- }
-
- // Legal
- return Op;
-}
-
-SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
- // No fma.ftz for bf16, so fall back to promotion
- if (useF32FTZ(DAG.getMachineFunction())) {
- return PromoteBinOpToF32(Op.getNode(), DAG);
- }
-
- // Legal
- return Op;
-}
-
-SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
- // No fma.ftz for bf16, so fall back to promotion
+SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
+ SelectionDAG &DAG) const {
if (useF32FTZ(DAG.getMachineFunction())) {
return PromoteBinOpToF32(Op.getNode(), DAG);
}
-
- // Legal
return Op;
}
@@ -2743,11 +2721,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::CopyToReg:
return LowerCopyToReg_128(Op, DAG);
case ISD::FADD:
- return LowerFADD(Op, DAG);
case ISD::FSUB:
- return LowerFSUB(Op, DAG);
case ISD::FMUL:
- return LowerFMUL(Op, DAG);
+ // Used only for bf16 on SM80, where we select fma for non-ftz operation
+ return PromoteBinOpIfF32FTZ(Op, DAG);
default:
llvm_unreachable("Custom lowering not defined for operation");
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 514bf7cc240c9c..5adf69d621552f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -278,9 +278,7 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
- SDValue LowerFADD(SDValue Op, SelectionDAG &DAG) const;
- SDValue LowerFSUB(SDValue Op, SelectionDAG &DAG) const;
- SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
+ SDValue PromoteBinOpIfF32FTZ(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
>From 403aaee454d7fa4cec42da37f20a579fb2983a0e Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Thu, 16 Jan 2025 05:07:05 +0000
Subject: [PATCH 7/7] Resolve NITs
---
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 9 +++---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 34 ++-------------------
llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp | 32 +++++++++++++++++++
llvm/lib/Target/NVPTX/NVPTXSubtarget.h | 2 ++
4 files changed, 41 insertions(+), 36 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 29ab4111b953a0..8f6adf2c22f922 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -2464,18 +2464,17 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
return false;
const NVPTXSubtarget *STI = TM.getSubtargetImpl();
- const bool IsNativelySupported =
- STI->getSmVersion() >= 90 && STI->getPTXVersion() >= 78;
- if (IsNativelySupported)
+ if (STI->hasNativeBF16Support(N->getOpcode()))
return false;
- assert(VT == MVT::bf16 || VT == MVT::v2bf16);
- const bool IsVec = VT == MVT::v2bf16;
+ const bool IsVec = VT.isVector();
+ assert(!IsVec || VT.getVectorNumElements() == 2);
SDLoc DL(N);
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
SmallVector<SDValue, 3> Operands;
auto GetConstant = [&](float Value) -> SDValue {
+ // BF16 immediates must be legalized to integer register values
APFloat APF(Value);
bool LosesInfo;
APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &LosesInfo);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 4ea167d4f3c4ae..899db28a0ef642 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -535,34 +535,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
LegalizeAction NoBF16Action) {
- bool IsOpSupported = STI.hasBF16Math();
- switch (Op) {
- // Several BF16 instructions are available on sm_90 only.
- case ISD::FADD:
- case ISD::FMUL:
- case ISD::FSUB:
- case ISD::SELECT:
- case ISD::SELECT_CC:
- case ISD::SETCC:
- case ISD::FEXP2:
- case ISD::FCEIL:
- case ISD::FFLOOR:
- case ISD::FNEARBYINT:
- case ISD::FRINT:
- case ISD::FROUNDEVEN:
- case ISD::FTRUNC:
- IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
- break;
- // Several BF16 instructions are available on sm_80 only.
- case ISD::FMINNUM:
- case ISD::FMAXNUM:
- case ISD::FMAXNUM_IEEE:
- case ISD::FMINNUM_IEEE:
- case ISD::FMAXIMUM:
- case ISD::FMINIMUM:
- IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
- break;
- }
+ bool IsOpSupported = STI.hasNativeBF16Support(Op);
setOperationAction(
Op, VT, IsOpSupported ? Action : NoBF16Action);
};
@@ -862,11 +835,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
AddPromotedToType(Op, MVT::bf16, MVT::f32);
}
- // Lower bf16 add/mul/sub as fma when it avoids promotion
+ // On SM80, we select add/mul/sub as fma to avoid promotion to float
for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
- if (getOperationAction(Op, VT) != Legal &&
- getOperationAction(ISD::FMA, VT) == Legal) {
+ if (!STI.hasNativeBF16Support(Op) && STI.hasNativeBF16Support(ISD::FMA)) {
setOperationAction(Op, VT, Custom);
}
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
index 74ce6a9fc4ac08..e5d680c19d9211 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
@@ -70,6 +70,38 @@ bool NVPTXSubtarget::allowFP16Math() const {
return hasFP16Math() && NoF16Math == false;
}
+bool NVPTXSubtarget::hasNativeBF16Support(int Opcode) const {
+ if (!hasBF16Math())
+ return false;
+
+ switch (Opcode) {
+ // Several BF16 instructions are available on sm_90 only.
+ case ISD::FADD:
+ case ISD::FMUL:
+ case ISD::FSUB:
+ case ISD::SELECT:
+ case ISD::SELECT_CC:
+ case ISD::SETCC:
+ case ISD::FEXP2:
+ case ISD::FCEIL:
+ case ISD::FFLOOR:
+ case ISD::FNEARBYINT:
+ case ISD::FRINT:
+ case ISD::FROUNDEVEN:
+ case ISD::FTRUNC:
+ return getSmVersion() >= 90 && getPTXVersion() >= 78;
+ // Several BF16 instructions are available on sm_80 only.
+ case ISD::FMINNUM:
+ case ISD::FMAXNUM:
+ case ISD::FMAXNUM_IEEE:
+ case ISD::FMINNUM_IEEE:
+ case ISD::FMAXIMUM:
+ case ISD::FMINIMUM:
+ return getSmVersion() >= 80 && getPTXVersion() >= 70;
+ }
+ return true;
+}
+
void NVPTXSubtarget::failIfClustersUnsupported(
std::string const &FailureMessage) const {
if (hasClusters())
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index bbc1cca7c12d85..3b5c28e357e0cc 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -118,6 +118,8 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
}
bool hasTargetName() const { return !TargetName.empty(); }
+ bool hasNativeBF16Support(int Opcode) const;
+
// Get maximum value of required alignments among the supported data types.
// From the PTX ISA doc, section 8.2.3:
// The memory consistency model relates operations executed on memory