[Mlir-commits] [llvm] [mlir] [SelectionDAG] Expand fixed point multiplication into libcall (PR #79352)

Simon Pilgrim llvmlistbot at llvm.org
Mon Jan 29 06:15:52 PST 2024


================
@@ -10149,6 +10149,122 @@ SDValue TargetLowering::expandShlSat(SDNode *Node, SelectionDAG &DAG) const {
   return DAG.getSelect(dl, VT, Cond, SatVal, Result);
 }
 
+void TargetLowering::ForceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
+                                        bool Signed, EVT WideVT,
+                                        const SDValue LL, const SDValue LH,
+                                        const SDValue RL, const SDValue RH,
+                                        SDValue &Lo, SDValue &Hi) const {
+  // We can fall back to a libcall with an illegal type for the MUL if we
+  // have a libcall big enough.
+  // Also, we can fall back to a division in some cases, but that's a big
+  // performance hit in the general case.
+  RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
+  if (WideVT == MVT::i16)
+    LC = RTLIB::MUL_I16;
+  else if (WideVT == MVT::i32)
+    LC = RTLIB::MUL_I32;
+  else if (WideVT == MVT::i64)
+    LC = RTLIB::MUL_I64;
+  else if (WideVT == MVT::i128)
+    LC = RTLIB::MUL_I128;
+
+  if (LC == RTLIB::UNKNOWN_LIBCALL || !getLibcallName(LC)) {
+    // We'll expand the multiplication by brute force because we have no other
+    // options. This is a trivially-generalized version of the code from
+    // Hacker's Delight (itself derived from Knuth's Algorithm M from section
+    // 4.3.1).
+    EVT VT = LL.getValueType();
+    unsigned Bits = VT.getSizeInBits();
+    unsigned HalfBits = Bits >> 1;
+    SDValue Mask =
+        DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
+    SDValue LLL = DAG.getNode(ISD::AND, dl, VT, LL, Mask);
+    SDValue RLL = DAG.getNode(ISD::AND, dl, VT, RL, Mask);
+
+    SDValue T = DAG.getNode(ISD::MUL, dl, VT, LLL, RLL);
+    SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
+
+    SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
+    SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
+    SDValue LLH = DAG.getNode(ISD::SRL, dl, VT, LL, Shift);
+    SDValue RLH = DAG.getNode(ISD::SRL, dl, VT, RL, Shift);
+
+    SDValue U = DAG.getNode(ISD::ADD, dl, VT,
+                            DAG.getNode(ISD::MUL, dl, VT, LLH, RLL), TH);
+    SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
+    SDValue UH = DAG.getNode(ISD::SRL, dl, VT, U, Shift);
+
+    SDValue V = DAG.getNode(ISD::ADD, dl, VT,
+                            DAG.getNode(ISD::MUL, dl, VT, LLL, RLH), UL);
+    SDValue VH = DAG.getNode(ISD::SRL, dl, VT, V, Shift);
+
+    SDValue W =
+        DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LLH, RLH),
+                    DAG.getNode(ISD::ADD, dl, VT, UH, VH));
+    Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
+                     DAG.getNode(ISD::SHL, dl, VT, V, Shift));
+
+    Hi = DAG.getNode(ISD::ADD, dl, VT, W,
+                     DAG.getNode(ISD::ADD, dl, VT,
+                                 DAG.getNode(ISD::MUL, dl, VT, RH, LL),
+                                 DAG.getNode(ISD::MUL, dl, VT, RL, LH)));
----------------
RKSimon wrote:

(style) Add a return here and drop the trailing else (or maybe flip the condition and do the libcall lowering inside the `if()`).

https://github.com/llvm/llvm-project/pull/79352


More information about the Mlir-commits mailing list