[llvm-commits] [llvm] r44268 - /llvm/trunk/lib/Analysis/ScalarEvolution.cpp
Nick Lewycky
nicholas at mxc.ca
Wed Nov 21 23:59:41 PST 2007
Author: nicholas
Date: Thu Nov 22 01:59:40 2007
New Revision: 44268
URL: http://llvm.org/viewvc/llvm-project?rev=44268&view=rev
Log:
Instead of calculating constant factors, calculate the minimum number of
trailing zero bits. Patch from Wojciech Matyjewicz.
Modified:
llvm/trunk/lib/Analysis/ScalarEvolution.cpp
Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=44268&r1=44267&r2=44268&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Thu Nov 22 01:59:40 2007
@@ -1410,62 +1410,60 @@
return SE.getUnknown(PN);
}
-/// GetConstantFactor - Determine the largest constant factor that S has. For
-/// example, turn {4,+,8} -> 4. (S umod result) should always equal zero.
-static APInt GetConstantFactor(SCEVHandle S) {
- if (SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
- const APInt& V = C->getValue()->getValue();
- if (!V.isMinValue())
- return V;
- else // Zero is a multiple of everything.
- return APInt::getHighBitsSet(C->getBitWidth(), 1);
- }
+/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
+/// guaranteed to end in (at every loop iteration). It is, at the same time,
+/// the minimum number of times S is divisible by 2. For example, given {4,+,8}
+/// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S.
+static uint32_t GetMinTrailingZeros(SCEVHandle S) {
+ if (SCEVConstant *C = dyn_cast<SCEVConstant>(S))
+ // APInt::countTrailingZeros() returns the number of trailing zeros in its
+ // internal representation, which length may be greater than the represented
+ // value bitwidth. This is why we use a min operation here.
+ return std::min(C->getValue()->getValue().countTrailingZeros(),
+ C->getBitWidth());
if (SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
- return GetConstantFactor(T->getOperand()).trunc(
- cast<IntegerType>(T->getType())->getBitWidth());
- if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S))
- return GetConstantFactor(E->getOperand()).zext(
- cast<IntegerType>(E->getType())->getBitWidth());
- if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S))
- return GetConstantFactor(E->getOperand()).sext(
- cast<IntegerType>(E->getType())->getBitWidth());
-
+ return std::min(GetMinTrailingZeros(T->getOperand()), T->getBitWidth());
+
+ if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
+ uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
+ return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
+ }
+
+ if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
+ uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
+ return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
+ }
+
if (SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
- // The result is the min of all operands.
- APInt Res(GetConstantFactor(A->getOperand(0)));
- for (unsigned i = 1, e = A->getNumOperands();
- i != e && Res.ugt(APInt(Res.getBitWidth(),1)); ++i) {
- APInt Tmp(GetConstantFactor(A->getOperand(i)));
- Res = APIntOps::umin(Res, Tmp);
- }
- return Res;
+ // The result is the min of all operands results.
+ uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
+ for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
+ MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
+ return MinOpRes;
}
if (SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
- // The result is the product of all the operands.
- APInt Res(GetConstantFactor(M->getOperand(0)));
- for (unsigned i = 1, e = M->getNumOperands(); i != e; ++i) {
- APInt Tmp(GetConstantFactor(M->getOperand(i)));
- Res *= Tmp;
- }
- return Res;
+ // The result is the sum of all operands results.
+ uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
+ uint32_t BitWidth = M->getBitWidth();
+ for (unsigned i = 1, e = M->getNumOperands();
+ SumOpRes != BitWidth && i != e; ++i)
+ SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
+ BitWidth);
+ return SumOpRes;
}
-
+
if (SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
- // For now, we just handle linear expressions.
- if (A->getNumOperands() == 2) {
- // We want the GCD between the start and the stride value.
- APInt Start(GetConstantFactor(A->getOperand(0)));
- if (Start == 1)
- return Start;
- APInt Stride(GetConstantFactor(A->getOperand(1)));
- return APIntOps::GreatestCommonDivisor(Start, Stride);
- }
+ // The result is the min of all operands results.
+ uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
+ for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
+ MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
+ return MinOpRes;
}
-
- // SCEVSDivExpr, SCEVUnknown.
- return APInt(S->getBitWidth(), 1);
+
+ // SCEVSDivExpr, SCEVUnknown
+ return 0;
}
/// createSCEV - We know that there is no SCEV for the specified value.
@@ -1493,17 +1491,12 @@
//
// In order for this transformation to be safe, the LHS must be of the
// form X*(2^n) and the Or constant must be less than 2^n.
-
if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
SCEVHandle LHS = getSCEV(I->getOperand(0));
- APInt CommonFact(GetConstantFactor(LHS));
- assert(!CommonFact.isMinValue() &&
- "Common factor should at least be 1!");
const APInt &CIVal = CI->getValue();
- if (CommonFact.countTrailingZeros() >=
+ if (GetMinTrailingZeros(LHS) >=
(CIVal.getBitWidth() - CIVal.countLeadingZeros()))
- return SE.getAddExpr(LHS,
- getSCEV(I->getOperand(1)));
+ return SE.getAddExpr(LHS, getSCEV(I->getOperand(1)));
}
break;
case Instruction::Xor:
More information about the llvm-commits mailing list