[llvm] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select c, (mad a, b, c)) (PR #96352)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 25 12:49:24 PDT 2024


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/96352

>From dbd09da875aee8d77b7e1c3fda518f0c88569668 Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Fri, 21 Jun 2024 19:58:33 +0000
Subject: [PATCH 1/2] [NVPTX] Fold (add (select 0, (mul a, b)), c) -> (select
 (mad a, b, c), c)

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 219 ++++++++++++--------
 llvm/test/CodeGen/NVPTX/combine-mad.ll      |  49 +++++
 2 files changed, 185 insertions(+), 83 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index f4ef7c9914f13..0c609554370a3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5215,103 +5215,129 @@ bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
   return F.getFnAttribute("unsafe-fp-math").getValueAsBool();
 }
 
+static bool isConstZero(const SDValue &Operand) {
+  const auto *Const = dyn_cast<ConstantSDNode>(Operand);
+  return Const && Const->getZExtValue() == 0;
+}
+
 /// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
 /// operands N0 and N1.  This is a helper for PerformADDCombine that is
 /// called with the default operands, and if that fails, with commuted
 /// operands.
-static SDValue PerformADDCombineWithOperands(
-    SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI,
-    const NVPTXSubtarget &Subtarget, CodeGenOptLevel OptLevel) {
-  SelectionDAG  &DAG = DCI.DAG;
-  // Skip non-integer, non-scalar case
-  EVT VT=N0.getValueType();
-  if (VT.isVector())
+static SDValue
+PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
+                              TargetLowering::DAGCombinerInfo &DCI) {
+  EVT VT = N0.getValueType();
+
+  // Since integer multiply-add costs the same as integer multiply
+  // but is more costly than integer add, do the fusion only when
+  // the mul is only used in the add.
+  if (!N0.getNode()->hasOneUse())
     return SDValue();
 
   // fold (add (mul a, b), c) -> (mad a, b, c)
   //
-  if (N0.getOpcode() == ISD::MUL) {
-    assert (VT.isInteger());
-    // For integer:
-    // Since integer multiply-add costs the same as integer multiply
-    // but is more costly than integer add, do the fusion only when
-    // the mul is only used in the add.
-    if (OptLevel == CodeGenOptLevel::None || VT != MVT::i32 ||
-        !N0.getNode()->hasOneUse())
+  if (N0.getOpcode() == ISD::MUL)
+    return DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT, N0.getOperand(0),
+                           N0.getOperand(1), N1);
+
+  // fold (add (select cond, 0, (mul a, b)), c)
+  //   -> (select cond, (mad a, b, c), c)
+  //
+  if (N0.getOpcode() == ISD::SELECT) {
+    bool ZeroCond;
+    if (isConstZero(N0->getOperand(1)))
+      ZeroCond = true;
+    else if (isConstZero(N0->getOperand(2)))
+      ZeroCond = false;
+    else
+      return SDValue();
+
+    SDValue M = N0->getOperand(ZeroCond ? 2 : 1);
+    if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
       return SDValue();
 
-    // Do the folding
-    return DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
-                       N0.getOperand(0), N0.getOperand(1), N1);
+    SDValue MAD = DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
+                                  M->getOperand(0), M->getOperand(1), N1);
+    return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
+                             (ZeroCond ? N1 : MAD), (ZeroCond ? MAD : N1));
   }
