[llvm-commits] CVS: llvm/lib/Transforms/Scalar/PredicateSimplifier.cpp

Nick Lewycky nicholas at mxc.ca
Sat Mar 10 10:13:05 PST 2007



Changes in directory llvm/lib/Transforms/Scalar:

PredicateSimplifier.cpp updated: 1.55 -> 1.56
---
Log message:

Add value ranges. Currently inefficient in both execution time and
optimization power.


---
Diffs of the changes:  (+397 -219)

 PredicateSimplifier.cpp |  616 ++++++++++++++++++++++++++++++------------------
 1 files changed, 397 insertions(+), 219 deletions(-)


Index: llvm/lib/Transforms/Scalar/PredicateSimplifier.cpp
diff -u llvm/lib/Transforms/Scalar/PredicateSimplifier.cpp:1.55 llvm/lib/Transforms/Scalar/PredicateSimplifier.cpp:1.56
--- llvm/lib/Transforms/Scalar/PredicateSimplifier.cpp:1.55	Sun Mar  4 18:00:42 2007
+++ llvm/lib/Transforms/Scalar/PredicateSimplifier.cpp	Sat Mar 10 12:12:48 2007
@@ -29,9 +29,9 @@
 //
 // These relationships define a graph between values of the same type. Each
 // Value is stored in a map table that retrieves the associated Node. This
-// is how EQ relationships are stored; the map contains pointers to the
-// same node. The node contains a most canonical Value* form and the list of
-// known relationships.
+// is how EQ relationships are stored; the map contains pointers from equal
+// Values to the same node. The node contains the most canonical Value* form
+// and the list of known relationships with other nodes.
 //
 // If two nodes are known to be inequal, then they will contain pointers to
 // each other with an "NE" relationship. If node getNode(%x) is less than
@@ -52,9 +52,9 @@
 // responsible for analyzing the variable and seeing what new inferences
 // can be made from each property. For example:
 //
-//   %P = icmp ne int* %ptr, null
-//   %a = and bool %P, %Q
-//   br bool %a label %cond_true, label %cond_false
+//   %P = icmp ne i32* %ptr, null
+//   %a = and i1 %P, %Q
+//   br i1 %a label %cond_true, label %cond_false
 //
 // For the true branch, the VRPSolver will start with %a EQ true and look at
 // the definition of %a and find that it can infer that %P and %Q are both
@@ -83,6 +83,7 @@
 #include "llvm/Analysis/ET-Forest.h"
 #include "llvm/Support/CFG.h"
 #include "llvm/Support/Compiler.h"
+#include "llvm/Support/ConstantRange.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/InstVisitor.h"
 #include "llvm/Transforms/Utils/Local.h"
@@ -165,6 +166,20 @@
     return Rev;
   }
 
+  /// OrderByDominance - A StrictWeakOrdering predicate over ETNodes. Nodes
+  /// are ranked by the width of their DFS interval (DFSNumOut - DFSNumIn),
+  /// narrowest first, so that walking a list sorted by this predicate meets
+  /// the most specific matching entry for a basic block first. Nodes of
+  /// equal width are ordered by pointer address, making the order total.
+  struct VISIBILITY_HIDDEN OrderByDominance {
+    bool operator()(const ETNode *LHS, const ETNode *RHS) const {
+      unsigned LWidth = LHS->getDFSNumOut() - LHS->getDFSNumIn();
+      unsigned RWidth = RHS->getDFSNumOut() - RHS->getDFSNumIn();
+      if (LWidth == RWidth) return LHS < RHS;
+      return LWidth < RWidth;
+    }
+  };
+
   /// The InequalityGraph stores the relationships between values.
   /// Each Value in the graph is assigned to a Node. Nodes are pointer
   /// comparable for equality. The caller is expected to maintain the logical
@@ -182,24 +197,10 @@
 
     class Node;
 
-    /// This is a StrictWeakOrdering predicate that sorts ETNodes by how many
-    /// children they have. With this, you can iterate through a list sorted by
-    /// this operation and the first matching entry is the most specific match
-    /// for your basic block. The order provided is total; ETNodes with the
-    /// same number of children are sorted by pointer address.
-    struct VISIBILITY_HIDDEN OrderByDominance {
-      bool operator()(const ETNode *LHS, const ETNode *RHS) const {
-        unsigned LHS_spread = LHS->getDFSNumOut() - LHS->getDFSNumIn();
-        unsigned RHS_spread = RHS->getDFSNumOut() - RHS->getDFSNumIn();
-        if (LHS_spread != RHS_spread) return LHS_spread < RHS_spread;
-        else return LHS < RHS;
-      }
-    };
-
     /// An Edge is contained inside a Node making one end of the edge implicit
     /// and contains a pointer to the other end. The edge contains a lattice
