[llvm-commits] [llvm] r44268 - /llvm/trunk/lib/Analysis/ScalarEvolution.cpp

Nick Lewycky nicholas at mxc.ca
Wed Nov 21 23:59:41 PST 2007


Author: nicholas
Date: Thu Nov 22 01:59:40 2007
New Revision: 44268

URL: http://llvm.org/viewvc/llvm-project?rev=44268&view=rev
Log:
Instead of calculating constant factors, calculate the minimum number of
trailing zero bits. Patch from Wojciech Matyjewicz.

Modified:
    llvm/trunk/lib/Analysis/ScalarEvolution.cpp

Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=44268&r1=44267&r2=44268&view=diff

==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Thu Nov 22 01:59:40 2007
@@ -1410,62 +1410,60 @@
   return SE.getUnknown(PN);
 }
 
-/// GetConstantFactor - Determine the largest constant factor that S has.  For
-/// example, turn {4,+,8} -> 4.    (S umod result) should always equal zero.
-static APInt GetConstantFactor(SCEVHandle S) {
-  if (SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
-    const APInt& V = C->getValue()->getValue();
-    if (!V.isMinValue())
-      return V;
-    else   // Zero is a multiple of everything.
-      return APInt::getHighBitsSet(C->getBitWidth(), 1);
-  }
+/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
+/// guaranteed to end in (at every loop iteration).  It is, at the same time,
+/// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
+/// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
+static uint32_t GetMinTrailingZeros(SCEVHandle S) {
+  if (SCEVConstant *C = dyn_cast<SCEVConstant>(S))
+    // APInt::countTrailingZeros() returns the number of trailing zeros in its
+    // internal representation, which length may be greater than the represented
+    // value bitwidth. This is why we use a min operation here.
+    return std::min(C->getValue()->getValue().countTrailingZeros(),
+                    C->getBitWidth());
 
   if (SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
-    return GetConstantFactor(T->getOperand()).trunc(
-                               cast<IntegerType>(T->getType())->getBitWidth());
-  if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S))
-    return GetConstantFactor(E->getOperand()).zext(
-                               cast<IntegerType>(E->getType())->getBitWidth());
-  if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S))
-    return GetConstantFactor(E->getOperand()).sext(
-                               cast<IntegerType>(E->getType())->getBitWidth());
-  
+    return std::min(GetMinTrailingZeros(T->getOperand()), T->getBitWidth());
+
+  if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
+    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
+    return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
+  }
+
+  if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
+    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
+    return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
+  }
+
   if (SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
-    // The result is the min of all operands.
-    APInt Res(GetConstantFactor(A->getOperand(0)));
-    for (unsigned i = 1, e = A->getNumOperands(); 
-         i != e && Res.ugt(APInt(Res.getBitWidth(),1)); ++i) {
-      APInt Tmp(GetConstantFactor(A->getOperand(i)));
-      Res = APIntOps::umin(Res, Tmp);
-    }
-    return Res;
+    // The result is the min of all operands results.
+    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
+    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
+      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
+    return MinOpRes;
   }
 
   if (SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
-    // The result is the product of all the operands.
-    APInt Res(GetConstantFactor(M->getOperand(0)));
-    for (unsigned i = 1, e = M->getNumOperands(); i != e; ++i) {
-      APInt Tmp(GetConstantFactor(M->getOperand(i)));
-      Res *= Tmp;
-    }
-    return Res;
+    // The result is the sum of all operands results.
+    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
+    uint32_t BitWidth = M->getBitWidth();
+    for (unsigned i = 1, e = M->getNumOperands();
+         SumOpRes != BitWidth && i != e; ++i)
+      SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
+                          BitWidth);
+    return SumOpRes;
   }
-    
+
   if (SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
-    // For now, we just handle linear expressions.
-    if (A->getNumOperands() == 2) {
-      // We want the GCD between the start and the stride value.
-      APInt Start(GetConstantFactor(A->getOperand(0)));
-      if (Start == 1) 
-        return Start;
-      APInt Stride(GetConstantFactor(A->getOperand(1)));
-      return APIntOps::GreatestCommonDivisor(Start, Stride);
-    }
+    // The result is the min of all operands results.
+    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
+    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
+      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
+    return MinOpRes;
   }
-  
-  // SCEVSDivExpr, SCEVUnknown.
-  return APInt(S->getBitWidth(), 1);
+
+  // SCEVSDivExpr, SCEVUnknown
+  return 0;
 }
 
 /// createSCEV - We know that there is no SCEV for the specified value.
@@ -1493,17 +1491,12 @@
       //
       // In order for this transformation to be safe, the LHS must be of the
       // form X*(2^n) and the Or constant must be less than 2^n.
-
       if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
         SCEVHandle LHS = getSCEV(I->getOperand(0));
-        APInt CommonFact(GetConstantFactor(LHS));
-        assert(!CommonFact.isMinValue() &&
-               "Common factor should at least be 1!");
         const APInt &CIVal = CI->getValue();
-        if (CommonFact.countTrailingZeros() >=
+        if (GetMinTrailingZeros(LHS) >=
             (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
-          return SE.getAddExpr(LHS,
-                               getSCEV(I->getOperand(1)));
+          return SE.getAddExpr(LHS, getSCEV(I->getOperand(1)));
       }
       break;
     case Instruction::Xor:





More information about the llvm-commits mailing list