[llvm] goldsteinn/cttz ctlz with p2 (PR #122512)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 10 11:21:55 PST 2025
https://github.com/goldsteinn created https://github.com/llvm/llvm-project/pull/122512
- **[InstCombine] Move `takeLog2` to InstructionCombiner; NFC**
- **[InstCombine] Add convenience helper `tryGetLog2`; NFC**
- **Fixup**
- **cttz/ctlz log2**
- **cttz/ctlz of p2 canonicalize**
- **Shifts**
>From b21edc4e29f08984f08fa3e84a8a2fa63ee3e791 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 10 Jan 2025 11:32:09 -0600
Subject: [PATCH 1/6] [InstCombine] Move `takeLog2` to InstructionCombiner; NFC
---
.../Transforms/InstCombine/InstCombiner.h | 6 +
.../InstCombine/InstCombineMulDivRem.cpp | 120 +-----------------
.../InstCombine/InstructionCombining.cpp | 99 +++++++++++++++
3 files changed, 111 insertions(+), 114 deletions(-)
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index fa6b60cba15aaf..8a87ee0839b7bb 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -195,6 +195,12 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
PatternMatch::m_Value()));
}
+ // Take the exact integer log2 of the value. If DoFold is true, create the
+ // actual instructions, otherwise return a non-null dummy value. Return
+ // nullptr on failure. Note, if DoFold is true the caller must ensure that
+ // takeLog2 will succeed, otherwise it may create stray instructions.
+ Value *takeLog2(Value *Op, unsigned Depth, bool AssumeNonZero, bool DoFold);
+
/// Return nonnull value if V is free to invert under the condition of
/// WillInvertAllUses.
/// If Builder is nonnull, it will return a simplified ~V.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 0c34cf01bdf1a9..576c3bc585db12 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -185,9 +185,6 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands,
return nullptr;
}
-static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
- bool AssumeNonZero, bool DoFold);
-
Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V =
@@ -531,18 +528,18 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
// (shl Op1, Log2(Op0))
// if Log2(Op1) folds away ->
// (shl Op0, Log2(Op1))
- if (takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
+ if (takeLog2(Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
/*DoFold*/ false)) {
- Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
+ Value *Res = takeLog2(Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
/*DoFold*/ true);
BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res);
// We can only propegate nuw flag.
Shl->setHasNoUnsignedWrap(HasNUW);
return Shl;
}
- if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
+ if (takeLog2(Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
/*DoFold*/ false)) {
- Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
+ Value *Res = takeLog2(Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
/*DoFold*/ true);
BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res);
// We can only propegate nuw flag.
@@ -1407,111 +1404,6 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
return nullptr;
}
-static const unsigned MaxDepth = 6;
-
-// Take the exact integer log2 of the value. If DoFold is true, create the
-// actual instructions, otherwise return a non-null dummy value. Return nullptr
-// on failure.
-static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
- bool AssumeNonZero, bool DoFold) {
- auto IfFold = [DoFold](function_ref<Value *()> Fn) {
- if (!DoFold)
- return reinterpret_cast<Value *>(-1);
- return Fn();
- };
-
- // FIXME: assert that Op1 isn't/doesn't contain undef.
-
- // log2(2^C) -> C
- if (match(Op, m_Power2()))
- return IfFold([&]() {
- Constant *C = ConstantExpr::getExactLogBase2(cast<Constant>(Op));
- if (!C)
- llvm_unreachable("Failed to constant fold udiv -> logbase2");
- return C;
- });
-
- // The remaining tests are all recursive, so bail out if we hit the limit.
- if (Depth++ == MaxDepth)
- return nullptr;
-
- // log2(zext X) -> zext log2(X)
- // FIXME: Require one use?
- Value *X, *Y;
- if (match(Op, m_ZExt(m_Value(X))))
- if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
- return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });
-
- // log2(trunc x) -> trunc log2(X)
- // FIXME: Require one use?
- if (match(Op, m_Trunc(m_Value(X)))) {
- auto *TI = cast<TruncInst>(Op);
- if (AssumeNonZero || TI->hasNoUnsignedWrap())
- if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
- return IfFold([&]() {
- return Builder.CreateTrunc(LogX, Op->getType(), "",
- /*IsNUW=*/TI->hasNoUnsignedWrap());
- });
- }
-
- // log2(X << Y) -> log2(X) + Y
- // FIXME: Require one use unless X is 1?
- if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) {
- auto *BO = cast<OverflowingBinaryOperator>(Op);
- // nuw will be set if the `shl` is trivially non-zero.
- if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap())
- if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
- return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
- }
-
- // log2(X >>u Y) -> log2(X) - Y
- // FIXME: Require one use?
- if (match(Op, m_LShr(m_Value(X), m_Value(Y)))) {
- auto *PEO = cast<PossiblyExactOperator>(Op);
- if (AssumeNonZero || PEO->isExact())
- if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
- return IfFold([&]() { return Builder.CreateSub(LogX, Y); });
- }
-
- // log2(X & Y) -> either log2(X) or log2(Y)
- // This requires `AssumeNonZero` as `X & Y` may be zero when X != Y.
- if (AssumeNonZero && match(Op, m_And(m_Value(X), m_Value(Y)))) {
- if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
- return IfFold([&]() { return LogX; });
- if (Value *LogY = takeLog2(Builder, Y, Depth, AssumeNonZero, DoFold))
- return IfFold([&]() { return LogY; });
- }
-
- // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y)
- // FIXME: Require one use?
- if (SelectInst *SI = dyn_cast<SelectInst>(Op))
- if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth,
- AssumeNonZero, DoFold))
- if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth,
- AssumeNonZero, DoFold))
- return IfFold([&]() {
- return Builder.CreateSelect(SI->getOperand(0), LogX, LogY);
- });
-
- // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
- // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
- auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op);
- if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) {
- // Use AssumeNonZero as false here. Otherwise we can hit case where
- // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow).
- if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth,
- /*AssumeNonZero*/ false, DoFold))
- if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth,
- /*AssumeNonZero*/ false, DoFold))
- return IfFold([&]() {
- return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX,
- LogY);
- });
- }
-
- return nullptr;
-}
-
/// If we have zero-extended operands of an unsigned div or rem, we may be able
/// to narrow the operation (sink the zext below the math).
static Instruction *narrowUDivURem(BinaryOperator &I,
@@ -1614,9 +1506,9 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
}
// Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away.
- if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true,
+ if (takeLog2(Op1, /*Depth*/ 0, /*AssumeNonZero*/ true,
/*DoFold*/ false)) {
- Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0,
+ Value *Res = takeLog2(Op1, /*Depth*/ 0,
/*AssumeNonZero*/ true, /*DoFold*/ true);
return replaceInstUsesWith(
I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 2fb60ef11499c7..ede7aa48cd6401 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2565,6 +2565,105 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP,
return nullptr;
}
+Value *InstCombiner::takeLog2(Value *Op, unsigned Depth, bool AssumeNonZero,
+ bool DoFold) {
+ auto IfFold = [DoFold](function_ref<Value *()> Fn) {
+ if (!DoFold)
+ return reinterpret_cast<Value *>(-1);
+ return Fn();
+ };
+
+ // FIXME: assert that Op1 isn't/doesn't contain undef.
+
+ // log2(2^C) -> C
+ if (match(Op, m_Power2()))
+ return IfFold([&]() {
+ Constant *C = ConstantExpr::getExactLogBase2(cast<Constant>(Op));
+ if (!C)
+ llvm_unreachable("Failed to constant fold udiv -> logbase2");
+ return C;
+ });
+
+ // The remaining tests are all recursive, so bail out if we hit the limit.
+ if (Depth++ == MaxAnalysisRecursionDepth)
+ return nullptr;
+
+ // log2(zext X) -> zext log2(X)
+ // FIXME: Require one use?
+ Value *X, *Y;
+ if (match(Op, m_ZExt(m_Value(X))))
+ if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold))
+ return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });
+
+ // log2(trunc x) -> trunc log2(X)
+ // FIXME: Require one use?
+ if (match(Op, m_Trunc(m_Value(X)))) {
+ auto *TI = cast<TruncInst>(Op);
+ if (AssumeNonZero || TI->hasNoUnsignedWrap())
+ if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold))
+ return IfFold([&]() {
+ return Builder.CreateTrunc(LogX, Op->getType(), "",
+ /*IsNUW=*/TI->hasNoUnsignedWrap());
+ });
+ }
+
+ // log2(X << Y) -> log2(X) + Y
+ // FIXME: Require one use unless X is 1?
+ if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) {
+ auto *BO = cast<OverflowingBinaryOperator>(Op);
+ // nuw will be set if the `shl` is trivially non-zero.
+ if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap())
+ if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold))
+ return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
+ }
+
+ // log2(X >>u Y) -> log2(X) - Y
+ // FIXME: Require one use?
+ if (match(Op, m_LShr(m_Value(X), m_Value(Y)))) {
+ auto *PEO = cast<PossiblyExactOperator>(Op);
+ if (AssumeNonZero || PEO->isExact())
+ if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold))
+ return IfFold([&]() { return Builder.CreateSub(LogX, Y); });
+ }
+
+ // log2(X & Y) -> either log2(X) or log2(Y)
+ // This requires `AssumeNonZero` as `X & Y` may be zero when X != Y.
+ if (AssumeNonZero && match(Op, m_And(m_Value(X), m_Value(Y)))) {
+ if (Value *LogX = takeLog2(X, Depth, AssumeNonZero, DoFold))
+ return IfFold([&]() { return LogX; });
+ if (Value *LogY = takeLog2(Y, Depth, AssumeNonZero, DoFold))
+ return IfFold([&]() { return LogY; });
+ }
+
+ // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y)
+ // FIXME: Require one use?
+ if (SelectInst *SI = dyn_cast<SelectInst>(Op))
+ if (Value *LogX = takeLog2(SI->getOperand(1), Depth, AssumeNonZero, DoFold))
+ if (Value *LogY =
+ takeLog2(SI->getOperand(2), Depth, AssumeNonZero, DoFold))
+ return IfFold([&]() {
+ return Builder.CreateSelect(SI->getOperand(0), LogX, LogY);
+ });
+
+ // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
+ // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
+ auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op);
+ if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) {
+ // Use AssumeNonZero as false here. Otherwise we can hit case where
+ // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow).
+ if (Value *LogX = takeLog2(MinMax->getLHS(), Depth,
+ /*AssumeNonZero*/ false, DoFold))
+ if (Value *LogY = takeLog2(MinMax->getRHS(), Depth,
+ /*AssumeNonZero*/ false, DoFold))
+ return IfFold([&]() {
+ return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX,
+ LogY);
+ });
+ }
+
+ return nullptr;
+}
+
Value *InstCombiner::getFreelyInvertedImpl(Value *V, bool WillInvertAllUses,
BuilderTy *Builder,
bool &DoesConsume, unsigned Depth) {
>From fea0bded1aff289c5094fe10c211980dd780ab4d Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 10 Jan 2025 11:35:47 -0600
Subject: [PATCH 2/6] [InstCombine] Add convenience helper `tryGetLog2`; NFC
This just encapsulates the common pattern:
```
if (takeLog2(..., /*DoFold=*/false)) {
Value * Log2 = takeLog2(..., /*DoFold=*/true);
...
}
```
---
.../llvm/Transforms/InstCombine/InstCombiner.h | 6 ++++++
.../InstCombine/InstCombineMulDivRem.cpp | 16 +++-------------
2 files changed, 9 insertions(+), 13 deletions(-)
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index 8a87ee0839b7bb..213b1b73bed067 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -201,6 +201,12 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
// takeLog2 will succeed, otherwise it may create stray instructions.
Value *takeLog2(Value *Op, unsigned Depth, bool AssumeNonZero, bool DoFold);
+ Value *tryGetLog2(Value *Op, bool AssumeNonZero) {
+ if (takeLog2(Op, /*Depth=*/0, AssumeNonZero, /*DoFold=*/false))
+ return takeLog2(Op, /*Depth=*/0, AssumeNonZero, /*DoFold=*/true);
+ return nullptr;
+ }
+
/// Return nonnull value if V is free to invert under the condition of
/// WillInvertAllUses.
/// If Builder is nonnull, it will return a simplified ~V.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 576c3bc585db12..e89a2c9579f422 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -528,19 +528,13 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
// (shl Op1, Log2(Op0))
// if Log2(Op1) folds away ->
// (shl Op0, Log2(Op1))
- if (takeLog2(Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
- /*DoFold*/ false)) {
- Value *Res = takeLog2(Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
- /*DoFold*/ true);
+ if (Value *Res = tryGetLog2(Op0, /*AssumeNonZero=*/false)) {
BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res);
// We can only propegate nuw flag.
Shl->setHasNoUnsignedWrap(HasNUW);
return Shl;
}
- if (takeLog2(Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
- /*DoFold*/ false)) {
- Value *Res = takeLog2(Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
- /*DoFold*/ true);
+ if (Value *Res = tryGetLog2(Op0, /*AssumeNonZero=*/false)) {
BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res);
// We can only propegate nuw flag.
Shl->setHasNoUnsignedWrap(HasNUW);
@@ -1506,13 +1500,9 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
}
// Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away.
- if (takeLog2(Op1, /*Depth*/ 0, /*AssumeNonZero*/ true,
- /*DoFold*/ false)) {
- Value *Res = takeLog2(Op1, /*Depth*/ 0,
- /*AssumeNonZero*/ true, /*DoFold*/ true);
+ if (Value *Res = tryGetLog2(Op1, /*AssumeNonZero=*/true))
return replaceInstUsesWith(
I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
- }
return nullptr;
}
>From 305c6c86e3ee7113613119b60f8de873e09d367b Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 10 Jan 2025 11:52:02 -0600
Subject: [PATCH 3/6] Fixup
---
llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index e89a2c9579f422..b275b9bebdb125 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -534,7 +534,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
Shl->setHasNoUnsignedWrap(HasNUW);
return Shl;
}
- if (Value *Res = tryGetLog2(Op0, /*AssumeNonZero=*/false)) {
+ if (Value *Res = tryGetLog2(Op1, /*AssumeNonZero=*/false)) {
BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res);
// We can only propegate nuw flag.
Shl->setHasNoUnsignedWrap(HasNUW);
>From e3b9d314e17abb6bc9411a475eca2e4676675997 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 10 Jan 2025 13:16:53 -0600
Subject: [PATCH 4/6] cttz/ctlz log2
---
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index c55c40c88bc845..af773a82912dd9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -588,6 +588,14 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) {
}
}
+ if (auto *R = IC.tryGetLog2(Op0, match(Op1, m_One()))) {
+ if (IsTZ)
+ return IC.replaceInstUsesWith(II, R);
+ Type *Ty = Op0->getType();
+ return BinaryOperator::CreateXor(
+ R, ConstantInt::get(Ty, Ty->getScalarSizeInBits() - 1U));
+ }
+
KnownBits Known = IC.computeKnownBits(Op0, 0, &II);
// Create a mask for bits above (ctlz) or below (cttz) the first known one.
>From 015d2e668f1f9d1df4dd6d97b84d5c101adb2b7b Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 10 Jan 2025 13:17:07 -0600
Subject: [PATCH 5/6] cttz/ctlz of p2 canonicalize
---
.../InstCombine/InstCombineAndOrXor.cpp | 17 +++++++++++++++++
1 file changed, 17 insertions(+)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index f82a557e5760c8..4821b0bcd3dcc1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -4985,5 +4985,22 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
if (Instruction *Res = foldBitwiseLogicWithIntrinsics(I, Builder))
return Res;
+ // (xor BitWidth - 1, (ctlz X_Pow2)) -> (cttz X_Pow2)
+ // (xor BitWidth - 1, (cttz X_Pow2)) -> (ctlz X_Pow2)
+ if (match(Op1, m_SpecificInt(Op0->getType()->getScalarSizeInBits() - 1)) &&
+ Op0->hasOneUse()) {
+ Intrinsic::ID ID = Intrinsic::not_intrinsic;
+ if (match(Op0, m_Intrinsic<Intrinsic::ctlz>(m_Value(X), m_Value(Y))))
+ ID = Intrinsic::cttz;
+ else if (match(Op0, m_Intrinsic<Intrinsic::cttz>(m_Value(X), m_Value(Y))))
+ ID = Intrinsic::ctlz;
+
+ if (ID != Intrinsic::not_intrinsic &&
+ isKnownToBeAPowerOfTwo(X, match(Y, m_One())))
+ return replaceInstUsesWith(
+ I, Builder.CreateBinaryIntrinsic(ID, X,
+ ConstantInt::getTrue(Y->getType())));
+ }
+
return nullptr;
}
>From 1b4d19ba9521f5e824d8bd1a9fdd121a50673606 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Fri, 10 Jan 2025 13:17:13 -0600
Subject: [PATCH 6/6] Shifts
---
.../InstCombine/InstCombineShifts.cpp | 77 +++++++++++++++++++
1 file changed, 77 insertions(+)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index d511e79e3e48ae..8c36b0ad3e9f1d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -12,6 +12,7 @@
#include "InstCombineInternal.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
@@ -978,6 +979,76 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
return new ZExtInst(Overflow, Ty);
}
+// Various folds for (shl/lshr C, (cttz/ctlz X_P2))
+static Instruction *foldLogicalShiftWithCtzOfP2(BinaryOperator &I,
+ InstCombinerImpl &IC) {
+  assert((I.getOpcode() == Instruction::Shl ||
+          I.getOpcode() == Instruction::LShr) && "Operator is not logical shift");
+ Value *Op0 = I.getOperand(0);
+ Value *Op1 = I.getOperand(1);
+ Value *CtzOp = nullptr;
+ Value *CtzMode = nullptr;
+ Intrinsic::ID ID;
+ if (I.getOpcode() == Instruction::Shl &&
+ match(Op1,
+ m_Intrinsic<Intrinsic::cttz>(m_Value(CtzOp), m_Value(CtzMode))))
+ ID = Intrinsic::cttz;
+ else if (I.getOpcode() == Instruction::LShr &&
+ match(Op1, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtzOp),
+ m_Value(CtzMode))))
+ ID = Intrinsic::ctlz;
+ else
+ return nullptr;
+
+ const APInt *C;
+ if (!match(Op0, m_APInt(C)))
+ return nullptr;
+
+ // TODO: We could extend to handle count trailing/leading ones by handling if
+ // CtzOp is (xor X, -1).
+
+ // (shl C_P2, (cttz X_P2)) -> (shl X_P2, (cttz C_P2))
+ // (lshr C_P2, (ctlz X_P2)) -> (lshr X_P2, (ctlz C_P2))
+ if (C->isPowerOf2() && IC.isKnownToBeAPowerOfTwo(CtzOp, /*OrZero=*/true)) {
+    Value *NewShAmt = ConstantInt::get(
+        CtzOp->getType(),
+        ID == Intrinsic::cttz ? C->countr_zero() : C->countl_zero());
+ if (I.getOpcode() == Instruction::Shl) {
+ BinaryOperator *NewBO = BinaryOperator::CreateShl(CtzOp, NewShAmt);
+ NewBO->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+ return NewBO;
+ }
+ BinaryOperator *NewBO = BinaryOperator::CreateLShr(CtzOp, NewShAmt);
+ NewBO->setIsExact(I.isExact());
+ return NewBO;
+ }
+ // (shl -C_P2, (cttz X_P2)) -> (shl -X_P2, (cttz -C_P2))
+ if (I.getOpcode() == Instruction::Shl && C->isNegatedPowerOf2() &&
+ CtzOp->hasOneUse() && IC.isKnownToBeAPowerOfTwo(CtzOp, /*OrZero=*/true)) {
+    Value *NewShAmt =
+        ConstantInt::get(CtzOp->getType(), C->countr_zero());
+ BinaryOperator *NewBO = BinaryOperator::CreateShl(CtzOp, NewShAmt);
+ NewBO->setHasNoSignedWrap(I.hasNoSignedWrap());
+ return NewBO;
+ }
+ // (lshr C_Mask, (ctlz X_P2))
+ // if C_Mask != -1:
+ // (lshr (sub X_P2, 1), (cttz ~C_Mask))
+ // if C_Mask == -1:
+ // (or disjoint (sub X_P2, 1), X_P2)
+ if (I.getOpcode() == Instruction::LShr && C->isMask() && CtzOp->hasOneUse() &&
+ IC.isKnownToBeAPowerOfTwo(CtzOp, /*OrZero=*/true)) {
+ Value *CtzOpMask = IC.Builder.CreateAdd(
+ CtzOp, Constant::getAllOnesValue(CtzOp->getType()));
+ if (C->isAllOnes())
+ return BinaryOperator::CreateDisjointOr(CtzOp, CtzOpMask);
+    Value *NewShAmt =
+        ConstantInt::get(CtzOp->getType(), C->countr_one());
+ return BinaryOperator::CreateLShr(CtzOpMask, NewShAmt);
+ }
+ return nullptr;
+}
+
// Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
assert(I.isShift() && "Expected a shift as input");
@@ -1266,6 +1337,9 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
}
}
+ if (auto *R = foldLogicalShiftWithCtzOfP2(I, *this))
+ return R;
+
return nullptr;
}
@@ -1613,6 +1687,9 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
if (Instruction *Overflow = foldLShrOverflowBit(I))
return Overflow;
+ if (auto *R = foldLogicalShiftWithCtzOfP2(I, *this))
+ return R;
+
return nullptr;
}
More information about the llvm-commits
mailing list