[llvm] [TargetLowering] Refactor to share most of the implementation of the two forceExpandWideMUL functions. NFC (PR #124241)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 24 00:09:08 PST 2025


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/124241

This patch is split into 2 commits. The first inlines one of the forceExpandWideMUL functions into its two call sites. Each call site uses a different part of the function. The libcall part is used by the other forceExpandWideMUL and the arithmetic expansion is used by ExpandIntRes_MUL. ExpandIntRes_MUL has its own libcall handling before this is called.

The second commit takes the similar code from ExpandIntRes_MUL and forceExpandWideMUL and combines them into a new function, forceExpandMUL, that has no libcall support. It optionally takes HiLHS and HiRHS as inputs. If they are provided, they are used in the Hi result calculation. The Signed flag can only be set if HiLHS and HiRHS are not set.

More details in the individual commit messages.

>From 3a608ef66d1e5fcde2956061b005a2730d25aa5e Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 23 Jan 2025 13:10:01 -0800
Subject: [PATCH 1/2] [TargetLowering] Inline one of the signatures of
 forceExpandWideMul into its callers. NFC

There are two call sites. One uses the non-libcall part and
the other uses the libcall part. Sink those pieces into their callers.

After this I'm going to merge the non-libcall part of the other
forceExpandWideMUL with the code from LegalizeIntegerTypes into
a new helper.
---
 llvm/include/llvm/CodeGen/TargetLowering.h    |  18 +--
 .../SelectionDAG/LegalizeIntegerTypes.cpp     |  42 ++++++-
 .../CodeGen/SelectionDAG/TargetLowering.cpp   | 107 ++++--------------
 3 files changed, 64 insertions(+), 103 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 59743dbe4d2ea4..9c541623140311 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5499,20 +5499,10 @@ class TargetLowering : public TargetLoweringBase {
   bool expandMULO(SDNode *Node, SDValue &Result, SDValue &Overflow,
                   SelectionDAG &DAG) const;
 
-  /// forceExpandWideMUL - Unconditionally expand a MUL into either a libcall or
-  /// brute force via a wide multiplication. The expansion works by
-  /// attempting to do a multiplication on a wider type twice the size of the
-  /// original operands. LL and LH represent the lower and upper halves of the
-  /// first operand. RL and RH represent the lower and upper halves of the
-  /// second operand. The upper and lower halves of the result are stored in Lo
-  /// and Hi.
-  void forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl, bool Signed,
-                          EVT WideVT, const SDValue LL, const SDValue LH,
-                          const SDValue RL, const SDValue RH, SDValue &Lo,
-                          SDValue &Hi) const;
-
-  /// Same as above, but creates the upper halves of each operand by
-  /// sign/zero-extending the operands.
+  /// Calculate full product of LHS and RHS either via a libcall or through
+  /// brute force expansion of the multiplication. The expansion works by
+  /// splitting the 2 inputs into 4 pieces that we can multiply and add together
+  /// without needing MULH or MUL_LOHI.
   void forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl, bool Signed,
                           const SDValue LHS, const SDValue RHS, SDValue &Lo,
                           SDValue &Hi) const;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index b0a624680231e9..8cfcd1e234c731 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -4294,10 +4294,44 @@ void DAGTypeLegalizer::ExpandIntRes_MUL(SDNode *N,
     LC = RTLIB::MUL_I128;
 
   if (LC == RTLIB::UNKNOWN_LIBCALL || !TLI.getLibcallName(LC)) {
-    // Perform a wide multiplication where the wide type is the original VT and
-    // the 4 parts are the split arguments.
-    TLI.forceExpandWideMUL(DAG, dl, /*Signed=*/true, VT, LL, LH, RL, RH, Lo,
-                           Hi);
+    // We'll expand the multiplication by brute force because we have no other
+    // options. This is a trivially-generalized version of the code from
+    // Hacker's Delight (itself derived from Knuth's Algorithm M from section
+    // 4.3.1).
+    unsigned Bits = NVT.getSizeInBits();
+    unsigned HalfBits = Bits >> 1;
+    SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl,
+                                   NVT);
+    SDValue LLL = DAG.getNode(ISD::AND, dl, NVT, LL, Mask);
+    SDValue RLL = DAG.getNode(ISD::AND, dl, NVT, RL, Mask);
+
+    SDValue T = DAG.getNode(ISD::MUL, dl, NVT, LLL, RLL);
+    SDValue TL = DAG.getNode(ISD::AND, dl, NVT, T, Mask);
+
+    SDValue Shift = DAG.getShiftAmountConstant(HalfBits, NVT, dl);
+    SDValue TH = DAG.getNode(ISD::SRL, dl, NVT, T, Shift);
+    SDValue LLH = DAG.getNode(ISD::SRL, dl, NVT, LL, Shift);
+    SDValue RLH = DAG.getNode(ISD::SRL, dl, NVT, RL, Shift);
+
+    SDValue U = DAG.getNode(ISD::ADD, dl, NVT,
+                            DAG.getNode(ISD::MUL, dl, NVT, LLH, RLL), TH);
+    SDValue UL = DAG.getNode(ISD::AND, dl, NVT, U, Mask);
+    SDValue UH = DAG.getNode(ISD::SRL, dl, NVT, U, Shift);
+
+    SDValue V = DAG.getNode(ISD::ADD, dl, NVT,
+                            DAG.getNode(ISD::MUL, dl, NVT, LLL, RLH), UL);
+    SDValue VH = DAG.getNode(ISD::SRL, dl, NVT, V, Shift);
+
+    SDValue W = DAG.getNode(ISD::ADD, dl, NVT,
+                            DAG.getNode(ISD::MUL, dl, NVT, LLH, RLH),
+                            DAG.getNode(ISD::ADD, dl, NVT, UH, VH));
+    Lo = DAG.getNode(ISD::ADD, dl, NVT, TL,
+                     DAG.getNode(ISD::SHL, dl, NVT, V, Shift));
+
+    Hi = DAG.getNode(ISD::ADD, dl, NVT, W,
+                     DAG.getNode(ISD::ADD, dl, NVT,
+                                 DAG.getNode(ISD::MUL, dl, NVT, RH, LL),
+                                 DAG.getNode(ISD::MUL, dl, NVT, RL, LH)));
     return;
   }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 0d039860b9f0fd..bb90c83877d865 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -10858,14 +10858,14 @@ SDValue TargetLowering::expandShlSat(SDNode *Node, SelectionDAG &DAG) const {
 }
 
 void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
