[llvm] [TargetLowering] Improve one signature of forceExpandWideMUL. (PR #123991)

Sergei Barannikov via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 24 09:03:24 PST 2025


================
@@ -10952,22 +10952,71 @@ void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
                                         SDValue &Hi) const {
   EVT VT = LHS.getValueType();
   assert(RHS.getValueType() == VT && "Mismatching operand types");
+  EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits() * 2);
+  // We can fall back to a libcall with an illegal type for the MUL if we
+  // have a libcall big enough.
+  RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
+  if (WideVT == MVT::i16)
+    LC = RTLIB::MUL_I16;
+  else if (WideVT == MVT::i32)
+    LC = RTLIB::MUL_I32;
+  else if (WideVT == MVT::i64)
+    LC = RTLIB::MUL_I64;
+  else if (WideVT == MVT::i128)
+    LC = RTLIB::MUL_I128;
 
-  SDValue HiLHS;
-  SDValue HiRHS;
-  if (Signed) {
-    // The high part is obtained by SRA'ing all but one of the bits of low
-    // part.
-    unsigned LoSize = VT.getFixedSizeInBits();
-    SDValue Shift = DAG.getShiftAmountConstant(LoSize - 1, VT, dl);
-    HiLHS = DAG.getNode(ISD::SRA, dl, VT, LHS, Shift);
-    HiRHS = DAG.getNode(ISD::SRA, dl, VT, RHS, Shift);
-  } else {
-    HiLHS = DAG.getConstant(0, dl, VT);
-    HiRHS = DAG.getConstant(0, dl, VT);
+  if (LC != RTLIB::UNKNOWN_LIBCALL && getLibcallName(LC)) {
+    SDValue HiLHS, HiRHS;
+    if (Signed) {
+      // The high part is obtained by SRA'ing all but one of the bits of low
+      // part.
+      unsigned LoSize = VT.getFixedSizeInBits();
+      SDValue Shift = DAG.getShiftAmountConstant(LoSize - 1, VT, dl);
+      HiLHS = DAG.getNode(ISD::SRA, dl, VT, LHS, Shift);
+      HiRHS = DAG.getNode(ISD::SRA, dl, VT, RHS, Shift);
+    } else {
+      HiLHS = DAG.getConstant(0, dl, VT);
+      HiRHS = DAG.getConstant(0, dl, VT);
+    }
+    forceExpandWideMUL(DAG, dl, Signed, WideVT, LHS, HiLHS, RHS, HiRHS, Lo, Hi);
+    return;
   }
-  EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits() * 2);
-  forceExpandWideMUL(DAG, dl, Signed, WideVT, LHS, HiLHS, RHS, HiRHS, Lo, Hi);
+
+  // Expand the multiplication by brute force. This is a generalized-version of
+  // the code from Hacker's Delight (itself derived from Knuth's Algorithm M
+  // from section 4.3.1) combined with the Hacker's delight code
+  // for calculating mulhs.
+  unsigned Bits = VT.getSizeInBits();
+  unsigned HalfBits = Bits / 2;
+  SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
+  SDValue LL = DAG.getNode(ISD::AND, dl, VT, LHS, Mask);
+  SDValue RL = DAG.getNode(ISD::AND, dl, VT, RHS, Mask);
+
+  SDValue T = DAG.getNode(ISD::MUL, dl, VT, LL, RL);
+  SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
+
+  SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
+  // This is always an unsigned shift.
+  SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
+
+  unsigned ShiftOpc = Signed ? ISD::SRA : ISD::SRL;
+  SDValue LH = DAG.getNode(ShiftOpc, dl, VT, LHS, Shift);
+  SDValue RH = DAG.getNode(ShiftOpc, dl, VT, RHS, Shift);
+
+  SDValue U =
+      DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RL), TH);
----------------
s-barannikov wrote:

`LH`/`RH` are already known to be 0/1 or 0/-1 here, depending on the value of `Signed` (because all bits except the most significant one are shifted out).
Making this a generic combine using known bits also sounds good.


https://github.com/llvm/llvm-project/pull/123991


More information about the llvm-commits mailing list