[llvm] deaa678 - AMDGPU/SDAG: Factor out the fold (add (mul x, y), z) --> mad_[iu]64_[iu]32

Nicolai Hähnle via llvm-commits llvm-commits at lists.llvm.org
Mon May 2 15:40:19 PDT 2022


Author: Nicolai Hähnle
Date: 2022-05-02T17:40:03-05:00
New Revision: deaa678137e52f51ca694fdfd1dc9988360fb69b

URL: https://github.com/llvm/llvm-project/commit/deaa678137e52f51ca694fdfd1dc9988360fb69b
DIFF: https://github.com/llvm/llvm-project/commit/deaa678137e52f51ca694fdfd1dc9988360fb69b.diff

LOG: AMDGPU/SDAG: Factor out the fold (add (mul x, y), z) --> mad_[iu]64_[iu]32

Refactor to simplify a follow-up change.

No functional change intended. However, there is one rather subtle logic
change: the subsequent combines (e.g. reassociation) are now *always*
skipped when one of the operands of the add is a mul, rather than only
when mad64_32 etc. are also available. This change makes sense because
the subsequent combines should never apply when one of the operands is a
mul.

Differential Revision: https://reviews.llvm.org/D123833

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/SIISelLowering.cpp
    llvm/lib/Target/AMDGPU/SIISelLowering.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 76c8edc4c6e9a..2018cc33807df 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -10661,39 +10661,64 @@ static SDValue getMad64_32(SelectionDAG &DAG, const SDLoc &SL,
   return DAG.getNode(ISD::TRUNCATE, SL, VT, Mad);
 }
 
-SDValue SITargetLowering::performAddCombine(SDNode *N,
+// Fold (add (mul x, y), z) --> (mad_[iu]64_[iu]32 x, y, z).
+SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
                                             DAGCombinerInfo &DCI) const {
+  assert(N->getOpcode() == ISD::ADD);
+
   SelectionDAG &DAG = DCI.DAG;
   EVT VT = N->getValueType(0);
   SDLoc SL(N);
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
 
-  if ((LHS.getOpcode() == ISD::MUL || RHS.getOpcode() == ISD::MUL)
-      && Subtarget->hasMad64_32() &&
-      !VT.isVector() && VT.getScalarSizeInBits() > 32 &&
-      VT.getScalarSizeInBits() <= 64) {
-    if (LHS.getOpcode() != ISD::MUL)
-      std::swap(LHS, RHS);
+  if (VT.isVector())
+    return SDValue();
 
-    SDValue MulLHS = LHS.getOperand(0);
-    SDValue MulRHS = LHS.getOperand(1);
-    SDValue AddRHS = RHS;
+  unsigned NumBits = VT.getScalarSizeInBits();
+  if (NumBits <= 32 || NumBits > 64)
+    return SDValue();
 
-    // TODO: Maybe restrict if SGPR inputs.
-    if (numBitsUnsigned(MulLHS, DAG) <= 32 &&
-        numBitsUnsigned(MulRHS, DAG) <= 32) {
-      MulLHS = DAG.getZExtOrTrunc(MulLHS, SL, MVT::i32);
-      MulRHS = DAG.getZExtOrTrunc(MulRHS, SL, MVT::i32);
-      AddRHS = DAG.getZExtOrTrunc(AddRHS, SL, MVT::i64);
-      return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, false);
-    }
+  if (LHS.getOpcode() != ISD::MUL) {
+    assert(RHS.getOpcode() == ISD::MUL);
+    std::swap(LHS, RHS);
+  }
+
+  SDValue MulLHS = LHS.getOperand(0);
+  SDValue MulRHS = LHS.getOperand(1);
+  SDValue AddRHS = RHS;
+
+  // TODO: Maybe restrict if SGPR inputs.
+  if (numBitsUnsigned(MulLHS, DAG) <= 32 &&
+      numBitsUnsigned(MulRHS, DAG) <= 32) {
+    MulLHS = DAG.getZExtOrTrunc(MulLHS, SL, MVT::i32);
+    MulRHS = DAG.getZExtOrTrunc(MulRHS, SL, MVT::i32);
+    AddRHS = DAG.getZExtOrTrunc(AddRHS, SL, MVT::i64);
+    return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, false);
+  }
+
+  if (numBitsSigned(MulLHS, DAG) <= 32 && numBitsSigned(MulRHS, DAG) <= 32) {
+    MulLHS = DAG.getSExtOrTrunc(MulLHS, SL, MVT::i32);
+    MulRHS = DAG.getSExtOrTrunc(MulRHS, SL, MVT::i32);
+    AddRHS = DAG.getSExtOrTrunc(AddRHS, SL, MVT::i64);
+    return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, true);
+  }
+
+  return SDValue();
+}
+
+SDValue SITargetLowering::performAddCombine(SDNode *N,
+                                            DAGCombinerInfo &DCI) const {
+  SelectionDAG &DAG = DCI.DAG;
+  EVT VT = N->getValueType(0);
+  SDLoc SL(N);
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
 
-    if (numBitsSigned(MulLHS, DAG) <= 32 && numBitsSigned(MulRHS, DAG) <= 32) {
-      MulLHS = DAG.getSExtOrTrunc(MulLHS, SL, MVT::i32);
-      MulRHS = DAG.getSExtOrTrunc(MulRHS, SL, MVT::i32);
-      AddRHS = DAG.getSExtOrTrunc(AddRHS, SL, MVT::i64);
-      return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, true);
+  if (LHS.getOpcode() == ISD::MUL || RHS.getOpcode() == ISD::MUL) {
+    if (Subtarget->hasMad64_32()) {
+      if (SDValue Folded = tryFoldToMad64_32(N, DCI))
+        return Folded;
     }
 
     return SDValue();

diff  --git a/llvm/lib/Target/AMDGPU/SIISelLowering.h b/llvm/lib/Target/AMDGPU/SIISelLowering.h
index 72a00c1596fe3..6105fe6b0a1fa 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.h
@@ -197,6 +197,7 @@ class SITargetLowering final : public AMDGPUTargetLowering {
   SDValue reassociateScalarOps(SDNode *N, SelectionDAG &DAG) const;
   unsigned getFusedOpcode(const SelectionDAG &DAG,
                           const SDNode *N0, const SDNode *N1) const;
+  SDValue tryFoldToMad64_32(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performAddCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performAddCarrySubCarryCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performSubCombine(SDNode *N, DAGCombinerInfo &DCI) const;


        


More information about the llvm-commits mailing list