-                                        bool Signed, EVT WideVT,
-                                        const SDValue LL, const SDValue LH,
-                                        const SDValue RL, const SDValue RH,
-                                        SDValue &Lo, SDValue &Hi) const {
+                                        bool Signed, const SDValue LHS,
+                                        const SDValue RHS, SDValue &Lo,
+                                        SDValue &Hi) const {
+  EVT VT = LHS.getValueType();
+  assert(RHS.getValueType() == VT && "Mismatching operand types");
+  EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits() * 2);
   // We can fall back to a libcall with an illegal type for the MUL if we
   // have a libcall big enough.
-  // Also, we can fall back to a division in some cases, but that's a big
-  // performance hit in the general case.
   RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
   if (WideVT == MVT::i16)
     LC = RTLIB::MUL_I16;
@@ -10876,47 +10876,20 @@ void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
   else if (WideVT == MVT::i128)
     LC = RTLIB::MUL_I128;
 
-  if (LC == RTLIB::UNKNOWN_LIBCALL || !getLibcallName(LC)) {
-    // We'll expand the multiplication by brute force because we have no other
-    // options. This is a trivially-generalized version of the code from
-    // Hacker's Delight (itself derived from Knuth's Algorithm M from section
-    // 4.3.1).
-    EVT VT = LL.getValueType();
-    unsigned Bits = VT.getSizeInBits();
-    unsigned HalfBits = Bits >> 1;
-    SDValue Mask =
-        DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
-    SDValue LLL = DAG.getNode(ISD::AND, dl, VT, LL, Mask);
-    SDValue RLL = DAG.getNode(ISD::AND, dl, VT, RL, Mask);
-
-    SDValue T = DAG.getNode(ISD::MUL, dl, VT, LLL, RLL);
-    SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
-
-    SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
-    SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
-    SDValue LLH = DAG.getNode(ISD::SRL, dl, VT, LL, Shift);
-    SDValue RLH = DAG.getNode(ISD::SRL, dl, VT, RL, Shift);
-
-    SDValue U = DAG.getNode(ISD::ADD, dl, VT,
-                            DAG.getNode(ISD::MUL, dl, VT, LLH, RLL), TH);
-    SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
-    SDValue UH = DAG.getNode(ISD::SRL, dl, VT, U, Shift);
-
-    SDValue V = DAG.getNode(ISD::ADD, dl, VT,
-                            DAG.getNode(ISD::MUL, dl, VT, LLL, RLH), UL);
-    SDValue VH = DAG.getNode(ISD::SRL, dl, VT, V, Shift);
-
-    SDValue W =
-        DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LLH, RLH),
-                    DAG.getNode(ISD::ADD, dl, VT, UH, VH));
-    Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
-                     DAG.getNode(ISD::SHL, dl, VT, V, Shift));
-
-    Hi = DAG.getNode(ISD::ADD, dl, VT, W,
-                     DAG.getNode(ISD::ADD, dl, VT,
-                                 DAG.getNode(ISD::MUL, dl, VT, RH, LL),
-                                 DAG.getNode(ISD::MUL, dl, VT, RL, LH)));
-  } else {
+  if (LC != RTLIB::UNKNOWN_LIBCALL && getLibcallName(LC)) {
+    SDValue HiLHS, HiRHS;
+    if (Signed) {
+      // The high part is obtained by SRA'ing all but one of the bits of low
+      // part.
+      unsigned LoSize = VT.getFixedSizeInBits();
+      SDValue Shift = DAG.getShiftAmountConstant(LoSize - 1, VT, dl);
+      HiLHS = DAG.getNode(ISD::SRA, dl, VT, LHS, Shift);
+      HiRHS = DAG.getNode(ISD::SRA, dl, VT, RHS, Shift);
+    } else {
+      HiLHS = DAG.getConstant(0, dl, VT);
+      HiRHS = DAG.getConstant(0, dl, VT);
+    }
+
     // Attempt a libcall.
     SDValue Ret;
     TargetLowering::MakeLibCallOptions CallOptions;
@@ -10927,10 +10900,10 @@ void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
       // depending on platform endianness. This is usually handled by
       // the C calling convention, but we can't defer to it in
       // the legalizer.
-      SDValue Args[] = {LL, LH, RL, RH};
+      SDValue Args[] = {LHS, HiLHS, RHS, HiRHS};
       Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
     } else {
-      SDValue Args[] = {LH, LL, RH, RL};
+      SDValue Args[] = {HiLHS, LHS, HiRHS, RHS};
       Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
     }
     assert(Ret.getOpcode() == ISD::MERGE_VALUES &&
@@ -10943,42 +10916,6 @@ void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
       Lo = Ret.getOperand(1);
       Hi = Ret.getOperand(0);
     }
-  }
-}
-
-void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
-                                        bool Signed, const SDValue LHS,
-                                        const SDValue RHS, SDValue &Lo,
-                                        SDValue &Hi) const {
-  EVT VT = LHS.getValueType();
-  assert(RHS.getValueType() == VT && "Mismatching operand types");
-  EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits() * 2);
-  // We can fall back to a libcall with an illegal type for the MUL if we
-  // have a libcall big enough.
-  RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
-  if (WideVT == MVT::i16)
-    LC = RTLIB::MUL_I16;
-  else if (WideVT == MVT::i32)
-    LC = RTLIB::MUL_I32;
-  else if (WideVT == MVT::i64)
-    LC = RTLIB::MUL_I64;
-  else if (WideVT == MVT::i128)
-    LC = RTLIB::MUL_I128;
-
-  if (LC != RTLIB::UNKNOWN_LIBCALL && getLibcallName(LC)) {
-    SDValue HiLHS, HiRHS;
-    if (Signed) {
-      // The high part is obtained by SRA'ing all but one of the bits of low
-      // part.
-      unsigned LoSize = VT.getFixedSizeInBits();
-      SDValue Shift = DAG.getShiftAmountConstant(LoSize - 1, VT, dl);
-      HiLHS = DAG.getNode(ISD::SRA, dl, VT, LHS, Shift);
-      HiRHS = DAG.getNode(ISD::SRA, dl, VT, RHS, Shift);
-    } else {
-      HiLHS = DAG.getConstant(0, dl, VT);
-      HiRHS = DAG.getConstant(0, dl, VT);
-    }
-    forceExpandWideMUL(DAG, dl, Signed, WideVT, LHS, HiLHS, RHS, HiRHS, Lo, Hi);
     return;
   }
 

