[llvm] [AggressiveInstCombine] Fold i64 x i64 -> i128 multiply-by-parts (PR #156879)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 28 07:09:19 PDT 2025
================
@@ -1457,6 +1457,268 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
return false;
}
+/// Match low part of 128-bit multiplication.
+///
+/// Use counts are checked to prevent total instruction count increase as per
+/// contributors guide:
+/// https://llvm.org/docs/InstCombineContributorGuide.html#multi-use-handling
+static bool foldMul128Low(Instruction &I) {
+ auto *Ty = I.getType();
+ if (!Ty->isIntegerTy(64))
+ return false;
+
+ // (low_accum << 32) | lo(lo(y) * lo(x))
+ Value *LowAccum = nullptr, *YLowXLow = nullptr;
+ if (!match(&I, m_c_DisjointOr(
+ m_OneUse(m_Shl(m_Value(LowAccum), m_SpecificInt(32))),
+ m_OneUse(
+ m_And(m_Value(YLowXLow), m_SpecificInt(0xffffffff))))))
+ return false;
+
+ // lo(cross_sum) + hi(lo(y) * lo(x))
+ Value *CrossSum = nullptr;
+ if (!match(
+ LowAccum,
+ m_c_Add(m_OneUse(m_And(m_Value(CrossSum), m_SpecificInt(0xffffffff))),
+ m_OneUse(m_LShr(m_Specific(YLowXLow), m_SpecificInt(32))))) ||
+ LowAccum->hasNUsesOrMore(3))
+ return false;
+
+ // (hi(y) * lo(x)) + (lo(y) * hi(x))
+ Value *YHigh = nullptr, *XLow = nullptr, *YLowXHigh = nullptr;
+ if (!match(CrossSum, m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XLow))),
+ m_Value(YLowXHigh))) ||
+ CrossSum->hasNUsesOrMore(4))
+ return false;
+
+ // lo(y) * lo(x)
+ Value *YLow = nullptr;
+ if (!match(YLowXLow, m_c_Mul(m_Value(YLow), m_Specific(XLow))) ||
+ YLowXLow->hasNUsesOrMore(3))
+ return false;
+
+ // lo(y) * hi(x)
+ Value *XHigh = nullptr;
+ if (!match(YLowXHigh, m_c_Mul(m_Specific(YLow), m_Value(XHigh))) ||
+ !YLowXHigh->hasNUses(2))
+ return false;
+
+ Value *X = nullptr;
+ // lo(x) = x & 0xffffffff
+ if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
+ !XLow->hasNUses(2))
+ return false;
+ // hi(x) = x >> 32
+ if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
+ !XHigh->hasNUses(2))
+ return false;
+
+ // Same for Y.
+ Value *Y = nullptr;
+ if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
+ !YLow->hasNUses(2))
+ return false;
+ if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
+ !YHigh->hasNUses(2))
+ return false;
+
+ IRBuilder<> Builder(&I);
+ Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
+ Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
+ Value *Mul128 = Builder.CreateMul(XExt, YExt);
+ Value *Res = Builder.CreateTrunc(Mul128, Builder.getInt64Ty());
+ I.replaceAllUsesWith(Res);
+
+ return true;
+}
+
+/// Match high part of 128-bit multiplication.
+///
+/// Use counts are checked to prevent total instruction count increase as per
+/// contributors guide:
+/// https://llvm.org/docs/InstCombineContributorGuide.html#multi-use-handling
+static bool foldMul128High(Instruction &I) {
+ auto *Ty = I.getType();
+ if (!Ty->isIntegerTy(64))
+ return false;
+
+ // intermediate_plus_carry + hi(low_accum)
+ Value *IntermediatePlusCarry = nullptr, *LowAccum = nullptr;
+ if (!match(&I,
+ m_c_Add(m_OneUse(m_Value(IntermediatePlusCarry)),
+ m_OneUse(m_LShr(m_Value(LowAccum), m_SpecificInt(32))))))
+ return false;
+
+ // match:
+ // (((hi(y) * hi(x)) + carry) + hi(cross_sum))
+ // or:
+ // ((hi(cross_sum) + (hi(y) * hi(x))) + carry)
+ CmpPredicate Pred;
+ Value *CrossSum = nullptr, *XHigh = nullptr, *YHigh = nullptr,
+ *Carry = nullptr;
+ if (!match(IntermediatePlusCarry,
+ m_c_Add(m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))),
----------------
davemgreen wrote:
m_c_Mul -> m_Mul (both operands are unconstrained m_Value captures, so the commutative matcher is redundant here)
https://github.com/llvm/llvm-project/pull/156879
More information about the llvm-commits mailing list