[llvm-commits] [llvm] r112267 - /llvm/trunk/lib/Analysis/ScalarEvolution.cpp

Dan Gohman gohman at apple.com
Fri Aug 27 08:26:01 PDT 2010


Author: djg
Date: Fri Aug 27 10:26:01 2010
New Revision: 112267

URL: http://llvm.org/viewvc/llvm-project?rev=112267&view=rev
Log:
Optimize SCEVComplexityCompare. Use a three-way (negative/zero/positive)
result instead of a boolean one, to avoid needing two calls to test for
equivalence, and sort addrecs by their degree before examining their operands.

Modified:
    llvm/trunk/lib/Analysis/ScalarEvolution.cpp

Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=112267&r1=112266&r2=112267&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Fri Aug 27 10:26:01 2010
@@ -523,24 +523,34 @@
   public:
     explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {}
 
+    // Return true or false if LHS is less than, or at least RHS, respectively.
     bool operator()(const SCEV *LHS, const SCEV *RHS) const {
+      return compare(LHS, RHS) < 0;
+    }
+
+    // Return negative, zero, or positive, if LHS is less than, equal to, or
+    // greater than RHS, respectively. A three-way result allows recursive
+    // comparisons to be more efficient.
+    int compare(const SCEV *LHS, const SCEV *RHS) const {
       // Fast-path: SCEVs are uniqued so we can do a quick equality check.
       if (LHS == RHS)
-        return false;
+        return 0;
 
       // Primarily, sort the SCEVs by their getSCEVType().
       unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
       if (LType != RType)
-        return LType < RType;
+        return (int)LType - (int)RType;
 
       // Aside from the getSCEVType() ordering, the particular ordering
       // isn't very important except that it's beneficial to be consistent,
       // so that (a + b) and (b + a) don't end up as different expressions.
-
-      // Sort SCEVUnknown values with some loose heuristics. TODO: This is
-      // not as complete as it could be.
-      if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS)) {
+      switch (LType) {
+      case scUnknown: {
+        const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
         const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
+
+        // Sort SCEVUnknown values with some loose heuristics. TODO: This is
+        // not as complete as it could be.
         const Value *LV = LU->getValue(), *RV = RU->getValue();
 
         // Order pointer values after integer values. This helps SCEVExpander
@@ -548,22 +558,23 @@
         bool LIsPointer = LV->getType()->isPointerTy(),
              RIsPointer = RV->getType()->isPointerTy();
         if (LIsPointer != RIsPointer)
-          return RIsPointer;
+          return (int)LIsPointer - (int)RIsPointer;
 
         // Compare getValueID values.
         unsigned LID = LV->getValueID(),
                  RID = RV->getValueID();
         if (LID != RID)
-          return LID < RID;
+          return (int)LID - (int)RID;
 
         // Sort arguments by their position.
         if (const Argument *LA = dyn_cast<Argument>(LV)) {
           const Argument *RA = cast<Argument>(RV);
-          return LA->getArgNo() < RA->getArgNo();
+          unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
+          return (int)LArgNo - (int)RArgNo;
         }
 
-        // For instructions, compare their loop depth, and their opcode.
-        // This is pretty loose.
+        // For instructions, compare their loop depth, and their operand
+        // count.  This is pretty loose.
         if (const Instruction *LInst = dyn_cast<Instruction>(LV)) {
           const Instruction *RInst = cast<Instruction>(RV);
 
@@ -574,82 +585,105 @@
             unsigned LDepth = LI->getLoopDepth(LParent),
                      RDepth = LI->getLoopDepth(RParent);
             if (LDepth != RDepth)
-              return LDepth < RDepth;
+              return (int)LDepth - (int)RDepth;
           }
 
           // Compare the number of operands.
           unsigned LNumOps = LInst->getNumOperands(),
                    RNumOps = RInst->getNumOperands();
-          if (LNumOps != RNumOps)
-            return LNumOps < RNumOps;
+          return (int)LNumOps - (int)RNumOps;
         }
 
-        return false;
+        return 0;
       }
 
-      // Compare constant values.
-      if (const SCEVConstant *LC = dyn_cast<SCEVConstant>(LHS)) {
+      case scConstant: {
+        const SCEVConstant *LC = cast<SCEVConstant>(LHS);
         const SCEVConstant *RC = cast<SCEVConstant>(RHS);
+
+        // Compare constant values.
         const APInt &LA = LC->getValue()->getValue();
         const APInt &RA = RC->getValue()->getValue();
         unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
         if (LBitWidth != RBitWidth)
-          return LBitWidth < RBitWidth;
-        return LA.ult(RA);
+          return (int)LBitWidth - (int)RBitWidth;
+        return LA.ult(RA) ? -1 : 1;
       }
 
-      // Compare addrec loop depths.
-      if (const SCEVAddRecExpr *LA = dyn_cast<SCEVAddRecExpr>(LHS)) {
+      case scAddRecExpr: {
+        const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
         const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
+
+        // Compare addrec loop depths.
         const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
         if (LLoop != RLoop) {
           unsigned LDepth = LLoop->getLoopDepth(),
                    RDepth = RLoop->getLoopDepth();
           if (LDepth != RDepth)
-            return LDepth < RDepth;
+            return (int)LDepth - (int)RDepth;
         }
+
+        // Addrec complexity grows with operand count.
+        unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
+        if (LNumOps != RNumOps)
+          return (int)LNumOps - (int)RNumOps;
+
+        // Lexicographically compare.
+        for (unsigned i = 0; i != LNumOps; ++i) {
+          long X = compare(LA->getOperand(i), RA->getOperand(i));
+          if (X != 0)
+            return X;
+        }
+
+        return 0;
       }
 
-      // Lexicographically compare n-ary expressions.
-      if (const SCEVNAryExpr *LC = dyn_cast<SCEVNAryExpr>(LHS)) {
+      case scAddExpr:
+      case scMulExpr:
+      case scSMaxExpr:
+      case scUMaxExpr: {
+        const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
         const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
+
+        // Lexicographically compare n-ary expressions.
         unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
         for (unsigned i = 0; i != LNumOps; ++i) {
           if (i >= RNumOps)
-            return false;
-          const SCEV *LOp = LC->getOperand(i), *ROp = RC->getOperand(i);
-          if (operator()(LOp, ROp))
-            return true;
-          if (operator()(ROp, LOp))
-            return false;
+            return 1;
+          long X = compare(LC->getOperand(i), RC->getOperand(i));
+          if (X != 0)
+            return X;
         }
-        return LNumOps < RNumOps;
+        return (int)LNumOps - (int)RNumOps;
       }
 
-      // Lexicographically compare udiv expressions.
-      if (const SCEVUDivExpr *LC = dyn_cast<SCEVUDivExpr>(LHS)) {
+      case scUDivExpr: {
+        const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
         const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
-        const SCEV *LL = LC->getLHS(), *LR = LC->getRHS(),
-                   *RL = RC->getLHS(), *RR = RC->getRHS();
-        if (operator()(LL, RL))
-          return true;
-        if (operator()(RL, LL))
-          return false;
-        if (operator()(LR, RR))
-          return true;
-        if (operator()(RR, LR))
-          return false;
-        return false;
+
+        // Lexicographically compare udiv expressions.
+        long X = compare(LC->getLHS(), RC->getLHS());
+        if (X != 0)
+          return X;
+        return compare(LC->getRHS(), RC->getRHS());
       }
 
-      // Compare cast expressions by operand.
-      if (const SCEVCastExpr *LC = dyn_cast<SCEVCastExpr>(LHS)) {
+      case scTruncate:
+      case scZeroExtend:
+      case scSignExtend: {
+        const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
         const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
-        return operator()(LC->getOperand(), RC->getOperand());
+
+        // Compare cast expressions by operand.
+        return compare(LC->getOperand(), RC->getOperand());
+      }
+
+      default:
+        break;
       }
 
       llvm_unreachable("Unknown SCEV kind!");
-      return false;
+      return 0;
     }
   };
 }





More information about the llvm-commits mailing list