>From 0f0c16114254b3675bc25f60b827dd0c61db4da9 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 23 Jan 2025 22:32:31 -0800
Subject: [PATCH 2/2] [SelectionDAG] Share code for two of our multiply
 expansions. NFC

ExpandIntRes_MUL and forceExpandWideMUL have very similar code.
ExpandIntRes_MUL calculates the Lo and Hi halves of the result from
2 sources that are each split into Hi and Lo halves. forceExpandWideMUL
calculates the Lo and Hi halves of the full product of 2 values.

The only differences are that forceExpandWideMUL uses ISD::SRA
instead of ISD::SRL for a signed wide multiply, and ExpandIntRes_MUL
needs 2 additional multiplies and 2 adds to compute HiRHS*LHS and
HiLHS*RHS and add them to Hi.

This patch introduces a new function that takes HiLHS and HiRHS as
optional values. If they are not null, they will be used in the
calculation of the Hi half. The Signed flag can only be set when
HiLHS/HiRHS are null.
---
 llvm/include/llvm/CodeGen/TargetLowering.h    |   9 ++
 .../SelectionDAG/LegalizeIntegerTypes.cpp     |  39 +------
 .../CodeGen/SelectionDAG/TargetLowering.cpp   | 102 +++++++++++-------
 3 files changed, 73 insertions(+), 77 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9c541623140311..4ef9a06a8644e2 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5499,6 +5499,15 @@ class TargetLowering : public TargetLoweringBase {
   bool expandMULO(SDNode *Node, SDValue &Result, SDValue &Overflow,
                   SelectionDAG &DAG) const;
 
+  /// Calculate the product twice the width of LHS and RHS into Lo/Hi. If
+  /// HiLHS/HiRHS are non-null they are included in the multiplication, and
+  /// Signed must be false. The expansion splits the 2 inputs into 4 pieces
+  /// that we multiply and add together without needing MULH or MUL_LOHI.
+  void forceExpandMUL(SelectionDAG &DAG, const SDLoc &dl, bool Signed,
+                      SDValue &Lo, SDValue &Hi, SDValue LHS, SDValue RHS,
+                      SDValue HiLHS = SDValue(),
+                      SDValue HiRHS = SDValue()) const;
+
   /// Calculate full product of LHS and RHS either via a libcall or through
   /// brute force expansion of the multiplication. The expansion works by
   /// splitting the 2 inputs into 4 pieces that we can multiply and add together
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 8cfcd1e234c731..ad9e3e4e99b979 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -4294,44 +4294,7 @@ void DAGTypeLegalizer::ExpandIntRes_MUL(SDNode *N,
     LC = RTLIB::MUL_I128;
 
   if (LC == RTLIB::UNKNOWN_LIBCALL || !TLI.getLibcallName(LC)) {
-    // We'll expand the multiplication by brute force because we have no other
-    // options. This is a trivially-generalized version of the code from
-    // Hacker's Delight (itself derived from Knuth's Algorithm M from section
-    // 4.3.1).
-    unsigned Bits = NVT.getSizeInBits();
-    unsigned HalfBits = Bits >> 1;
-    SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl,
-                                   NVT);
-    SDValue LLL = DAG.getNode(ISD::AND, dl, NVT, LL, Mask);
-    SDValue RLL = DAG.getNode(ISD::AND, dl, NVT, RL, Mask);
-
-    SDValue T = DAG.getNode(ISD::MUL, dl, NVT, LLL, RLL);
-    SDValue TL = DAG.getNode(ISD::AND, dl, NVT, T, Mask);
-
-    SDValue Shift = DAG.getShiftAmountConstant(HalfBits, NVT, dl);
-    SDValue TH = DAG.getNode(ISD::SRL, dl, NVT, T, Shift);
-    SDValue LLH = DAG.getNode(ISD::SRL, dl, NVT, LL, Shift);
-    SDValue RLH = DAG.getNode(ISD::SRL, dl, NVT, RL, Shift);
-
-    SDValue U = DAG.getNode(ISD::ADD, dl, NVT,
-                            DAG.getNode(ISD::MUL, dl, NVT, LLH, RLL), TH);
-    SDValue UL = DAG.getNode(ISD::AND, dl, NVT, U, Mask);
-    SDValue UH = DAG.getNode(ISD::SRL, dl, NVT, U, Shift);
-
-    SDValue V = DAG.getNode(ISD::ADD, dl, NVT,
-                            DAG.getNode(ISD::MUL, dl, NVT, LLL, RLH), UL);
-    SDValue VH = DAG.getNode(ISD::SRL, dl, NVT, V, Shift);
-
-    SDValue W = DAG.getNode(ISD::ADD, dl, NVT,
-                            DAG.getNode(ISD::MUL, dl, NVT, LLH, RLH),
-                            DAG.getNode(ISD::ADD, dl, NVT, UH, VH));
-    Lo = DAG.getNode(ISD::ADD, dl, NVT, TL,
-                     DAG.getNode(ISD::SHL, dl, NVT, V, Shift));
-
-    Hi = DAG.getNode(ISD::ADD, dl, NVT, W,
-                     DAG.getNode(ISD::ADD, dl, NVT,
-                                 DAG.getNode(ISD::MUL, dl, NVT, RH, LL),
-                                 DAG.getNode(ISD::MUL, dl, NVT, RL, LH)));
+    TLI.forceExpandMUL(DAG, dl, /*Signed=*/false, Lo, Hi, LL, RL, LH, RH);
     return;
   }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index bb90c83877d865..523a55781ba52d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -10857,6 +10857,64 @@ SDValue TargetLowering::expandShlSat(SDNode *Node, SelectionDAG &DAG) const {
   return DAG.getSelect(dl, VT, Cond, SatVal, Result);
 }
 
+void TargetLowering::forceExpandMUL(SelectionDAG &DAG, const SDLoc &dl,
+                                    bool Signed, SDValue &Lo, SDValue &Hi,
+                                    SDValue LHS, SDValue RHS, SDValue HiLHS,
+                                    SDValue HiRHS) const {
+  EVT VT = LHS.getValueType();
+  assert(RHS.getValueType() == VT && "Mismatching operand types");
+
+  assert(!HiLHS == !HiRHS && "HiLHS and HiRHS must be set together");
+  assert((!Signed || !HiLHS) &&
+         "Signed flag should only be set when HiLHS and HiRHS are null");
+
+  // We'll expand the multiplication by brute force because we have no other
+  // options. This is a trivially-generalized version of the code from
+  // Hacker's Delight (itself derived from Knuth's Algorithm M from section
+  // 4.3.1). If Signed is set, we can use arithmetic right shifts to propagate
+  // sign bits while calculating the Hi half.
+  unsigned Bits = VT.getSizeInBits();
+  unsigned HalfBits = Bits / 2;
+  SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
+  SDValue LL = DAG.getNode(ISD::AND, dl, VT, LHS, Mask);
+  SDValue RL = DAG.getNode(ISD::AND, dl, VT, RHS, Mask);
+
+  SDValue T = DAG.getNode(ISD::MUL, dl, VT, LL, RL);
+  SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
+
+  SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
+  // LL and RL are masked (unsigned) halves, so T's high half is a logical shift.
+  SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
+
+  unsigned ShiftOpc = Signed ? ISD::SRA : ISD::SRL;
+  SDValue LH = DAG.getNode(ShiftOpc, dl, VT, LHS, Shift);
+  SDValue RH = DAG.getNode(ShiftOpc, dl, VT, RHS, Shift);
+
+  SDValue U =
+      DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RL), TH);
+  SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
+  SDValue UH = DAG.getNode(ShiftOpc, dl, VT, U, Shift);
+
+  SDValue V =
+      DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LL, RH), UL);
+  SDValue VH = DAG.getNode(ShiftOpc, dl, VT, V, Shift);
+
+  Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
+                   DAG.getNode(ISD::SHL, dl, VT, V, Shift));
+
+  Hi = DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RH),
+                   DAG.getNode(ISD::ADD, dl, VT, UH, VH));
+
+  // If HiLHS and HiRHS are set, multiply each one by the other operand's low
+  // half (LHS/RHS) and add the products to Hi.
+  if (HiLHS) {
+    Hi = DAG.getNode(ISD::ADD, dl, VT, Hi,
+                     DAG.getNode(ISD::ADD, dl, VT,
+                                 DAG.getNode(ISD::MUL, dl, VT, HiRHS, LHS),
+                                 DAG.getNode(ISD::MUL, dl, VT, RHS, HiLHS)));
+  }
+}
+
 void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
                                         bool Signed, const SDValue LHS,
                                         const SDValue RHS, SDValue &Lo,