-    /// value specifying the relationship between the two nodes. Further, there
-    /// is an ETNode specifying which subtree of the dominator the edge applies.
+    /// value specifying the relationship and an ETNode specifying the root
+    /// in the dominator tree to which this edge applies.
     class VISIBILITY_HIDDEN Edge {
     public:
       Edge(unsigned T, LatticeVal V, ETNode *ST)
@@ -221,10 +222,6 @@
     /// A single node in the InequalityGraph. This stores the canonical Value
     /// for the node, as well as the relationships with the neighbours.
     ///
-    /// Because the lists are intended to be used for traversal, it is invalid
-    /// for the node to list itself in LessEqual or GreaterEqual lists. The
-    /// fact that a node is equal to itself is implied, and may be checked
-    /// with pointer comparison.
     /// @brief A single node in the InequalityGraph.
     class VISIBILITY_HIDDEN Node {
       friend class InequalityGraph;
@@ -366,96 +363,6 @@
 
     std::vector<Node> Nodes;
 
-    std::vector<std::pair<ConstantInt *, unsigned> > Ints;
-
-    /// This is used to keep the ConstantInts list in unsigned ascending order.
-    /// If the bitwidths don't match, this sorts smaller values ahead.
-    struct SortByZExt {
-      bool operator()(const std::pair<ConstantInt *, unsigned> &LHS,
-                      const std::pair<ConstantInt *, unsigned> &RHS) const {
-        if (LHS.first->getType()->getBitWidth() !=
-            RHS.first->getType()->getBitWidth())
-          return LHS.first->getType()->getBitWidth() <
-                 RHS.first->getType()->getBitWidth();
-        return LHS.first->getValue().ult(RHS.first->getValue());
-      }
-    };
-
-    /// True when the bitwidth of LHS < bitwidth of RHS.
-    struct FindByIntegerWidth {
-      bool operator()(const std::pair<ConstantInt *, unsigned> &LHS,
-                      const std::pair<ConstantInt *, unsigned> &RHS) const {
-        return LHS.first->getType()->getBitWidth() <
-               RHS.first->getType()->getBitWidth();
-      }
-    };
-
-    void initializeInt(ConstantInt *CI, unsigned index) {
-      std::vector<std::pair<ConstantInt *, unsigned> >::iterator begin, end,
-          last, iULT, iUGT, iSLT, iSGT;
-
-      std::pair<ConstantInt *, unsigned> pair = std::make_pair(CI, index);
-
-      begin = std::lower_bound(Ints.begin(), Ints.end(), pair,
-                               FindByIntegerWidth());
-      end   = std::upper_bound(begin, Ints.end(), pair, FindByIntegerWidth());
-
-      if (begin == end) last = end;
-      else last = end - 1;
-
-      iUGT = std::lower_bound(begin, end, pair, SortByZExt());
-      iULT = (iUGT == begin || begin == end) ? end : iUGT - 1;
-
-      if (iUGT != end && iULT != end &&
-          iULT->first->getValue().isNegative() == 
-          iUGT->first->getValue().isNegative()) { // signs match
-        iSGT = iUGT;
-        iSLT = iULT;
-      } else {
-        if (iULT == end || iUGT == end) {
-          if (iULT == end) iSLT = last;  else iSLT = iULT;
-          if (iUGT == end) iSGT = begin; else iSGT = iUGT;
-        } else if (iULT->first->getValue().isNegative()) {
-          assert(iUGT->first->getValue().isPositive() && 
-                 "Bad sign comparison.");
-          iSGT = iUGT;
-          iSLT = iULT;
-        } else {
-          assert(iULT->first->getValue().isPositive() &&
-                 iUGT->first->getValue().isNegative() &&"Bad sign comparison.");
-          iSGT = iULT;
-          iSLT = iUGT;
-        }
-
-        if (iSGT != end &&
-            iSGT->first->getValue().slt(CI->getValue())) 
-          iSGT = end;
-        if (iSLT != end &&
-            iSLT->first->getValue().sgt(CI->getValue())) 
-          iSLT = end;
-
-        if (begin != end) {
-          if (begin->first->getValue().slt(CI->getValue()))
-            if (iSLT == end ||
-                begin->first->getValue().sgt(iSLT->first->getValue()))
-              iSLT = begin;
-        }
-        if (last != end) {
-          if (last->first->getValue().sgt(CI->getValue()))
-            if (iSGT == end ||
-                last->first->getValue().slt(iSGT->first->getValue()))
-              iSGT = last;
-        }
-      }
-
-      if (iULT != end) addInequality(iULT->second, index, TreeRoot, ULT);
-      if (iUGT != end) addInequality(iUGT->second, index, TreeRoot, UGT);
-      if (iSLT != end) addInequality(iSLT->second, index, TreeRoot, SLT);
-      if (iSGT != end) addInequality(iSGT->second, index, TreeRoot, SGT);
-
-      Ints.insert(iUGT, pair);
-    }
-
   public:
     /// node - returns the node object at a given index retrieved from getNode.
     /// Index zero is reserved and may not be passed in here. The pointer
@@ -498,10 +405,6 @@
              "Attempt to create a duplicate Node.");
       NodeMap.insert(std::lower_bound(NodeMap.begin(), NodeMap.end(),
                                       MapEntry), MapEntry);
-
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
-        initializeInt(CI, MapEntry.index);
-
       return MapEntry.index;
     }
 
@@ -679,20 +582,25 @@
       N2->update(n1, reversePredicate(LV1), Subtree);
     }
 
-    /// Removes a Value from the graph, but does not delete any nodes. As this
-    /// method does not delete Nodes, V may not be the canonical choice for
-    /// a node with any relationships. It is invalid to call newNode on a Value
-    /// that has been removed.
+    /// remove - Removes a Value from the graph. If the value is the canonical
+    /// choice for a Node, destroys the Node from the graph deleting all edges
+    /// to and from it. This method does not renumber the nodes.
     void remove(Value *V) {
       for (unsigned i = 0; i < NodeMap.size();) {
         NodeMapType::iterator I = NodeMap.begin()+i;
-        assert((node(I->index)->getValue() != V || node(I->index)->begin() ==
-                node(I->index)->end()) && "Tried to delete in-use node.");
         if (I->V == V) {
-#ifndef NDEBUG
-          if (node(I->index)->getValue() == V)
-            node(I->index)->Canonical = NULL;
-#endif
+          Node *N = node(I->index);
+          if (node(I->index)->getValue() == V) {
+            for (Node::iterator NI = N->begin(), NE = N->end(); NI != NE; ++NI){
+              Node::iterator Iter = node(NI->To)->find(I->index, TreeRoot);
+              do {
+                node(NI->To)->Relations.erase(Iter);
+                Iter = node(NI->To)->find(I->index, TreeRoot);
+              } while (Iter != node(NI->To)->end());
+            }
+            N->Canonical = NULL;
+          }
+          N->Relations.clear();
           NodeMap.erase(I);
         } else ++i;
       }
@@ -721,6 +629,297 @@
 #endif
   };
 
