[llvm] 5350e1b - [KnownBits] Implement accurate unsigned and signed max and min

Jay Foad via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 7 01:10:15 PDT 2020


Author: Jay Foad
Date: 2020-09-07T09:09:01+01:00
New Revision: 5350e1b5096aa4707aa525baf7398d93b4a4f1a5

URL: https://github.com/llvm/llvm-project/commit/5350e1b5096aa4707aa525baf7398d93b4a4f1a5
DIFF: https://github.com/llvm/llvm-project/commit/5350e1b5096aa4707aa525baf7398d93b4a4f1a5.diff

LOG: [KnownBits] Implement accurate unsigned and signed max and min

Use the new implementation in ValueTracking, SelectionDAG and
GlobalISel.

Differential Revision: https://reviews.llvm.org/D87034

Added: 
    

Modified: 
    llvm/include/llvm/Support/KnownBits.h
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/Support/KnownBits.cpp
    llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
    llvm/unittests/Support/KnownBitsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 5b3de63cd359..a29e150b904a 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -173,6 +173,10 @@ struct KnownBits {
                      One.extractBits(NumBits, BitPosition));
   }
 
+  /// Return KnownBits based on this, but updated given that the underlying
+  /// value is known to be greater than or equal to Val.
+  KnownBits makeGE(const APInt &Val) const;
+
   /// Returns the minimum number of trailing zero bits.
   unsigned countMinTrailingZeros() const {
     return Zero.countTrailingOnes();
@@ -241,6 +245,18 @@ struct KnownBits {
   static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
                                     KnownBits RHS);
 
+  /// Compute known bits for umax(LHS, RHS).
+  static KnownBits umax(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits for umin(LHS, RHS).
+  static KnownBits umin(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits for smax(LHS, RHS).
+  static KnownBits smax(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits for smin(LHS, RHS).
+  static KnownBits smin(const KnownBits &LHS, const KnownBits &RHS);
+
   /// Insert the bits from a smaller known bits starting at bitPosition.
   void insertBits(const KnownBits &SubBits, unsigned BitPosition) {
     Zero.insertBits(SubBits.Zero, BitPosition);

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 3a6ee355c646..6e5a7195bb19 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1212,59 +1212,41 @@ static void computeKnownBitsFromOperator(const Operator *I,
     if (SelectPatternResult::isMinOrMax(SPF)) {
       computeKnownBits(RHS, Known, Depth + 1, Q);
       computeKnownBits(LHS, Known2, Depth + 1, Q);
-    } else {
-      computeKnownBits(I->getOperand(2), Known, Depth + 1, Q);
-      computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+      switch (SPF) {
+      default:
+        llvm_unreachable("Unhandled select pattern flavor!");
+      case SPF_SMAX:
+        Known = KnownBits::smax(Known, Known2);
+        break;
+      case SPF_SMIN:
+        Known = KnownBits::smin(Known, Known2);
+        break;
+      case SPF_UMAX:
+        Known = KnownBits::umax(Known, Known2);
+        break;
+      case SPF_UMIN:
+        Known = KnownBits::umin(Known, Known2);
+        break;
+      }
+      break;
     }
 
-    unsigned MaxHighOnes = 0;
-    unsigned MaxHighZeros = 0;
-    if (SPF == SPF_SMAX) {
-      // If both sides are negative, the result is negative.
-      if (Known.isNegative() && Known2.isNegative())
-        // We can derive a lower bound on the result by taking the max of the
-        // leading one bits.
-        MaxHighOnes =
-            std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes());
-      // If either side is non-negative, the result is non-negative.
-      else if (Known.isNonNegative() || Known2.isNonNegative())
-        MaxHighZeros = 1;
-    } else if (SPF == SPF_SMIN) {
-      // If both sides are non-negative, the result is non-negative.
-      if (Known.isNonNegative() && Known2.isNonNegative())
-        // We can derive an upper bound on the result by taking the max of the
-        // leading zero bits.
-        MaxHighZeros = std::max(Known.countMinLeadingZeros(),
-                                Known2.countMinLeadingZeros());
-      // If either side is negative, the result is negative.
-      else if (Known.isNegative() || Known2.isNegative())
-        MaxHighOnes = 1;
-    } else if (SPF == SPF_UMAX) {
-      // We can derive a lower bound on the result by taking the max of the
-      // leading one bits.
-      MaxHighOnes =
-          std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes());
-    } else if (SPF == SPF_UMIN) {
-      // We can derive an upper bound on the result by taking the max of the
-      // leading zero bits.
-      MaxHighZeros =
-          std::max(Known.countMinLeadingZeros(), Known2.countMinLeadingZeros());
-    } else if (SPF == SPF_ABS) {
+    computeKnownBits(I->getOperand(2), Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
+
+    // Only known if known in both the LHS and RHS.
+    Known.One &= Known2.One;
+    Known.Zero &= Known2.Zero;
+
+    if (SPF == SPF_ABS) {
       // RHS from matchSelectPattern returns the negation part of abs pattern.
       // If the negate has an NSW flag we can assume the sign bit of the result
       // will be 0 because that makes abs(INT_MIN) undefined.
       if (match(RHS, m_Neg(m_Specific(LHS))) &&
           Q.IIQ.hasNoSignedWrap(cast<Instruction>(RHS)))
-        MaxHighZeros = 1;
+        Known.Zero.setSignBit();
     }
 
-    // Only known if known in both the LHS and RHS.
-    Known.One &= Known2.One;
-    Known.Zero &= Known2.Zero;
-    if (MaxHighOnes > 0)
-      Known.One.setHighBits(MaxHighOnes);
-    if (MaxHighZeros > 0)
-      Known.Zero.setHighBits(MaxHighZeros);
     break;
   }
   case Instruction::FPTrunc:

diff  --git a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
index c615462af407..3ebbac9fd659 100644
--- a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
@@ -308,11 +308,24 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
                         Known, DemandedElts, Depth + 1);
     break;
   }
-  case TargetOpcode::G_SMIN:
+  case TargetOpcode::G_SMIN: {
+    // TODO: Handle clamp pattern with number of sign bits
+    KnownBits KnownRHS;
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
+                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, DemandedElts,
+                         Depth + 1);
+    Known = KnownBits::smin(Known, KnownRHS);
+    break;
+  }
   case TargetOpcode::G_SMAX: {
     // TODO: Handle clamp pattern with number of sign bits
-    computeKnownBitsMin(MI.getOperand(1).getReg(), MI.getOperand(2).getReg(),
-                        Known, DemandedElts, Depth + 1);
+    KnownBits KnownRHS;
+    computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
+                         Depth + 1);
+    computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS, DemandedElts,
+                         Depth + 1);
+    Known = KnownBits::smax(Known, KnownRHS);
     break;
   }
   case TargetOpcode::G_UMIN: {
@@ -321,13 +334,7 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
                          DemandedElts, Depth + 1);
     computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS,
                          DemandedElts, Depth + 1);
