[llvm] [TargetLowering] Pull similar code out of the forceExpandWideMUL into a helper. NFC (PR #124371)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 24 16:02:39 PST 2025


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/124371

These functions have similar code. One of them calculates the 2x width full product from 2 sources. The other calculates the product from 2 sources that have low and high halves.

This patch introduces a new function that takes HiLHS and HiRHS as optional values. If they are not null, they will be used in the calculation of the Hi half. The Signed flag can only be set when HiLHS/HiRHS are null.

>From 8de730458200de484a0bdbf365739f1945e70d17 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Fri, 24 Jan 2025 15:59:33 -0800
Subject: [PATCH] [TargetLowering] Pull similar code out of the
 forceExpandWideMUL into a helper. NFC

These functions have similar code. One of them calculates the 2x width
full product from 2 sources. The other calculates the product from
2 sources that have low and high halves.

This patch introduces a new function that takes HiLHS and HiRHS as
optional values. If they are not null, they will be used in the
calculation of the Hi half. The Signed flag can only be set when
HiLHS/HiRHS are null.
---
 llvm/include/llvm/CodeGen/TargetLowering.h    |   9 +
 .../CodeGen/SelectionDAG/TargetLowering.cpp   | 186 ++++++++++--------
 2 files changed, 108 insertions(+), 87 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 861cffdc115a46..4ad2835a70404d 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5503,6 +5503,15 @@ class TargetLowering : public TargetLoweringBase {
   bool expandMULO(SDNode *Node, SDValue &Result, SDValue &Overflow,
                   SelectionDAG &DAG) const;
 
+  /// Calculate the product twice the width of LHS and RHS. If HiLHS/HiRHS are
+  /// non-null they will be included in the multiplication. The expansion works
+  /// by splitting the 2 inputs into 4 pieces that we can multiply and add
+  /// together without needing MULH or MUL_LOHI.
+  void forceExpandMultiply(SelectionDAG &DAG, const SDLoc &dl, bool Signed,
+                           SDValue &Lo, SDValue &Hi, SDValue LHS, SDValue RHS,
+                           SDValue HiLHS = SDValue(),
+                           SDValue HiRHS = SDValue()) const;
+
   /// forceExpandWideMUL - Unconditionally expand a MUL into either a libcall or
   /// brute force via a wide multiplication. The expansion works by
   /// attempting to do a multiplication on a wider type twice the size of the
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 0d039860b9f0fd..8757d19458dc55 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -10857,6 +10857,64 @@ SDValue TargetLowering::expandShlSat(SDNode *Node, SelectionDAG &DAG) const {
   return DAG.getSelect(dl, VT, Cond, SatVal, Result);
 }
 
+void TargetLowering::forceExpandMultiply(SelectionDAG &DAG, const SDLoc &dl,
+                                         bool Signed, SDValue &Lo, SDValue &Hi,
+                                         SDValue LHS, SDValue RHS,
+                                         SDValue HiLHS, SDValue HiRHS) const {
+  EVT VT = LHS.getValueType();
+  assert(RHS.getValueType() == VT && "Mismatching operand types");
+
+  assert((HiLHS && HiRHS) || (!HiLHS && !HiRHS));
+  assert((!Signed || !HiLHS) &&
+         "Signed flag should only be set when HiLHS and HiRHS are null");
+
+  // We'll expand the multiplication by brute force because we have no other
+  // options. This is a trivially-generalized version of the code from
+  // Hacker's Delight (itself derived from Knuth's Algorithm M from section
+  // 4.3.1). If Signed is set, we can use arithmetic right shifts to propagate
+  // sign bits while calculating the Hi half.
+  unsigned Bits = VT.getSizeInBits();
+  unsigned HalfBits = Bits / 2;
+  SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
+  SDValue LL = DAG.getNode(ISD::AND, dl, VT, LHS, Mask);
+  SDValue RL = DAG.getNode(ISD::AND, dl, VT, RHS, Mask);
+
+  SDValue T = DAG.getNode(ISD::MUL, dl, VT, LL, RL);
+  SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
+
+  SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
+  // This is always an unsigned shift.
+  SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
+
+  unsigned ShiftOpc = Signed ? ISD::SRA : ISD::SRL;
+  SDValue LH = DAG.getNode(ShiftOpc, dl, VT, LHS, Shift);
+  SDValue RH = DAG.getNode(ShiftOpc, dl, VT, RHS, Shift);
+
+  SDValue U =
+      DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RL), TH);
+  SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
+  SDValue UH = DAG.getNode(ShiftOpc, dl, VT, U, Shift);
+
+  SDValue V =
+      DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LL, RH), UL);
+  SDValue VH = DAG.getNode(ShiftOpc, dl, VT, V, Shift);
+
+  Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
+                   DAG.getNode(ISD::SHL, dl, VT, V, Shift));
+
+  Hi = DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RH),
+                   DAG.getNode(ISD::ADD, dl, VT, UH, VH));
+
+  // If HiLHS and HiRHS are set, multiply each by the opposite low part and add
+  // the products to Hi.
+  if (HiLHS) {
+    Hi = DAG.getNode(ISD::ADD, dl, VT, Hi,
+                     DAG.getNode(ISD::ADD, dl, VT,
+                                 DAG.getNode(ISD::MUL, dl, VT, HiRHS, LHS),
+                                 DAG.getNode(ISD::MUL, dl, VT, RHS, HiLHS)));
+  }
+}
+
 void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
                                         bool Signed, EVT WideVT,
                                         const SDValue LL, const SDValue LH,