+  class VRPSolver;
+
+  /// ValueRanges tracks the known integer ranges and anti-ranges of the nodes
+  /// in the InequalityGraph.
+  class VISIBILITY_HIDDEN ValueRanges {
+
+    /// A ScopedRange ties a Value with a ConstantRange under the scope of
+    /// a rooted subtree in the dominator tree.
+    class VISIBILITY_HIDDEN ScopedRange {
+    public:
+      ScopedRange(Value *V, ConstantRange CR, ETNode *ST)
+        : V(V), CR(CR), Subtree(ST) {}
+
+      Value *V;
+      ConstantRange CR;
+      ETNode *Subtree;
+
+      bool operator<(const ScopedRange &range) const {
+        if (V != range.V) return V < range.V;
+        else return OrderByDominance()(Subtree, range.Subtree);
+      }
+
+      bool operator<(const Value *value) const {
+        return V < value;
+      }
+    };
+
+    std::vector<ScopedRange> Ranges;
+    typedef std::vector<ScopedRange>::iterator iterator;
+
+    // XXX: this is a copy of the code in InequalityGraph::Node. Perhaps an
+    // intrusive domtree-scoped container is in order?
+
+    iterator begin() { return Ranges.begin(); }
+    iterator end()   { return Ranges.end();   }
+
+    /// find - Returns the first (most specific) range recorded for V whose
+    /// subtree dominates Subtree, or end() if there is none.
+    iterator find(Value *V, ETNode *Subtree) {
+      iterator E = end();
+      for (iterator I = std::lower_bound(begin(), E, V);
+           I != E && I->V == V; ++I)
+        if (Subtree->DominatedBy(I->Subtree)) return I;
+      return E;
+    }
+
+    /// update - Narrows V's range in Subtree to CR and tightens each range
+    /// recorded for V in a subtree that Subtree dominates.
+    void update(Value *V, ConstantRange CR, ETNode *Subtree) {
+      assert(!CR.isEmptySet() && "Empty ConstantRange!");
+      if (CR.isFullSet()) return;
+
+      iterator I = find(V, Subtree);
+      if (I == end()) {
+        ScopedRange range(V, CR, Subtree);
+        Ranges.insert(std::lower_bound(begin(), end(), range), range);
+      } else {
+        CR = CR.intersectWith(I->CR);
+        assert(!CR.isEmptySet() && "Empty intersection of ConstantRanges!");
+
+        if (CR != I->CR) {
+          if (Subtree != I->Subtree) {
+            assert(Subtree->DominatedBy(I->Subtree) &&
+                   "Find returned subtree that doesn't apply.");
+
+            ScopedRange range(V, CR, Subtree);
+            Ranges.insert(std::lower_bound(begin(), end(), range), range);
+            I = find(V, Subtree); // the insert invalidated I
+          }
+
+          // Also tighten every range recorded in a subtree that Subtree
+          // dominates, intersecting each one with CR independently.
+          for (iterator B = begin(); I->V == V; --I) {
+            if (I->Subtree->DominatedBy(Subtree)) {
+              I->CR = I->CR.intersectWith(CR);
+              assert(!I->CR.isEmptySet() &&
+                     "Empty intersection of ConstantRanges!");
+            }
+            if (I == B) break;
+          }
+        }
+      }
+    }
+
+    /// range - Creates a ConstantRange representing the set of all values
+    /// that match the ICmpInst::Predicate with any of the values in CR.
+    ConstantRange range(ICmpInst::Predicate ICmpOpcode,
+                        const ConstantRange &CR) {
+      uint32_t W = CR.getBitWidth();
+      switch (ICmpOpcode) {
+        default: assert(!"Invalid ICmp opcode to range()");
+        case ICmpInst::ICMP_EQ:
+          return ConstantRange(CR.getLower(), CR.getUpper());
+        case ICmpInst::ICMP_NE:
+          if (CR.isSingleElement())
+            return ConstantRange(CR.getUpper(), CR.getLower());
+          return ConstantRange(W);
+        case ICmpInst::ICMP_ULT:
+          return ConstantRange(APInt::getMinValue(W), CR.getUnsignedMax());
+        case ICmpInst::ICMP_SLT:
+          return ConstantRange(APInt::getSignedMinValue(W), CR.getSignedMax());
+        case ICmpInst::ICMP_ULE: {
+          APInt UMax = CR.getUnsignedMax();
+          if (UMax == APInt::getMaxValue(W))
+            return ConstantRange(W);
+          return ConstantRange(APInt::getMinValue(W), UMax + 1);
+        }
+        case ICmpInst::ICMP_SLE: {
+          APInt SMax = CR.getSignedMax();
+          // Only SMax == SignedMax makes this the full set (cf. ICMP_ULE).
+          if (SMax == APInt::getSignedMaxValue(W))
+            return ConstantRange(W);
+          return ConstantRange(APInt::getSignedMinValue(W), SMax + 1);
+        }
+        case ICmpInst::ICMP_UGT:
+          return ConstantRange(CR.getUnsignedMin() + 1,
+                               APInt::getMaxValue(W) + 1);
+        case ICmpInst::ICMP_SGT:
+          return ConstantRange(CR.getSignedMin() + 1,
+                               APInt::getSignedMaxValue(W) + 1);
+        case ICmpInst::ICMP_UGE: {
+          APInt UMin = CR.getUnsignedMin();
+          if (UMin == APInt::getMinValue(W))
+            return ConstantRange(W);
+          return ConstantRange(UMin, APInt::getMaxValue(W) + 1);
+        }
+        case ICmpInst::ICMP_SGE: {
+          APInt SMin = CR.getSignedMin();
+          if (SMin == APInt::getSignedMinValue(W))
+            return ConstantRange(W);
+          return ConstantRange(SMin, APInt::getSignedMaxValue(W) + 1);
+        }
+      }
+    }
+
+    /// create - Creates a ConstantRange that matches the given LatticeVal
+    /// relation with a given integer.
+    ConstantRange create(LatticeVal LV, const ConstantRange &CR) {
+      assert(!CR.isEmptySet() && "Can't deal with empty set.");
+
+      if (LV == NE)
+        return range(ICmpInst::ICMP_NE, CR);
+
+      unsigned LV_s = LV & (SGT_BIT|SLT_BIT);
+      unsigned LV_u = LV & (UGT_BIT|ULT_BIT);
+      bool hasEQ = LV & EQ_BIT;
+
+      ConstantRange Range(CR.getBitWidth());
+
+      if (LV_s == SGT_BIT) {
+        Range = Range.intersectWith(range(
+                    hasEQ ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_SGT, CR));
+      } else if (LV_s == SLT_BIT) {
+        Range = Range.intersectWith(range(
+                    hasEQ ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_SLT, CR));
+      }
+
+      if (LV_u == UGT_BIT) {
+        Range = Range.intersectWith(range(
+                    hasEQ ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_UGT, CR));
+      } else if (LV_u == ULT_BIT) {
+        Range = Range.intersectWith(range(
+                    hasEQ ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_ULT, CR));
+      }
+
+      return Range;
+    }
+
+    /// rangeFromValue - Returns the range V is known to lie in at Subtree: a
+    /// single element for a ConstantInt, otherwise the recorded range that
+    /// applies to Subtree, or the full W-bit set if nothing is recorded.
+    ConstantRange rangeFromValue(Value *V, ETNode *Subtree, uint32_t W) {
+      if (ConstantInt *C = dyn_cast<ConstantInt>(V))
+        return ConstantRange(C->getValue());
+      iterator I = find(V, Subtree);
+      if (I != end())
+        return I->CR;
+      return ConstantRange(W);
+    }
+
+    /// widthOfValue - Returns the bit width of V's integer type, or 0 if V
+    /// is not of IntegerType.
+    static uint32_t widthOfValue(Value *V) {
+      if (const IntegerType *ITy = dyn_cast<IntegerType>(V->getType()))
+        return ITy->getBitWidth();
+
+      // XXX: I'd like to transform T* into the appropriate integer by
+      // bit length, however that data may not be available.
+      return 0;
+    }
+
+  public:
+
+    bool isRelatedBy(Value *V1, Value *V2, ETNode *Subtree, LatticeVal LV) {
+      uint32_t W = widthOfValue(V1);
+      if (!W) return false;
+
+      ConstantRange CR1 = rangeFromValue(V1, Subtree, W);
+      ConstantRange CR2 = rangeFromValue(V2, Subtree, W);
+
+      // True iff all values in CR1 are LV to all values in CR2.
+      switch(LV) {
+      default: assert(!"Impossible lattice value!");
+      case NE:
+        return CR1.intersectWith(CR2).isEmptySet();
+      case ULT:
+        return CR1.getUnsignedMax().ult(CR2.getUnsignedMin());
+      case ULE:
+        return CR1.getUnsignedMax().ule(CR2.getUnsignedMin());
+      case UGT:
+        return CR1.getUnsignedMin().ugt(CR2.getUnsignedMax());
+      case UGE:
+        return CR1.getUnsignedMin().uge(CR2.getUnsignedMax());
+      case SLT:
+        return CR1.getSignedMax().slt(CR2.getSignedMin());
+      case SLE:
+        return CR1.getSignedMax().sle(CR2.getSignedMin());
+      case SGT:
+        return CR1.getSignedMin().sgt(CR2.getSignedMax());
+      case SGE:
+        return CR1.getSignedMin().sge(CR2.getSignedMax());
+      case LT:
+        return CR1.getUnsignedMax().ult(CR2.getUnsignedMin()) &&
+               CR1.getSignedMax().slt(CR2.getSignedMin());
+      case LE:
+        return CR1.getUnsignedMax().ule(CR2.getUnsignedMin()) &&
+               CR1.getSignedMax().sle(CR2.getSignedMin());
+      case GT:
+        return CR1.getUnsignedMin().ugt(CR2.getUnsignedMax()) &&
+               CR1.getSignedMin().sgt(CR2.getSignedMax());
+      case GE:
+        return CR1.getUnsignedMin().uge(CR2.getUnsignedMax()) &&
+               CR1.getSignedMin().sge(CR2.getSignedMax());
+      case SLTUGT:
+        return CR1.getSignedMax().slt(CR2.getSignedMin()) &&
+               CR1.getUnsignedMin().ugt(CR2.getUnsignedMax());
+      case SLEUGE:
+        return CR1.getSignedMax().sle(CR2.getSignedMin()) &&
+               CR1.getUnsignedMin().uge(CR2.getUnsignedMax());
+      case SGTULT:
+        return CR1.getSignedMin().sgt(CR2.getSignedMax()) &&
+               CR1.getUnsignedMax().ult(CR2.getUnsignedMin());
+      case SGEULE:
+        return CR1.getSignedMin().sge(CR2.getSignedMax()) &&
+               CR1.getUnsignedMax().ule(CR2.getUnsignedMin());
+      }
+    }
+
+    void addToWorklist(Value *V, const APInt *I, ICmpInst::Predicate Pred,
+                       VRPSolver *VRP);
+
+    void addInequality(Value *V1, Value *V2, ETNode *Subtree, LatticeVal LV,
+                       VRPSolver *VRP) {
+      assert(!isRelatedBy(V1, V2, Subtree, LV) && "Asked to do useless work.");
+
+      if (LV == NE) return; // we can't represent those.
+      // XXX: except in the case where isSingleElement and equal to either
+      // Lower or Upper. That's probably not profitable. (Type::Int1Ty?)
+
+      uint32_t W = widthOfValue(V1);
+      if (!W) return;
+
+      ConstantRange CR1 = rangeFromValue(V1, Subtree, W);
+      ConstantRange CR2 = rangeFromValue(V2, Subtree, W);
+
+      if (!CR1.isSingleElement()) {
+        ConstantRange NewCR1 = CR1.intersectWith(create(LV, CR2));
+        if (NewCR1 != CR1) {
+          if (NewCR1.isSingleElement())
+            addToWorklist(V1, NewCR1.getSingleElement(),
+                          ICmpInst::ICMP_EQ, VRP);
+          else
+            update(V1, NewCR1, Subtree);
+        }
+      }
+
+      if (!CR2.isSingleElement()) {
+        ConstantRange NewCR2 = CR2.intersectWith(create(reversePredicate(LV),
+                                                        CR1));
+        if (NewCR2 != CR2) {
+          if (NewCR2.isSingleElement())
+            addToWorklist(V2, NewCR2.getSingleElement(),
+                          ICmpInst::ICMP_EQ, VRP);
+          else
+            update(V2, NewCR2, Subtree);
+        }
+      }
+    }
+  };
+
+
   /// UnreachableBlocks keeps tracks of blocks that are for one reason or
   /// another discovered to be unreachable. This is used to cull the graph when
   /// analyzing instructions, and to mark blocks with the "unreachable"
