[llvm] [AggressiveInstCombine] Match long high-half multiply (PR #168396)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 18 10:01:28 PST 2025
================
@@ -1466,6 +1466,306 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
return false;
}
+/// Match high part of long multiplication.
+///
+/// Considering a multiply made up of high and low parts, we can split the
+/// multiply into:
+/// x * y == (xh*T + xl) * (yh*T + yl)
+/// where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
+/// This expands to
+/// xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
+/// which can be drawn as
+/// [ xh*yh ]
+/// [ xh*yl ]
+/// [ xl*yh ]
+/// [ xl*yl ]
+/// We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
+/// some carries. The carry makes this difficult and there are multiple ways of
+/// representing it. The ones we attempt to support here are:
+/// Carry: xh*yh + carry + lowsum>>32
+/// carry = lowsum < xh*yl ? 0x100000000 : 0
+/// lowsum = xh*yl + xl*yh + (xl*yl>>32)
+/// Ladder: xh*yh + c2>>32 + c3>>32
+/// c2 = xh*yl + (xl*yl>>32); c3 = (c2&0xffffffff) + xl*yh
+/// Carry4: xh*yh + carry + crosssum>>32 + ((xl*yl>>32) + (crosssum&0xffffffff)) >> 32
+/// crosssum = xh*yl + xl*yh
+/// carry = crosssum < xh*yl ? 0x100000000 : 0
+/// Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
+/// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
+///
+/// They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
+/// tree is xh*yh, xh*yl, xl*yh and xl*yl.
+static bool foldMulHigh(Instruction &I) {
+ Type *Ty = I.getType();
+ if (!Ty->isIntOrIntVectorTy())
+ return false;
+
+ unsigned BW = Ty->getScalarSizeInBits();
+ APInt LowMask = APInt::getLowBitsSet(BW, BW / 2);
+ if (BW % 2 != 0)
+ return false;
+
+ auto CreateMulHigh = [&](Value *X, Value *Y) {
+ IRBuilder<> Builder(&I);
+ Type *NTy = Ty->getWithNewBitWidth(BW * 2);
+ Value *XExt = Builder.CreateZExt(X, NTy);
+ Value *YExt = Builder.CreateZExt(Y, NTy);
+ Value *Mul = Builder.CreateMul(XExt, YExt);
+ Value *High = Builder.CreateLShr(Mul, BW);
+ Value *Res = Builder.CreateTrunc(High, Ty);
+ I.replaceAllUsesWith(Res);
+ LLVM_DEBUG(dbgs() << "Created long multiply from parts of " << *X << " and "
+ << *Y << "\n");
+ return true;
+ };
+
+ // Common check routines for X_lo*Y_lo and X_hi*Y_lo
+ auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) {
+ return match(XlYl, m_c_Mul(m_And(m_Specific(X), m_SpecificInt(LowMask)),
+ m_And(m_Specific(Y), m_SpecificInt(LowMask))));
+ };
+ auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) {
+ return match(XhYl, m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(BW / 2)),
+ m_And(m_Specific(Y), m_SpecificInt(LowMask))));
+ };
+
+ auto foldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry,
+ Instruction *B) {
+ // Looking for LowSum >> 32 and carry (select)
+ if (Carry->getOpcode() != Instruction::Select)
+ std::swap(Carry, B);
+
+ // Carry = LowSum < XhYl ? 0x100000000 : 0
+ CmpPredicate Pred;
+ Value *LowSum, *XhYl;
+ if (!match(Carry,
+ m_OneUse(m_Select(
+ m_OneUse(m_ICmp(Pred, m_Value(LowSum), m_Value(XhYl))),
+ m_SpecificInt(APInt(BW, 1) << BW / 2), m_SpecificInt(0)))) ||
----------------
topperc wrote:
APInt::getOneBitSet(BW, BW / 2)
https://github.com/llvm/llvm-project/pull/168396
More information about the llvm-commits
mailing list