[llvm] 5350e1b - [KnownBits] Implement accurate unsigned and signed max and min
Jay Foad via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 7 01:10:15 PDT 2020
Author: Jay Foad
Date: 2020-09-07T09:09:01+01:00
New Revision: 5350e1b5096aa4707aa525baf7398d93b4a4f1a5
URL: https://github.com/llvm/llvm-project/commit/5350e1b5096aa4707aa525baf7398d93b4a4f1a5
DIFF: https://github.com/llvm/llvm-project/commit/5350e1b5096aa4707aa525baf7398d93b4a4f1a5.diff
LOG: [KnownBits] Implement accurate unsigned and signed max and min
Use the new implementation in ValueTracking, SelectionDAG and
GlobalISel.
Differential Revision: https://reviews.llvm.org/D87034
Added:
Modified:
llvm/include/llvm/Support/KnownBits.h
llvm/lib/Analysis/ValueTracking.cpp
llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/lib/Support/KnownBits.cpp
llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
llvm/unittests/Support/KnownBitsTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 5b3de63cd359..a29e150b904a 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -173,6 +173,10 @@ struct KnownBits {
One.extractBits(NumBits, BitPosition));
}
+ /// Return KnownBits based on this, but updated given that the underlying
+ /// value is known to be greater than or equal to Val.
+ KnownBits makeGE(const APInt &Val) const;
+
/// Returns the minimum number of trailing zero bits.
unsigned countMinTrailingZeros() const {
return Zero.countTrailingOnes();
@@ -241,6 +245,18 @@ struct KnownBits {
static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
KnownBits RHS);
+ /// Compute known bits for umax(LHS, RHS).
+ static KnownBits umax(const KnownBits &LHS, const KnownBits &RHS);
+
+ /// Compute known bits for umin(LHS, RHS).
+ static KnownBits umin(const KnownBits &LHS, const KnownBits &RHS);
+
+ /// Compute known bits for smax(LHS, RHS).
+ static KnownBits smax(const KnownBits &LHS, const KnownBits &RHS);
+
+ /// Compute known bits for smin(LHS, RHS).
+ static KnownBits smin(const KnownBits &LHS, const KnownBits &RHS);
+
/// Insert the bits from a smaller known bits starting at bitPosition.
void insertBits(const KnownBits &SubBits, unsigned BitPosition) {
Zero.insertBits(SubBits.Zero, BitPosition);
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 3a6ee355c646..6e5a7195bb19 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1212,59 +1212,41 @@ static void computeKnownBitsFromOperator(const Operator *I,
if (SelectPatternResult::isMinOrMax(SPF)) {
computeKnownBits(RHS, Known, Depth + 1, Q);
computeKnownBits(LHS, Known2, Depth + 1, Q);
- } else {
- computeKnownBits(I->getOperand(2), Known, Depth + 1, Q);
- computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+ switch (SPF) {
+ default:
+ llvm_unreachable("Unhandled select pattern flavor!");
+ case SPF_SMAX:
+ Known = KnownBits::smax(Known, Known2);
+ break;
+ case SPF_SMIN:
+ Known = KnownBits::smin(Known, Known2);
+ break;
+ case SPF_UMAX:
+ Known = KnownBits::umax(Known, Known2);
+ break;
+ case SPF_UMIN:
+ Known = KnownBits::umin(Known, Known2);
+ break;
+ }
+ break;
}
- unsigned MaxHighOnes = 0;
- unsigned MaxHighZeros = 0;
- if (SPF == SPF_SMAX) {
- // If both sides are negative, the result is negative.
- if (Known.isNegative() && Known2.isNegative())
- // We can derive a lower bound on the result by taking the max of the
- // leading one bits.
- MaxHighOnes =
- std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes());
- // If either side is non-negative, the result is non-negative.
- else if (Known.isNonNegative() || Known2.isNonNegative())
- MaxHighZeros = 1;
- } else if (SPF == SPF_SMIN) {
- // If both sides are non-negative, the result is non-negative.
- if (Known.isNonNegative() && Known2.isNonNegative())
- // We can derive an upper bound on the result by taking the max of the
- // leading zero bits.
- MaxHighZeros = std::max(Known.countMinLeadingZeros(),
- Known2.countMinLeadingZeros());
- // If either side is negative, the result is negative.
- else if (Known.isNegative() || Known2.isNegative())
- MaxHighOnes = 1;
- } else if (SPF == SPF_UMAX) {
- // We can derive a lower bound on the result by taking the max of the
- // leading one bits.
- MaxHighOnes =
- std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes());
- } else if (SPF == SPF_UMIN) {
- // We can derive an upper bound on the result by taking the max of the
- // leading zero bits.
- MaxHighZeros =
- std::max(Known.countMinLeadingZeros(), Known2.countMinLeadingZeros());
- } else if (SPF == SPF_ABS) {
+ computeKnownBits(I->getOperand(2), Known, Depth + 1, Q);
+ computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+
+ // Only known if known in both the LHS and RHS.
+ Known.One &= Known2.One;
+ Known.Zero &= Known2.Zero;
+
+ if (SPF == SPF_ABS) {
// RHS from matchSelectPattern returns the negation part of abs pattern.
// If the negate has an NSW flag we can assume the sign bit of the result
// will be 0 because that makes abs(INT_MIN) undefined.
if (match(RHS, m_Neg(m_Specific(LHS))) &&
Q.IIQ.hasNoSignedWrap(cast<Instruction>(RHS)))
- MaxHighZeros = 1;
+ Known.Zero.setSignBit();
}
- // Only known if known in both the LHS and RHS.
- Known.One &= Known2.One;
- Known.Zero &= Known2.Zero;
- if (MaxHighOnes > 0)
- Known.One.setHighBits(MaxHighOnes);
- if (MaxHighZeros > 0)
- Known.Zero.setHighBits(MaxHighZeros);
break;
}
case Instruction::FPTrunc:
diff --git a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
index c615462af407..3ebbac9fd659 100644
--- a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
@@ -308,11 +308,24 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
Known, DemandedElts, Depth + 1);
break;
}
- case TargetOpcode::G_SMIN:
+ case TargetOpcode::G_SMIN: {
+ // TODO: Handle clamp pattern with number of sign bits
+ KnownBits KnownRHS;
+ computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
+ Depth + 1);
+ computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, DemandedElts,
+ Depth + 1);
+ Known = KnownBits::smin(Known, KnownRHS);
+ break;
+ }
case TargetOpcode::G_SMAX: {
// TODO: Handle clamp pattern with number of sign bits
- computeKnownBitsMin(MI.getOperand(1).getReg(), MI.getOperand(2).getReg(),
- Known, DemandedElts, Depth + 1);
+ KnownBits KnownRHS;
+ computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
+ Depth + 1);
+ computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, DemandedElts,
+ Depth + 1);
+ Known = KnownBits::smax(Known, KnownRHS);
break;
}
case TargetOpcode::G_UMIN: {
@@ -321,13 +334,7 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
DemandedElts, Depth + 1);
computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS,
DemandedElts, Depth + 1);
-
- // UMIN - we know that the result will have the maximum of the
- // known zero leading bits of the inputs.
- unsigned LeadZero = Known.countMinLeadingZeros();
- LeadZero = std::max(LeadZero, KnownRHS.countMinLeadingZeros());
- Known &= KnownRHS;
- Known.Zero.setHighBits(LeadZero);
+ Known = KnownBits::umin(Known, KnownRHS);
break;
}
case TargetOpcode::G_UMAX: {
@@ -336,14 +343,7 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
DemandedElts, Depth + 1);
computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS,
DemandedElts, Depth + 1);
-
- // UMAX - we know that the result will have the maximum of the
- // known one leading bits of the inputs.
- unsigned LeadOne = Known.countMinLeadingOnes();
- LeadOne = std::max(LeadOne, KnownRHS.countMinLeadingOnes());
- Known.Zero &= KnownRHS.Zero;
- Known.One &= KnownRHS.One;
- Known.One.setHighBits(LeadOne);
+ Known = KnownBits::umax(Known, KnownRHS);
break;
}
case TargetOpcode::G_FCMP:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 73e042c47540..d2b3e009c202 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3390,29 +3390,13 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::UMIN: {
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-
- // UMIN - we know that the result will have the maximum of the
- // known zero leading bits of the inputs.
- unsigned LeadZero = Known.countMinLeadingZeros();
- LeadZero = std::max(LeadZero, Known2.countMinLeadingZeros());
-
- Known.Zero &= Known2.Zero;
- Known.One &= Known2.One;
- Known.Zero.setHighBits(LeadZero);
+ Known = KnownBits::umin(Known, Known2);
break;
}
case ISD::UMAX: {
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-
- // UMAX - we know that the result will have the maximum of the
- // known one leading bits of the inputs.
- unsigned LeadOne = Known.countMinLeadingOnes();
- LeadOne = std::max(LeadOne, Known2.countMinLeadingOnes());
-
- Known.Zero &= Known2.Zero;
- Known.One &= Known2.One;
- Known.One.setHighBits(LeadOne);
+ Known = KnownBits::umax(Known, Known2);
break;
}
case ISD::SMIN:
@@ -3446,12 +3430,13 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
}
}
- // Fallback - just get the shared known bits of the operands.
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
if (Known.isUnknown()) break; // Early-out
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
- Known.Zero &= Known2.Zero;
- Known.One &= Known2.One;
+ if (IsMax)
+ Known = KnownBits::smax(Known, Known2);
+ else
+ Known = KnownBits::smin(Known, Known2);
break;
}
case ISD::FrameIndex:
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 1ff66d504cbe..aad50e124034 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -83,6 +83,68 @@ KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
return KnownOut;
}
+KnownBits KnownBits::makeGE(const APInt &Val) const {
+ // Count the number of leading bit positions where our underlying value is
+ // known to be less than or equal to Val.
+ unsigned N = (Zero | Val).countLeadingOnes();
+
+ // For each of those bit positions, if Val has a 1 in that bit then our
+ // underlying value must also have a 1.
+ APInt MaskedVal(Val);
+ MaskedVal.clearLowBits(getBitWidth() - N);
+ return KnownBits(Zero, One | MaskedVal);
+}
+
+KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
+ // If we can prove that LHS >= RHS then use LHS as the result. Likewise for
+ // RHS. Ideally our caller would already have spotted these cases and
+ // optimized away the umax operation, but we handle them here for
+ // completeness.
+ if (LHS.getMinValue().uge(RHS.getMaxValue()))
+ return LHS;
+ if (RHS.getMinValue().uge(LHS.getMaxValue()))
+ return RHS;
+
+ // If the result of the umax is LHS then it must be greater than or equal to
+ // the minimum possible value of RHS. Likewise for RHS. Any known bits that
+ // are common to these two values are also known in the result.
+ KnownBits L = LHS.makeGE(RHS.getMinValue());
+ KnownBits R = RHS.makeGE(LHS.getMinValue());
+ return KnownBits(L.Zero & R.Zero, L.One & R.One);
+}
+
+KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
+ // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0]
+ auto Flip = [](KnownBits Val) { return KnownBits(Val.One, Val.Zero); };
+ return Flip(umax(Flip(LHS), Flip(RHS)));
+}
+
+KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
+ // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
+ auto Flip = [](KnownBits Val) {
+ unsigned SignBitPosition = Val.getBitWidth() - 1;
+ APInt Zero = Val.Zero;
+ APInt One = Val.One;
+ Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
+ One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
+ return KnownBits(Zero, One);
+ };
+ return Flip(umax(Flip(LHS), Flip(RHS)));
+}
+
+KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
+ // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]
+ auto Flip = [](KnownBits Val) {
+ unsigned SignBitPosition = Val.getBitWidth() - 1;
+ APInt Zero = Val.One;
+ APInt One = Val.Zero;
+ Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
+ One.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
+ return KnownBits(Zero, One);
+ };
+ return Flip(umax(Flip(LHS), Flip(RHS)));
+}
+
KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
// Result bit is 0 if either operand bit is 0.
Zero |= RHS.Zero;
diff --git a/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp b/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
index 30ff37536faf..faf6f7087ac0 100644
--- a/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
@@ -719,9 +719,9 @@ TEST_F(AArch64GISelMITest, TestKnownBitsUMax) {
KnownBits KnownUmax = Info.getKnownBits(CopyUMax);
EXPECT_EQ(64u, KnownUmax.getBitWidth());
- EXPECT_EQ(0u, KnownUmax.Zero.getZExtValue());
+ EXPECT_EQ(0xffu, KnownUmax.Zero.getZExtValue());
EXPECT_EQ(0xffffffffffffff00, KnownUmax.One.getZExtValue());
- EXPECT_EQ(0u, KnownUmax.Zero.getZExtValue());
+ EXPECT_EQ(0xffu, KnownUmax.Zero.getZExtValue());
EXPECT_EQ(0xffffffffffffff00, KnownUmax.One.getZExtValue());
}
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 694e5c4dcc71..89555a5881a5 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -103,13 +103,15 @@ TEST(KnownBitsTest, BinaryExhaustive) {
unsigned Bits = 4;
ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
- KnownBits KnownAnd(Bits), KnownOr(Bits), KnownXor(Bits);
+ KnownBits KnownAnd(Bits);
KnownAnd.Zero.setAllBits();
KnownAnd.One.setAllBits();
- KnownOr.Zero.setAllBits();
- KnownOr.One.setAllBits();
- KnownXor.Zero.setAllBits();
- KnownXor.One.setAllBits();
+ KnownBits KnownOr(KnownAnd);
+ KnownBits KnownXor(KnownAnd);
+ KnownBits KnownUMax(KnownAnd);
+ KnownBits KnownUMin(KnownAnd);
+ KnownBits KnownSMax(KnownAnd);
+ KnownBits KnownSMin(KnownAnd);
ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
@@ -126,6 +128,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
Res = N1 ^ N2;
KnownXor.One &= Res;
KnownXor.Zero &= ~Res;
+
+ Res = APIntOps::umax(N1, N2);
+ KnownUMax.One &= Res;
+ KnownUMax.Zero &= ~Res;
+
+ Res = APIntOps::umin(N1, N2);
+ KnownUMin.One &= Res;
+ KnownUMin.Zero &= ~Res;
+
+ Res = APIntOps::smax(N1, N2);
+ KnownSMax.One &= Res;
+ KnownSMax.Zero &= ~Res;
+
+ Res = APIntOps::smin(N1, N2);
+ KnownSMin.One &= Res;
+ KnownSMin.Zero &= ~Res;
});
});
@@ -140,6 +158,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
KnownBits ComputedXor = Known1 ^ Known2;
EXPECT_EQ(KnownXor.Zero, ComputedXor.Zero);
EXPECT_EQ(KnownXor.One, ComputedXor.One);
+
+ KnownBits ComputedUMax = KnownBits::umax(Known1, Known2);
+ EXPECT_EQ(KnownUMax.Zero, ComputedUMax.Zero);
+ EXPECT_EQ(KnownUMax.One, ComputedUMax.One);
+
+ KnownBits ComputedUMin = KnownBits::umin(Known1, Known2);
+ EXPECT_EQ(KnownUMin.Zero, ComputedUMin.Zero);
+ EXPECT_EQ(KnownUMin.One, ComputedUMin.One);
+
+ KnownBits ComputedSMax = KnownBits::smax(Known1, Known2);
+ EXPECT_EQ(KnownSMax.Zero, ComputedSMax.Zero);
+ EXPECT_EQ(KnownSMax.One, ComputedSMax.One);
+
+ KnownBits ComputedSMin = KnownBits::smin(Known1, Known2);
+ EXPECT_EQ(KnownSMin.Zero, ComputedSMin.Zero);
+ EXPECT_EQ(KnownSMin.One, ComputedSMin.One);
});
});
}
More information about the llvm-commits mailing list