@@ -781,6 +980,8 @@
   /// @brief VRPSolver calculates inferences from a new relationship.
   class VISIBILITY_HIDDEN VRPSolver {
   private:
+    friend class ValueRanges;
+
     struct Operation {
       Value *LHS, *RHS;
       ICmpInst::Predicate Op;
@@ -792,6 +993,8 @@
 
     InequalityGraph &IG;
     UnreachableBlocks &UB;
+    ValueRanges &VR;
+
     ETForest *Forest;
     ETNode *Top;
     BasicBlock *TopBB;
@@ -997,14 +1200,7 @@
       if (exitEarly) return true;
 
       // Create N1.
-      // XXX: this should call newNode, but instead the node might be created
-      // in isRelatedBy. That's also a fixme.
-      if (!n1) {
-        n1 = IG.getOrInsertNode(V1, Top);
-
-        if (isa<ConstantInt>(V1))
-          if (IG.isRelatedBy(n1, n2, Top, NE)) return false;
-      }
+      if (!n1) n1 = IG.newNode(V1);
 
       // Migrate relationships from removed nodes to N1.
       Node *N1 = IG.node(n1);
@@ -1094,20 +1290,22 @@
     }
 
   public:
-    VRPSolver(InequalityGraph &IG, UnreachableBlocks &UB, ETForest *Forest,
-              bool &modified, BasicBlock *TopBB)
+    VRPSolver(InequalityGraph &IG, UnreachableBlocks &UB, ValueRanges &VR,
+              ETForest *Forest, bool &modified, BasicBlock *TopBB)
       : IG(IG),
         UB(UB),
+        VR(VR),
         Forest(Forest),
         Top(Forest->getNodeForBlock(TopBB)),
         TopBB(TopBB),
         TopInst(NULL),
         modified(modified) {}
 
-    VRPSolver(InequalityGraph &IG, UnreachableBlocks &UB, ETForest *Forest,
-              bool &modified, Instruction *TopInst)
+    VRPSolver(InequalityGraph &IG, UnreachableBlocks &UB, ValueRanges &VR,
+              ETForest *Forest, bool &modified, Instruction *TopInst)
       : IG(IG),
         UB(UB),
+        VR(VR),
         Forest(Forest),
         TopInst(TopInst),
         modified(modified)
@@ -1122,12 +1320,6 @@
           return ConstantExpr::getCompare(Pred, C1, C2) ==
                  ConstantInt::getTrue();
 
-      // XXX: this is lousy. If we're passed a Constant, then we might miss
-      // some relationships if it isn't in the IG because the relationships
-      // added by initializeConstant are missing.
-      if (isa<Constant>(V1)) IG.getOrInsertNode(V1, Top);
-      if (isa<Constant>(V2)) IG.getOrInsertNode(V2, Top);
-
       if (unsigned n1 = IG.getNode(V1, Top))
         if (unsigned n2 = IG.getNode(V2, Top)) {
           if (n1 == n2) return Pred == ICmpInst::ICMP_EQ ||
@@ -1136,10 +1328,11 @@
                                Pred == ICmpInst::ICMP_SLE ||
                                Pred == ICmpInst::ICMP_SGE;
           if (Pred == ICmpInst::ICMP_EQ) return false;
-          return IG.isRelatedBy(n1, n2, Top, cmpInstToLattice(Pred));
+          if (IG.isRelatedBy(n1, n2, Top, cmpInstToLattice(Pred))) return true;
         }
 
-      return false;
+      if (Pred == ICmpInst::ICMP_EQ) return V1 == V2;
+      return VR.isRelatedBy(V1, V2, Top, cmpInstToLattice(Pred));
     }
 
     /// add - adds a new property to the work queue
