[llvm-commits] CVS: llvm/lib/Analysis/ScalarEvolution.cpp

Reid Spencer reid at x10sys.com
Wed Feb 28 23:26:05 PST 2007



Changes in directory llvm/lib/Analysis:

ScalarEvolution.cpp updated: 1.97 -> 1.98
---
Log message:

APIntify various computations in ScalarEvolution


---
Diffs of the changes:  (+52 -55)

 ScalarEvolution.cpp |  107 +++++++++++++++++++++++++---------------------------
 1 files changed, 52 insertions(+), 55 deletions(-)


Index: llvm/lib/Analysis/ScalarEvolution.cpp
diff -u llvm/lib/Analysis/ScalarEvolution.cpp:1.97 llvm/lib/Analysis/ScalarEvolution.cpp:1.98
--- llvm/lib/Analysis/ScalarEvolution.cpp:1.97	Wed Feb 28 17:31:17 2007
+++ llvm/lib/Analysis/ScalarEvolution.cpp	Thu Mar  1 01:25:48 2007
@@ -1176,7 +1176,7 @@
     /// in the header of its containing loop, we know the loop executes a
     /// constant number of times, and the PHI node is just a recurrence
     /// involving constants, fold it.
-    Constant *getConstantEvolutionLoopExitValue(PHINode *PN, uint64_t Its,
+    Constant *getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its,
                                                 const Loop *L);
   };
 }
@@ -1729,7 +1729,7 @@
     // Evaluate the condition for this iteration.
     Result = ConstantExpr::getICmp(predicate, Result, RHS);
     if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