@@ -10877,45 +10935,7 @@ void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
     LC = RTLIB::MUL_I128;
 
   if (LC == RTLIB::UNKNOWN_LIBCALL || !getLibcallName(LC)) {
-    // We'll expand the multiplication by brute force because we have no other
-    // options. This is a trivially-generalized version of the code from
-    // Hacker's Delight (itself derived from Knuth's Algorithm M from section
-    // 4.3.1).
-    EVT VT = LL.getValueType();
-    unsigned Bits = VT.getSizeInBits();
-    unsigned HalfBits = Bits >> 1;
-    SDValue Mask =
-        DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
-    SDValue LLL = DAG.getNode(ISD::AND, dl, VT, LL, Mask);
-    SDValue RLL = DAG.getNode(ISD::AND, dl, VT, RL, Mask);
-
-    SDValue T = DAG.getNode(ISD::MUL, dl, VT, LLL, RLL);
-    SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
-
-    SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
-    SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
-    SDValue LLH = DAG.getNode(ISD::SRL, dl, VT, LL, Shift);
-    SDValue RLH = DAG.getNode(ISD::SRL, dl, VT, RL, Shift);
-
-    SDValue U = DAG.getNode(ISD::ADD, dl, VT,
-                            DAG.getNode(ISD::MUL, dl, VT, LLH, RLL), TH);
-    SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
-    SDValue UH = DAG.getNode(ISD::SRL, dl, VT, U, Shift);
-
-    SDValue V = DAG.getNode(ISD::ADD, dl, VT,
-                            DAG.getNode(ISD::MUL, dl, VT, LLL, RLH), UL);
-    SDValue VH = DAG.getNode(ISD::SRL, dl, VT, V, Shift);
-
-    SDValue W =
-        DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LLH, RLH),
-                    DAG.getNode(ISD::ADD, dl, VT, UH, VH));
-    Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
-                     DAG.getNode(ISD::SHL, dl, VT, V, Shift));
-
-    Hi = DAG.getNode(ISD::ADD, dl, VT, W,
-                     DAG.getNode(ISD::ADD, dl, VT,
-                                 DAG.getNode(ISD::MUL, dl, VT, RH, LL),
-                                 DAG.getNode(ISD::MUL, dl, VT, RL, LH)));
+    forceExpandMultiply(DAG, dl, /*Signed=*/false, Lo, Hi, LL, RL, LH, RH);
   } else {
     // Attempt a libcall.
     SDValue Ret;
@@ -10965,58 +10985,50 @@ void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
   else if (WideVT == MVT::i128)
     LC = RTLIB::MUL_I128;
 
-  if (LC != RTLIB::UNKNOWN_LIBCALL && getLibcallName(LC)) {
-    SDValue HiLHS, HiRHS;
-    if (Signed) {
-      // The high part is obtained by SRA'ing all but one of the bits of low
-      // part.
-      unsigned LoSize = VT.getFixedSizeInBits();
-      SDValue Shift = DAG.getShiftAmountConstant(LoSize - 1, VT, dl);
-      HiLHS = DAG.getNode(ISD::SRA, dl, VT, LHS, Shift);
-      HiRHS = DAG.getNode(ISD::SRA, dl, VT, RHS, Shift);
-    } else {
-      HiLHS = DAG.getConstant(0, dl, VT);
-      HiRHS = DAG.getConstant(0, dl, VT);
-    }
-    forceExpandWideMUL(DAG, dl, Signed, WideVT, LHS, HiLHS, RHS, HiRHS, Lo, Hi);
+  if (LC == RTLIB::UNKNOWN_LIBCALL || !getLibcallName(LC)) {
+    forceExpandMultiply(DAG, dl, Signed, Lo, Hi, LHS, RHS);
     return;
   }
 
-  // Expand the multiplication by brute force. This is a generalized-version of
-  // the code from Hacker's Delight (itself derived from Knuth's Algorithm M
-  // from section 4.3.1) combined with the Hacker's delight code
-  // for calculating mulhs.
-  unsigned Bits = VT.getSizeInBits();
-  unsigned HalfBits = Bits / 2;
-  SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
-  SDValue LL = DAG.getNode(ISD::AND, dl, VT, LHS, Mask);
-  SDValue RL = DAG.getNode(ISD::AND, dl, VT, RHS, Mask);
-
-  SDValue T = DAG.getNode(ISD::MUL, dl, VT, LL, RL);
-  SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
-
-  SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
-  // This is always an unsigned shift.
-  SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
-
-  unsigned ShiftOpc = Signed ? ISD::SRA : ISD::SRL;
-  SDValue LH = DAG.getNode(ShiftOpc, dl, VT, LHS, Shift);
-  SDValue RH = DAG.getNode(ShiftOpc, dl, VT, RHS, Shift);
-
-  SDValue U =
-      DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RL), TH);
-  SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
-  SDValue UH = DAG.getNode(ShiftOpc, dl, VT, U, Shift);
-
-  SDValue V =
-      DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LL, RH), UL);
-  SDValue VH = DAG.getNode(ShiftOpc, dl, VT, V, Shift);
-
-  Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
-                   DAG.getNode(ISD::SHL, dl, VT, V, Shift));
+  SDValue HiLHS, HiRHS;
+  if (Signed) {
+    // The high part is obtained by SRA'ing all but one of the bits of low
+    // part.
+    unsigned LoSize = VT.getFixedSizeInBits();
+    SDValue Shift = DAG.getShiftAmountConstant(LoSize - 1, VT, dl);
+    HiLHS = DAG.getNode(ISD::SRA, dl, VT, LHS, Shift);
+    HiRHS = DAG.getNode(ISD::SRA, dl, VT, RHS, Shift);
+  } else {
+    HiLHS = DAG.getConstant(0, dl, VT);
+    HiRHS = DAG.getConstant(0, dl, VT);
+  }
 
