[llvm] [InstCombine] lshr (mul (X, 2^N + 1)), N -> add (X, lshr(X, N)) (PR #90295)

via llvm-commits llvm-commits at lists.llvm.org
Sat May 18 09:46:20 PDT 2024


================
@@ -1483,6 +1494,16 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
       }
     }
 
+    // lshr (mul nsw (X, 2^N + 1)), N -> add nsw (X, lshr(X, N))
+    if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_APInt(MulC))))) {
+      if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+          MulC->logBase2() == ShAmtC) {
+        return BinaryOperator::CreateNSWAdd(
+            X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "",
+                                  I.isExact()));
+      }
+    }
+
----------------
goldsteinn wrote:

In my opinion, at least, the code would be simpler written as:
```
// Fold a logical right shift of a multiply by (2^N + 1) by N bits.
// MulC must have the form 2^N + 1 (so MulC - 1 is a power of two whose log
// equals the shift amount). Each fold below is only legal when the mul cannot
// wrap, so each one is gated on the specific wrap flag(s) it relies on.
if (match(Op0, m_Mul(m_Value(X), m_APInt(MulC))) && BitWidth > 2 &&
    (*MulC - 1).isPowerOf2() && MulC->logBase2() == ShAmtC) {
  auto *BO0 = cast<OverflowingBinaryOperator>(Op0);
  // Look for a "splat" mul pattern - it replicates bits across each half
  // of a value, so a right shift is just a mask of the low bits:
  // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1
  // Since MulC == 2^N + 1, the mask (2^N) - 1 is MulC - 2.
  if (ShAmtC * 2 == BitWidth && BO0->hasNoUnsignedWrap())
    return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2));

  // lshr (mul nuw/nsw (X, 2^N + 1)), N -> add nuw/nsw (X, lshr(X, N))
  // Either wrap flag on the mul makes the fold value-correct. Each flag on
  // the new add must come from the *same* flag on the mul: with only nsw,
  // a negative X makes X + lshr(X, N) wrap unsigned, so setting nuw there
  // would turn a valid result into poison.
  if (Op0->hasOneUse() &&
      (BO0->hasNoUnsignedWrap() || BO0->hasNoSignedWrap())) {
    auto *NewAdd = BinaryOperator::CreateAdd(
        X,
        Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "", I.isExact()));
    NewAdd->setHasNoSignedWrap(BO0->hasNoSignedWrap());
    NewAdd->setHasNoUnsignedWrap(BO0->hasNoUnsignedWrap());
    return NewAdd;
  }
}
```

https://github.com/llvm/llvm-project/pull/90295


More information about the llvm-commits mailing list