[llvm] [AggressiveInstCombine] Match long high-half multiply (PR #168396)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 19 06:24:27 PST 2025


================
@@ -1466,6 +1466,307 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
   return false;
 }
 
+/// Match high part of long multiplication.
+///
+/// Considering a multiply made up of high and low parts, we can split the
+/// multiply into:
+///  x * y == (xh*T + xl) * (yh*T + yl)
+/// where xh == x>>32, xl == x & 0xffffffff, T = 2^32 (i64 shown; any even BW).
+/// This expands to
+///  xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
+/// which can be drawn as
+/// [  xh*yh  ]
+///      [  xh*yl  ]
+///      [  xl*yh  ]
+///           [  xl*yl  ]
+/// We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
+/// some carries. The carries make this difficult and there are multiple ways of
+/// representing it. The ones we attempt to support here are:
+///  Carry:  xh*yh + carry + lowsum>>32
+///          carry = lowsum < xh*yl ? 0x100000000 : 0
+///          lowsum = xh*yl + xl*yh + (xl*yl>>32)
+///  Ladder: xh*yh + c2>>32 + c3>>32
+///          c2 = xh*yl + (xl*yl>>32); c3 = (c2&0xffffffff) + xl*yh
+///  Carry4: xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
+///          crosssum = xh*yl + xl*yh
+///          carry = crosssum < xh*yl ? 0x100000000 : 0
+///  Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
+///          low = (xl*yl)>>32 + ((xl*yh)&0xffffffff) + ((xh*yl)&0xffffffff)
+///
+/// They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
+/// tree is xh*yh, xh*yl, xl*yh and xl*yl.
+static bool foldMulHigh(Instruction &I) {
+  Type *Ty = I.getType();
+  if (!Ty->isIntOrIntVectorTy())
+    return false;
+
+  unsigned BW = Ty->getScalarSizeInBits();
+  APInt LowMask = APInt::getLowBitsSet(BW, BW / 2);
+  if (BW % 2 != 0)
+    return false;
+
+  auto CreateMulHigh = [&](Value *X, Value *Y) {
+    IRBuilder<> Builder(&I);
+    Type *NTy = Ty->getWithNewBitWidth(BW * 2);
+    Value *XExt = Builder.CreateZExt(X, NTy);
+    Value *YExt = Builder.CreateZExt(Y, NTy);
+    Value *Mul = Builder.CreateMul(XExt, YExt);
+    Value *High = Builder.CreateLShr(Mul, BW);
+    Value *Res = Builder.CreateTrunc(High, Ty);
+    Res->takeName(&I);
+    I.replaceAllUsesWith(Res);
+    LLVM_DEBUG(dbgs() << "Created long multiply from parts of " << *X << " and "
+                      << *Y << "\n");
+    return true;
+  };
+
+  // Common check routines for X_lo*Y_lo and X_hi*Y_lo
+  auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) {
+    return match(XlYl, m_c_Mul(m_And(m_Specific(X), m_SpecificInt(LowMask)),
+                               m_And(m_Specific(Y), m_SpecificInt(LowMask))));
+  };
+  auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) {
+    return match(XhYl, m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(BW / 2)),
+                               m_And(m_Specific(Y), m_SpecificInt(LowMask))));
+  };
+
+  auto FoldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry,
+                              Instruction *B) {
+    // Looking for LowSum >> 32 and carry (select)
+    if (Carry->getOpcode() != Instruction::Select)
+      std::swap(Carry, B);
+
+    // Carry = LowSum < XhYl ? 0x100000000 : 0
+    Value *LowSum, *XhYl;
+    if (!match(Carry,
+               m_OneUse(m_Select(
+                   m_OneUse(m_SpecificICmp(ICmpInst::ICMP_ULT, m_Value(LowSum),
+                                           m_Value(XhYl))),
+                   m_SpecificInt(APInt::getOneBitSet(BW, BW / 2)),
+                   m_SpecificInt(0)))))
+      return false;
+
+    // XhYl can be Xh*Yl or Xl*Yh
+    if (!CheckHiLo(XhYl, X, Y)) {
+      if (CheckHiLo(XhYl, Y, X))
+        std::swap(X, Y);
+      else
+        return false;
+    }
+    if (XhYl->hasNUsesOrMore(3))
+      return false;
+
+    // B = LowSum >> 32 (i.e. LowSum >> BW/2)
+    if (!match(B,
+               m_OneUse(m_LShr(m_Specific(LowSum), m_SpecificInt(BW / 2)))) ||
+        LowSum->hasNUsesOrMore(3))
+      return false;
+
+    // LowSum = XhYl + XlYh + XlYl>>32
+    Value *XlYh, *XlYl;
+    auto XlYlHi = m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2));
+    if (!match(LowSum,
+               m_c_Add(m_Specific(XhYl),
+                       m_OneUse(m_c_Add(m_OneUse(m_Value(XlYh)), XlYlHi)))) &&
+        !match(LowSum, m_c_Add(m_OneUse(m_Value(XlYh)),
+                               m_OneUse(m_c_Add(m_Specific(XhYl), XlYlHi)))) &&
+        !match(LowSum,
+               m_c_Add(XlYlHi, m_OneUse(m_c_Add(m_Specific(XhYl),
+                                                m_OneUse(m_Value(XlYh)))))))
+      return false;
+
+    // Check XlYl and XlYh
+    if (!CheckLoLo(XlYl, X, Y))
+      return false;
+    if (!CheckHiLo(XlYh, Y, X))
+      return false;
+
+    return CreateMulHigh(X, Y);
+  };
+
+  auto FoldMulHighLadder = [&](Value *X, Value *Y, Instruction *A,
+                               Instruction *B) {
+    //  xh*yh + c2>>32 + c3>>32
+    //  c2 = xh*yl + (xl*yl >> 32); c3 = (c2&0xffffffff) + xl*yh
+    Value *XlYh, *XhYl, *C2, *C3;
+    // Strip off the two expected shifts.
+    if (!match(A, m_LShr(m_Value(C2), m_SpecificInt(BW / 2))) ||
+        !match(B, m_LShr(m_Value(C3), m_SpecificInt(BW / 2))))
+      return false;
+
+    // Match c3 = (c2&0xffffffff) + xl*yh
+    if (!match(C3, m_c_Add(m_And(m_Specific(C2), m_SpecificInt(LowMask)),
+                           m_Value(XhYl))))
----------------
dtcxzyw wrote:

```suggestion
                           m_Value(XlYh))))
```

https://github.com/llvm/llvm-project/pull/168396


More information about the llvm-commits mailing list