-  Hi = DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RH),
-                   DAG.getNode(ISD::ADD, dl, VT, UH, VH));
+  // Attempt a libcall.
+  SDValue Ret;
+  TargetLowering::MakeLibCallOptions CallOptions;
+  CallOptions.setIsSigned(Signed);
+  CallOptions.setIsPostTypeLegalization(true);
+  if (shouldSplitFunctionArgumentsAsLittleEndian(DAG.getDataLayout())) {
+    // Halves of WideVT are packed into registers in different order
+    // depending on platform endianness. This is usually handled by
+    // the C calling convention, but we can't defer to it in
+    // the legalizer.
+    SDValue Args[] = {LHS, HiLHS, RHS, HiRHS};
+    Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
+  } else {
+    SDValue Args[] = {HiLHS, LHS, HiRHS, RHS};
+    Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
+  }
+  assert(Ret.getOpcode() == ISD::MERGE_VALUES &&
+         "Ret value is a collection of constituent nodes holding result.");
+  if (DAG.getDataLayout().isLittleEndian()) {
+    // Same as above.
+    Lo = Ret.getOperand(0);
+    Hi = Ret.getOperand(1);
+  } else {
+    Lo = Ret.getOperand(1);
+    Hi = Ret.getOperand(0);
+  }
 }
 
 SDValue



More information about the llvm-commits mailing list