[llvm-commits] CVS: llvm/lib/Transforms/Scalar/PredicateSimplifier.cpp

Nick Lewycky nicholas at mxc.ca
Sun Jan 28 18:57:09 PST 2007



Changes in directory llvm/lib/Transforms/Scalar:

PredicateSimplifier.cpp updated: 1.49 -> 1.50
---
Log message:

Simplify the names of lattice values; for example, SGTUNE becomes SGT.

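Rewrite initializeConstant (now initializeInt) so that a new integer constant
is related only to its nearest neighbours (at most four: the closest unsigned-
and signed-less and -greater constants), instead of calling addInequality once
for every existing constant of the same type. This fixes a major performance
bottleneck.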

The test "X == Y || X->DominatedBy(Y)" is redundant because X->DominatedBy(Y)
already holds when X == Y. Remove the "X == Y" part.

Fix a crash in makeEqual. There, getOrInsertNode could add a new constant node
that was already related by NE to the other member we are trying to make
equal. makeEqual now detects this contradiction and returns false, which also
allows us to mark more BBs as unreachable.



---
Diffs of the changes:  (+127 -71)

 PredicateSimplifier.cpp |  198 ++++++++++++++++++++++++++++++------------------
 1 files changed, 127 insertions(+), 71 deletions(-)


Index: llvm/lib/Transforms/Scalar/PredicateSimplifier.cpp
diff -u llvm/lib/Transforms/Scalar/PredicateSimplifier.cpp:1.49 llvm/lib/Transforms/Scalar/PredicateSimplifier.cpp:1.50
--- llvm/lib/Transforms/Scalar/PredicateSimplifier.cpp:1.49	Tue Jan 16 20:23:37 2007
+++ llvm/lib/Transforms/Scalar/PredicateSimplifier.cpp	Sun Jan 28 20:56:54 2007
@@ -52,14 +52,14 @@
 // responsible for analyzing the variable and seeing what new inferences
 // can be made from each property. For example:
 //
-//   %P = seteq int* %ptr, null
-//   %a = or bool %P, %Q
+//   %P = setne int* %ptr, null
+//   %a = and bool %P, %Q
 //   br bool %a label %cond_true, label %cond_false
 //
 // For the true branch, the VRPSolver will start with %a EQ true and look at
 // the definition of %a and find that it can infer that %P and %Q are both
 // true. From %P being true, it can infer that %ptr NE null. For the false
-// branch it can't infer anything from the "or" instruction.
+// branch it can't infer anything from the "and" instruction.
 //
 // Besides branches, we can also infer properties from instruction that may
 // have undefined behaviour in certain cases. For example, the dividend of
@@ -102,18 +102,18 @@
   //   0   1   0   1  1 -- GE                  11
   //   0   1   1   0  0 -- SGTULT              12
   //   0   1   1   0  1 -- SGEULE              13
-  //   0   1   1   1  0 -- SGTUNE              14
-  //   0   1   1   1  1 -- SGEUANY             15
+  //   0   1   1   1  0 -- SGT                 14
+  //   0   1   1   1  1 -- SGE                 15
   //   1   0   0   1  0 -- SLTUGT              18
   //   1   0   0   1  1 -- SLEUGE              19
   //   1   0   1   0  0 -- LT                  20
   //   1   0   1   0  1 -- LE                  21
-  //   1   0   1   1  0 -- SLTUNE              22
-  //   1   0   1   1  1 -- SLEUANY             23
-  //   1   1   0   1  0 -- SNEUGT              26
-  //   1   1   0   1  1 -- SANYUGE             27
-  //   1   1   1   0  0 -- SNEULT              28
-  //   1   1   1   0  1 -- SANYULE             29
+  //   1   0   1   1  0 -- SLT                 22
+  //   1   0   1   1  1 -- SLE                 23
+  //   1   1   0   1  0 -- UGT                 26
+  //   1   1   0   1  1 -- UGE                 27
+  //   1   1   1   0  0 -- ULT                 28
+  //   1   1   1   0  1 -- ULE                 29
   //   1   1   1   1  0 -- NE                  30
   enum LatticeBits {
     EQ_BIT = 1, UGT_BIT = 2, ULT_BIT = 4, SGT_BIT = 8, SLT_BIT = 16
@@ -128,23 +128,23 @@
     SGEULE = SGTULT | EQ_BIT,
     SLTUGT = SLT_BIT | UGT_BIT,
     SLEUGE = SLTUGT | EQ_BIT,
-    SNEULT = SLT_BIT | SGT_BIT | ULT_BIT,
-    SNEUGT = SLT_BIT | SGT_BIT | UGT_BIT,
-    SLTUNE = SLT_BIT | ULT_BIT | UGT_BIT,
-    SGTUNE = SGT_BIT | ULT_BIT | UGT_BIT,
-    SLEUANY = SLT_BIT | ULT_BIT | UGT_BIT | EQ_BIT,
-    SGEUANY = SGT_BIT | ULT_BIT | UGT_BIT | EQ_BIT,
-    SANYULE = SLT_BIT | SGT_BIT | ULT_BIT | EQ_BIT,
-    SANYUGE = SLT_BIT | SGT_BIT | UGT_BIT | EQ_BIT
+    ULT = SLT_BIT | SGT_BIT | ULT_BIT,
+    UGT = SLT_BIT | SGT_BIT | UGT_BIT,
+    SLT = SLT_BIT | ULT_BIT | UGT_BIT,
+    SGT = SGT_BIT | ULT_BIT | UGT_BIT,
+    SLE = SLT | EQ_BIT,
+    SGE = SGT | EQ_BIT,
+    ULE = ULT | EQ_BIT,
+    UGE = UGT | EQ_BIT
   };
 
   static bool validPredicate(LatticeVal LV) {
     switch (LV) {
     case GT: case GE: case LT: case LE: case NE:
-    case SGTULT: case SGTUNE: case SGEULE:
-    case SLTUGT: case SLTUNE: case SLEUGE:
-    case SNEULT: case SNEUGT:
-    case SLEUANY: case SGEUANY: case SANYULE: case SANYUGE:
+    case SGTULT: case SGT: case SGEULE:
+    case SLTUGT: case SLT: case SLEUGE:
+    case ULT: case UGT:
+    case SLE: case SGE: case ULE: case UGE:
       return true;
     default:
       return false;
@@ -366,36 +366,91 @@
 
     std::vector<Node> Nodes;
 
-    std::vector<std::pair<ConstantInt *, unsigned> > Constants;
-    void initializeConstant(Constant *C, unsigned index) {
-      ConstantInt *CI = dyn_cast<ConstantInt>(C);
-      if (!CI) return;
-
-      // XXX: instead of O(n) calls to addInequality, just find the 2, 3 or 4
-      // nodes that are nearest less than or greater than (signed or unsigned).
-      for (std::vector<std::pair<ConstantInt *, unsigned> >::iterator
-           I = Constants.begin(), E = Constants.end(); I != E; ++I) {
-        ConstantInt *Other = I->first;
-        if (CI->getType() == Other->getType()) {
-          unsigned lv = 0;
-
-          if (CI->getZExtValue() < Other->getZExtValue())
-            lv |= ULT_BIT;
-          else
-            lv |= UGT_BIT;
-
-          if (CI->getSExtValue() < Other->getSExtValue())
-            lv |= SLT_BIT;
-          else
-            lv |= SGT_BIT;
-
-          LatticeVal LV = static_cast<LatticeVal>(lv);
-          assert(validPredicate(LV) && "Not a valid predicate.");
-          if (!isRelatedBy(index, I->second, TreeRoot, LV))
-            addInequality(index, I->second, TreeRoot, LV);
-        }
+    std::vector<std::pair<ConstantInt *, unsigned> > Ints;
+
+    /// This is used to keep the ConstantInts list in unsigned ascending order.
+    /// If the bitwidths don't match, this sorts smaller values ahead.
+    struct SortByZExt {
+      bool operator()(const std::pair<ConstantInt *, unsigned> &LHS,
+                      const std::pair<ConstantInt *, unsigned> &RHS) const {
+        if (LHS.first->getType()->getBitWidth() !=
+            RHS.first->getType()->getBitWidth())
+          return LHS.first->getType()->getBitWidth() <
+                 RHS.first->getType()->getBitWidth();
+        return LHS.first->getZExtValue() < RHS.first->getZExtValue();
+      }
+    };
+
+    /// True when the bitwidth of LHS < bitwidth of RHS.
+    struct FindByIntegerWidth {
+      bool operator()(const std::pair<ConstantInt *, unsigned> &LHS,
+                      const std::pair<ConstantInt *, unsigned> &RHS) const {
+        return LHS.first->getType()->getBitWidth() <
+               RHS.first->getType()->getBitWidth();
       }
-      Constants.push_back(std::make_pair(CI, index));
+    };
+
+    void initializeInt(ConstantInt *CI, unsigned index) {
+      std::vector<std::pair<ConstantInt *, unsigned> >::iterator begin, end,
+          last, iULT, iUGT, iSLT, iSGT;
+
+      std::pair<ConstantInt *, unsigned> pair = std::make_pair(CI, index);
+
+      begin = std::lower_bound(Ints.begin(), Ints.end(), pair,
+                               FindByIntegerWidth());
+      end   = std::upper_bound(begin, Ints.end(), pair, FindByIntegerWidth());
+
+      if (begin == end) last = end;
+      else last = end - 1;
+
+      iUGT = std::lower_bound(begin, end, pair, SortByZExt());
+      iULT = (iUGT == begin || begin == end) ? end : iUGT - 1;
+
+      if (iUGT != end && iULT != end &&
+          (iULT->first->getSExtValue() >> 63) ==
+          (iUGT->first->getSExtValue() >> 63)) { // signs match
+        iSGT = iUGT;
+        iSLT = iULT;
+      } else {
+        if (iULT == end || iUGT == end) {
+          if (iULT == end) iSLT = last;  else iSLT = iULT;
+          if (iUGT == end) iSGT = begin; else iSGT = iUGT;
+	} else if (iULT->first->getSExtValue() < 0) {
+          assert(iUGT->first->getSExtValue() >= 0 && "Bad sign comparison.");
+          iSGT = iUGT;
+          iSLT = iULT;
+        } else {
+          assert(iULT->first->getSExtValue() >= 0 &&
+                 iUGT->first->getSExtValue() < 0 && "Bad sign comparison.");
+          iSGT = iULT;
+          iSLT = iUGT;
+	}
+
+        if (iSGT != end &&
+            iSGT->first->getSExtValue() < CI->getSExtValue()) iSGT = end;
+        if (iSLT != end &&
+            iSLT->first->getSExtValue() > CI->getSExtValue()) iSLT = end;
+
+        if (begin != end) {
+          if (begin->first->getSExtValue() < CI->getSExtValue())
+            if (iSLT == end ||
+                begin->first->getSExtValue() > iSLT->first->getSExtValue())
+              iSLT = begin;
+	}
+        if (last != end) {
+          if (last->first->getSExtValue() > CI->getSExtValue())
+            if (iSGT == end ||
+                last->first->getSExtValue() < iSGT->first->getSExtValue())
+              iSGT = last;
+	}
+      }
+
+      if (iULT != end) addInequality(iULT->second, index, TreeRoot, ULT);
+      if (iUGT != end) addInequality(iUGT->second, index, TreeRoot, UGT);
+      if (iSLT != end) addInequality(iSLT->second, index, TreeRoot, SLT);
+      if (iSGT != end) addInequality(iSGT->second, index, TreeRoot, SGT);
+
+      Ints.insert(iUGT, pair);
     }
 
   public:
@@ -441,8 +496,8 @@
       NodeMap.insert(std::lower_bound(NodeMap.begin(), NodeMap.end(),
                                       MapEntry), MapEntry);
 
-      if (Constant *C = dyn_cast<Constant>(V))
-        initializeConstant(C, MapEntry.index);
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
+        initializeInt(CI, MapEntry.index);
 
       return MapEntry.index;
     }
@@ -485,11 +540,7 @@
       ToRepoint.push_back(V);
 
       if (unsigned Conflict = getNode(V, Subtree)) {
-        // XXX: NodeMap.size() exceeds 68000 entries compiling kimwitu++!
-        // This adds 57 seconds to the otherwise 3 second build. Unacceptable.
-        //
-        // IDEA: could we iterate 1..Nodes.size() calling getNode? It's
-        // O(n log n) but kimwitu++ only has about 300 nodes.
+        // XXX: NodeMap.size() exceeds 68,000 entries compiling kimwitu++!
         for (NodeMapType::iterator I = NodeMap.begin(), E = NodeMap.end();
              I != E; ++I) {
           if (I->index == Conflict && Subtree->DominatedBy(I->Subtree))
@@ -541,7 +592,7 @@
       // add %a < %n2 too. This keeps the graph fully connected.
       if (LV1 != NE) {
         // Someone with a head for this sort of logic, please review this.
-        // Given that %x SLTUGT %y and %a SLEUANY %x, what is the relationship
+        // Given that %x SLTUGT %y and %a SLE %x, what is the relationship
         // between %a and %y? I believe the below code is correct, but I don't
         // think it's the most efficient solution.
 
@@ -795,7 +846,7 @@
         return IdomI(TopInst, I);
       else {
         ETNode *Node = Forest->getNodeForBlock(I->getParent());
-        return Node == Top || Node->DominatedBy(Top);
+        return Node->DominatedBy(Top);
       }
     }
 
@@ -945,7 +996,12 @@
       // Create N1.
       // XXX: this should call newNode, but instead the node might be created
       // in isRelatedBy. That's also a fixme.
-      if (!n1) n1 = IG.getOrInsertNode(V1, Top);
+      if (!n1) {
+        n1 = IG.getOrInsertNode(V1, Top);
+
+        if (isa<ConstantInt>(V1))
+          if (IG.isRelatedBy(n1, n2, Top, NE)) return false;
+      }
 
       // Migrate relationships from removed nodes to N1.
       Node *N1 = IG.node(n1);
@@ -954,7 +1010,7 @@
         unsigned n = *I;
         Node *N = IG.node(n);
         for (Node::iterator NI = N->begin(), NE = N->end(); NI != NE; ++NI) {
-          if (Top == NI->Subtree || NI->Subtree->DominatedBy(Top)) {
+          if (NI->Subtree->DominatedBy(Top)) {
             if (NI->To == n1) {
               assert((NI->LV & EQ_BIT) && "Node inequal to itself.");
               continue;
@@ -1016,21 +1072,21 @@
         case ICmpInst::ICMP_NE:
           return NE;
         case ICmpInst::ICMP_UGT:
-          return SNEUGT;
+          return UGT;
         case ICmpInst::ICMP_UGE:
-          return SANYUGE;
+          return UGE;
         case ICmpInst::ICMP_ULT:
-          return SNEULT;
+          return ULT;
         case ICmpInst::ICMP_ULE:
-          return SANYULE;
+          return ULE;
         case ICmpInst::ICMP_SGT:
-          return SGTUNE;
+          return SGT;
         case ICmpInst::ICMP_SGE:
-          return SGEUANY;
+          return SGE;
         case ICmpInst::ICMP_SLT:
-          return SLTUNE;
+          return SLT;
         case ICmpInst::ICMP_SLE:
-          return SLEUANY;
+          return SLE;
       }
     }
 






More information about the llvm-commits mailing list