@@ -1169,12 +1362,11 @@
         Value *Op0 = IG.canonicalize(BO->getOperand(0), Top);
         Value *Op1 = IG.canonicalize(BO->getOperand(1), Top);
 
-        // TODO: "and bool true, %x" EQ %y then %x EQ %y.
+        // TODO: "and i32 -1, %x" EQ %y then %x EQ %y.
 
         switch (BO->getOpcode()) {
           case Instruction::And: {
-            // "and int %a, %b"  EQ -1   then %a EQ -1   and %b EQ -1
-            // "and bool %a, %b" EQ true then %a EQ true and %b EQ true
+            // "and i32 %a, %b"  EQ -1 then %a EQ -1 and %b EQ -1
             ConstantInt *CI = ConstantInt::getAllOnesValue(Ty);
             if (Canonical == CI) {
               add(CI, Op0, ICmpInst::ICMP_EQ, NewContext);
@@ -1182,8 +1374,7 @@
             }
           } break;
           case Instruction::Or: {
-            // "or int %a, %b"  EQ 0     then %a EQ 0     and %b EQ 0
-            // "or bool %a, %b" EQ false then %a EQ false and %b EQ false
+            // "or i32 %a, %b" EQ 0 then %a EQ 0 and %b EQ 0
             Constant *Zero = Constant::getNullValue(Ty);
             if (Canonical == Zero) {
               add(Zero, Op0, ICmpInst::ICMP_EQ, NewContext);
@@ -1191,13 +1382,10 @@
             }
           } break;
           case Instruction::Xor: {
-            // "xor bool true,  %a" EQ true  then %a EQ false
-            // "xor bool true,  %a" EQ false then %a EQ true
-            // "xor bool false, %a" EQ true  then %a EQ true
-            // "xor bool false, %a" EQ false then %a EQ false
-            // "xor int %c, %a" EQ %c then %a EQ 0
-            // "xor int %c, %a" NE %c then %a NE 0
-            // 1. Repeat all of the above, with order of operands reversed.
+            // "xor i32 %c, %a" EQ %b then %a EQ %c ^ %b
+            // "xor i32 %c, %a" EQ %c then %a EQ 0
+            // "xor i32 %c, %a" NE %c then %a NE 0
+            // Repeat the above, with order of operands reversed.
             Value *LHS = Op0;
             Value *RHS = Op1;
             if (!isa<Constant>(LHS)) std::swap(LHS, RHS);
@@ -1221,7 +1409,7 @@
             break;
         }
       } else if (ICmpInst *IC = dyn_cast<ICmpInst>(I)) {
-        // "icmp ult int %a, int %y" EQ true then %a u< y
+        // "icmp ult i32 %a, %y" EQ true then %a u< y
         // etc.
 
         if (Canonical == ConstantInt::getTrue()) {
@@ -1234,7 +1422,7 @@
       } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) {
         if (I->getType()->isFPOrFPVector()) return;
 
-        // Given: "%a = select bool %x, int %b, int %c"
+        // Given: "%a = select i1 %x, i32 %b, i32 %c"
         // %a EQ %b and %b NE %c then %x EQ true
         // %a EQ %c and %b NE %c then %x EQ false
 
@@ -1246,7 +1434,7 @@
             add(SI->getCondition(), ConstantInt::getTrue(),
                 ICmpInst::ICMP_EQ, NewContext);
           else if (Canonical == IG.canonicalize(False, Top) ||
-                   isRelatedBy(I, True, ICmpInst::ICMP_NE))
+                   isRelatedBy(Canonical, True, ICmpInst::ICMP_NE))
             add(SI->getCondition(), ConstantInt::getFalse(),
                 ICmpInst::ICMP_EQ, NewContext);
         }
