[llvm] [AggressiveInstCombine] Fold i64 x i64 -> i128 multiply-by-parts (PR #156879)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 28 07:09:19 PDT 2025


================
@@ -1457,6 +1457,268 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
   return false;
 }
 
+/// Match the low 64 bits of a 64x64->128-bit multiply built from 32-bit parts.
+/// On a match, replaces uses of I with trunc(zext(x) * zext(y)); returns true.
+/// Use counts are checked to prevent total instruction count increase as per
+/// the contributor guide:
+/// https://llvm.org/docs/InstCombineContributorGuide.html#multi-use-handling
+static bool foldMul128Low(Instruction &I) {
+  auto *Ty = I.getType();
+  if (!Ty->isIntegerTy(64))
+    return false;
+
+  // I = (low_accum << 32) | lo(lo(y) * lo(x))
+  Value *LowAccum = nullptr, *YLowXLow = nullptr;
+  if (!match(&I, m_c_DisjointOr(
+                     m_OneUse(m_Shl(m_Value(LowAccum), m_SpecificInt(32))),
+                     m_OneUse(
+                         m_And(m_Value(YLowXLow), m_SpecificInt(0xffffffff))))))
+    return false;
+
+  // low_accum = lo(cross_sum) + hi(lo(y) * lo(x))
+  Value *CrossSum = nullptr;
+  if (!match(
+          LowAccum,
+          m_c_Add(m_OneUse(m_And(m_Value(CrossSum), m_SpecificInt(0xffffffff))),
+                  m_OneUse(m_LShr(m_Specific(YLowXLow), m_SpecificInt(32))))) ||
+      LowAccum->hasNUsesOrMore(3))
+    return false;
+
+  // cross_sum = (hi(y) * lo(x)) + (lo(y) * hi(x))
+  Value *YHigh = nullptr, *XLow = nullptr, *YLowXHigh = nullptr;
+  if (!match(CrossSum, m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XLow))),
+                               m_Value(YLowXHigh))) ||
+      CrossSum->hasNUsesOrMore(4))
+    return false;
+
+  // lo(y) * lo(x), sharing the lo(x) operand bound by the cross_sum match
+  Value *YLow = nullptr;
+  if (!match(YLowXLow, m_c_Mul(m_Value(YLow), m_Specific(XLow))) ||
+      YLowXLow->hasNUsesOrMore(3))
+    return false;
+
+  // lo(y) * hi(x), sharing the lo(y) operand bound just above
+  Value *XHigh = nullptr;
+  if (!match(YLowXHigh, m_c_Mul(m_Specific(YLow), m_Value(XHigh))) ||
+      !YLowXHigh->hasNUses(2))
+    return false;
+
+  Value *X = nullptr;
+  // lo(x) = x & 0xffffffff
+  if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
+      !XLow->hasNUses(2))
+    return false;
+  // hi(x) = x >> 32
+  if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
+      !XHigh->hasNUses(2))
+    return false;
+
+  // lo(y) = y & 0xffffffff, hi(y) = y >> 32
+  Value *Y = nullptr;
+  if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
+      !YLow->hasNUses(2))
+    return false;
+  if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
+      !YHigh->hasNUses(2))
+    return false;
+
+  IRBuilder<> Builder(&I);
+  Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
+  Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
+  Value *Mul128 = Builder.CreateMul(XExt, YExt);
+  Value *Res = Builder.CreateTrunc(Mul128, Builder.getInt64Ty());
+  I.replaceAllUsesWith(Res);
+
+  return true;
+}
+
+/// Match high part of 128-bit multiplication.
+///
+/// Use counts are checked to prevent total instruction count increase as per
+/// contributors guide:
+/// https://llvm.org/docs/InstCombineContributorGuide.html#multi-use-handling
+static bool foldMul128High(Instruction &I) {
+  auto *Ty = I.getType();
+  if (!Ty->isIntegerTy(64))
+    return false;
+
+  // intermediate_plus_carry + hi(low_accum)
+  Value *IntermediatePlusCarry = nullptr, *LowAccum = nullptr;
+  if (!match(&I,
+             m_c_Add(m_OneUse(m_Value(IntermediatePlusCarry)),
+                     m_OneUse(m_LShr(m_Value(LowAccum), m_SpecificInt(32))))))
+    return false;
+
+  // match:
+  //   (((hi(y) * hi(x)) + carry) + hi(cross_sum))
+  // or:
+  //   ((hi(cross_sum) + (hi(y) * hi(x))) + carry)
+  CmpPredicate Pred;
+  Value *CrossSum = nullptr, *XHigh = nullptr, *YHigh = nullptr,
+        *Carry = nullptr;
+  if (!match(IntermediatePlusCarry,
+             m_c_Add(m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))),
----------------
davemgreen wrote:

m_c_Mul -> m_Mul

https://github.com/llvm/llvm-project/pull/156879


More information about the llvm-commits mailing list