[llvm] 0d9c027 - [InstCombine] Make `takeLog2` visible in all of InstCombine; NFC
Noah Goldstein via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 10 14:22:03 PST 2025
Author: Noah Goldstein
Date: 2025-01-10T16:21:35-06:00
New Revision: 0d9c027ad7fa36a607386e24d4928c9046f6ff56
URL: https://github.com/llvm/llvm-project/commit/0d9c027ad7fa36a607386e24d4928c9046f6ff56
DIFF: https://github.com/llvm/llvm-project/commit/0d9c027ad7fa36a607386e24d4928c9046f6ff56.diff
LOG: [InstCombine] Make `takeLog2` visible in all of InstCombine; NFC
Also add a `tryGetLog2` helper that encapsulates the common pattern:
```
if (takeLog2(..., /*DoFold=*/false)) {
  Value *Log2 = takeLog2(..., /*DoFold=*/true);
  ...
}
```
Closes #122498
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineInternal.h
llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index f6992119280c16..83e1da98deeda0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -785,6 +785,18 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
void handlePotentiallyDeadBlocks(SmallVectorImpl<BasicBlock *> &Worklist);
void handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc);
void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr);
+
+ /// Take the exact integer log2 of the value. If DoFold is true, create the
+ /// actual instructions, otherwise return a non-null dummy value. Return
+ /// nullptr on failure. Note, if DoFold is true the caller must ensure that
+ /// takeLog2 will succeed, otherwise it may create stray instructions.
+ Value *takeLog2(Value *Op, unsigned Depth, bool AssumeNonZero, bool DoFold);
+
+ Value *tryGetLog2(Value *Op, bool AssumeNonZero) {
+ if (takeLog2(Op, /*Depth=*/0, AssumeNonZero, /*DoFold=*/false))
+ return takeLog2(Op, /*Depth=*/0, AssumeNonZero, /*DoFold=*/true);
+ return nullptr;
+ }
};
class Negator final {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 0c34cf01bdf1a9..1c5070a1b867c4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -185,9 +185,6 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands,
return nullptr;
}
-static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
- bool AssumeNonZero, bool DoFold);
-
Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V =
@@ -531,19 +528,13 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
// (shl Op1, Log2(Op0))
// if Log2(Op1) folds away ->
// (shl Op0, Log2(Op1))
- if (takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
- /*DoFold*/ false)) {
- Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
- /*DoFold*/ true);
+ if (Value *Res = tryGetLog2(Op0, /*AssumeNonZero=*/false)) {
BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res);
// We can only propegate nuw flag.
Shl->setHasNoUnsignedWrap(HasNUW);
return Shl;
}
- if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
- /*DoFold*/ false)) {
- Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
- /*DoFold*/ true);
+ if (Value *Res = tryGetLog2(Op1, /*AssumeNonZero=*/false)) {
BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res);
// We can only propegate nuw flag.
Shl->setHasNoUnsignedWrap(HasNUW);
@@ -1407,13 +1398,8 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
return nullptr;
}
-static const unsigned MaxDepth = 6;
-
-// Take the exact integer log2 of the value. If DoFold is true, create the
-// actual instructions, otherwise return a non-null dummy value. Return nullptr
-// on failure.
-static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
- bool AssumeNonZero, bool DoFold) {
+Value *InstCombinerImpl::takeLog2(Value *Op, unsigned Depth, bool AssumeNonZero,
+ bool DoFold) {
auto IfFold = [DoFold](function_ref<Value *()> Fn) {
if (!DoFold)
return reinterpret_cast<Value *>(-1);
@@ -1432,14 +1418,14 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
});
// The remaining tests are all recursive, so bail out if we hit the limit.
- if (Depth++ == MaxDepth)
+ if (Depth++ == MaxAnalysisRecursionDepth)
return nullptr;
// log2(zext X) -> zext log2(X)
// FIXME: Require one use?
Value *X, *Y;
if (match(Op, m_ZExt(m_Value(X))))
- if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
+ if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold))
return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });
// log2(trunc x) -> trunc log2(X)
@@ -1447,7 +1433,7 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
if (match(Op, m_Trunc(m_Value(X)))) {
auto *TI = cast<TruncInst>(Op);
if (AssumeNonZero || TI->hasNoUnsignedWrap())
- if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
+ if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold))
return IfFold([&]() {
return Builder.CreateTrunc(LogX, Op->getType(), "",
/*IsNUW=*/TI->hasNoUnsignedWrap());
@@ -1460,7 +1446,7 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
auto *BO = cast<OverflowingBinaryOperator>(Op);
// nuw will be set if the `shl` is trivially non-zero.
if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap())
- if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
+ if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold))
return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
}
@@ -1469,26 +1455,25 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
if (match(Op, m_LShr(m_Value(X), m_Value(Y)))) {
auto *PEO = cast<PossiblyExactOperator>(Op);
if (AssumeNonZero || PEO->isExact())
- if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
+ if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold))
return IfFold([&]() { return Builder.CreateSub(LogX, Y); });
}
// log2(X & Y) -> either log2(X) or log2(Y)
// This requires `AssumeNonZero` as `X & Y` may be zero when X != Y.
if (AssumeNonZero && match(Op, m_And(m_Value(X), m_Value(Y)))) {
- if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
+ if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold))
return IfFold([&]() { return LogX; });
- if (Value *LogY = takeLog2(Builder, Y, Depth, AssumeNonZero, DoFold))
+ if (Value *LogY = takeLog2(Y, Depth, AssumeNonZero, DoFold))
return IfFold([&]() { return LogY; });
}
// log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y)
// FIXME: Require one use?
if (SelectInst *SI = dyn_cast<SelectInst>(Op))
- if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth,
- AssumeNonZero, DoFold))
- if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth,
- AssumeNonZero, DoFold))
+ if (Value *LogX = takeLog2(SI->getOperand(1), Depth, AssumeNonZero, DoFold))
+ if (Value *LogY =
+ takeLog2(SI->getOperand(2), Depth, AssumeNonZero, DoFold))
return IfFold([&]() {
return Builder.CreateSelect(SI->getOperand(0), LogX, LogY);
});
@@ -1499,9 +1484,9 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) {
// Use AssumeNonZero as false here. Otherwise we can hit case where
// log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow).
- if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth,
+ if (Value *LogX = takeLog2(MinMax->getLHS(), Depth,
/*AssumeNonZero*/ false, DoFold))
- if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth,
+ if (Value *LogY = takeLog2(MinMax->getRHS(), Depth,
/*AssumeNonZero*/ false, DoFold))
return IfFold([&]() {
return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX,
@@ -1614,13 +1599,9 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
}
// Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away.
- if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true,
- /*DoFold*/ false)) {
- Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0,
- /*AssumeNonZero*/ true, /*DoFold*/ true);
+ if (Value *Res = tryGetLog2(Op1, /*AssumeNonZero=*/true))
return replaceInstUsesWith(
I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
- }
return nullptr;
}
More information about the llvm-commits
mailing list