[mlir] [llvm] [SelectionDAG] Expand fixed point multiplication into libcall (PR #79352)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 29 06:15:52 PST 2024


================
@@ -10149,6 +10149,122 @@ SDValue TargetLowering::expandShlSat(SDNode *Node, SelectionDAG &DAG) const {
   return DAG.getSelect(dl, VT, Cond, SatVal, Result);
 }
 
+// Expand a multiplication of two WideVT-wide values, each supplied as a pair
+// of halves (LL = low, LH = high for the left operand; RL/RH for the right),
+// and return the low and high halves of the WideVT-wide result in Lo and Hi.
+// Assumes all four halves share LL's value type (the brute-force path builds
+// every node in that type) and presumably that WideVT is twice that width --
+// TODO confirm at call sites.
+//
+// Strategy: if WideVT is i16/i32/i64/i128 and the target names a MUL_I<N>
+// libcall for it, call that; otherwise emit an inline schoolbook expansion.
+//
+// NOTE(review): Signed is only consulted on the libcall path (setSExt). The
+// brute-force path computes the product modulo 2^(2*Bits) from all four
+// halves, so signedness must already be encoded in LH/RH by the caller (the
+// LHS/RHS overload of this function derives them with SRA when Signed is set).
+void TargetLowering::ForceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
+                                        bool Signed, EVT WideVT,
+                                        const SDValue LL, const SDValue LH,
+                                        const SDValue RL, const SDValue RH,
+                                        SDValue &Lo, SDValue &Hi) const {
+  // We can fall back to a libcall with an illegal type for the MUL if we
+  // have a libcall big enough.
+  // Also, we can fall back to a division in some cases, but that's a big
+  // performance hit in the general case.
+  // Any WideVT other than the four listed here leaves LC as UNKNOWN_LIBCALL
+  // and forces the brute-force expansion below.
+  RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
+  if (WideVT == MVT::i16)
+    LC = RTLIB::MUL_I16;
+  else if (WideVT == MVT::i32)
+    LC = RTLIB::MUL_I32;
+  else if (WideVT == MVT::i64)
+    LC = RTLIB::MUL_I64;
+  else if (WideVT == MVT::i128)
+    LC = RTLIB::MUL_I128;
+
+  // A libcall enum value is only usable if the target provides a name for it.
+  if (LC == RTLIB::UNKNOWN_LIBCALL || !getLibcallName(LC)) {
+    // We'll expand the multiplication by brute force because we have no other
+    // options. This is a trivially-generalized version of the code from
+    // Hacker's Delight (itself derived from Knuth's Algorithm M from section
+    // 4.3.1).
+    //
+    // Writing B = Bits:
+    //   (LH*2^B + LL) * (RH*2^B + RL) mod 2^(2B)
+    //     = LL*RL + ((LL*RH + LH*RL) mod 2^B) * 2^B
+    // so the full 2B-bit product LL*RL is computed from half-width digits,
+    // and the cross terms only contribute to the high word. LH*RH lies
+    // entirely above 2^(2B) and is dropped.
+    EVT VT = LL.getValueType();
+    unsigned Bits = VT.getSizeInBits();
+    unsigned HalfBits = Bits >> 1;
+    SDValue Mask =
+        DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
+    // Low half-digits of the two low words.
+    SDValue LLL = DAG.getNode(ISD::AND, dl, VT, LL, Mask);
+    SDValue RLL = DAG.getNode(ISD::AND, dl, VT, RL, Mask);
+
+    // T = low*low partial product; TL is the lowest half-digit of the result.
+    SDValue T = DAG.getNode(ISD::MUL, dl, VT, LLL, RLL);
+    SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
+
+    SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
+    // TH is the carry out of T; LLH/RLH are the high half-digits of LL/RL.
+    SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
+    SDValue LLH = DAG.getNode(ISD::SRL, dl, VT, LL, Shift);
+    SDValue RLH = DAG.getNode(ISD::SRL, dl, VT, RL, Shift);
+
+    // U = high(LL)*low(RL) + carry from T, split into digit UL and carry UH.
+    SDValue U = DAG.getNode(ISD::ADD, dl, VT,
+                            DAG.getNode(ISD::MUL, dl, VT, LLH, RLL), TH);
+    SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
+    SDValue UH = DAG.getNode(ISD::SRL, dl, VT, U, Shift);
+
+    // V = low(LL)*high(RL) + UL; its low half is the second result digit and
+    // VH is its carry.
+    SDValue V = DAG.getNode(ISD::ADD, dl, VT,
+                            DAG.getNode(ISD::MUL, dl, VT, LLL, RLH), UL);
+    SDValue VH = DAG.getNode(ISD::SRL, dl, VT, V, Shift);
+
+    // W = high(LL)*high(RL) + both carries: the high word of LL*RL.
+    SDValue W =
+        DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LLH, RLH),
+                    DAG.getNode(ISD::ADD, dl, VT, UH, VH));
+    // Low word of LL*RL: TL in the bottom half, V's low half in the top half
+    // (the SHL discards V's upper bits).
+    Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
+                     DAG.getNode(ISD::SHL, dl, VT, V, Shift));
+
+    // High word: high word of LL*RL plus the cross terms, wrapping in VT.
+    Hi = DAG.getNode(ISD::ADD, dl, VT, W,
+                     DAG.getNode(ISD::ADD, dl, VT,
+                                 DAG.getNode(ISD::MUL, dl, VT, RH, LL),
+                                 DAG.getNode(ISD::MUL, dl, VT, RL, LH)));
+  } else {
+    // Attempt a libcall.
+    // The call returns WideVT; its arguments are passed as separate halves
+    // whose order depends on the target (see below).
+    SDValue Ret;
+    TargetLowering::MakeLibCallOptions CallOptions;
+    CallOptions.setSExt(Signed);
+    CallOptions.setIsPostTypeLegalization(true);
+    if (shouldSplitFunctionArgumentsAsLittleEndian(DAG.getDataLayout())) {
+      // Halves of WideVT are packed into registers in different order
+      // depending on platform endianness. This is usually handled by
+      // the C calling convention, but we can't defer to it in
+      // the legalizer.
+      SDValue Args[] = {LL, LH, RL, RH};
+      Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
+    } else {
+      SDValue Args[] = {LH, LL, RH, RL};
+      Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
+    }
+    // The result halves are read directly as operands 0 and 1 of the
+    // MERGE_VALUES node returned for the call.
+    assert(Ret.getOpcode() == ISD::MERGE_VALUES &&
+           "Ret value is a collection of constituent nodes holding result.");
+    if (DAG.getDataLayout().isLittleEndian()) {
+      // Same as above: the order of the result halves depends on
+      // endianness. On little-endian targets operand 0 is the low half.
+      Lo = Ret.getOperand(0);
+      Hi = Ret.getOperand(1);
+    } else {
+      Lo = Ret.getOperand(1);
+      Hi = Ret.getOperand(0);
+    }
+  }
+}
+
+void TargetLowering::ForceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
+                                        bool Signed, const SDValue LHS,
+                                        const SDValue RHS, SDValue &Lo,
+                                        SDValue &Hi) const {
+  EVT VT = LHS.getValueType();
+  assert(RHS.getValueType() == VT && "Mismatching operand types");
+
+  SDValue HiLHS;
+  SDValue HiRHS;
+  if (Signed) {
+    // The high part is obtained by SRA'ing all but one of the bits of low
+    // part.
+    unsigned LoSize = VT.getFixedSizeInBits();
+    HiLHS = DAG.getNode(
+        ISD::SRA, dl, VT, LHS,
+        DAG.getConstant(LoSize - 1, dl, getPointerTy(DAG.getDataLayout())));
+    HiRHS = DAG.getNode(
+        ISD::SRA, dl, VT, RHS,
+        DAG.getConstant(LoSize - 1, dl, getPointerTy(DAG.getDataLayout())));
----------------
RKSimon wrote:

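Use getShiftAmountConstant here (as the brute-force path above already does for its `Shift` amount) instead of getConstant(LoSize - 1, dl, getPointerTy(DAG.getDataLayout()))?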

https://github.com/llvm/llvm-project/pull/79352


More information about the llvm-commits mailing list