@@ -1293,10 +1481,10 @@
           }
         }
 
-        // "%x = add int %y, %z" and %x EQ %y then %z EQ 0
-        // "%x = mul int %y, %z" and %x EQ %y then %z EQ 1
+        // "%x = add i32 %y, %z" and %x EQ %y then %z EQ 0
+        // "%x = mul i32 %y, %z" and %x EQ %y then %z EQ 1
         // 1. Repeat all of the above, with order of operands reversed.
-        // "%x = udiv int %y, %z" and %x EQ %y then %z EQ 1
+        // "%x = udiv i32 %y, %z" and %x EQ %y then %z EQ 1
 
         Value *Known = Op0, *Unknown = Op1;
         if (Known != BO) std::swap(Known, Unknown);
@@ -1326,11 +1514,11 @@
           }
         }
 
-        // TODO: "%a = add int %b, 1" and %b > %z then %a >= %z.
+        // TODO: "%a = add i32 %b, 1" and %b > %z then %a >= %z.
 
       } else if (ICmpInst *IC = dyn_cast<ICmpInst>(I)) {
-        // "%a = icmp ult %b, %c" and %b u< %c  then %a EQ true
-        // "%a = icmp ult %b, %c" and %b u>= %c then %a EQ false
+        // "%a = icmp ult i32 %b, %c" and %b u< %c  then %a EQ true
+        // "%a = icmp ult i32 %b, %c" and %b u>= %c then %a EQ false
         // etc.
 
         Value *Op0 = IG.canonicalize(IC->getOperand(0), Top);
@@ -1343,7 +1531,7 @@
           add(IC, ConstantInt::getFalse(), ICmpInst::ICMP_EQ, NewContext);
         }
 
