[llvm] [NVPTX] Lower bfloat16 add/mul/sub as fma on SM80 (PR #121065)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 24 12:14:50 PST 2024


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff 852feea820f3f8b2fc44c851cc3ce5fe9576fa64 81a8726c34075070443454b9a4eac189e5fe843b --extensions h,cpp -- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp llvm/lib/Target/NVPTX/NVPTXISelLowering.h
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d7eb815904..a50ac311c8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -17534,8 +17534,8 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
       return N2;
   }
 
-  const bool PreferFMAAdd = (
-    TLI.isOperationLegal(ISD::FMA, VT) && !TLI.isOperationLegal(ISD::FADD, VT));
+  const bool PreferFMAAdd = (TLI.isOperationLegal(ISD::FMA, VT) &&
+                             !TLI.isOperationLegal(ISD::FADD, VT));
 
   // FIXME: Support splat of constant.
   if (!PreferFMAAdd && N0CFP && N0CFP->isExactlyValue(1.0))
@@ -17588,8 +17588,8 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
       (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
        (N1.hasOneUse() &&
         !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
-      return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
-                             matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
+    return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
+                           matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
   }
 
   // FIXME: Support splat of constant.
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 256cb8abd8..eb41a71809 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -856,8 +856,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // Lower bf16 add/mul/sub as fma when it avoids promotion
   for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
     for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
-      if (STI.getSmVersion() != 61 &&
-          getOperationAction(Op, VT) != Legal && 
+      if (STI.getSmVersion() != 61 && getOperationAction(Op, VT) != Legal &&
           getOperationAction(ISD::FMA, VT) == Legal) {
         setOperationAction(Op, VT, Custom);
       }
@@ -2514,8 +2513,7 @@ static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
   return DAG.getFPExtendOrRound(Res, DL, VT);
 }
 
-SDValue NVPTXTargetLowering::LowerFADD(SDValue Op,
-                                       SelectionDAG &DAG) const {
+SDValue NVPTXTargetLowering::LowerFADD(SDValue Op, SelectionDAG &DAG) const {
   // No fma.ftz for bf16, so fall back to promotion
   if (useF32FTZ(DAG.getMachineFunction())) {
     return PromoteBinOpToF32(Op.getNode(), DAG);
@@ -2525,16 +2523,11 @@ SDValue NVPTXTargetLowering::LowerFADD(SDValue Op,
   SDLoc DL(Op);
   auto VT = Op.getValueType();
   auto One = DAG.getConstantFP(1.0, DL, VT);
-  SmallVector<SDValue, 3> Operands{
-    Op->getOperand(0),
-    One,
-    Op->getOperand(1)
-  };
+  SmallVector<SDValue, 3> Operands{Op->getOperand(0), One, Op->getOperand(1)};
   return DAG.getNode(ISD::FMA, DL, VT, Operands);
 }
 
-SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op,
-                                       SelectionDAG &DAG) const {
+SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op, SelectionDAG &DAG) const {
   // No fma.ftz for bf16, so fall back to promotion
   if (useF32FTZ(DAG.getMachineFunction())) {
     return PromoteBinOpToF32(Op.getNode(), DAG);
@@ -2544,16 +2537,12 @@ SDValue NVPTXTargetLowering::LowerFSUB(SDValue Op,
   SDLoc DL(Op);
   auto VT = Op.getValueType();
   auto NegOne = DAG.getConstantFP(-1.0, DL, VT);
-  SmallVector<SDValue, 3> Operands{
-    Op->getOperand(1),
-    NegOne,
-    Op->getOperand(0)
-  };
+  SmallVector<SDValue, 3> Operands{Op->getOperand(1), NegOne,
+                                   Op->getOperand(0)};
   return DAG.getNode(ISD::FMA, DL, VT, Operands);
 }
 
-SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op,
-                                       SelectionDAG &DAG) const {
+SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
   // No fma.ftz for bf16, so fall back to promotion
   if (useF32FTZ(DAG.getMachineFunction())) {
     return PromoteBinOpToF32(Op.getNode(), DAG);
@@ -2563,11 +2552,7 @@ SDValue NVPTXTargetLowering::LowerFMUL(SDValue Op,
   SDLoc DL(Op);
   auto VT = Op.getValueType();
   auto Zero = DAG.getConstantFP(0.0, DL, VT);
-  SmallVector<SDValue, 3> Operands{
-    Op->getOperand(0),
-    Op->getOperand(1),
-    Zero
-  };
+  SmallVector<SDValue, 3> Operands{Op->getOperand(0), Op->getOperand(1), Zero};
   return DAG.getNode(ISD::FMA, DL, VT, Operands);
 }
 
@@ -2768,7 +2753,7 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerFSUB(Op, DAG);
   case ISD::FMUL:
     return LowerFMUL(Op, DAG);
-  
+
   default:
     llvm_unreachable("Custom lowering not defined for operation");
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/121065


More information about the llvm-commits mailing list