-
-    // UMIN - we know that the result will have the maximum of the
-    // known zero leading bits of the inputs.
-    unsigned LeadZero = Known.countMinLeadingZeros();
-    LeadZero = std::max(LeadZero, KnownRHS.countMinLeadingZeros());
-    Known &= KnownRHS;
-    Known.Zero.setHighBits(LeadZero);
+    Known = KnownBits::umin(Known, KnownRHS);
     break;
   }
   case TargetOpcode::G_UMAX: {
@@ -336,14 +343,7 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
                          DemandedElts, Depth + 1);
     computeKnownBitsImpl(MI.getOperand(2).getReg(), KnownRHS,
                          DemandedElts, Depth + 1);
-
-    // UMAX - we know that the result will have the maximum of the
-    // known one leading bits of the inputs.
-    unsigned LeadOne = Known.countMinLeadingOnes();
-    LeadOne = std::max(LeadOne, KnownRHS.countMinLeadingOnes());
-    Known.Zero &= KnownRHS.Zero;
-    Known.One &= KnownRHS.One;
-    Known.One.setHighBits(LeadOne);
+    Known = KnownBits::umax(Known, KnownRHS);
     break;
   }
   case TargetOpcode::G_FCMP:

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 73e042c47540..d2b3e009c202 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3390,29 +3390,13 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
   case ISD::UMIN: {
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-
-    // UMIN - we know that the result will have the maximum of the
-    // known zero leading bits of the inputs.
-    unsigned LeadZero = Known.countMinLeadingZeros();
-    LeadZero = std::max(LeadZero, Known2.countMinLeadingZeros());
-
-    Known.Zero &= Known2.Zero;
-    Known.One &= Known2.One;
-    Known.Zero.setHighBits(LeadZero);
+    Known = KnownBits::umin(Known, Known2);
     break;
   }
   case ISD::UMAX: {
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-
-    // UMAX - we know that the result will have the maximum of the
-    // known one leading bits of the inputs.
-    unsigned LeadOne = Known.countMinLeadingOnes();
-    LeadOne = std::max(LeadOne, Known2.countMinLeadingOnes());
-
-    Known.Zero &= Known2.Zero;
-    Known.One &= Known2.One;
-    Known.One.setHighBits(LeadOne);
+    Known = KnownBits::umax(Known, Known2);
     break;
   }
   case ISD::SMIN:
