[llvm] [NVPTX] Lower bfloat16 add/mul/sub as fma on SM80 (PR #121065)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 24 13:47:26 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag
Author: None (peterbell10)
<details>
<summary>Changes</summary>
SM80 has fma for bfloat16 but not add/mul/sub. Currently these are just promoted to f32, but we can instead write them in terms of the fma:
```
FADD(a, b) -> FMA(a, 1.0, b)
FMUL(a, b) -> FMA(a, b, 0.0)
FSUB(a, b) -> FMA(b, -1.0, a)
```
Unfortunately there is no `fma.ftz` for bf16, so when ftz is enabled we still fall back to promotion.
This is also the inverse of some generic DAGCombiner patterns, so I've had to add checks to prevent it from reversing the legalization, which would cause an infinite loop.
---
Patch is 40.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/121065.diff
9 Files Affected:
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+14-12)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+73)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+4)
- (modified) llvm/test/CodeGen/NVPTX/atomics-sm90.ll (+25-31)
- (modified) llvm/test/CodeGen/NVPTX/bf16-instructions.ll (+28-63)
- (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll (+37-103)
- (modified) llvm/test/CodeGen/NVPTX/fma-relu-contract.ll (+12-48)
- (modified) llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll (+12-38)
- (modified) llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll (+24-86)
``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6cbfef2d238bbe..a50ac311c82869 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -17534,10 +17534,13 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
return N2;
}
+ const bool PreferFMAAdd = (TLI.isOperationLegal(ISD::FMA, VT) &&
+ !TLI.isOperationLegal(ISD::FADD, VT));
+
// FIXME: Support splat of constant.
- if (N0CFP && N0CFP->isExactlyValue(1.0))
+ if (!PreferFMAAdd && N0CFP && N0CFP->isExactlyValue(1.0))
return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
- if (N1CFP && N1CFP->isExactlyValue(1.0))
+ if (!PreferFMAAdd && N1CFP && N1CFP->isExactlyValue(1.0))
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
// Canonicalize (fma c, x, y) -> (fma x, c, y)
@@ -17569,7 +17572,7 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
// (fma x, -1, y) -> (fadd (fneg x), y)
// FIXME: Support splat of constant.
- if (N1CFP) {
+ if (N1CFP && !PreferFMAAdd) {
if (N1CFP->isExactlyValue(1.0))
return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
@@ -17579,15 +17582,14 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
AddToWorklist(RHSNeg.getNode());
return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
}
-
- // fma (fneg x), K, y -> fma x -K, y
- if (matcher.match(N0, ISD::FNEG) &&
- (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
- (N1.hasOneUse() &&
- !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
- return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
- matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
- }
+ }
+ // fma (fneg x), K, y -> fma x -K, y
+ if (N1CFP && matcher.match(N0, ISD::FNEG) &&
+ (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
+ (N1.hasOneUse() &&
+ !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
+ return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
+ matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
}
// FIXME: Support splat of constant.
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 5c1f717694a4c7..47f56abae3c056 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -853,6 +853,16 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
AddPromotedToType(Op, MVT::bf16, MVT::f32);
}
+ // Lower bf16 add/mul/sub as fma when it avoids promotion
+ for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
+ for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
+ if (getOperationAction(Op, VT) != Legal &&
+ getOperationAction(ISD::FMA, VT) == Legal) {
+ setOperationAction(Op, VT, Custom);
+ }
+ }
+ }
+
// f16/f16x2 neg was introduced in PTX 60, SM_53.
const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
STI.getPTXVersion() >= 60 &&
@@ -2490,6 +2500,62 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
}
+static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
+ EVT VT = N->getValueType(0);
+ EVT NVT = MVT::f32;
+ if (VT.isVector()) {
+ NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
+ }
+ SDLoc DL(N);
+ SDValue Tmp0 = DAG.getFPExtendOrRound(N->getOperand(0), DL, NVT);
+ SDValue Tmp1 = DAG.getFPExtendOrRound(N->getOperand(1), DL, NVT);
+ SDValue Res = DAG.getNode(N->getOpcode(), DL, NVT, Tmp0, Tmp1, N->getFlags());
+ return DAG.getFPExtendOrRound(Res, DL, VT);
+}
+
+SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
+ // No fma.ftz for bf16, so fall back to promotion
+ if (useF32FTZ(DAG.getMachineFunction())) {
+ return PromoteBinOpToF32(Op.getNode(), DAG);
+ }
+
+ // FADD(a, b) -> FMA(a, 1.0, b)
+ SDLoc DL(Op);
+ auto VT = Op.getValueType();
+ auto One = DAG.getConstantFP(1.0, DL, VT);
+ SmallVector<SDValue, 3> Operands{Op->getOperand(0), One, Op->getOperand(1)};
+ return DAG.getNode(ISD::FMA, DL, VT, Operands);
+}
+
+SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
+ // No fma.ftz for bf16, so fall back to promotion
+ if (useF32FTZ(DAG.getMachineFunction())) {
+ return PromoteBinOpToF32(Op.getNode(), DAG);
+ }
+
+ // FSUB(a, b) -> FMA(b, -1.0, a)
+ SDLoc DL(Op);
+ auto VT = Op.getValueType();
+ auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
+ SmallVector<SDValue, 3> Operands{Op->getOperand(1), NegOne,
+ Op->getOperand(0)};
+ return DAG.getNode(ISD::FMA, DL, VT, Operands);
+}
+
+SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
+ // No fma.ftz for bf16, so fall back to promotion
+ if (useF32FTZ(DAG.getMachineFunction())) {
+ return PromoteBinOpToF32(Op.getNode(), DAG);
+ }
+
+ // FMUL(a, b) -> FMA(a, b, 0.0)
+ SDLoc DL(Op);
+ auto VT = Op.getValueType();
+ auto Zero = DAG.getConstantFP(0.0, DL, VT);
+ SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1), Zero};
+ return DAG.getNode(ISD::FMA, DL, VT, Operands);
+}
+
SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
SelectionDAG &DAG) const {
assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
@@ -2681,6 +2747,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerSTACKSAVE(Op, DAG);
case ISD::CopyToReg:
return LowerCopyToReg_128(Op, DAG);
+ case ISD::FADD:
+ return LowerFADD(Op, DAG);
+ case ISD::FSUB:
+ return LowerFSUB(Op, DAG);
+ case ISD::FMUL:
+ return LowerFMUL(Op, DAG);
+
default:
llvm_unreachable("Custom lowering not defined for operation");
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 4a98fe21b81dc6..b7d32dd5327646 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -279,6 +279,10 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFADD(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFSUB(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
+
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/NVPTX/atomics-sm90.ll b/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
index f81b785f13225c..67552b95e04915 100644
--- a/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
+++ b/llvm/test/CodeGen/NVPTX/atomics-sm90.ll
@@ -46,58 +46,52 @@ define void @test(ptr %dp0, ptr addrspace(1) %dp1, ptr addrspace(3) %dp3, bfloat
; CHECKPTX71-LABEL: test(
; CHECKPTX71: {
; CHECKPTX71-NEXT: .reg .pred %p<5>;
-; CHECKPTX71-NEXT: .reg .b16 %rs<22>;
+; CHECKPTX71-NEXT: .reg .b16 %rs<26>;
; CHECKPTX71-NEXT: .reg .b32 %r<4>;
-; CHECKPTX71-NEXT: .reg .f32 %f<12>;
; CHECKPTX71-EMPTY:
; CHECKPTX71-NEXT: // %bb.0:
; CHECKPTX71-NEXT: ld.param.b16 %rs13, [test_param_3];
; CHECKPTX71-NEXT: ld.param.u32 %r3, [test_param_2];
; CHECKPTX71-NEXT: ld.param.u32 %r2, [test_param_1];
; CHECKPTX71-NEXT: ld.param.u32 %r1, [test_param_0];
-; CHECKPTX71-NEXT: ld.b16 %rs18, [%r1];
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f1, %rs13;
+; CHECKPTX71-NEXT: ld.b16 %rs22, [%r1];
; CHECKPTX71-NEXT: $L__BB0_1: // %atomicrmw.start14
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f2, %rs18;
-; CHECKPTX71-NEXT: add.rn.f32 %f3, %f2, %f1;
-; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs14, %f3;
-; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs18, %rs14;
-; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs18;
-; CHECKPTX71-NEXT: mov.u16 %rs18, %rs3;
+; CHECKPTX71-NEXT: mov.b16 %rs14, 0x3F80;
+; CHECKPTX71-NEXT: fma.rn.bf16 %rs15, %rs22, %rs14, %rs13;
+; CHECKPTX71-NEXT: atom.cas.b16 %rs3, [%r1], %rs22, %rs15;
+; CHECKPTX71-NEXT: setp.ne.s16 %p1, %rs3, %rs22;
+; CHECKPTX71-NEXT: mov.u16 %rs22, %rs3;
; CHECKPTX71-NEXT: @%p1 bra $L__BB0_1;
; CHECKPTX71-NEXT: // %bb.2: // %atomicrmw.end13
-; CHECKPTX71-NEXT: ld.b16 %rs19, [%r1];
+; CHECKPTX71-NEXT: ld.b16 %rs23, [%r1];
; CHECKPTX71-NEXT: $L__BB0_3: // %atomicrmw.start8
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f4, %rs19;
-; CHECKPTX71-NEXT: add.rn.f32 %f5, %f4, 0f3F800000;
-; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs15, %f5;
-; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs19, %rs15;
-; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs19;
-; CHECKPTX71-NEXT: mov.u16 %rs19, %rs6;
+; CHECKPTX71-NEXT: mov.b16 %rs16, 0x3F80;
+; CHECKPTX71-NEXT: fma.rn.bf16 %rs17, %rs23, %rs16, %rs16;
+; CHECKPTX71-NEXT: atom.cas.b16 %rs6, [%r1], %rs23, %rs17;
+; CHECKPTX71-NEXT: setp.ne.s16 %p2, %rs6, %rs23;
+; CHECKPTX71-NEXT: mov.u16 %rs23, %rs6;
; CHECKPTX71-NEXT: @%p2 bra $L__BB0_3;
; CHECKPTX71-NEXT: // %bb.4: // %atomicrmw.end7
-; CHECKPTX71-NEXT: ld.global.b16 %rs20, [%r2];
+; CHECKPTX71-NEXT: ld.global.b16 %rs24, [%r2];
; CHECKPTX71-NEXT: $L__BB0_5: // %atomicrmw.start2
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f7, %rs20;
-; CHECKPTX71-NEXT: add.rn.f32 %f8, %f7, %f1;
-; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs16, %f8;
-; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs20, %rs16;
-; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs20;
-; CHECKPTX71-NEXT: mov.u16 %rs20, %rs9;
+; CHECKPTX71-NEXT: mov.b16 %rs18, 0x3F80;
+; CHECKPTX71-NEXT: fma.rn.bf16 %rs19, %rs24, %rs18, %rs13;
+; CHECKPTX71-NEXT: atom.global.cas.b16 %rs9, [%r2], %rs24, %rs19;
+; CHECKPTX71-NEXT: setp.ne.s16 %p3, %rs9, %rs24;
+; CHECKPTX71-NEXT: mov.u16 %rs24, %rs9;
; CHECKPTX71-NEXT: @%p3 bra $L__BB0_5;
; CHECKPTX71-NEXT: // %bb.6: // %atomicrmw.end1
-; CHECKPTX71-NEXT: ld.shared.b16 %rs21, [%r3];
+; CHECKPTX71-NEXT: ld.shared.b16 %rs25, [%r3];
; CHECKPTX71-NEXT: $L__BB0_7: // %atomicrmw.start
; CHECKPTX71-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECKPTX71-NEXT: cvt.f32.bf16 %f10, %rs21;
-; CHECKPTX71-NEXT: add.rn.f32 %f11, %f10, %f1;
-; CHECKPTX71-NEXT: cvt.rn.bf16.f32 %rs17, %f11;
-; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs21, %rs17;
-; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs21;
-; CHECKPTX71-NEXT: mov.u16 %rs21, %rs12;
+; CHECKPTX71-NEXT: mov.b16 %rs20, 0x3F80;
+; CHECKPTX71-NEXT: fma.rn.bf16 %rs21, %rs25, %rs20, %rs13;
+; CHECKPTX71-NEXT: atom.shared.cas.b16 %rs12, [%r3], %rs25, %rs21;
+; CHECKPTX71-NEXT: setp.ne.s16 %p4, %rs12, %rs25;
+; CHECKPTX71-NEXT: mov.u16 %rs25, %rs12;
; CHECKPTX71-NEXT: @%p4 bra $L__BB0_7;
; CHECKPTX71-NEXT: // %bb.8: // %atomicrmw.end
; CHECKPTX71-NEXT: ret;
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 6828bac18cad7f..eeb13b52130042 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -42,17 +42,14 @@ define bfloat @test_fadd(bfloat %0, bfloat %1) {
;
; SM80-LABEL: test_fadd(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<4>;
-; SM80-NEXT: .reg .f32 %f<4>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_param_0];
; SM80-NEXT: ld.param.b16 %rs2, [test_fadd_param_1];
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs1;
-; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3;
-; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
+; SM80-NEXT: mov.b16 %rs3, 0x3F80;
+; SM80-NEXT: fma.rn.bf16 %rs4, %rs1, %rs3, %rs2;
+; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fadd(
@@ -113,17 +110,14 @@ define bfloat @test_fsub(bfloat %0, bfloat %1) {
;
; SM80-LABEL: test_fsub(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<4>;
-; SM80-NEXT: .reg .f32 %f<4>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b16 %rs1, [test_fsub_param_0];
; SM80-NEXT: ld.param.b16 %rs2, [test_fsub_param_1];
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs1;
-; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f3;
-; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
+; SM80-NEXT: mov.b16 %rs3, 0xBF80;
+; SM80-NEXT: fma.rn.bf16 %rs4, %rs2, %rs3, %rs1;
+; SM80-NEXT: st.param.b16 [func_retval0], %rs4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fsub(
@@ -202,23 +196,14 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_faddx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<5>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-NEXT: .reg .f32 %f<7>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
-; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
-; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT: add.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
+; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_1];
+; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_0];
+; SM80-NEXT: mov.b32 %r3, 1065369472;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_faddx2(
@@ -303,23 +288,14 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_fsubx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<5>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-NEXT: .reg .f32 %f<7>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b32 %r1, [test_fsubx2_param_0];
; SM80-NEXT: ld.param.b32 %r2, [test_fsubx2_param_1];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
+; SM80-NEXT: mov.b32 %r3, -1082081408;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fsubx2(
@@ -404,23 +380,14 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_fmulx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<5>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-NEXT: .reg .f32 %f<7>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
-; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_0];
-; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_1];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT: mul.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT: mul.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
+; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_1];
+; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_0];
+; SM80-NEXT: mov.b32 %r3, 0;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r1, %r3;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fmulx2(
@@ -727,15 +694,13 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
;
; SM80-LABEL: test_fadd_imm_1(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<3>;
-; SM80-NEXT: .reg .f32 %f<3>;
+; SM80-NEXT: .reg .b16 %rs<4>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_imm_1_param_0];
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs2, %f2;
-; SM80-NEXT: st.param.b16 [func_retval0], %rs2;
+; SM80-NEXT: mov.b16 %rs2, 0x3F80;
+; SM80-NEXT: fma.rn.bf16 %rs3, %rs1, %rs2, %rs2;
+; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fadd_imm_1(
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
index 03cdeb9683abae..31d089a19450e1 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
@@ -22,19 +22,14 @@ define <2 x bfloat> @test_ret_const() #0 {
define <2 x bfloat> @test_fadd_imm_0(<2 x bfloat> %a) #0 {
; SM80-LABEL: test_fadd_imm_0(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<3>;
-; SM80-NEXT: .reg .b32 %r<3>;
-; SM80-NEXT: .reg .f32 %f<5>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b32 %r1, [test_fadd_imm_0_param_0];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT: cvt.f32.bf16 %f3, %rs2;
-; SM80-NEXT: add.rn.f32 %f4, %f3, 0f40000000;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r2, %f4, %f2;
-; SM80-NEXT: st.param.b32 [func_retval0], %r2;
+; SM80-NEXT: mov.b32 %r2, 1073758080;
+; SM80-NEXT: mov.b32 %r3, 1065369472;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r1, %r3, %r2;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM90-LABEL: test_fadd_imm_0(
@@ -54,15 +49,13 @@ define <2 x bfloat> @test_fadd_imm_0(<2 x bfloat> %a) #0 {
define bfloat @test_fadd_imm_1(bfloat %a) #0 {
; SM80-LABEL: test_fadd_imm_1(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<3>;
-; SM80-NEXT: .reg .f32 %f<3>;
+; SM80-NEXT: .reg .b16 %rs<4>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b16 %rs1, [test_fadd_imm_1_param_0];
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: add.rn.f32 %f2, %f1, 0f3F800000;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs2, %f2;
-; SM80-NEXT: st.param.b16 [func_retval0], %rs2;
+; SM80-NEXT: mov.b16 %rs2, 0x3F80;
+; SM80-NEXT: fma.rn.bf16 %rs3, %rs1, %rs2, %rs2;
+; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
; SM80-NEXT: ret;
;
; SM90-LABEL: test_fadd_imm_1(
@@ -82,23 +75,14 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-LABEL: test_fsubx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<5>;
-; SM80-NEXT: .reg .b32 %r<4>;
-; SM80-NEXT: .reg .f32 %f<7>;
+; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
; SM80-NEXT: ld.param.b32 %r1, [test_fsubx2_param_0];
; SM80-NEXT: ld.param.b32 %r2, [test_fsubx2_param_1];
-; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
-; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
-; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
-; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
-; SM80-NEXT: st.param.b32 [func_retval0], %r3;
+; SM80-NEXT: mov.b32 %r3, -1082081408;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM90-LABEL: test_fsubx2(
@@...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/121065
More information about the llvm-commits
mailing list