-        // TODO: "bool %x s<u> %y" implies %x = true and %y = false.
+        // TODO: "i1 %x s<u> %y" implies %x = true and %y = false.
 
         // TODO: make the predicate more strict, if possible.
 
@@ -1363,11 +1551,12 @@
           add(SI, SI->getTrueValue(), ICmpInst::ICMP_EQ, NewContext);
         }
       } else if (CastInst *CI = dyn_cast<CastInst>(I)) {
-        if (CI->getDestTy()->isFPOrFPVector()) return;
+        const Type *Ty = CI->getDestTy();
+        if (Ty->isFPOrFPVector()) return;
 
         if (Constant *C = dyn_cast<Constant>(
                 IG.canonicalize(CI->getOperand(0), Top))) {
-          add(CI, ConstantExpr::getCast(CI->getOpcode(), C, CI->getDestTy()),
+          add(CI, ConstantExpr::getCast(CI->getOpcode(), C, Ty),
               ICmpInst::ICMP_EQ, NewContext);
         }
 
@@ -1413,20 +1602,20 @@
           }
         }
 
-        if (compare(O.RHS, O.LHS)) {
+        if (compare(O.LHS, O.RHS)) {
           std::swap(O.LHS, O.RHS);
           O.Op = ICmpInst::getSwappedPredicate(O.Op);
         }
 
         if (O.Op == ICmpInst::ICMP_EQ) {
-          if (!makeEqual(O.LHS, O.RHS))
+          if (!makeEqual(O.RHS, O.LHS))
             UB.mark(TopBB);
         } else {
           LatticeVal LV = cmpInstToLattice(O.Op);
 
           if ((LV & EQ_BIT) &&
               isRelatedBy(O.LHS, O.RHS, ICmpInst::getSwappedPredicate(O.Op))) {
-            if (!makeEqual(O.LHS, O.RHS))
+            if (!makeEqual(O.RHS, O.LHS))
               UB.mark(TopBB);
           } else {
             if (isRelatedBy(O.LHS, O.RHS, ICmpInst::getInversePredicate(O.Op))){
@@ -1435,10 +1624,10 @@
               continue;
             }
 
-            unsigned n1 = IG.getOrInsertNode(O.LHS, Top);
-            unsigned n2 = IG.getOrInsertNode(O.RHS, Top);
+            unsigned n1 = IG.getNode(O.LHS, Top);
+            unsigned n2 = IG.getNode(O.RHS, Top);
 
-            if (n1 == n2) {
+            if (n1 && n1 == n2) {
               if (O.Op != ICmpInst::ICMP_UGE && O.Op != ICmpInst::ICMP_ULE &&
                   O.Op != ICmpInst::ICMP_SGE && O.Op != ICmpInst::ICMP_SLE)
                 UB.mark(TopBB);
@@ -1447,40 +1636,20 @@
               continue;
             }
 
-            if (IG.isRelatedBy(n1, n2, Top, LV)) {
+            if (VR.isRelatedBy(O.LHS, O.RHS, Top, LV) ||
+                (n1 && n2 && IG.isRelatedBy(n1, n2, Top, LV))) {
               WorkList.pop_front();
               continue;
             }
 
-            // Generalize %x u> -10 to %x > -10.
-            if (ConstantInt *CI = dyn_cast<ConstantInt>(O.RHS)) {
-              // xform doesn't apply to i1
-              if (CI->getType()->getBitWidth() > 1) {
-                if (LV == SLT && CI->getValue().isNegative()) {
-                  // i8 %x s< -5 implies %x < -5 and %x u> 127
-
-                  const IntegerType *Ty = CI->getType();
-                  LV = LT;
-                  add(O.LHS, ConstantInt::get(
-                        APInt::getSignedMaxValue(Ty->getBitWidth())),
-                      ICmpInst::ICMP_UGT);
-                } else if (LV == SGT && CI->getValue().isPositive()) {
-                  // i8 %x s> 5 implies %x > 5 and %x u< 128
-
-                  const IntegerType *Ty = CI->getType();
-                  LV = LT;
-                  add(O.LHS, ConstantInt::get(
-                        APInt::getSignedMinValue(Ty->getBitWidth())),
-                      ICmpInst::ICMP_ULT);
-                } else if (CI->getValue().isPositive()) {
-                  if (LV == ULT || LV == SLT) LV = LT;
-                  if (LV == UGT || LV == SGT) LV = GT;
-                }
-              }
+            VR.addInequality(O.LHS, O.RHS, Top, LV, this);
+            if ((!isa<ConstantInt>(O.RHS) && !isa<ConstantInt>(O.LHS)) ||
+                LV == NE) {
+              if (!n1) n1 = IG.newNode(O.LHS);
+              if (!n2) n2 = IG.newNode(O.RHS);
+              IG.addInequality(n1, n2, Top, LV);
             }
 
-            IG.addInequality(n1, n2, Top, LV);
-
             if (Instruction *I1 = dyn_cast<Instruction>(O.LHS)) {
               if (below(I1) ||
                   Top->DominatedBy(Forest->getNodeForBlock(I1->getParent())))
@@ -1523,6 +1692,11 @@
     }
   };
 
+  void ValueRanges::addToWorklist(Value *V, const APInt *I,  // I: non-null
+                                  ICmpInst::Predicate Pred, VRPSolver *VRP) {
+    VRP->add(V, ConstantInt::get(*I), Pred, VRP->TopInst); // adds V Pred *I
+  }  // The new relation's context is VRP->TopInst.
+
   /// PredicateSimplifier - This class is a simplifier that replaces
   /// one equivalent variable with another. It also tracks what
   /// can't be equal and will solve setcc instructions when possible.
@@ -1533,6 +1707,7 @@
     bool modified;
     InequalityGraph *IG;
     UnreachableBlocks UB;
+    ValueRanges *VR;
 
     std::vector<DominatorTree::Node *> WorkList;
 
@@ -1560,9 +1735,10 @@
     public:
       InequalityGraph &IG;
       UnreachableBlocks &UB;
+      ValueRanges &VR;
 
       Forwards(PredicateSimplifier *PS, DominatorTree::Node *DTNode)
-        : PS(PS), DTNode(DTNode), IG(*PS->IG), UB(PS->UB) {}
+        : PS(PS), DTNode(DTNode), IG(*PS->IG), UB(PS->UB), VR(*PS->VR) {}
 
       void visitTerminatorInst(TerminatorInst &TI);
       void visitBranchInst(BranchInst &BI);
@@ -1577,7 +1753,7 @@
 
       void visitBinaryOperator(BinaryOperator &BO);
     };
-
+  
     // Used by terminator instructions to proceed from the current basic
     // block to the next. Verifies that "current" dominates "next",
     // then calls visitBasicBlock.
@@ -1665,6 +1841,7 @@
     modified = false;
     BasicBlock *RootBlock = &F.getEntryBlock();
     IG = new InequalityGraph(Forest->getNodeForBlock(RootBlock));
+    VR = new ValueRanges();
     WorkList.push_back(DT->getRootNode());
 
     do {
@@ -1673,6 +1850,7 @@
       if (!UB.isDead(DTNode->getBlock())) visitBasicBlock(DTNode);
     } while (!WorkList.empty());
 
+    delete VR;
     delete IG;
 
     modified |= UB.kill();
@@ -1707,13 +1885,13 @@
 
       if (Dest == TrueDest) {
         DOUT << "(" << DTNode->getBlock()->getName() << ") true set:\n";
-        VRPSolver VRP(IG, UB, PS->Forest, PS->modified, Dest);
+        VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, Dest);
         VRP.add(ConstantInt::getTrue(), Condition, ICmpInst::ICMP_EQ);
         VRP.solve();
         DEBUG(IG.dump());
       } else if (Dest == FalseDest) {
         DOUT << "(" << DTNode->getBlock()->getName() << ") false set:\n";
-        VRPSolver VRP(IG, UB, PS->Forest, PS->modified, Dest);
+        VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, Dest);
         VRP.add(ConstantInt::getFalse(), Condition, ICmpInst::ICMP_EQ);
         VRP.solve();
         DEBUG(IG.dump());
@@ -1735,7 +1913,7 @@
       DOUT << "Switch thinking about BB %" << BB->getName()
            << "(" << PS->Forest->getNodeForBlock(BB)->getDFSNumIn() << ")\n";
 
-      VRPSolver VRP(IG, UB, PS->Forest, PS->modified, BB);
+      VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, BB);
       if (BB == SI.getDefaultDest()) {
         for (unsigned i = 1, e = SI.getNumCases(); i < e; ++i)
           if (SI.getSuccessor(i) != BB)
@@ -1750,7 +1928,7 @@
   }
 
   void PredicateSimplifier::Forwards::visitAllocaInst(AllocaInst &AI) {
-    VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &AI);
+    VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &AI); // AI != null
     VRP.add(Constant::getNullValue(AI.getType()), &AI, ICmpInst::ICMP_NE);
     VRP.solve();
   }
