[llvm] [TargetLowering] Refactor to share most of the implementation of the two forceExpandWideMUL functions. NFC (PR #124241)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 24 00:09:41 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Craig Topper (topperc)

<details>
<summary>Changes</summary>

This patch is split into 2 commits. The first inlines one of the forceExpandWideMUL functions into its two call sites. Each call site uses a different part of the function. The libcall part is used by the other forceExpandWideMUL and the arithmetic expansion is used by ExpandIntRes_MUL. ExpandIntRes_MUL has its own libcall handling before this is called.

The second commit takes the similar code from ExpandIntRes_MUL and forceExpandWideMUL and combines it into a new forceExpandMUL function that has no libcall support. It optionally takes HiLHS and HiRHS as inputs. If they are provided, they are used in the Hi result calculation. The Signed flag can only be set if HiLHS and HiRHS are not set.

More details in the individual commit messages.

---
Full diff: https://github.com/llvm/llvm-project/pull/124241.diff


3 Files Affected:

- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+13-14) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+1-4) 
- (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+83-122) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 59743dbe4d2ea4..4ef9a06a8644e2 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5499,20 +5499,19 @@ class TargetLowering : public TargetLoweringBase {
   bool expandMULO(SDNode *Node, SDValue &Result, SDValue &Overflow,
                   SelectionDAG &DAG) const;
 
-  /// forceExpandWideMUL - Unconditionally expand a MUL into either a libcall or
-  /// brute force via a wide multiplication. The expansion works by
-  /// attempting to do a multiplication on a wider type twice the size of the
-  /// original operands. LL and LH represent the lower and upper halves of the
-  /// first operand. RL and RH represent the lower and upper halves of the
-  /// second operand. The upper and lower halves of the result are stored in Lo
-  /// and Hi.
-  void forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl, bool Signed,
-                          EVT WideVT, const SDValue LL, const SDValue LH,
-                          const SDValue RL, const SDValue RH, SDValue &Lo,
-                          SDValue &Hi) const;
-
-  /// Same as above, but creates the upper halves of each operand by
-  /// sign/zero-extending the operands.
+  /// Calculate the product twice the width of LHS and RHS. If HiLHS/HiRHS are
+  /// non-null they will be included in the multiplication. The expansion works
+  /// by splitting the 2 inputs into 4 pieces that we can multiply and add
+  /// together without needing MULH or MUL_LOHI.
+  void forceExpandMUL(SelectionDAG &DAG, const SDLoc &dl, bool Signed,
+                      SDValue &Lo, SDValue &Hi, SDValue LHS, SDValue RHS,
+                      SDValue HiLHS = SDValue(),
+                      SDValue HiRHS = SDValue()) const;
+
+  /// Calculate full product of LHS and RHS either via a libcall or through
+  /// brute force expansion of the multiplication. The expansion works by
+  /// splitting the 2 inputs into 4 pieces that we can multiply and add together
+  /// without needing MULH or MUL_LOHI.
   void forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl, bool Signed,
                           const SDValue LHS, const SDValue RHS, SDValue &Lo,
                           SDValue &Hi) const;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index b0a624680231e9..ad9e3e4e99b979 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -4294,10 +4294,7 @@ void DAGTypeLegalizer::ExpandIntRes_MUL(SDNode *N,
     LC = RTLIB::MUL_I128;
 
   if (LC == RTLIB::UNKNOWN_LIBCALL || !TLI.getLibcallName(LC)) {
-    // Perform a wide multiplication where the wide type is the original VT and
-    // the 4 parts are the split arguments.
-    TLI.forceExpandWideMUL(DAG, dl, /*Signed=*/true, VT, LL, LH, RL, RH, Lo,
-                           Hi);
+    TLI.forceExpandMUL(DAG, dl, /*Signed=*/false, Lo, Hi, LL, RL, LH, RH);
     return;
   }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 0d039860b9f0fd..523a55781ba52d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -10857,15 +10857,73 @@ SDValue TargetLowering::expandShlSat(SDNode *Node, SelectionDAG &DAG) const {
   return DAG.getSelect(dl, VT, Cond, SatVal, Result);
 }
 
+void TargetLowering::forceExpandMUL(SelectionDAG &DAG, const SDLoc &dl,
+                                    bool Signed, SDValue &Lo, SDValue &Hi,
+                                    SDValue LHS, SDValue RHS, SDValue HiLHS,
+                                    SDValue HiRHS) const {
+  EVT VT = LHS.getValueType();
+  assert(RHS.getValueType() == VT && "Mismatching operand types");
+
+  assert((HiLHS && HiRHS) || (!HiLHS && !HiRHS));
+  assert((!Signed || !HiLHS) &&
+         "Signed flag should only be set when HiLHS and RiRHS are null");
+
+  // We'll expand the multiplication by brute force because we have no other
+  // options. This is a trivially-generalized version of the code from
+  // Hacker's Delight (itself derived from Knuth's Algorithm M from section
+  // 4.3.1). If Signed is set, we can use arithmetic right shifts to propagate
+  // sign bits while calculating the Hi half.
+  unsigned Bits = VT.getSizeInBits();
+  unsigned HalfBits = Bits / 2;
+  SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
+  SDValue LL = DAG.getNode(ISD::AND, dl, VT, LHS, Mask);
+  SDValue RL = DAG.getNode(ISD::AND, dl, VT, RHS, Mask);
+
+  SDValue T = DAG.getNode(ISD::MUL, dl, VT, LL, RL);
+  SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
+
+  SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
+  // This is always an unsigned shift.
+  SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
+
+  unsigned ShiftOpc = Signed ? ISD::SRA : ISD::SRL;
+  SDValue LH = DAG.getNode(ShiftOpc, dl, VT, LHS, Shift);
+  SDValue RH = DAG.getNode(ShiftOpc, dl, VT, RHS, Shift);
+
+  SDValue U =
+      DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RL), TH);
+  SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
+  SDValue UH = DAG.getNode(ShiftOpc, dl, VT, U, Shift);
+
+  SDValue V =
+      DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LL, RH), UL);
+  SDValue VH = DAG.getNode(ShiftOpc, dl, VT, V, Shift);
+
+  Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
+                   DAG.getNode(ISD::SHL, dl, VT, V, Shift));
+
+  Hi = DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RH),
+                   DAG.getNode(ISD::ADD, dl, VT, UH, VH));
+
+  // If HiLHS and HiRHS are set, multiply them by the opposite low part and add
+  // the products to Hi.
+  if (HiLHS) {
+    Hi = DAG.getNode(ISD::ADD, dl, VT, Hi,
+                     DAG.getNode(ISD::ADD, dl, VT,
+                                 DAG.getNode(ISD::MUL, dl, VT, HiRHS, LHS),
+                                 DAG.getNode(ISD::MUL, dl, VT, RHS, HiLHS)));
+  }
+}
+
 void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
