[llvm-branch-commits] [llvm] 23b4198 - [Support] Add KnownBits::icmp helpers.

Simon Pilgrim via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Jan 4 05:12:41 PST 2021


Author: Simon Pilgrim
Date: 2021-01-04T12:46:27Z
New Revision: 23b41986527a3fc5615480a8f7a0b0debd5fcef4

URL: https://github.com/llvm/llvm-project/commit/23b41986527a3fc5615480a8f7a0b0debd5fcef4
DIFF: https://github.com/llvm/llvm-project/commit/23b41986527a3fc5615480a8f7a0b0debd5fcef4.diff

LOG: [Support] Add KnownBits::icmp helpers.

Check if all possible values for a pair of knownbits give the same icmp result - these are based off the checks performed in InstCombineCompares.cpp and D86578.

Add exhaustive unit test coverage - a followup will update InstCombineCompares.cpp to use this.

Added: 
    

Modified: 
    llvm/include/llvm/Support/KnownBits.h
    llvm/lib/Support/KnownBits.cpp
    llvm/unittests/Support/KnownBitsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index ec88b9807174..edb771d659e2 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -15,6 +15,7 @@
 #define LLVM_SUPPORT_KNOWNBITS_H
 
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/Optional.h"
 
 namespace llvm {
 
@@ -328,6 +329,36 @@ struct KnownBits {
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
   static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS);
 
+  /// Determine if these known bits always give the same ICMP_EQ result.
+  static Optional<bool> eq(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Determine if these known bits always give the same ICMP_NE result.
+  static Optional<bool> ne(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Determine if these known bits always give the same ICMP_UGT result.
+  static Optional<bool> ugt(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Determine if these known bits always give the same ICMP_UGE result.
+  static Optional<bool> uge(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Determine if these known bits always give the same ICMP_ULT result.
+  static Optional<bool> ult(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Determine if these known bits always give the same ICMP_ULE result.
+  static Optional<bool> ule(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Determine if these known bits always give the same ICMP_SGT result.
+  static Optional<bool> sgt(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Determine if these known bits always give the same ICMP_SGE result.
+  static Optional<bool> sge(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Determine if these known bits always give the same ICMP_SLT result.
+  static Optional<bool> slt(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Determine if these known bits always give the same ICMP_SLE result.
+  static Optional<bool> sle(const KnownBits &LHS, const KnownBits &RHS);
+
   /// Insert the bits from a smaller known bits starting at bitPosition.
   void insertBits(const KnownBits &SubBits, unsigned BitPosition) {
     Zero.insertBits(SubBits.Zero, BitPosition);

diff  --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 2c25b7d9bac5..0147d21d153a 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -268,6 +268,75 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
   return Known;
 }
 
+Optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) {
+  if (LHS.isConstant() && RHS.isConstant()) // Both fully known: just compare.
+    return Optional<bool>(LHS.getConstant() == RHS.getConstant());
+  if (LHS.getMaxValue().ult(RHS.getMinValue()) ||
+      LHS.getMinValue().ugt(RHS.getMaxValue()))
+    return Optional<bool>(false); // The [min, max] ranges do not overlap.
+  if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero))
+    return Optional<bool>(false); // A bit is known one in one side, zero in the other.
+  return None; // No check above was decisive.
+}
+
+Optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) {
+  if (Optional<bool> KnownEQ = eq(LHS, RHS)) // NE is the inverse of a known EQ.
+    return Optional<bool>(!KnownEQ.getValue());
+  return None;
+}
+
+Optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) {
+  if (LHS.isConstant() && RHS.isConstant())
+    return Optional<bool>(LHS.getConstant().ugt(RHS.getConstant()));
+  // LHS >u RHS -> false if umax(LHS) <= umin(RHS)
+  if (LHS.getMaxValue().ule(RHS.getMinValue()))
+    return Optional<bool>(false);
+  // LHS >u RHS -> true if umin(LHS) > umax(RHS)
+  if (LHS.getMinValue().ugt(RHS.getMaxValue()))
+    return Optional<bool>(true);
+  return None;
+}
+
+Optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) {
+  if (Optional<bool> IsUGT = ugt(RHS, LHS)) // LHS >=u RHS is !(RHS >u LHS).
+    return Optional<bool>(!IsUGT.getValue());
+  return None;
+}
+
+Optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) {
+  return ugt(RHS, LHS); // LHS <u RHS is RHS >u LHS.
+}
+
+Optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) {
+  return uge(RHS, LHS); // LHS <=u RHS is RHS >=u LHS.
+}
+
+Optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) {
+  if (LHS.isConstant() && RHS.isConstant())
+    return Optional<bool>(LHS.getConstant().sgt(RHS.getConstant()));
+  // LHS >s RHS -> false if smax(LHS) <= smin(RHS)
+  if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue()))
+    return Optional<bool>(false);
+  // LHS >s RHS -> true if smin(LHS) > smax(RHS)
+  if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue()))
+    return Optional<bool>(true);
+  return None;
+}
+
+Optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) {
+  if (Optional<bool> KnownSGT = sgt(RHS, LHS)) // LHS >=s RHS is !(RHS >s LHS).
+    return Optional<bool>(!KnownSGT.getValue());
+  return None;
+}
+
+Optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) {
+  return sgt(RHS, LHS); // LHS <s RHS is RHS >s LHS.
+}
+
+Optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) {
+  return sge(RHS, LHS); // LHS <=s RHS is RHS >=s LHS.
+}
+
 KnownBits KnownBits::abs(bool IntMinIsPoison) const {
   // If the source's MSB is zero then we know the rest of the bits already.
   if (isNonNegative())

diff  --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 528a645ec51a..ba587a1e2f65 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -281,6 +281,93 @@ TEST(KnownBitsTest, UnaryExhaustive) {
   });
 }
 
