[Mlir-commits] [llvm] [mlir] [SelectionDAG] Expand fixed point multiplication into libcall (PR #79352)
Simon Pilgrim
llvmlistbot at llvm.org
Mon Jan 29 06:15:52 PST 2024
================
@@ -10149,6 +10149,122 @@ SDValue TargetLowering::expandShlSat(SDNode *Node, SelectionDAG &DAG) const {
return DAG.getSelect(dl, VT, Cond, SatVal, Result);
}
+void TargetLowering::ForceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
+ bool Signed, EVT WideVT,
+ const SDValue LL, const SDValue LH,
+ const SDValue RL, const SDValue RH,
+ SDValue &Lo, SDValue &Hi) const {
+ // We can fall back to a libcall with an illegal type for the MUL if we
+ // have a libcall big enough.
+ // Also, we can fall back to a division in some cases, but that's a big
+ // performance hit in the general case.
+ RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
+ if (WideVT == MVT::i16)
+ LC = RTLIB::MUL_I16;
+ else if (WideVT == MVT::i32)
+ LC = RTLIB::MUL_I32;
+ else if (WideVT == MVT::i64)
+ LC = RTLIB::MUL_I64;
+ else if (WideVT == MVT::i128)
+ LC = RTLIB::MUL_I128;
+
+ if (LC == RTLIB::UNKNOWN_LIBCALL || !getLibcallName(LC)) {
+ // We'll expand the multiplication by brute force because we have no other
+ // options. This is a trivially-generalized version of the code from
+ // Hacker's Delight (itself derived from Knuth's Algorithm M from section
+ // 4.3.1).
+ EVT VT = LL.getValueType();
+ unsigned Bits = VT.getSizeInBits();
+ unsigned HalfBits = Bits >> 1;
+ SDValue Mask =
+ DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
+ SDValue LLL = DAG.getNode(ISD::AND, dl, VT, LL, Mask);
+ SDValue RLL = DAG.getNode(ISD::AND, dl, VT, RL, Mask);
+
+ SDValue T = DAG.getNode(ISD::MUL, dl, VT, LLL, RLL);
+ SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
+
+ SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
+ SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
+ SDValue LLH = DAG.getNode(ISD::SRL, dl, VT, LL, Shift);
+ SDValue RLH = DAG.getNode(ISD::SRL, dl, VT, RL, Shift);
+
+ SDValue U = DAG.getNode(ISD::ADD, dl, VT,
+ DAG.getNode(ISD::MUL, dl, VT, LLH, RLL), TH);
+ SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
+ SDValue UH = DAG.getNode(ISD::SRL, dl, VT, U, Shift);
+
+ SDValue V = DAG.getNode(ISD::ADD, dl, VT,
+ DAG.getNode(ISD::MUL, dl, VT, LLL, RLH), UL);
+ SDValue VH = DAG.getNode(ISD::SRL, dl, VT, V, Shift);
+
+ SDValue W =
+ DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LLH, RLH),
+ DAG.getNode(ISD::ADD, dl, VT, UH, VH));
+ Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
+ DAG.getNode(ISD::SHL, dl, VT, V, Shift));
+
+ Hi = DAG.getNode(ISD::ADD, dl, VT, W,
+ DAG.getNode(ISD::ADD, dl, VT,
+ DAG.getNode(ISD::MUL, dl, VT, RH, LL),
+ DAG.getNode(ISD::MUL, dl, VT, RL, LH)));
----------------
RKSimon wrote:
(style) Add a return here and drop the trailing else (or maybe flip the condition and do the libcall lowering inside the if()).
https://github.com/llvm/llvm-project/pull/79352
More information about the Mlir-commits
mailing list