@@ -1760,7 +1938,7 @@
     // avoid "load uint* null" -> null NE null.
     if (isa<Constant>(Ptr)) return;
 
-    VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &LI);
+    VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &LI);
     VRP.add(Constant::getNullValue(Ptr->getType()), Ptr, ICmpInst::ICMP_NE);
     VRP.solve();
   }
@@ -1769,13 +1947,13 @@
     Value *Ptr = SI.getPointerOperand();
     if (isa<Constant>(Ptr)) return;
 
-    VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &SI);
+    VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &SI);
     VRP.add(Constant::getNullValue(Ptr->getType()), Ptr, ICmpInst::ICMP_NE);
     VRP.solve();
   }
 
   void PredicateSimplifier::Forwards::visitSExtInst(SExtInst &SI) {
-    VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &SI);
+    VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &SI);
     uint32_t SrcBitWidth = cast<IntegerType>(SI.getSrcTy())->getBitWidth();
     uint32_t DstBitWidth = cast<IntegerType>(SI.getDestTy())->getBitWidth();
     APInt Min(APInt::getSignedMinValue(SrcBitWidth));
@@ -1788,7 +1966,7 @@
   }
 
   void PredicateSimplifier::Forwards::visitZExtInst(ZExtInst &ZI) {
-    VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &ZI);
+    VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &ZI);
     uint32_t SrcBitWidth = cast<IntegerType>(ZI.getSrcTy())->getBitWidth();
     uint32_t DstBitWidth = cast<IntegerType>(ZI.getDestTy())->getBitWidth();
     APInt Max(APInt::getMaxValue(SrcBitWidth));
@@ -1806,7 +1984,7 @@
       case Instruction::UDiv:
       case Instruction::SDiv: {
         Value *Divisor = BO.getOperand(1);
-        VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &BO);
+        VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &BO);
         VRP.add(Constant::getNullValue(Divisor->getType()), Divisor,
                 ICmpInst::ICMP_NE);
         VRP.solve();






More information about the llvm-commits mailing list