[llvm] [NVPTX] Lower bfloat16 add/mul/sub as fma on SM80 (PR #121065)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 24 12:14:50 PST 2024
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff 852feea820f3f8b2fc44c851cc3ce5fe9576fa64 81a8726c34075070443454b9a4eac189e5fe843b --extensions h,cpp -- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp llvm/lib/Target/NVPTX/NVPTXISelLowering.h
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d7eb815904..a50ac311c8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -17534,8 +17534,8 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
return N2;
}
- const bool PreferFMAAdd = (
- TLI.isOperationLegal(ISD::FMA, VT) && !TLI.isOperationLegal(ISD::FADD, VT));
+ const bool PreferFMAAdd = (TLI.isOperationLegal(ISD::FMA, VT) &&
+ !TLI.isOperationLegal(ISD::FADD, VT));
// FIXME: Support splat of constant.
if (!PreferFMAAdd && N0CFP && N0CFP->isExactlyValue(1.0))
@@ -17588,8 +17588,8 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
(TLI.isOperationLegal(ISD::ConstantFP, VT) ||
(N1.hasOneUse() &&
!TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
- return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
- matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
+ return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
+ matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
}
// FIXME: Support splat of constant.
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 256cb8abd8..eb41a71809 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -856,8 +856,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// Lower bf16 add/mul/sub as fma when it avoids promotion
for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
- if (STI.getSmVersion() != 61 &&
- getOperationAction(Op, VT) != Legal &&
+ if (STI.getSmVersion() != 61 && getOperationAction(Op, VT) != Legal &&
getOperationAction(ISD::FMA, VT) == Legal) {
setOperationAction(Op, VT, Custom);
}
@@ -2514,8 +2513,7 @@ static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
return DAG.getFPExtendOrRound(Res, DL, VT);
}
-SDValue NVPTXTargetLowering::LowerFADD(SDValue Op,
- SelectionDAG &DAG) const {
+SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
// No fma.ftz for bf16, so fall back to promotion
if (useF32FTZ(DAG.getMachineFunction())) {
return PromoteBinOpToF32(Op.getNode(), DAG);
@@ -2525,16 +2523,11 @@ SDValue NVPTXTargetLowering::LowerFADD(SDValue Op,
SDLoc DL(Op);
auto VT = Op.getValueType();
auto One = DAG.getConstantFP(1.0, DL, VT);
- SmallVector<SDValue, 3> Operands{
- Op->getOperand(0),
- One,
- Op->getOperand(1)
- };
+ SmallVector<SDValue, 3> Operands{Op->getOperand(0), One, Op->getOperand(1)};
return DAG.getNode(ISD::FMA, DL, VT, Operands);
}
-SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op,
- SelectionDAG &DAG) const {
+SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
// No fma.ftz for bf16, so fall back to promotion
if (useF32FTZ(DAG.getMachineFunction())) {
return PromoteBinOpToF32(Op.getNode(), DAG);
@@ -2544,16 +2537,12 @@ SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op,
SDLoc DL(Op);
auto VT = Op.getValueType();
auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
- SmallVector<SDValue, 3> Operands{
- Op->getOperand(1),
- NegOne,
- Op->getOperand(0)
- };
+ SmallVector<SDValue, 3> Operands{Op->getOperand(1), NegOne,
+ Op->getOperand(0)};
return DAG.getNode(ISD::FMA, DL, VT, Operands);
}
-SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op,
- SelectionDAG &DAG) const {
+SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
// No fma.ftz for bf16, so fall back to promotion
if (useF32FTZ(DAG.getMachineFunction())) {
return PromoteBinOpToF32(Op.getNode(), DAG);
@@ -2563,11 +2552,7 @@ SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op,
SDLoc DL(Op);
auto VT = Op.getValueType();
auto Zero = DAG.getConstantFP(0.0, DL, VT);
- SmallVector<SDValue, 3> Operands{
- Op->getOperand(0),
- Op->getOperand(1),
- Zero
- };
+ SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1), Zero};
return DAG.getNode(ISD::FMA, DL, VT, Operands);
}
@@ -2768,7 +2753,7 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerFSUB(Op, DAG);
case ISD::FMUL:
return LowerFMUL(Op, DAG);
-
+
default:
llvm_unreachable("Custom lowering not defined for operation");
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/121065
More information about the llvm-commits mailing list