[mlir] [llvm] [SelectionDAG] Expand fixed point multiplication into libcall (PR #79352)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 29 06:15:52 PST 2024
================
@@ -10149,6 +10149,122 @@ SDValue TargetLowering::expandShlSat(SDNode *Node, SelectionDAG &DAG) const {
return DAG.getSelect(dl, VT, Cond, SatVal, Result);
}
+// Multiply two wide values that are each supplied as a pair of halves,
+// (LH:LL) * (RH:RL), and return the low and high halves of the WideVT-wide
+// (truncated) product through Lo and Hi.
+//
+// Two strategies are used:
+//  * If WideVT is i16/i32/i64/i128 and the target provides a name for the
+//    matching RTLIB::MUL_I* libcall, the four halves are passed to it.
+//  * Otherwise the product is expanded by hand using only VT-sized
+//    AND/SRL/SHL/MUL/ADD nodes (half-word schoolbook multiplication).
+//
+// NOTE(review): VT is taken from LL; LH, RL and RH are presumably of the same
+// type, and WideVT presumably twice as wide as VT. Neither is asserted here -
+// confirm at the call sites.
+//
+// \param Signed  Only consulted on the libcall path (forwarded to
+//                MakeLibCallOptions::setSExt).
+// \param WideVT  Selects the libcall and is its return type.
+// \param LL, LH  Low/high halves of the left operand.
+// \param RL, RH  Low/high halves of the right operand.
+// \param Lo, Hi  Out: low/high halves of the product.
+void TargetLowering::ForceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
+ bool Signed, EVT WideVT,
+ const SDValue LL, const SDValue LH,
+ const SDValue RL, const SDValue RH,
+ SDValue &Lo, SDValue &Hi) const {
+ // We can fall back to a libcall with an illegal type for the MUL if we
+ // have a libcall big enough.
+ // Also, we can fall back to a division in some cases, but that's a big
+ // performance hit in the general case.
+ RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
+ if (WideVT == MVT::i16)
+ LC = RTLIB::MUL_I16;
+ else if (WideVT == MVT::i32)
+ LC = RTLIB::MUL_I32;
+ else if (WideVT == MVT::i64)
+ LC = RTLIB::MUL_I64;
+ else if (WideVT == MVT::i128)
+ LC = RTLIB::MUL_I128;
+
+ // Expand by hand when there is no libcall for this width, or the target
+ // does not name one.
+ if (LC == RTLIB::UNKNOWN_LIBCALL || !getLibcallName(LC)) {
+ // We'll expand the multiplication by brute force because we have no other
+ // options. This is a trivially-generalized version of the code from
+ // Hacker's Delight (itself derived from Knuth's Algorithm M from section
+ // 4.3.1).
+ //
+ // Notation: H = HalfBits. The low operand halves are split into half-words,
+ // LL = LLH * 2^H + LLL and RL = RLH * 2^H + RLL, and the partial products
+ // are accumulated with explicit carries. Every intermediate is at most
+ // (2^H-1)^2 + (2^H-1) < 2^Bits, so no VT-sized ADD or MUL below wraps.
+ // Signed is not needed here: the result is formed modulo 2^(2*Bits) from
+ // the supplied halves, which is identical for signed and unsigned inputs.
+ EVT VT = LL.getValueType();
+ unsigned Bits = VT.getSizeInBits();
+ unsigned HalfBits = Bits >> 1;
+ // Mask keeps the low HalfBits bits of a VT value.
+ SDValue Mask =
+ DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
+ SDValue LLL = DAG.getNode(ISD::AND, dl, VT, LL, Mask);
+ SDValue RLL = DAG.getNode(ISD::AND, dl, VT, RL, Mask);
+
+ // T = LLL * RLL: both factors are < 2^H, so the product fits in VT.
+ SDValue T = DAG.getNode(ISD::MUL, dl, VT, LLL, RLL);
+ SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
+
+ // TH is the carry out of T; LLH/RLH are the high half-words.
+ SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
+ SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
+ SDValue LLH = DAG.getNode(ISD::SRL, dl, VT, LL, Shift);
+ SDValue RLH = DAG.getNode(ISD::SRL, dl, VT, RL, Shift);
+
+ // U = LLH * RLL + TH, split into half-words UL/UH.
+ SDValue U = DAG.getNode(ISD::ADD, dl, VT,
+ DAG.getNode(ISD::MUL, dl, VT, LLH, RLL), TH);
+ SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
+ SDValue UH = DAG.getNode(ISD::SRL, dl, VT, U, Shift);
+
+ // V = LLL * RLH + UL; VH is its carry into the high word.
+ SDValue V = DAG.getNode(ISD::ADD, dl, VT,
+ DAG.getNode(ISD::MUL, dl, VT, LLL, RLH), UL);
+ SDValue VH = DAG.getNode(ISD::SRL, dl, VT, V, Shift);
+
+ // W = LLH * RLH + UH + VH is the high VT word of LL * RL.
+ SDValue W =
+ DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LLH, RLH),
+ DAG.getNode(ISD::ADD, dl, VT, UH, VH));
+ // Low VT word of LL * RL: the SHL drops VH, leaving VL * 2^H + TL.
+ Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
+ DAG.getNode(ISD::SHL, dl, VT, V, Shift));
+
+ // High half: high word of LL * RL plus the cross terms RH * LL and
+ // RL * LH (only their low Bits bits matter). LH * RH lies entirely
+ // above the result width and is therefore omitted.
+ Hi = DAG.getNode(ISD::ADD, dl, VT, W,
+ DAG.getNode(ISD::ADD, dl, VT,
+ DAG.getNode(ISD::MUL, dl, VT, RH, LL),
+ DAG.getNode(ISD::MUL, dl, VT, RL, LH)));
+ } else {
+ // Attempt a libcall.
+ SDValue Ret;
+ TargetLowering::MakeLibCallOptions CallOptions;
+ CallOptions.setSExt(Signed);
+ // NOTE(review): flagged post-type-legalization presumably because WideVT
+ // may be illegal at this point (see the comment at the top) - confirm.
+ CallOptions.setIsPostTypeLegalization(true);
+ if (shouldSplitFunctionArgumentsAsLittleEndian(DAG.getDataLayout())) {
+ // Halves of WideVT are packed into registers in different order
+ // depending on platform endianness. This is usually handled by
+ // the C calling convention, but we can't defer to it in
+ // the legalizer.
+ SDValue Args[] = {LL, LH, RL, RH};
+ Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
+ } else {
+ SDValue Args[] = {LH, LL, RH, RL};
+ Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
+ }
+ assert(Ret.getOpcode() == ISD::MERGE_VALUES &&
+ "Ret value is a collection of constituent nodes holding result.");
+ if (DAG.getDataLayout().isLittleEndian()) {
+ // As with the arguments above, which MERGE_VALUES operand holds the low
+ // half of the result depends on the data layout's endianness.
+ Lo = Ret.getOperand(0);
+ Hi = Ret.getOperand(1);
+ } else {
+ Lo = Ret.getOperand(1);
+ Hi = Ret.getOperand(0);
+ }
+ }
+}
+
+void TargetLowering::ForceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
+ bool Signed, const SDValue LHS,
+ const SDValue RHS, SDValue &Lo,
+ SDValue &Hi) const {
+ EVT VT = LHS.getValueType();
+ assert(RHS.getValueType() == VT && "Mismatching operand types");
+
+ SDValue HiLHS;
+ SDValue HiRHS;
+ if (Signed) {
+ // The high part is obtained by SRA'ing all but one of the bits of low
+ // part.
+ unsigned LoSize = VT.getFixedSizeInBits();
+ HiLHS = DAG.getNode(
+ ISD::SRA, dl, VT, LHS,
+ DAG.getConstant(LoSize - 1, dl, getPointerTy(DAG.getDataLayout())));
+ HiRHS = DAG.getNode(
+ ISD::SRA, dl, VT, RHS,
+ DAG.getConstant(LoSize - 1, dl, getPointerTy(DAG.getDataLayout())));
----------------
RKSimon wrote:
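Use getShiftAmountConstant? (i.e. `DAG.getShiftAmountConstant(LoSize - 1, VT, dl)` rather than `DAG.getConstant(LoSize - 1, dl, getPointerTy(DAG.getDataLayout()))`)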
https://github.com/llvm/llvm-project/pull/79352
More information about the llvm-commits mailing list