@@ -3446,12 +3430,13 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
       }
     }
 
-    // Fallback - just get the shared known bits of the operands.
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     if (Known.isUnknown()) break; // Early-out
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known.Zero &= Known2.Zero;
-    Known.One &= Known2.One;
+    if (IsMax)
+      Known = KnownBits::smax(Known, Known2);
+    else
+      Known = KnownBits::smin(Known, Known2);
     break;
   }
   case ISD::FrameIndex:

diff  --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 1ff66d504cbe..aad50e124034 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -83,6 +83,68 @@ KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
   return KnownOut;
 }
 
+KnownBits KnownBits::makeGE(const APInt &Val) const {
+  // Across the leading run where each bit is known zero or Val has a one, our
+  // bits can never exceed Val's, so value >= Val forces them to equal Val's.
+  unsigned PrefixLen = (Zero | Val).countLeadingOnes();
+
+  // Val's ones in that prefix become known ones. If no value can be >= Val,
+  // the result is conflicting (some bit both known zero and known one).
+  APInt ForcedOnes = Val;
+  ForcedOnes.clearLowBits(getBitWidth() - PrefixLen);
+  return KnownBits(Zero, One | ForcedOnes);
+}
+
+KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
+  // Exact shortcuts: if one operand is provably never below the other, the
+  // umax is that operand. Ideally callers have already optimized these cases
+  // away; handling them here keeps this function complete on its own.
+  APInt LHSMin = LHS.getMinValue(), RHSMin = RHS.getMinValue();
+  if (LHSMin.uge(RHS.getMaxValue()))
+    return LHS;
+  if (RHSMin.uge(LHS.getMaxValue()))
+    return RHS;
+  // Otherwise either operand may be chosen, and whichever is chosen is at
+  // least the other's minimum. Refine both candidates with that fact and
+  // keep only the bits on which the two refined candidates agree.
+  KnownBits LHSCand = LHS.makeGE(RHSMin);
+  KnownBits RHSCand = RHS.makeGE(LHSMin);
+  APInt CommonZero = LHSCand.Zero & RHSCand.Zero;
+  return KnownBits(CommonZero, LHSCand.One & RHSCand.One);
+}
+
+KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
+  // Flip(x) = ~x maps [0, 0xFFFFFFFF] to [0xFFFFFFFF, 0]; it is self-inverse.
+  auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); };
+  return Flip(umax(Flip(LHS), Flip(RHS)));
+}
+
+KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
+  // Flip(x) = x ^ SignMask maps [-0x80000000, 0x7FFFFFFF] to [0, 0xFFFFFFFF].
+  auto Flip = [](const KnownBits &Val) {
+    unsigned SignBitPosition = Val.getBitWidth() - 1;
+    APInt Zero = Val.Zero;
+    APInt One = Val.One;
+    Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
+    One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
+    return KnownBits(Zero, One);
+  };
+  return Flip(umax(Flip(LHS), Flip(RHS))); // Flip is its own inverse.
+}
+
+KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
+  // Flip(x) = x ^ ~SignMask maps [-0x80000000, 0x7FFFFFFF] to [0xFFFFFFFF, 0].
+  auto Flip = [](const KnownBits &Val) {
+    unsigned SignBitPosition = Val.getBitWidth() - 1;
+    APInt Zero = Val.One;
+    APInt One = Val.Zero;
+    Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
+    One.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
+    return KnownBits(Zero, One);
+  };
+  return Flip(umax(Flip(LHS), Flip(RHS))); // Flip is its own inverse.
+}
+
 KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
   // Result bit is 0 if either operand bit is 0.
   Zero |= RHS.Zero;