-    if (cast<ConstantInt>(Result)->getZExtValue() == false) {
+    if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
 #if 0
       cerr << "\n***\n*** Computed loop count " << *ItCst
            << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
@@ -1824,13 +1824,13 @@
 /// constant number of times, and the PHI node is just a recurrence
 /// involving constants, fold it.
 Constant *ScalarEvolutionsImpl::
-getConstantEvolutionLoopExitValue(PHINode *PN, uint64_t Its, const Loop *L) {
+getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its, const Loop *L){
   std::map<PHINode*, Constant*>::iterator I =
     ConstantEvolutionLoopExitValue.find(PN);
   if (I != ConstantEvolutionLoopExitValue.end())
     return I->second;
 
-  if (Its > MaxBruteForceIterations)
+  if (Its.ugt(APInt(Its.getBitWidth(),MaxBruteForceIterations)))
     return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
 
   Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
@@ -1850,11 +1850,11 @@
     return RetVal = 0;  // Not derived from same PHI.
 
   // Execute the loop symbolically to determine the exit value.
-  unsigned IterationNum = 0;
-  unsigned NumIterations = Its;
-  if (NumIterations != Its)
-    return RetVal = 0;  // More than 2^32 iterations??
+  if (Its.getActiveBits() >= 32)
+    return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
 
+  unsigned NumIterations = Its.getZExtValue(); // must be in range
+  unsigned IterationNum = 0;
   for (Constant *PHIVal = StartCST; ; ++IterationNum) {
     if (IterationNum == NumIterations)
       return RetVal = PHIVal;  // Got exit value!
@@ -1904,7 +1904,7 @@
     // Couldn't symbolically evaluate.
     if (!CondVal) return UnknownValue;
 
-    if (CondVal->getZExtValue() == uint64_t(ExitWhen)) {
+    if (CondVal->getValue() == uint64_t(ExitWhen)) {
       ConstantEvolutionLoopExitValue[PN] = PHIVal;
       ++NumBruteForceTripCountsComputed;
       return SCEVConstant::get(ConstantInt::get(Type::Int32Ty, IterationNum));
@@ -1946,7 +1946,7 @@
               // this is a constant evolving PHI node, get the final value at
               // the specified iteration number.
               Constant *RV = getConstantEvolutionLoopExitValue(PN,
-                                               ICC->getValue()->getZExtValue(),
+                                                    ICC->getValue()->getValue(),
                                                                LI);
               if (RV) return SCEVUnknown::get(RV);
             }
@@ -2063,57 +2063,68 @@
 static std::pair<SCEVHandle,SCEVHandle>
 SolveQuadraticEquation(const SCEVAddRecExpr *AddRec) {
   assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
-  SCEVConstant *L = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
-  SCEVConstant *M = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
-  SCEVConstant *N = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
+  SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
+  SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
+  SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));

   // We currently can only solve this if the coefficients are constants.
-  if (!L || !M || !N) {
+  if (!LC || !MC || !NC) {
     SCEV *CNC = new SCEVCouldNotCompute();
     return std::make_pair(CNC, CNC);
   }

-  Constant *C = L->getValue();
-  Constant *Two = ConstantInt::get(C->getType(), 2);
-
-  // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
-  // The B coefficient is M-N/2
-  Constant *B = ConstantExpr::getSub(M->getValue(),
-                                     ConstantExpr::getSDiv(N->getValue(),
-                                                          Two));
-  // The A coefficient is N/2
-  Constant *A = ConstantExpr::getSDiv(N->getValue(), Two);
-
-  // Compute the B^2-4ac term.
-  Constant *SqrtTerm =
-    ConstantExpr::getMul(ConstantInt::get(C->getType(), 4),
-                         ConstantExpr::getMul(A, C));
-  SqrtTerm = ConstantExpr::getSub(ConstantExpr::getMul(B, B), SqrtTerm);
-
-  // Compute floor(sqrt(B^2-4ac))
-  uint64_t SqrtValV = cast<ConstantInt>(SqrtTerm)->getZExtValue();
-  uint64_t SqrtValV2 = (uint64_t)sqrt((double)SqrtValV);
-  // The square root might not be precise for arbitrary 64-bit integer
-  // values.  Do some sanity checks to ensure it's correct.
-  if (SqrtValV2*SqrtValV2 > SqrtValV ||
-      (SqrtValV2+1)*(SqrtValV2+1) <= SqrtValV) {
-    SCEV *CNC = new SCEVCouldNotCompute();
-    return std::make_pair(CNC, CNC);
-  }
-
-  ConstantInt *SqrtVal = ConstantInt::get(Type::Int64Ty, SqrtValV2);
-  SqrtTerm = ConstantExpr::getTruncOrBitCast(SqrtVal, SqrtTerm->getType());
-
-  Constant *NegB = ConstantExpr::getNeg(B);
-  Constant *TwoA = ConstantExpr::getMul(A, Two);
-
-  // The divisions must be performed as signed divisions.
-  Constant *Solution1 =
-    ConstantExpr::getSDiv(ConstantExpr::getAdd(NegB, SqrtTerm), TwoA);
-  Constant *Solution2 =
-    ConstantExpr::getSDiv(ConstantExpr::getSub(NegB, SqrtTerm), TwoA);
-  return std::make_pair(SCEVUnknown::get(Solution1),
-                        SCEVUnknown::get(Solution2));
+  uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
+  APInt L(LC->getValue()->getValue());
+  APInt M(MC->getValue()->getValue());
+  APInt N(NC->getValue()->getValue());
+  APInt Two(BitWidth, 2);
+  APInt Four(BitWidth, 4);
+
+  {
+    using namespace APIntOps;
+    APInt C(L);
+    // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
+    // The B coefficient is M-N/2
+    APInt B(M);
+    B -= sdiv(N,Two);
+
+    // The A coefficient is N/2
+    APInt A(N);
+    A = A.sdiv(Two);
+
+    // Compute the B^2-4ac term.
+    APInt SqrtTerm(B);
+    SqrtTerm *= B;
+    SqrtTerm -= Four * (A * C);
+
+    // A negative discriminant has no real square root, so there are no
+    // real roots to report; give up rather than call sqrt() on it.
+    if (SqrtTerm.isNegative()) {
+      SCEV *CNC = new SCEVCouldNotCompute();
+      return std::make_pair(CNC, CNC);
+    }
+
+    // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
+    // integer value or else APInt::sqrt() will assert.
+    APInt SqrtVal(SqrtTerm.sqrt());
+
+    // Compute the two solutions for the quadratic formula.
+    // The divisions must be performed as signed divisions.
+    APInt NegB(-B);
+    APInt TwoA( A * Two );
+    // 2A is zero whenever N/2 truncates to zero (N in {-1, 0, 1}); the
+    // sdivs below would then divide by zero, so bail out instead.
+    if (TwoA.isMinValue()) {
+      SCEV *CNC = new SCEVCouldNotCompute();
+      return std::make_pair(CNC, CNC);
+    }
+
+    ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA));
+    ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA));
+
+    return std::make_pair(SCEVUnknown::get(Solution1),
+                          SCEVUnknown::get(Solution2));
+  } // end APIntOps namespace
 }
 
 /// HowFarToZero - Return the number of times a backedge comparing the specified






More information about the llvm-commits mailing list