-  else if (N0.getOpcode() == ISD::FMUL) {
-    if (VT == MVT::f32 || VT == MVT::f64) {
-      const auto *TLI = static_cast<const NVPTXTargetLowering *>(
-          &DAG.getTargetLoweringInfo());
-      if (!TLI->allowFMA(DAG.getMachineFunction(), OptLevel))
-        return SDValue();
 
-      // For floating point:
-      // Do the fusion only when the mul has less than 5 uses and all
-      // are add.
-      // The heuristic is that if a use is not an add, then that use
-      // cannot be fused into fma, therefore mul is still needed anyway.
-      // If there are more than 4 uses, even if they are all add, fusing
-      // them will increase register pressue.
-      //
-      int numUses = 0;
-      int nonAddCount = 0;
-      for (const SDNode *User : N0.getNode()->uses()) {
-        numUses++;
-        if (User->getOpcode() != ISD::FADD)
-          ++nonAddCount;
-      }
+  return SDValue();
+}
+
+static SDValue
+PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
+                               TargetLowering::DAGCombinerInfo &DCI,
+                               CodeGenOptLevel OptLevel) {
+  EVT VT = N0.getValueType();
+  if (N0.getOpcode() == ISD::FMUL) {
+    const auto *TLI = static_cast<const NVPTXTargetLowering *>(
+        &DCI.DAG.getTargetLoweringInfo());
+    if (!TLI->allowFMA(DCI.DAG.getMachineFunction(), OptLevel))
+      return SDValue();
+
+    // For floating point:
+    // Do the fusion only when the mul has less than 5 uses and all
+    // are add.
+    // The heuristic is that if a use is not an add, then that use
+    // cannot be fused into fma, therefore mul is still needed anyway.
+    // If there are more than 4 uses, even if they are all add, fusing
+    // them will increase register pressure.
+    //
+    int numUses = 0;
+    int nonAddCount = 0;
+    for (const SDNode *User : N0.getNode()->uses()) {
+      numUses++;
+      if (User->getOpcode() != ISD::FADD)
+        ++nonAddCount;
       if (numUses >= 5)
         return SDValue();
-      if (nonAddCount) {
-        int orderNo = N->getIROrder();
-        int orderNo2 = N0.getNode()->getIROrder();
-        // simple heuristics here for considering potential register
-        // pressure, the logics here is that the differnce are used
-        // to measure the distance between def and use, the longer distance
-        // more likely cause register pressure.
-        if (orderNo - orderNo2 < 500)
-          return SDValue();
-
-        // Now, check if at least one of the FMUL's operands is live beyond the node N,
-        // which guarantees that the FMA will not increase register pressure at node N.
-        bool opIsLive = false;
-        const SDNode *left = N0.getOperand(0).getNode();
-        const SDNode *right = N0.getOperand(1).getNode();
-
-        if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
-          opIsLive = true;
-
-        if (!opIsLive)
-          for (const SDNode *User : left->uses()) {
-            int orderNo3 = User->getIROrder();
-            if (orderNo3 > orderNo) {
-              opIsLive = true;
-              break;
-            }
-          }
+    }
+    if (nonAddCount) {
+      int orderNo = N->getIROrder();
+      int orderNo2 = N0.getNode()->getIROrder();
+      // Simple heuristic for potential register pressure: the difference in
+      // IR order is used to measure the distance between def and use; the
+      // longer the distance, the more likely it is to cause register
+      // pressure.
+      if (orderNo - orderNo2 < 500)
+        return SDValue();
 
-        if (!opIsLive)
-          for (const SDNode *User : right->uses()) {
-            int orderNo3 = User->getIROrder();
-            if (orderNo3 > orderNo) {
-              opIsLive = true;
-              break;
-            }
+      // Now, check if at least one of the FMUL's operands is live beyond the
+      // node N, which guarantees that the FMA will not increase register
+      // pressure at node N.
+      bool opIsLive = false;
+      const SDNode *left = N0.getOperand(0).getNode();
+      const SDNode *right = N0.getOperand(1).getNode();
+
+      if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
+        opIsLive = true;
+
+      if (!opIsLive)
+        for (const SDNode *User : left->uses()) {
+          int orderNo3 = User->getIROrder();
+          if (orderNo3 > orderNo) {
+            opIsLive = true;
+            break;
           }
+        }
 
-        if (!opIsLive)
-          return SDValue();
-      }
+      if (!opIsLive)
+        for (const SDNode *User : right->uses()) {
+          int orderNo3 = User->getIROrder();
+          if (orderNo3 > orderNo) {
+            opIsLive = true;
+            break;
+          }
+        }
 
-      return DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                         N0.getOperand(0), N0.getOperand(1), N1);
+      if (!opIsLive)
+        return SDValue();
     }
+
+    return DCI.DAG.getNode(ISD::FMA, SDLoc(N), VT, N0.getOperand(0),
+                           N0.getOperand(1), N1);
   }
 
   return SDValue();
@@ -5332,18 +5358,44 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
 ///
 static SDValue PerformADDCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
-                                 const NVPTXSubtarget &Subtarget,
+                                 CodeGenOptLevel OptLevel) {
+  if (OptLevel == CodeGenOptLevel::None)
+    return SDValue();
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+
+  // Skip vectors and any scalar type other than i32.
+  EVT VT = N0.getValueType();
+  if (VT.isVector() || VT != MVT::i32)
+    return SDValue();
+
+  // First try with the default operand order.
+  if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
+    return Result;
+
+  // If that didn't work, try again with the operands commuted.
+  return PerformADDCombineWithOperands(N, N1, N0, DCI);
+}
+
+/// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
+///
+static SDValue PerformFADDCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
                                  CodeGenOptLevel OptLevel) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
 
+  EVT VT = N0.getValueType();
+  if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
+    return SDValue();
+
   // First try with the default operand order.