diff  --git a/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp b/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
index 30ff37536faf..faf6f7087ac0 100644
--- a/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/KnownBitsTest.cpp
@@ -719,9 +719,9 @@ TEST_F(AArch64GISelMITest, TestKnownBitsUMax) {
 
   KnownBits KnownUmax = Info.getKnownBits(CopyUMax);
   EXPECT_EQ(64u, KnownUmax.getBitWidth());
-  EXPECT_EQ(0u, KnownUmax.Zero.getZExtValue());
+  EXPECT_EQ(0xffu, KnownUmax.Zero.getZExtValue());
   EXPECT_EQ(0xffffffffffffff00, KnownUmax.One.getZExtValue());
 
-  EXPECT_EQ(0u, KnownUmax.Zero.getZExtValue());
+  EXPECT_EQ(0xffu, KnownUmax.Zero.getZExtValue());
   EXPECT_EQ(0xffffffffffffff00, KnownUmax.One.getZExtValue());
 }

diff  --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 694e5c4dcc71..89555a5881a5 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -103,13 +103,15 @@ TEST(KnownBitsTest, BinaryExhaustive) {
   unsigned Bits = 4;
   ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
     ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
-      KnownBits KnownAnd(Bits), KnownOr(Bits), KnownXor(Bits);
+      KnownBits KnownAnd(Bits);
       KnownAnd.Zero.setAllBits();
       KnownAnd.One.setAllBits();
-      KnownOr.Zero.setAllBits();
-      KnownOr.One.setAllBits();
-      KnownXor.Zero.setAllBits();
-      KnownXor.One.setAllBits();
+      KnownBits KnownOr(KnownAnd);
+      KnownBits KnownXor(KnownAnd);
+      KnownBits KnownUMax(KnownAnd);
+      KnownBits KnownUMin(KnownAnd);
+      KnownBits KnownSMax(KnownAnd);
+      KnownBits KnownSMin(KnownAnd);
 
       ForeachNumInKnownBits(Known1, [&](const APInt &N1) {
         ForeachNumInKnownBits(Known2, [&](const APInt &N2) {
@@ -126,6 +128,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
           Res = N1 ^ N2;
           KnownXor.One &= Res;
           KnownXor.Zero &= ~Res;
+
+          Res = APIntOps::umax(N1, N2);
+          KnownUMax.One &= Res;
+          KnownUMax.Zero &= ~Res;
+
+          Res = APIntOps::umin(N1, N2);
+          KnownUMin.One &= Res;
+          KnownUMin.Zero &= ~Res;
+
+          Res = APIntOps::smax(N1, N2);
+          KnownSMax.One &= Res;
+          KnownSMax.Zero &= ~Res;
+
+          Res = APIntOps::smin(N1, N2);
+          KnownSMin.One &= Res;
+          KnownSMin.Zero &= ~Res;
         });
       });
 
@@ -140,6 +158,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       KnownBits ComputedXor = Known1 ^ Known2;
       EXPECT_EQ(KnownXor.Zero, ComputedXor.Zero);
       EXPECT_EQ(KnownXor.One, ComputedXor.One);
+
+      KnownBits ComputedUMax = KnownBits::umax(Known1, Known2);
+      EXPECT_EQ(KnownUMax.Zero, ComputedUMax.Zero);
+      EXPECT_EQ(KnownUMax.One, ComputedUMax.One);
+
+      KnownBits ComputedUMin = KnownBits::umin(Known1, Known2);
+      EXPECT_EQ(KnownUMin.Zero, ComputedUMin.Zero);
+      EXPECT_EQ(KnownUMin.One, ComputedUMin.One);
+
+      KnownBits ComputedSMax = KnownBits::smax(Known1, Known2);
+      EXPECT_EQ(KnownSMax.Zero, ComputedSMax.Zero);
+      EXPECT_EQ(KnownSMax.One, ComputedSMax.One);
+
+      KnownBits ComputedSMin = KnownBits::smin(Known1, Known2);
+      EXPECT_EQ(KnownSMin.Zero, ComputedSMin.Zero);
+      EXPECT_EQ(KnownSMin.One, ComputedSMin.One);
     });
   });
 }


        


More information about the llvm-commits mailing list