-                                        bool Signed, EVT WideVT,
-                                        const SDValue LL, const SDValue LH,
-                                        const SDValue RL, const SDValue RH,
-                                        SDValue &Lo, SDValue &Hi) const {
+                                        bool Signed, const SDValue LHS,
+                                        const SDValue RHS, SDValue &Lo,
+                                        SDValue &Hi) const {
+  EVT VT = LHS.getValueType();
+  assert(RHS.getValueType() == VT && "Mismatching operand types");
+  EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits() * 2);
   // We can fall back to a libcall with an illegal type for the MUL if we
   // have a libcall big enough.
-  // Also, we can fall back to a division in some cases, but that's a big
-  // performance hit in the general case.
   RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
   if (WideVT == MVT::i16)
     LC = RTLIB::MUL_I16;
@@ -10877,46 +10935,23 @@ void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
     LC = RTLIB::MUL_I128;
 
   if (LC == RTLIB::UNKNOWN_LIBCALL || !getLibcallName(LC)) {
-    // We'll expand the multiplication by brute force because we have no other
-    // options. This is a trivially-generalized version of the code from
-    // Hacker's Delight (itself derived from Knuth's Algorithm M from section
-    // 4.3.1).
-    EVT VT = LL.getValueType();
-    unsigned Bits = VT.getSizeInBits();
-    unsigned HalfBits = Bits >> 1;
-    SDValue Mask =
-        DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
-    SDValue LLL = DAG.getNode(ISD::AND, dl, VT, LL, Mask);
-    SDValue RLL = DAG.getNode(ISD::AND, dl, VT, RL, Mask);
-
-    SDValue T = DAG.getNode(ISD::MUL, dl, VT, LLL, RLL);
-    SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
-
-    SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
-    SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
-    SDValue LLH = DAG.getNode(ISD::SRL, dl, VT, LL, Shift);
-    SDValue RLH = DAG.getNode(ISD::SRL, dl, VT, RL, Shift);
-
-    SDValue U = DAG.getNode(ISD::ADD, dl, VT,
-                            DAG.getNode(ISD::MUL, dl, VT, LLH, RLL), TH);
-    SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
-    SDValue UH = DAG.getNode(ISD::SRL, dl, VT, U, Shift);
-
-    SDValue V = DAG.getNode(ISD::ADD, dl, VT,
-                            DAG.getNode(ISD::MUL, dl, VT, LLL, RLH), UL);
-    SDValue VH = DAG.getNode(ISD::SRL, dl, VT, V, Shift);
-
-    SDValue W =
-        DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LLH, RLH),
-                    DAG.getNode(ISD::ADD, dl, VT, UH, VH));
-    Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
-                     DAG.getNode(ISD::SHL, dl, VT, V, Shift));
-
-    Hi = DAG.getNode(ISD::ADD, dl, VT, W,
-                     DAG.getNode(ISD::ADD, dl, VT,
-                                 DAG.getNode(ISD::MUL, dl, VT, RH, LL),
-                                 DAG.getNode(ISD::MUL, dl, VT, RL, LH)));
-  } else {
+    forceExpandMUL(DAG, dl, Signed, Lo, Hi, LHS, RHS);
+    return;
+  }
+
+    SDValue HiLHS, HiRHS;
+    if (Signed) {
+      // The high part is obtained by SRA'ing all but one of the bits of low
+      // part.
+      unsigned LoSize = VT.getFixedSizeInBits();
+      SDValue Shift = DAG.getShiftAmountConstant(LoSize - 1, VT, dl);
+      HiLHS = DAG.getNode(ISD::SRA, dl, VT, LHS, Shift);
+      HiRHS = DAG.getNode(ISD::SRA, dl, VT, RHS, Shift);
+    } else {
+      HiLHS = DAG.getConstant(0, dl, VT);
+      HiRHS = DAG.getConstant(0, dl, VT);
+    }
+
     // Attempt a libcall.
     SDValue Ret;
     TargetLowering::MakeLibCallOptions CallOptions;
