[llvm] [AggressiveInstCombine] Match long high-half multiply (PR #168396)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 19 06:24:26 PST 2025
================
@@ -1466,6 +1466,307 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
return false;
}
+/// Match high part of long multiplication.
+///
+/// Considering a multiply made up of high and low parts, we can split the
+/// multiply into:
+/// x * y == (xh*T + xl) * (yh*T + yl)
+/// where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
+/// This expands to
+/// xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
+/// which can be drawn as
+/// [  xh*yh  ]
+///      [  xh*yl  ]
+///      [  xl*yh  ]
+///           [  xl*yl  ]
+/// We are looking for the "high" half, which is
+///   xh*yh + (xh*yl>>32) + (xl*yh>>32) + some carries.
+/// The carries make this difficult, and there are multiple ways of representing
+/// them. The ones we attempt to support here are:
+/// Carry: xh*yh + carry + lowsum>>32
+/// carry = lowsum < xh*yl ? 0x100000000 : 0
+/// lowsum = xh*yl + xl*yh + (xl*yl>>32)
+/// Ladder: xh*yh + c2>>32 + c3>>32
+/// c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
+/// Carry4: xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
+/// crosssum = xh*yl + xl*yh
+/// carry = crosssum < xh*yl ? 0x100000000 : 0
+/// Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
+/// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
+///
+/// They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
+/// tree is xh*yh, xh*yl, xl*yh and xl*yl.
+static bool foldMulHigh(Instruction &I) {
+ Type *Ty = I.getType();
+ if (!Ty->isIntOrIntVectorTy())
+ return false;
+
+ unsigned BW = Ty->getScalarSizeInBits();
+ APInt LowMask = APInt::getLowBitsSet(BW, BW / 2);
+ if (BW % 2 != 0)
+ return false;
+
+ auto CreateMulHigh = [&](Value *X, Value *Y) {
+ IRBuilder<> Builder(&I);
+ Type *NTy = Ty->getWithNewBitWidth(BW * 2);
+ Value *XExt = Builder.CreateZExt(X, NTy);
+ Value *YExt = Builder.CreateZExt(Y, NTy);
+ Value *Mul = Builder.CreateMul(XExt, YExt);
+ Value *High = Builder.CreateLShr(Mul, BW);
+ Value *Res = Builder.CreateTrunc(High, Ty);
+ Res->takeName(&I);
+ I.replaceAllUsesWith(Res);
+ LLVM_DEBUG(dbgs() << "Created long multiply from parts of " << *X << " and "
+ << *Y << "\n");
+ return true;
+ };
+
+ // Common check routines for X_lo*Y_lo and X_hi*Y_lo
+ auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) {
+ return match(XlYl, m_c_Mul(m_And(m_Specific(X), m_SpecificInt(LowMask)),
+ m_And(m_Specific(Y), m_SpecificInt(LowMask))));
+ };
+ auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) {
+ return match(XhYl, m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(BW / 2)),
+ m_And(m_Specific(Y), m_SpecificInt(LowMask))));
+ };
+
+ auto FoldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry,
+ Instruction *B) {
+ // Looking for LowSum >> 32 and carry (select)
+ if (Carry->getOpcode() != Instruction::Select)
+ std::swap(Carry, B);
+
+ // Carry = LowSum < XhYl ? 0x100000000 : 0
+ Value *LowSum, *XhYl;
+ if (!match(Carry,
+ m_OneUse(m_Select(
+ m_OneUse(m_SpecificICmp(ICmpInst::ICMP_ULT, m_Value(LowSum),
+ m_Value(XhYl))),
+ m_SpecificInt(APInt::getOneBitSet(BW, BW / 2)),
+ m_SpecificInt(0)))))
+ return false;
+
+ // XhYl can be Xh*Yl or Xl*Yh
+ if (!CheckHiLo(XhYl, X, Y)) {
+ if (CheckHiLo(XhYl, Y, X))
+ std::swap(X, Y);
+ else
+ return false;
+ }
+ if (XhYl->hasNUsesOrMore(3))
+ return false;
+
+ // B = LowSum >> 32
+ if (!match(B,
+ m_OneUse(m_LShr(m_Specific(LowSum), m_SpecificInt(BW / 2)))) ||
+ LowSum->hasNUsesOrMore(3))
+ return false;
+
+ // LowSum = XhYl + XlYh + XlYl>>32
+ Value *XlYh, *XlYl;
+ auto XlYlHi = m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2));
+ if (!match(LowSum,
+ m_c_Add(m_Specific(XhYl),
+ m_OneUse(m_c_Add(m_OneUse(m_Value(XlYh)), XlYlHi)))) &&
+ !match(LowSum, m_c_Add(m_OneUse(m_Value(XlYh)),
+ m_OneUse(m_c_Add(m_Specific(XhYl), XlYlHi)))) &&
+ !match(LowSum,
+ m_c_Add(XlYlHi, m_OneUse(m_c_Add(m_Specific(XhYl),
+ m_OneUse(m_Value(XlYh)))))))
+ return false;
+
+ // Check XlYl and XlYh
+ if (!CheckLoLo(XlYl, X, Y))
+ return false;
+ if (!CheckHiLo(XlYh, Y, X))
+ return false;
+
+ return CreateMulHigh(X, Y);
+ };
+
+ auto FoldMulHighLadder = [&](Value *X, Value *Y, Instruction *A,
+ Instruction *B) {
+ // xh*yh + c2>>32 + c3>>32
+ // c2 = xh*yl + (xl*yl >> 32); c3 = c2&0xffffffff + xl*yh
+ Value *XlYh, *XhYl, *C2, *C3;
+ // Strip off the two expected shifts.
+ if (!match(A, m_LShr(m_Value(C2), m_SpecificInt(BW / 2))) ||
+ !match(B, m_LShr(m_Value(C3), m_SpecificInt(BW / 2))))
+ return false;
+
+ // Match c3 = c2&0xffffffff + xl*yh
+ if (!match(C3, m_c_Add(m_And(m_Specific(C2), m_SpecificInt(LowMask)),
+ m_Value(XhYl))))
+ std::swap(C2, C3);
+ if (!match(C3,
+ m_c_Add(m_OneUse(m_And(m_Specific(C2), m_SpecificInt(LowMask))),
+ m_Value(XhYl))) ||
----------------
dtcxzyw wrote:
```suggestion
m_Value(XlYh))) ||
```
https://github.com/llvm/llvm-project/pull/168396
More information about the llvm-commits
mailing list