[llvm] r281217 - [InstCombine] add helper function for foldICmpUsingKnownBits; NFCI
Sanjay Patel via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 12 08:24:32 PDT 2016
Author: spatel
Date: Mon Sep 12 10:24:31 2016
New Revision: 281217
URL: http://llvm.org/viewvc/llvm-project?rev=281217&view=rev
Log:
[InstCombine] add helper function for foldICmpUsingKnownBits; NFCI
Modified:
llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp?rev=281217&r1=281216&r2=281217&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp Mon Sep 12 10:24:31 2016
@@ -3135,6 +3135,274 @@ bool InstCombiner::replacedSelectWithOpe
return false;
}
+/// Try to fold the comparison based on range information we can get by checking
+/// whether bits are known to be zero or one in the inputs.
+Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ Type *Ty = Op0->getType();
+
+ // Get scalar or pointer size.
+ unsigned BitWidth = Ty->isIntOrIntVectorTy()
+ ? Ty->getScalarSizeInBits()
+ : DL.getTypeSizeInBits(Ty->getScalarType());
+
+ if (!BitWidth)
+ return nullptr;
+
+ // If this is a normal comparison, it demands all bits. If it is a sign bit
+ // comparison, it only demands the sign bit.
+ bool IsSignBit = false;
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+ bool UnusedBit;
+ IsSignBit = isSignBitCheck(I.getPredicate(), CI->getValue(), UnusedBit);
+ }
+
+ APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0);
+ APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0);
+
+ if (SimplifyDemandedBits(I.getOperandUse(0),
+ DemandedBitsLHSMask(I, BitWidth, IsSignBit),
+ Op0KnownZero, Op0KnownOne, 0))
+ return &I;
+
+ if (SimplifyDemandedBits(I.getOperandUse(1), APInt::getAllOnesValue(BitWidth),
+ Op1KnownZero, Op1KnownOne, 0))
+ return &I;
+
+ // Given the known and unknown bits, compute a range that the LHS could be
+ // in. Compute the Min, Max and RHS values based on the known bits. For the
+ // EQ and NE we use unsigned values.
+ APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0);
+ APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0);
+ if (I.isSigned()) {
+ ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min,
+ Op0Max);
+ ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min,
+ Op1Max);
+ } else {
+ ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min,
+ Op0Max);
+ ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min,
+ Op1Max);
+ }
+
+ // If Min and Max are known to be the same, then SimplifyDemandedBits
+ // figured out that the operand is a constant. Just constant fold this now so
+ // that code below can assume that Min != Max.
+ if (!isa<Constant>(Op0) && Op0Min == Op0Max)
+ return new ICmpInst(I.getPredicate(),
+ ConstantInt::get(Op0->getType(), Op0Min), Op1);
+ if (!isa<Constant>(Op1) && Op1Min == Op1Max)
+ return new ICmpInst(I.getPredicate(), Op0,
+ ConstantInt::get(Op1->getType(), Op1Min));
+
+ // Based on the range information we know about the LHS, see if we can
+ // simplify this comparison. For example, (x&4) < 8 is always true.
+ switch (I.getPredicate()) {
+ default:
+ llvm_unreachable("Unknown icmp opcode!");
+ case ICmpInst::ICMP_EQ: {
+ if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
+ return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+
+ // If all bits are known zero except for one, then we know at most one
+ // bit is set. If the comparison is against zero, then this is a check
+ // to see if *that* bit is set.
+ APInt Op0KnownZeroInverted = ~Op0KnownZero;
+ if (~Op1KnownZero == 0) {
+ // If the LHS is an AND with the same constant, look through it.
+ Value *LHS = nullptr;
+ ConstantInt *LHSC = nullptr;
+ if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) ||
+ LHSC->getValue() != Op0KnownZeroInverted)
+ LHS = Op0;
+
+ // If the LHS is 1 << x, and we know the result is a power of 2 like 8,
+ // then turn "((1 << x)&8) == 0" into "x != 3".
+ // or turn "((1 << x)&7) == 0" into "x > 2".
+ Value *X = nullptr;
+ if (match(LHS, m_Shl(m_One(), m_Value(X)))) {
+ APInt ValToCheck = Op0KnownZeroInverted;
+ if (ValToCheck.isPowerOf2()) {
+ unsigned CmpVal = ValToCheck.countTrailingZeros();
+ return new ICmpInst(ICmpInst::ICMP_NE, X,
+ ConstantInt::get(X->getType(), CmpVal));
+ } else if ((++ValToCheck).isPowerOf2()) {
+ unsigned CmpVal = ValToCheck.countTrailingZeros() - 1;
+ return new ICmpInst(ICmpInst::ICMP_UGT, X,
+ ConstantInt::get(X->getType(), CmpVal));
+ }
+ }
+
+ // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
+ // then turn "((8 >>u x)&1) == 0" into "x != 3".
+ const APInt *CI;
+ if (Op0KnownZeroInverted == 1 &&
+ match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
+ return new ICmpInst(
+ ICmpInst::ICMP_NE, X,
+ ConstantInt::get(X->getType(), CI->countTrailingZeros()));
+ }
+ break;
+ }
+ case ICmpInst::ICMP_NE: {
+ if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
+ return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+
+ // If all bits are known zero except for one, then we know at most one
+ // bit is set. If the comparison is against zero, then this is a check
+ // to see if *that* bit is set.
+ APInt Op0KnownZeroInverted = ~Op0KnownZero;
+ if (~Op1KnownZero == 0) {
+ // If the LHS is an AND with the same constant, look through it.
+ Value *LHS = nullptr;
+ ConstantInt *LHSC = nullptr;
+ if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) ||
+ LHSC->getValue() != Op0KnownZeroInverted)
+ LHS = Op0;
+
+ // If the LHS is 1 << x, and we know the result is a power of 2 like 8,
+ // then turn "((1 << x)&8) != 0" into "x == 3".
+ // or turn "((1 << x)&7) != 0" into "x < 3".
+ Value *X = nullptr;
+ if (match(LHS, m_Shl(m_One(), m_Value(X)))) {
+ APInt ValToCheck = Op0KnownZeroInverted;
+ if (ValToCheck.isPowerOf2()) {
+ unsigned CmpVal = ValToCheck.countTrailingZeros();
+ return new ICmpInst(ICmpInst::ICMP_EQ, X,
+ ConstantInt::get(X->getType(), CmpVal));
+ } else if ((++ValToCheck).isPowerOf2()) {
+ unsigned CmpVal = ValToCheck.countTrailingZeros();
+ return new ICmpInst(ICmpInst::ICMP_ULT, X,
+ ConstantInt::get(X->getType(), CmpVal));
+ }
+ }
+
+ // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
+ // then turn "((8 >>u x)&1) != 0" into "x == 3".
+ const APInt *CI;
+ if (Op0KnownZeroInverted == 1 &&
+ match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
+ return new ICmpInst(
+ ICmpInst::ICMP_EQ, X,
+ ConstantInt::get(X->getType(), CI->countTrailingZeros()));
+ }
+ break;
+ }
+ case ICmpInst::ICMP_ULT: {
+ if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B)
+ return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+ if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
+ return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+
+ const APInt *CmpC;
+ if (match(Op1, m_APInt(CmpC))) {
+ // A <u C -> A == C-1 if min(A)+1 == C
+ if (Op1Max == Op0Min + 1) {
+ Constant *CMinus1 = ConstantInt::get(Op0->getType(), *CmpC - 1);
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0, CMinus1);
+ }
+ // (x <u 2147483648) -> (x >s -1) -> true if sign bit clear
+ if (CmpC->isMinSignedValue()) {
+ Constant *AllOnes = Constant::getAllOnesValue(Op0->getType());
+ return new ICmpInst(ICmpInst::ICMP_SGT, Op0, AllOnes);
+ }
+ }
+ break;
+ }
+ case ICmpInst::ICMP_UGT: {
+ if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
+ return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+
+ if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= min(B)
+ return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+
+ if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+
+ const APInt *CmpC;
+ if (match(Op1, m_APInt(CmpC))) {
+ // A >u C -> A == C+1 if max(A)-1 == C
+ if (*CmpC == Op0Max - 1)
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), *CmpC + 1));
+
+ // (x >u 2147483647) -> (x <s 0) -> true if sign bit set
+ if (CmpC->isMaxSignedValue())
+ return new ICmpInst(ICmpInst::ICMP_SLT, Op0,
+ Constant::getNullValue(Op0->getType()));
+ }
+ break;
+ }
+ case ICmpInst::ICMP_SLT:
+ if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(B)
+ return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+ if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(B)
+ return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+ if (Op1Max == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ Builder->getInt(CI->getValue() - 1));
+ }
+ break;
+ case ICmpInst::ICMP_SGT:
+ if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
+ return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+ if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
+ return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+
+ if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+ if (Op1Min == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ Builder->getInt(CI->getValue() + 1));
+ }
+ break;
+ case ICmpInst::ICMP_SGE:
+ assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
+ if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
+ return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+ if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
+ return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ break;
+ case ICmpInst::ICMP_SLE:
+ assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
+ if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B)
+ return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+ if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
+ return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ break;
+ case ICmpInst::ICMP_UGE:
+ assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
+ if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B)
+ return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+ if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
+ return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ break;
+ case ICmpInst::ICMP_ULE:
+ assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
+ if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B)
+ return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+ if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
+ return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ break;
+ }
+
+ // Turn a signed comparison into an unsigned one if both operands are known to
+ // have the same sign.
+ if (I.isSigned() &&
+ ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) ||
+ (Op0KnownOne.isNegative() && Op1KnownOne.isNegative())))
+ return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
+
+ return nullptr;
+}
+
/// If we have an icmp le or icmp ge instruction with a constant operand, turn
/// it into the appropriate icmp lt or icmp gt instruction. This transform
/// allows them to be folded in visitICmpInst.
@@ -3276,14 +3544,6 @@ Instruction *InstCombiner::visitICmpInst
if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I))
return NewICmp;
- unsigned BitWidth = 0;
- if (Ty->isIntOrIntVectorTy())
- BitWidth = Ty->getScalarSizeInBits();
- else // Get pointer size.
- BitWidth = DL.getTypeSizeInBits(Ty->getScalarType());
-
- bool isSignBit = false;
-
// See if we are doing a comparison with a constant.
if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
Value *A = nullptr, *B = nullptr;
@@ -3365,11 +3625,6 @@ Instruction *InstCombiner::visitICmpInst
}
}
- // If this comparison is a normal comparison, it demands all
- // bits, if it is a sign bit comparison, it only demands the sign bit.
- bool UnusedBit;
- isSignBit = isSignBitCheck(I.getPredicate(), CI->getValue(), UnusedBit);
-
// Canonicalize icmp instructions based on dominating conditions.
BasicBlock *Parent = I.getParent();
BasicBlock *Dom = Parent->getSinglePredecessor();
@@ -3393,12 +3648,19 @@ Instruction *InstCombiner::visitICmpInst
return replaceInstUsesWith(I, Builder->getFalse());
if (Difference.isEmptySet())
return replaceInstUsesWith(I, Builder->getTrue());
+
+ // If this is a normal comparison, it demands all bits. If it is a sign
+ // bit comparison, it only demands the sign bit.
+ bool UnusedBit;
+ bool IsSignBit =
+ isSignBitCheck(I.getPredicate(), CI->getValue(), UnusedBit);
+
// Canonicalizing a sign bit comparison that gets used in a branch,
// pessimizes codegen by generating branch on zero instruction instead
// of a test and branch. So we avoid canonicalizing in such situations
// because test and branch instruction has better branch displacement
// than compare and branch instruction.
- if (!isBranchOnSignBitCheck(I, isSignBit) && !I.isEquality()) {
+ if (!isBranchOnSignBitCheck(I, IsSignBit) && !I.isEquality()) {
if (auto *AI = Intersection.getSingleElement())
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Builder->getInt(*AI));
if (auto *AD = Difference.getSingleElement())
@@ -3407,251 +3669,8 @@ Instruction *InstCombiner::visitICmpInst
}
}
- // See if we can fold the comparison based on range information we can get
- // by checking whether bits are known to be zero or one in the input.
- if (BitWidth != 0) {
- APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0);
- APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0);
-
- if (SimplifyDemandedBits(I.getOperandUse(0),
- DemandedBitsLHSMask(I, BitWidth, isSignBit),
- Op0KnownZero, Op0KnownOne, 0))
- return &I;
- if (SimplifyDemandedBits(I.getOperandUse(1),
- APInt::getAllOnesValue(BitWidth), Op1KnownZero,
- Op1KnownOne, 0))
- return &I;
-
- // Given the known and unknown bits, compute a range that the LHS could be
- // in. Compute the Min, Max and RHS values based on the known bits. For the
- // EQ and NE we use unsigned values.
- APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0);
- APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0);
- if (I.isSigned()) {
- ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne,
- Op0Min, Op0Max);
- ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne,
- Op1Min, Op1Max);
- } else {
- ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne,
- Op0Min, Op0Max);
- ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne,
- Op1Min, Op1Max);
- }
-
- // If Min and Max are known to be the same, then SimplifyDemandedBits
- // figured out that the LHS is a constant. Just constant fold this now so
- // that code below can assume that Min != Max.
- if (!isa<Constant>(Op0) && Op0Min == Op0Max)
- return new ICmpInst(I.getPredicate(),
- ConstantInt::get(Op0->getType(), Op0Min), Op1);
- if (!isa<Constant>(Op1) && Op1Min == Op1Max)
- return new ICmpInst(I.getPredicate(), Op0,
- ConstantInt::get(Op1->getType(), Op1Min));
-
- // Based on the range information we know about the LHS, see if we can
- // simplify this comparison. For example, (x&4) < 8 is always true.
- switch (I.getPredicate()) {
- default: llvm_unreachable("Unknown icmp opcode!");
- case ICmpInst::ICMP_EQ: {
- if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-
- // If all bits are known zero except for one, then we know at most one
- // bit is set. If the comparison is against zero, then this is a check
- // to see if *that* bit is set.
- APInt Op0KnownZeroInverted = ~Op0KnownZero;
- if (~Op1KnownZero == 0) {
- // If the LHS is an AND with the same constant, look through it.
- Value *LHS = nullptr;
- ConstantInt *LHSC = nullptr;
- if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) ||
- LHSC->getValue() != Op0KnownZeroInverted)
- LHS = Op0;
-
- // If the LHS is 1 << x, and we know the result is a power of 2 like 8,
- // then turn "((1 << x)&8) == 0" into "x != 3".
- // or turn "((1 << x)&7) == 0" into "x > 2".
- Value *X = nullptr;
- if (match(LHS, m_Shl(m_One(), m_Value(X)))) {
- APInt ValToCheck = Op0KnownZeroInverted;
- if (ValToCheck.isPowerOf2()) {
- unsigned CmpVal = ValToCheck.countTrailingZeros();
- return new ICmpInst(ICmpInst::ICMP_NE, X,
- ConstantInt::get(X->getType(), CmpVal));
- } else if ((++ValToCheck).isPowerOf2()) {
- unsigned CmpVal = ValToCheck.countTrailingZeros() - 1;
- return new ICmpInst(ICmpInst::ICMP_UGT, X,
- ConstantInt::get(X->getType(), CmpVal));
- }
- }
-
- // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
- // then turn "((8 >>u x)&1) == 0" into "x != 3".
- const APInt *CI;
- if (Op0KnownZeroInverted == 1 &&
- match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
- return new ICmpInst(ICmpInst::ICMP_NE, X,
- ConstantInt::get(X->getType(),
- CI->countTrailingZeros()));
- }
- break;
- }
- case ICmpInst::ICMP_NE: {
- if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-
- // If all bits are known zero except for one, then we know at most one
- // bit is set. If the comparison is against zero, then this is a check
- // to see if *that* bit is set.
- APInt Op0KnownZeroInverted = ~Op0KnownZero;
- if (~Op1KnownZero == 0) {
- // If the LHS is an AND with the same constant, look through it.
- Value *LHS = nullptr;
- ConstantInt *LHSC = nullptr;
- if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) ||
- LHSC->getValue() != Op0KnownZeroInverted)
- LHS = Op0;
-
- // If the LHS is 1 << x, and we know the result is a power of 2 like 8,
- // then turn "((1 << x)&8) != 0" into "x == 3".
- // or turn "((1 << x)&7) != 0" into "x < 3".
- Value *X = nullptr;
- if (match(LHS, m_Shl(m_One(), m_Value(X)))) {
- APInt ValToCheck = Op0KnownZeroInverted;
- if (ValToCheck.isPowerOf2()) {
- unsigned CmpVal = ValToCheck.countTrailingZeros();
- return new ICmpInst(ICmpInst::ICMP_EQ, X,
- ConstantInt::get(X->getType(), CmpVal));
- } else if ((++ValToCheck).isPowerOf2()) {
- unsigned CmpVal = ValToCheck.countTrailingZeros();
- return new ICmpInst(ICmpInst::ICMP_ULT, X,
- ConstantInt::get(X->getType(), CmpVal));
- }
- }
-
- // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
- // then turn "((8 >>u x)&1) != 0" into "x == 3".
- const APInt *CI;
- if (Op0KnownZeroInverted == 1 &&
- match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
- return new ICmpInst(ICmpInst::ICMP_EQ, X,
- ConstantInt::get(X->getType(),
- CI->countTrailingZeros()));
- }
- break;
- }
- case ICmpInst::ICMP_ULT: {
- if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
- return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A <u C -> A == C-1 if min(A)+1 == C
- if (Op1Max == Op0Min + 1) {
- Constant *CMinus1 = ConstantInt::get(Op0->getType(), *CmpC - 1);
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0, CMinus1);
- }
- // (x <u 2147483648) -> (x >s -1) -> true if sign bit clear
- if (CmpC->isMinSignedValue()) {
- Constant *AllOnes = Constant::getAllOnesValue(Op0->getType());
- return new ICmpInst(ICmpInst::ICMP_SGT, Op0, AllOnes);
- }
- }
- break;
- }
- case ICmpInst::ICMP_UGT: {
- if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-
- if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-
- if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
- return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A >u C -> A == C+1 if max(a)-1 == C
- if (*CmpC == Op0Max - 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
-
- // (x >u 2147483647) -> (x <s 0) -> true if sign bit set
- if (CmpC->isMaxSignedValue())
- return new ICmpInst(ICmpInst::ICMP_SLT, Op0,
- Constant::getNullValue(Op0->getType()));
- }
- break;
- }
- case ICmpInst::ICMP_SLT:
- if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
- return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
- if (Op1Max == Op0Min+1) // A <s C -> A == C-1 if min(A)+1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- Builder->getInt(CI->getValue()-1));
- }
- break;
- case ICmpInst::ICMP_SGT:
- if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-
- if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
- return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
- if (Op1Min == Op0Max-1) // A >s C -> A == C+1 if max(A)-1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- Builder->getInt(CI->getValue()+1));
- }
- break;
- case ICmpInst::ICMP_SGE:
- assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
- if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- break;
- case ICmpInst::ICMP_SLE:
- assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
- if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- break;
- case ICmpInst::ICMP_UGE:
- assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
- if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- break;
- case ICmpInst::ICMP_ULE:
- assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
- if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- break;
- }
-
- // Turn a signed comparison into an unsigned one if both operands
- // are known to have the same sign.
- if (I.isSigned() &&
- ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) ||
- (Op0KnownOne.isNegative() && Op1KnownOne.isNegative())))
- return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
- }
+ if (Instruction *Res = foldICmpUsingKnownBits(I))
+ return Res;
// Test if the ICmpInst instruction is used exclusively by a select as
// part of a minimum or maximum operation. If so, refrain from doing
Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h?rev=281217&r1=281216&r2=281217&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h Mon Sep 12 10:24:31 2016
@@ -556,6 +556,7 @@ private:
ICmpInst::Predicate Pred);
Instruction *foldICmpWithCastAndCast(ICmpInst &ICI);
+ Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp);
Instruction *foldICmpInstWithConstant(ICmpInst &Cmp);
Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc,
More information about the llvm-commits
mailing list