@@ -10927,10 +10962,10 @@ void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
       // depending on platform endianness. This is usually handled by
       // the C calling convention, but we can't defer to it in
       // the legalizer.
-      SDValue Args[] = {LL, LH, RL, RH};
+      SDValue Args[] = {LHS, HiLHS, RHS, HiRHS};
       Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
     } else {
-      SDValue Args[] = {LH, LL, RH, RL};
+      SDValue Args[] = {HiLHS, LHS, HiRHS, RHS};
       Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
     }
     assert(Ret.getOpcode() == ISD::MERGE_VALUES &&
@@ -10943,80 +10978,6 @@ void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
       Lo = Ret.getOperand(1);
       Hi = Ret.getOperand(0);
     }
-  }
-}
-
-void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
-                                        bool Signed, const SDValue LHS,
-                                        const SDValue RHS, SDValue &Lo,
-                                        SDValue &Hi) const {
-  EVT VT = LHS.getValueType();
-  assert(RHS.getValueType() == VT && "Mismatching operand types");
-  EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits() * 2);
-  // We can fall back to a libcall with an illegal type for the MUL if we
-  // have a libcall big enough.
-  RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
-  if (WideVT == MVT::i16)
-    LC = RTLIB::MUL_I16;
-  else if (WideVT == MVT::i32)
-    LC = RTLIB::MUL_I32;
-  else if (WideVT == MVT::i64)
-    LC = RTLIB::MUL_I64;
-  else if (WideVT == MVT::i128)
-    LC = RTLIB::MUL_I128;
-
-  if (LC != RTLIB::UNKNOWN_LIBCALL && getLibcallName(LC)) {
-    SDValue HiLHS, HiRHS;
-    if (Signed) {
-      // The high part is obtained by SRA'ing all but one of the bits of low
-      // part.
-      unsigned LoSize = VT.getFixedSizeInBits();
-      SDValue Shift = DAG.getShiftAmountConstant(LoSize - 1, VT, dl);
-      HiLHS = DAG.getNode(ISD::SRA, dl, VT, LHS, Shift);
-      HiRHS = DAG.getNode(ISD::SRA, dl, VT, RHS, Shift);
-    } else {
-      HiLHS = DAG.getConstant(0, dl, VT);
-      HiRHS = DAG.getConstant(0, dl, VT);
-    }
-    forceExpandWideMUL(DAG, dl, Signed, WideVT, LHS, HiLHS, RHS, HiRHS, Lo, Hi);
-    return;
-  }
-
-  // Expand the multiplication by brute force. This is a generalized-version of
-  // the code from Hacker's Delight (itself derived from Knuth's Algorithm M
-  // from section 4.3.1) combined with the Hacker's delight code
-  // for calculating mulhs.
-  unsigned Bits = VT.getSizeInBits();
-  unsigned HalfBits = Bits / 2;
-  SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
-  SDValue LL = DAG.getNode(ISD::AND, dl, VT, LHS, Mask);
-  SDValue RL = DAG.getNode(ISD::AND, dl, VT, RHS, Mask);
-
-  SDValue T = DAG.getNode(ISD::MUL, dl, VT, LL, RL);
-  SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
-
-  SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
-  // This is always an unsigned shift.
-  SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
-
-  unsigned ShiftOpc = Signed ? ISD::SRA : ISD::SRL;
-  SDValue LH = DAG.getNode(ShiftOpc, dl, VT, LHS, Shift);
-  SDValue RH = DAG.getNode(ShiftOpc, dl, VT, RHS, Shift);
-
-  SDValue U =
-      DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RL), TH);
-  SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
-  SDValue UH = DAG.getNode(ShiftOpc, dl, VT, U, Shift);
-
-  SDValue V =
-      DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LL, RH), UL);
-  SDValue VH = DAG.getNode(ShiftOpc, dl, VT, V, Shift);
-
-  Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
-                   DAG.getNode(ISD::SHL, dl, VT, V, Shift));
-
-  Hi = DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RH),
-                   DAG.getNode(ISD::ADD, dl, VT, UH, VH));
 }
 
 SDValue

``````````

</details>


https://github.com/llvm/llvm-project/pull/124241


More information about the llvm-commits mailing list