[llvm] [InstCombine] Fold X * (2^N + 1) >> N -> X + X >> N, or directly to X if X >> N is 0 (PR #90295)

via llvm-commits llvm-commits at lists.llvm.org
Mon May 20 20:43:46 PDT 2024


================
@@ -1479,6 +1479,30 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
   if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1))))
     return X;
 
+  // Look for a "splat" mul pattern - it replicates bits across each half
+  // of a value, so a right shift is just a mask of the low bits:
+  // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1
+  const APInt *MulC;
+  const APInt *ShAmt;
+  if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) &&
+      match(Op1, m_APInt(ShAmt))) {
+    unsigned ShAmtC = ShAmt->getZExtValue();
+    unsigned BitWidth = ShAmt->getBitWidth();
+    if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+        MulC->logBase2() == ShAmtC) {
+      // FIXME: This condition should be covered by the computeKnownBits, but
+      // for some reason it is not, so keep this in for now. This has no
+      // negative effects, but KnownBits should be able to infer a number of
+      // leading bits based on 2^N + 1 not wrapping, as that means 2^N must not
+      // wrap either, which means the top N bits of X must be 0.
+      if (ShAmtC * 2 == BitWidth)
+        return X;
+      const KnownBits XKnown = computeKnownBits(X, /* Depth */ 0, Q);
+      if (XKnown.countMaxActiveBits() <= ShAmtC)
+        return X;
----------------
AtariDreams wrote:

Fixed it!

https://github.com/llvm/llvm-project/pull/90295
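
For anyone following the arithmetic behind the hunk above, here is a small standalone check of the identity in the PR title. It is only a sketch and is not taken from the patch: the names checkFold and verifyAll are made up for illustration, and it models narrow integer widths with plain uint64_t instead of APInt. The reasoning is that X * (2^N + 1) equals (X << N) + X, so if the multiply does not wrap, shifting right by N gives X + (X >> N); when X fits in N bits (which nuw forces when the width is exactly 2N), that is just X.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: exhaustively checks one (bitWidth, n) pair.
// bitWidth is the modelled integer width, n the shift amount (0 < n < bitWidth).
bool checkFold(unsigned bitWidth, unsigned n) {
  const uint64_t limit = uint64_t{1} << bitWidth; // first value that no longer fits
  const uint64_t mulC = (uint64_t{1} << n) + 1;   // the 2^N + 1 constant
  for (uint64_t x = 0; x < limit; ++x) {
    const uint64_t product = x * mulC;
    if (product >= limit)
      continue; // the multiply would wrap, so the nuw precondition does not hold
    const uint64_t shifted = product >> n;
    // General form: (X * (2^N + 1)) >> N == X + (X >> N).
    if (shifted != x + (x >> n))
      return false;
    // If X has at most N active bits, X >> N is 0 and the result is X.
    if (x < (uint64_t{1} << n) && shifted != x)
      return false;
    // If the width is exactly 2N, nuw alone implies X < 2^N, so the result is X.
    if (2 * n == bitWidth && shifted != x)
      return false;
  }
  return true;
}

// Hypothetical driver: runs the check for every width from 2 to 12 bits.
void verifyAll() {
  for (unsigned width = 2; width <= 12; ++width)
    for (unsigned n = 1; n < width; ++n)
      assert(checkFold(width, n));
}
```

Calling verifyAll() from a test harness exercises every shift amount for the small widths; checkFold(16, 8) covers the i16 case where the width is exactly twice the shift amount.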