+TEST(KnownBitsTest, ICmpExhaustive) {
+  unsigned Bits = 4;
+  ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+      bool AllEQ = true, NoneEQ = true; // All*: held for every enumerated pair.
+      bool AllNE = true, NoneNE = true; // None*: held for no enumerated pair.
+      bool AllUGT = true, NoneUGT = true;
+      bool AllUGE = true, NoneUGE = true;
+      bool AllULT = true, NoneULT = true;
+      bool AllULE = true, NoneULE = true;
+      bool AllSGT = true, NoneSGT = true;
+      bool AllSGE = true, NoneSGE = true;
+      bool AllSLT = true, NoneSLT = true;
+      bool AllSLE = true, NoneSLE = true;
+      // Evaluate every predicate on each (N1, N2) pair the helpers enumerate.
+      ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
+        ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
+          AllEQ &= N1.eq(N2);
+          AllNE &= N1.ne(N2);
+          AllUGT &= N1.ugt(N2);
+          AllUGE &= N1.uge(N2);
+          AllULT &= N1.ult(N2);
+          AllULE &= N1.ule(N2);
+          AllSGT &= N1.sgt(N2);
+          AllSGE &= N1.sge(N2);
+          AllSLT &= N1.slt(N2);
+          AllSLE &= N1.sle(N2);
+          NoneEQ &= !N1.eq(N2);
+          NoneNE &= !N1.ne(N2);
+          NoneUGT &= !N1.ugt(N2);
+          NoneUGE &= !N1.uge(N2);
+          NoneULT &= !N1.ult(N2);
+          NoneULE &= !N1.ule(N2);
+          NoneSGT &= !N1.sgt(N2);
+          NoneSGE &= !N1.sge(N2);
+          NoneSLT &= !N1.slt(N2);
+          NoneSLE &= !N1.sle(N2);
+        });
+      });
+      // Ask the KnownBits helpers under test for their answers.
+      Optional<bool> KnownEQ = KnownBits::eq(Known1, Known2);
+      Optional<bool> KnownNE = KnownBits::ne(Known1, Known2);
+      Optional<bool> KnownUGT = KnownBits::ugt(Known1, Known2);
+      Optional<bool> KnownUGE = KnownBits::uge(Known1, Known2);
+      Optional<bool> KnownULT = KnownBits::ult(Known1, Known2);
+      Optional<bool> KnownULE = KnownBits::ule(Known1, Known2);
+      Optional<bool> KnownSGT = KnownBits::sgt(Known1, Known2);
+      Optional<bool> KnownSGE = KnownBits::sge(Known1, Known2);
+      Optional<bool> KnownSLT = KnownBits::slt(Known1, Known2);
+      Optional<bool> KnownSLE = KnownBits::sle(Known1, Known2);
+      // A helper must give an answer exactly when all enumerated pairs agree.
+      EXPECT_EQ(AllEQ || NoneEQ, KnownEQ.hasValue());
+      EXPECT_EQ(AllNE || NoneNE, KnownNE.hasValue());
+      EXPECT_EQ(AllUGT || NoneUGT, KnownUGT.hasValue());
+      EXPECT_EQ(AllUGE || NoneUGE, KnownUGE.hasValue());
+      EXPECT_EQ(AllULT || NoneULT, KnownULT.hasValue());
+      EXPECT_EQ(AllULE || NoneULE, KnownULE.hasValue());
+      EXPECT_EQ(AllSGT || NoneSGT, KnownSGT.hasValue());
+      EXPECT_EQ(AllSGE || NoneSGE, KnownSGE.hasValue());
+      EXPECT_EQ(AllSLT || NoneSLT, KnownSLT.hasValue());
+      EXPECT_EQ(AllSLE || NoneSLE, KnownSLE.hasValue());
+      // An answer of true is given exactly when the predicate held every time.
+      EXPECT_EQ(AllEQ, KnownEQ.hasValue() && KnownEQ.getValue());
+      EXPECT_EQ(AllNE, KnownNE.hasValue() && KnownNE.getValue());
+      EXPECT_EQ(AllUGT, KnownUGT.hasValue() && KnownUGT.getValue());
+      EXPECT_EQ(AllUGE, KnownUGE.hasValue() && KnownUGE.getValue());
+      EXPECT_EQ(AllULT, KnownULT.hasValue() && KnownULT.getValue());
+      EXPECT_EQ(AllULE, KnownULE.hasValue() && KnownULE.getValue());
+      EXPECT_EQ(AllSGT, KnownSGT.hasValue() && KnownSGT.getValue());
+      EXPECT_EQ(AllSGE, KnownSGE.hasValue() && KnownSGE.getValue());
+      EXPECT_EQ(AllSLT, KnownSLT.hasValue() && KnownSLT.getValue());
+      EXPECT_EQ(AllSLE, KnownSLE.hasValue() && KnownSLE.getValue());
+      // An answer of false is given exactly when the predicate never held.
+      EXPECT_EQ(NoneEQ, KnownEQ.hasValue() && !KnownEQ.getValue());
+      EXPECT_EQ(NoneNE, KnownNE.hasValue() && !KnownNE.getValue());
+      EXPECT_EQ(NoneUGT, KnownUGT.hasValue() && !KnownUGT.getValue());
+      EXPECT_EQ(NoneUGE, KnownUGE.hasValue() && !KnownUGE.getValue());
+      EXPECT_EQ(NoneULT, KnownULT.hasValue() && !KnownULT.getValue());
+      EXPECT_EQ(NoneULE, KnownULE.hasValue() && !KnownULE.getValue());
+      EXPECT_EQ(NoneSGT, KnownSGT.hasValue() && !KnownSGT.getValue());
+      EXPECT_EQ(NoneSGE, KnownSGE.hasValue() && !KnownSGE.getValue());
+      EXPECT_EQ(NoneSLT, KnownSLT.hasValue() && !KnownSLT.getValue());
+      EXPECT_EQ(NoneSLE, KnownSLE.hasValue() && !KnownSLE.getValue());
+    });
+  });
+}
+
 TEST(KnownBitsTest, GetMinMaxVal) {
   unsigned Bits = 4;
   ForeachKnownBits(Bits, [&](const KnownBits &Known) {


        


More information about the llvm-branch-commits mailing list