@@ -10876,7 +10934,11 @@ void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
   else if (WideVT == MVT::i128)
     LC = RTLIB::MUL_I128;
 
-  if (LC != RTLIB::UNKNOWN_LIBCALL && getLibcallName(LC)) {
+  if (LC == RTLIB::UNKNOWN_LIBCALL || !getLibcallName(LC)) {
+    forceExpandMUL(DAG, dl, Signed, Lo, Hi, LHS, RHS);
+    return;
+  }
+
     SDValue HiLHS, HiRHS;
     if (Signed) {
       // The high part is obtained by SRA'ing all but one of the bits of low
@@ -10916,44 +10978,6 @@ void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
       Lo = Ret.getOperand(1);
       Hi = Ret.getOperand(0);
     }
-    return;
-  }
-
-  // Expand the multiplication by brute force. This is a generalized-version of
-  // the code from Hacker's Delight (itself derived from Knuth's Algorithm M
-  // from section 4.3.1) combined with the Hacker's delight code
-  // for calculating mulhs.
-  unsigned Bits = VT.getSizeInBits();
-  unsigned HalfBits = Bits / 2;
-  SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
-  SDValue LL = DAG.getNode(ISD::AND, dl, VT, LHS, Mask);
-  SDValue RL = DAG.getNode(ISD::AND, dl, VT, RHS, Mask);
-
-  SDValue T = DAG.getNode(ISD::MUL, dl, VT, LL, RL);
-  SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
-
-  SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
-  // This is always an unsigned shift.
-  SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
-
-  unsigned ShiftOpc = Signed ? ISD::SRA : ISD::SRL;
-  SDValue LH = DAG.getNode(ShiftOpc, dl, VT, LHS, Shift);
-  SDValue RH = DAG.getNode(ShiftOpc, dl, VT, RHS, Shift);
-
-  SDValue U =
-      DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RL), TH);
-  SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
-  SDValue UH = DAG.getNode(ShiftOpc, dl, VT, U, Shift);
-
-  SDValue V =
-      DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LL, RH), UL);
-  SDValue VH = DAG.getNode(ShiftOpc, dl, VT, V, Shift);
-
-  Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
-                   DAG.getNode(ISD::SHL, dl, VT, V, Shift));
-
-  Hi = DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RH),
-                   DAG.getNode(ISD::ADD, dl, VT, UH, VH));
 }
 
 SDValue



More information about the llvm-commits mailing list