-  if (SDValue Result =
-          PerformADDCombineWithOperands(N, N0, N1, DCI, Subtarget, OptLevel))
+  if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
     return Result;
 
   // If that didn't work, try again with the operands commuted.
-  return PerformADDCombineWithOperands(N, N1, N0, DCI, Subtarget, OptLevel);
+  return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
 }
 
 static SDValue PerformANDCombine(SDNode *N,
@@ -5876,8 +5928,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
   switch (N->getOpcode()) {
     default: break;
     case ISD::ADD:
+      return PerformADDCombine(N, DCI, OptLevel);
     case ISD::FADD:
-      return PerformADDCombine(N, DCI, STI, OptLevel);
+      return PerformFADDCombine(N, DCI, OptLevel);
     case ISD::MUL:
       return PerformMULCombine(N, DCI, OptLevel);
     case ISD::SHL:
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad.ll b/llvm/test/CodeGen/NVPTX/combine-mad.ll
index 0637bc916ea49..56bfaa14c5877 100644
--- a/llvm/test/CodeGen/NVPTX/combine-mad.ll
+++ b/llvm/test/CodeGen/NVPTX/combine-mad.ll
@@ -134,3 +134,52 @@ define i32 @test3(i32 %n, i32 %m, i32 %s) {
   %mul = mul i32 %sel, %m
   ret i32 %mul
 }
+
+;; (add (select 0, (mul a, b)), c) -> (select (mad a, b, c), c)
+define i32 @test4(i32 %a, i32 %b, i32 %c, i1 %p) {
+; CHECK-LABEL: test4(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u8 %rs1, [test4_param_3];
+; CHECK-NEXT:    and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT:    setp.eq.b16 %p1, %rs2, 1;
+; CHECK-NEXT:    ld.param.u32 %r1, [test4_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test4_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test4_param_2];
+; CHECK-NEXT:    mad.lo.s32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    selp.b32 %r5, %r4, %r3, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r5;
+; CHECK-NEXT:    ret;
+  %mul = mul i32 %a, %b
+  %sel = select i1 %p, i32 %mul, i32 0
+  %add = add i32 %c, %sel
+  ret i32 %add
+}
+
+define i32 @test4_rev(i32 %a, i32 %b, i32 %c, i1 %p) {
+; CHECK-LABEL: test4_rev(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u8 %rs1, [test4_rev_param_3];
+; CHECK-NEXT:    and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT:    setp.eq.b16 %p1, %rs2, 1;
+; CHECK-NEXT:    ld.param.u32 %r1, [test4_rev_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test4_rev_param_1];
+; CHECK-NEXT:    ld.param.u32 %r3, [test4_rev_param_2];
+; CHECK-NEXT:    mad.lo.s32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    selp.b32 %r5, %r3, %r4, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r5;
+; CHECK-NEXT:    ret;
+  %mul = mul i32 %a, %b
+  %sel = select i1 %p, i32 0, i32 %mul
+  %add = add i32 %c, %sel
+  ret i32 %add
+}

>From 94733c9cc88ec1a249984a81dfd5e19cff1528ee Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Tue, 25 Jun 2024 19:49:07 +0000
Subject: [PATCH 2/2] address comments

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 0c609554370a3..d59f330cf0b77 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5242,25 +5242,26 @@ PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
                            N0.getOperand(1), N1);
 
   // fold (add (select cond, 0, (mul a, b)), c)
-  //   -> (select cond, (mad a, b, c), c)
+  //   -> (select cond, c, (mad a, b, c))
   //
   if (N0.getOpcode() == ISD::SELECT) {
-    bool ZeroCond;
+    unsigned ZeroOpNum;
     if (isConstZero(N0->getOperand(1)))
-      ZeroCond = true;
+      ZeroOpNum = 1;
     else if (isConstZero(N0->getOperand(2)))
-      ZeroCond = false;
+      ZeroOpNum = 2;
     else
       return SDValue();
 
-    SDValue M = N0->getOperand(ZeroCond ? 2 : 1);
+    SDValue M = N0->getOperand((ZeroOpNum == 1) ? 2 : 1);
     if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
       return SDValue();
 
     SDValue MAD = DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
                                   M->getOperand(0), M->getOperand(1), N1);
     return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
-                             (ZeroCond ? N1 : MAD), (ZeroCond ? MAD : N1));
+                             ((ZeroOpNum == 1) ? N1 : MAD),
+                             ((ZeroOpNum == 1) ? MAD : N1));
   }
 
   return SDValue();



More information about the llvm-commits mailing list