[llvm-commits] [llvm] r40641 - /llvm/trunk/lib/Transforms/Scalar/LoopStrengthReduce.cpp

Dan Gohman djg at cray.com
Thu Aug 2 14:28:52 PDT 2007


>> Use SCEVExpander::InsertCastOfTo instead of calling new IntToPtrInst
>> directly, because the insert point used by the SCEVExpander may vary
>> from what LSR originally computes.
> 
> Did this cause a bug?  If so, can you please commit a regtest?  Thanks,

So far I've only seen the problem in the tree where I'm working on
the attached patch, which moves the getelementptr logic out of LSR into
SCEV, and generalizes it to handle more pointer expressions. 

Dan

-- 
Dan Gohman, Cray Inc.
-------------- next part --------------
Index: include/llvm/Analysis/ScalarEvolutionExpander.h
===================================================================
--- include/llvm/Analysis/ScalarEvolutionExpander.h	(revision 40735)
+++ include/llvm/Analysis/ScalarEvolutionExpander.h	(working copy)
@@ -144,6 +144,9 @@
     Value *visitAddRecExpr(SCEVAddRecExpr *S);
 
     Value *visitUnknown(SCEVUnknown *S) {
+      if (isa<PointerType>(S->getValue()->getType()))
+        return InsertCastOfTo(Instruction::PtrToInt,
+                              S->getValue(), S->getType());
       return S->getValue();
     }
   };
Index: include/llvm/Analysis/ScalarEvolutionExpressions.h
===================================================================
--- include/llvm/Analysis/ScalarEvolutionExpressions.h	(revision 40735)
+++ include/llvm/Analysis/ScalarEvolutionExpressions.h	(working copy)
@@ -498,14 +498,17 @@
   ///
   class SCEVUnknown : public SCEV {
     Value *V;
-    SCEVUnknown(Value *v) : SCEV(scUnknown), V(v) {}
+    const Type *Ty;
+    SCEVUnknown(Value *v, const Type *ty) : SCEV(scUnknown), V(v), Ty(ty) {}
 
   protected:
     ~SCEVUnknown();
   public:
     /// get method - For SCEVUnknown, this just gets and returns a new
-    /// SCEVUnknown.
-    static SCEVHandle get(Value *V);
+    /// SCEVUnknown. Ty is usually the type of V, unless V does not have an
+    /// integer type, in which case it is the type that V will be implicitly
+    /// cast to for the purposes of SCEV analysis.
+    static SCEVHandle get(Value *V, const Type *Ty);
 
     /// getIntegerSCEV - Given an integer or FP type, create a constant for the
     /// specified signed integer value and return a SCEV for the constant.
@@ -524,7 +527,7 @@
       return this;
     }
 
-    virtual const Type *getType() const;
+    virtual const Type *getType() const { return Ty; }
 
     virtual void print(std::ostream &OS) const;
     void print(std::ostream *OS) const { if (OS) print(*OS); }
Index: lib/Analysis/ScalarEvolutionExpander.cpp
===================================================================
--- lib/Analysis/ScalarEvolutionExpander.cpp	(revision 40735)
+++ lib/Analysis/ScalarEvolutionExpander.cpp	(working copy)
@@ -200,7 +200,7 @@
   // folders, then expandCodeFor the closed form.  This allows the folders to
   // simplify the expression without having to build a bunch of special code
   // into this folder.
-  SCEVHandle IH = SCEVUnknown::get(I);   // Get I as a "symbolic" SCEV.
+  SCEVHandle IH = SCEVUnknown::get(I, Ty);   // Get I as a "symbolic" SCEV.
 
   SCEVHandle V = S->evaluateAtIteration(IH);
   //cerr << "Evaluated: " << *this << "\n     to: " << *V << "\n";
Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp	(revision 40735)
+++ lib/Analysis/ScalarEvolution.cpp	(working copy)
@@ -68,6 +68,7 @@
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Assembly/Writer.h"
+#include "llvm/Target/TargetData.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Support/CFG.h"
 #include "llvm/Support/CommandLine.h"
@@ -405,10 +406,6 @@
   return true;
 }
 
-const Type *SCEVUnknown::getType() const {
-  return V->getType();
-}
-
 void SCEVUnknown::print(std::ostream &OS) const {
   WriteAsOperand(OS, V, false);
 }
@@ -488,7 +485,7 @@
     C = ConstantFP::get(Ty, Val);
   else 
     C = ConstantInt::get(Ty, Val);
-  return SCEVUnknown::get(C);
+  return SCEVUnknown::get(C, Ty);
 }
 
 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
@@ -505,11 +502,26 @@
   return SCEVZeroExtendExpr::get(V, Ty);
 }
 
+/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
+/// input value to the specified type.  If the type must be extended, it is sign
+/// extended.
+static SCEVHandle getTruncateOrSignExtend(const SCEVHandle &V, const Type *Ty) {
+  const Type *SrcTy = V->getType();
+  assert(SrcTy->isInteger() && Ty->isInteger() &&
+         "Cannot truncate or sign extend with non-integer arguments!");
+  if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits())
+    return V;  // No conversion
+  if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits())
+    return SCEVTruncateExpr::get(V, Ty);
+  return SCEVSignExtendExpr::get(V, Ty);
+}
+
 /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
 ///
 SCEVHandle SCEV::getNegativeSCEV(const SCEVHandle &V) {
   if (SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
-    return SCEVUnknown::get(ConstantExpr::getNeg(VC->getValue()));
+    return SCEVConstant::get
+              (cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
 
   return SCEVMulExpr::get(V, SCEVUnknown::getIntegerSCEV(-1, V->getType()));
 }
@@ -578,7 +590,7 @@
 SCEVHandle SCEVTruncateExpr::get(const SCEVHandle &Op, const Type *Ty) {
   if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
     return SCEVUnknown::get(
-        ConstantExpr::getTrunc(SC->getValue(), Ty));
+        ConstantExpr::getTrunc(SC->getValue(), Ty), Ty);
 
   // If the input value is a chrec scev made out of constants, truncate
   // all of the constants.
@@ -602,7 +614,7 @@
 SCEVHandle SCEVZeroExtendExpr::get(const SCEVHandle &Op, const Type *Ty) {
   if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
     return SCEVUnknown::get(
-        ConstantExpr::getZExt(SC->getValue(), Ty));
+        ConstantExpr::getZExt(SC->getValue(), Ty), Ty);
 
   // FIXME: If the input value is a chrec scev, and we can prove that the value
   // did not overflow the old, smaller, value, we can zero extend all of the
@@ -617,7 +629,7 @@
 SCEVHandle SCEVSignExtendExpr::get(const SCEVHandle &Op, const Type *Ty) {
   if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
     return SCEVUnknown::get(
-        ConstantExpr::getSExt(SC->getValue(), Ty));
+        ConstantExpr::getSExt(SC->getValue(), Ty), Ty);
 
   // FIXME: If the input value is a chrec scev, and we can prove that the value
   // did not overflow the old, smaller, value, we can sign extend all of the
@@ -1037,7 +1049,8 @@
     if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
       Constant *LHSCV = LHSC->getValue();
       Constant *RHSCV = RHSC->getValue();
-      return SCEVUnknown::get(ConstantExpr::getSDiv(LHSCV, RHSCV));
+      return SCEVConstant::get(cast<ConstantInt>(ConstantExpr::getSDiv(LHSCV,
+                                                                       RHSCV)));
     }
   }
 
@@ -1085,11 +1098,14 @@
   return Result;
 }
 
-SCEVHandle SCEVUnknown::get(Value *V) {
+SCEVHandle SCEVUnknown::get(Value *V, const Type *Ty) {
+  assert((V->getType() == Ty ||
+          (isa<PointerType>(V->getType()) && isa<IntegerType>(Ty))) &&
+         "Type for SCEVUnknown must match value or imply a PtrToInt cast");
   if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
     return SCEVConstant::get(CI);
   SCEVUnknown *&Result = (*SCEVUnknowns)[V];
-  if (Result == 0) Result = new SCEVUnknown(V);
+  if (Result == 0) Result = new SCEVUnknown(V, Ty);
   return Result;
 }
 
@@ -1111,6 +1127,10 @@
     ///
     LoopInfo &LI;
 
+    /// TD - The target data information for the target we are targeting.
+    ///
+    TargetData &TD;
+
     /// UnknownValue - This SCEV is used to represent unknown trip counts and
     /// things.
     SCEVHandle UnknownValue;
@@ -1131,8 +1151,8 @@
     std::map<PHINode*, Constant*> ConstantEvolutionLoopExitValue;
 
   public:
-    ScalarEvolutionsImpl(Function &f, LoopInfo &li)
-      : F(f), LI(li), UnknownValue(new SCEVCouldNotCompute()) {}
+    ScalarEvolutionsImpl(Function &f, LoopInfo &li, TargetData &td)
+      : F(f), LI(li), TD(td), UnknownValue(new SCEVCouldNotCompute()) {}
 
     /// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
     /// expression and create a new one.
@@ -1179,7 +1199,7 @@
 
     /// createNodeForPHI - Provide the special handling we need to analyze PHI
     /// SCEVs.
-    SCEVHandle createNodeForPHI(PHINode *PN);
+    SCEVHandle createNodeForPHI(PHINode *PN, const Type *Ty);
 
     /// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value
     /// for the specified instruction and replaces any references to the
@@ -1302,7 +1322,7 @@
 /// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
 /// a loop header, making it a potential recurrence, or it doesn't.
 ///
-SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) {
+SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN, const Type *Ty) {
   if (PN->getNumIncomingValues() == 2)  // The loops have been canonicalized.
     if (const Loop *L = LI.getLoopFor(PN->getParent()))
       if (L->getHeader() == PN->getParent()) {
@@ -1312,7 +1332,7 @@
         unsigned BackEdge     = IncomingEdge^1;
 
         // While we are analyzing this PHI node, handle its value symbolically.
-        SCEVHandle SymbolicName = SCEVUnknown::get(PN);
+        SCEVHandle SymbolicName = SCEVUnknown::get(PN, Ty);
         assert(Scalars.find(PN) == Scalars.end() &&
                "PHI node already processed?");
         Scalars.insert(std::make_pair(PN, SymbolicName));
@@ -1393,7 +1413,7 @@
       }
 
   // If it's not a loop phi, we can't handle it yet.
-  return SCEVUnknown::get(PN);
+  return SCEVUnknown::get(PN, Ty);
 }
 
 /// GetConstantFactor - Determine the largest constant factor that S has.  For
@@ -1459,83 +1479,153 @@
 /// Analyze the expression.
 ///
 SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
-  if (Instruction *I = dyn_cast<Instruction>(V)) {
-    switch (I->getOpcode()) {
+  const Type *IntPtrTy = TD.getIntPtrType();
+  const Type *Ty = isa<PointerType>(V->getType()) ? IntPtrTy : V->getType();
+  if (User *U = dyn_cast<User>(V)) {
+    unsigned Opcode = 0;
+    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
+      Opcode = CE->getOpcode();
+    } else if (Instruction *I = dyn_cast<Instruction>(V)) {
+      Opcode = I->getOpcode();
+    }
+    switch (Opcode) {
     case Instruction::Add:
-      return SCEVAddExpr::get(getSCEV(I->getOperand(0)),
-                              getSCEV(I->getOperand(1)));
+      return SCEVAddExpr::get(getSCEV(U->getOperand(0)),
+                              getSCEV(U->getOperand(1)));
     case Instruction::Mul:
-      return SCEVMulExpr::get(getSCEV(I->getOperand(0)),
-                              getSCEV(I->getOperand(1)));
+      return SCEVMulExpr::get(getSCEV(U->getOperand(0)),
+                              getSCEV(U->getOperand(1)));
     case Instruction::SDiv:
-      return SCEVSDivExpr::get(getSCEV(I->getOperand(0)),
-                              getSCEV(I->getOperand(1)));
+      return SCEVSDivExpr::get(getSCEV(U->getOperand(0)),
+                              getSCEV(U->getOperand(1)));
       break;
 
     case Instruction::Sub:
-      return SCEV::getMinusSCEV(getSCEV(I->getOperand(0)),
-                                getSCEV(I->getOperand(1)));
+      return SCEV::getMinusSCEV(getSCEV(U->getOperand(0)),
+                                getSCEV(U->getOperand(1)));
     case Instruction::Or:
       // If the RHS of the Or is a constant, we may have something like:
       // X*4+1 which got turned into X*4|1.  Handle this as an add so loop
       // optimizations will transparently handle this case.
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
-        SCEVHandle LHS = getSCEV(I->getOperand(0));
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
+        SCEVHandle LHS = getSCEV(U->getOperand(0));
         APInt CommonFact(GetConstantFactor(LHS));
         assert(!CommonFact.isMinValue() &&
                "Common factor should at least be 1!");
         if (CommonFact.ugt(CI->getValue())) {
           // If the LHS is a multiple that is larger than the RHS, use +.
           return SCEVAddExpr::get(LHS,
-                                  getSCEV(I->getOperand(1)));
+                                  getSCEV(U->getOperand(1)));
         }
       }
       break;
     case Instruction::Xor:
       // If the RHS of the xor is a signbit, then this is just an add.
       // Instcombine turns add of signbit into xor as a strength reduction step.
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
         if (CI->getValue().isSignBit())
-          return SCEVAddExpr::get(getSCEV(I->getOperand(0)),
-                                  getSCEV(I->getOperand(1)));
+          return SCEVAddExpr::get(getSCEV(U->getOperand(0)),
+                                  getSCEV(U->getOperand(1)));
       }
       break;
 
     case Instruction::Shl:
       // Turn shift left of a constant amount into a multiply.
-      if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
         uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
         Constant *X = ConstantInt::get(
           APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
-        return SCEVMulExpr::get(getSCEV(I->getOperand(0)), getSCEV(X));
+        return SCEVMulExpr::get(getSCEV(U->getOperand(0)), getSCEV(X));
       }
       break;
 
     case Instruction::Trunc:
-      return SCEVTruncateExpr::get(getSCEV(I->getOperand(0)), I->getType());
+      return SCEVTruncateExpr::get(getSCEV(U->getOperand(0)), U->getType());
 
     case Instruction::ZExt:
-      return SCEVZeroExtendExpr::get(getSCEV(I->getOperand(0)), I->getType());
+      return SCEVZeroExtendExpr::get(getSCEV(U->getOperand(0)), U->getType());
 
     case Instruction::SExt:
-      return SCEVSignExtendExpr::get(getSCEV(I->getOperand(0)), I->getType());
+      return SCEVSignExtendExpr::get(getSCEV(U->getOperand(0)), U->getType());
 
     case Instruction::BitCast:
-      // BitCasts are no-op casts so we just eliminate the cast.
-      if (I->getType()->isInteger() &&
-          I->getOperand(0)->getType()->isInteger())
-        return getSCEV(I->getOperand(0));
+      // BitCasts between integers or pointers are no-op casts so we just
+      // eliminate the cast.
+      if ((U->getType()->isInteger() ||
+           isa<PointerType>(U->getType())) &&
+          (U->getOperand(0)->getType()->isInteger() ||
+           isa<PointerType>(U->getOperand(0)->getType())))
+        return getSCEV(U->getOperand(0));
       break;
 
+    case Instruction::IntToPtr:
+      // A conversion from intptr_t to a pointer is effectively a no-op cast.
+      if (U->getOperand(0)->getType() == IntPtrTy)
+        return getSCEV(U->getOperand(0));
+      break;
+
+    case Instruction::PtrToInt:
+      if (U->getType() == IntPtrTy)
+        return getSCEV(U->getOperand(0));
+      break;
+
+    case Instruction::GetElementPtr: {
+      Value *Ptr = U->getOperand(0);
+      std::vector<Value *> Indicies;
+      SCEVHandle GEPVal = SCEVUnknown::getIntegerSCEV(0, IntPtrTy);
+      for (GetElementPtrInst::op_iterator I = U->op_begin() + 1,
+                                          E = U->op_end();
+           I != E; ++I) {
+        Value *Index = cast<Value>(I);
+        Indicies.push_back(Index);
+        SCEVHandle Q = getSCEV(Index);
+        if (!isa<PointerType>(Q->getType()))
+          Q = getTruncateOrSignExtend(Q, IntPtrTy);
+        const Type *BaseType =
+          GetElementPtrInst::getIndexedType(Ptr->getType(),
+                                            &Indicies[0],
+                                            Indicies.size() - 1,
+                                            true);
+        // Is this index for a struct or an array? Note that the
+        // first index in a getelementptr is always an array index.
+        if (Indicies.size() > 1 && isa<StructType>(BaseType)) {
+          const StructType *STy = cast<StructType>(BaseType);
+          unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
+
+          // Get structure layout information...
+          const StructLayout *Layout = TD.getStructLayout(STy);
+
+          // Add in the offset, as calculated by the structure layout info...
+          uint64_t Offset = Layout->getElementOffset(FieldNo);
+
+          Q = SCEVUnknown::getIntegerSCEV(Offset, IntPtrTy);
+        } else {
+          const Type *ScaleType =
+            GetElementPtrInst::getIndexedType(Ptr->getType(),
+                                              &Indicies[0],
+                                              Indicies.size(),
+                                              true);
+          SCEVHandle ScaleSCEV =
+            SCEVUnknown::getIntegerSCEV(TD.getTypeSize(ScaleType), IntPtrTy);
+          Q = SCEVMulExpr::get(Q, ScaleSCEV);
+        }
+
+        GEPVal = SCEVAddExpr::get(GEPVal, Q);
+      }
+      SCEVHandle PtrSCEV = getSCEV(Ptr);
+      SCEVHandle Result = SCEVAddExpr::get(PtrSCEV, GEPVal);
+      return Result;
+    }
+
     case Instruction::PHI:
-      return createNodeForPHI(cast<PHINode>(I));
+      return createNodeForPHI(cast<PHINode>(U), Ty);
 
     default: // We cannot analyze this expression.
       break;
     }
   }
 
-  return SCEVUnknown::get(V);
+  return SCEVUnknown::get(V, Ty);
 }
 
 
@@ -2040,7 +2130,7 @@
               Constant *RV = getConstantEvolutionLoopExitValue(PN,
                                                     ICC->getValue()->getValue(),
                                                                LI);
-              if (RV) return SCEVUnknown::get(RV);
+              if (RV) return SCEVUnknown::get(RV, RV->getType());
             }
           }
 
@@ -2074,7 +2164,7 @@
           }
         }
         Constant *C =ConstantFoldInstOperands(I, &Operands[0], Operands.size());
-        return SCEVUnknown::get(C);
+        return SCEVUnknown::get(C, C->getType());
       }
     }
 
@@ -2246,7 +2336,7 @@
         Constant *Rem = ConstantExpr::getSRem(StartNegC, StepC->getValue());
         if (Rem->isNullValue()) {
           Constant *Result =ConstantExpr::getSDiv(StartNegC,StepC->getValue());
-          return SCEVUnknown::get(Result);
+          return SCEVUnknown::get(Result, Result->getType());
         }
       }
     }
@@ -2546,7 +2636,9 @@
 //===----------------------------------------------------------------------===//
 
 bool ScalarEvolution::runOnFunction(Function &F) {
-  Impl = new ScalarEvolutionsImpl(F, getAnalysis<LoopInfo>());
+  Impl = new ScalarEvolutionsImpl(F,
+                                  getAnalysis<LoopInfo>(),
+                                  getAnalysis<TargetData>());
   return false;
 }
 
@@ -2558,6 +2650,7 @@
 void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesAll();
   AU.addRequiredTransitive<LoopInfo>();
+  AU.addRequiredTransitive<TargetData>();
 }
 
 SCEVHandle ScalarEvolution::getSCEV(Value *V) const {
Index: lib/Transforms/Scalar/LoopStrengthReduce.cpp
===================================================================
--- lib/Transforms/Scalar/LoopStrengthReduce.cpp	(revision 40735)
+++ lib/Transforms/Scalar/LoopStrengthReduce.cpp	(working copy)
@@ -130,11 +130,6 @@
     /// dependent on random ordering of pointers in the process.
     std::vector<SCEVHandle> StrideOrder;
 
-    /// CastedValues - As we need to cast values to uintptr_t, this keeps track
-    /// of the casted version of each value.  This is accessed by
-    /// getCastedVersionOf.
-    std::map<Value*, Value*> CastedPointers;
-
     /// DeadInsts - Keep track of instructions we may have made dead, so that
     /// we can remove them after we are done working.
     std::set<Instruction*> DeadInsts;
@@ -166,13 +161,9 @@
       AU.addRequired<ScalarEvolution>();
     }
     
-    /// getCastedVersionOf - Return the specified value casted to uintptr_t.
-    ///
-    Value *getCastedVersionOf(Instruction::CastOps opcode, Value *V);
 private:
     bool AddUsersIfInteresting(Instruction *I, Loop *L,
                                std::set<Instruction*> &Processed);
-    SCEVHandle GetExpressionSCEV(Instruction *E, Loop *L);
 
     void OptimizeIndvars(Loop *L);
     bool FindIVForUser(ICmpInst *Cond, IVStrideUse *&CondUse,
@@ -196,24 +187,7 @@
   return new LoopStrengthReduce(TLI);
 }
 
-/// getCastedVersionOf - Return the specified value casted to uintptr_t. This
-/// assumes that the Value* V is of integer or pointer type only.
-///
-Value *LoopStrengthReduce::getCastedVersionOf(Instruction::CastOps opcode, 
-                                              Value *V) {
-  if (V->getType() == UIntPtrTy) return V;
-  if (Constant *CB = dyn_cast<Constant>(V))
-    return ConstantExpr::getCast(opcode, CB, UIntPtrTy);
 
-  Value *&New = CastedPointers[V];
-  if (New) return New;
-  
-  New = SCEVExpander::InsertCastOfTo(opcode, V, UIntPtrTy);
-  DeadInsts.insert(cast<Instruction>(New));
-  return New;
-}
-
-
 /// DeleteTriviallyDeadInstructions - If any of the instructions is the
 /// specified set are trivially dead, delete them and see if this makes any of
 /// their operands subsequently dead.
@@ -234,71 +208,6 @@
 }
 
 
-/// GetExpressionSCEV - Compute and return the SCEV for the specified
-/// instruction.
-SCEVHandle LoopStrengthReduce::GetExpressionSCEV(Instruction *Exp, Loop *L) {
-  // Pointer to pointer bitcast instructions return the same value as their
-  // operand.
-  if (BitCastInst *BCI = dyn_cast<BitCastInst>(Exp)) {
-    if (SE->hasSCEV(BCI) || !isa<Instruction>(BCI->getOperand(0)))
-      return SE->getSCEV(BCI);
-    SCEVHandle R = GetExpressionSCEV(cast<Instruction>(BCI->getOperand(0)), L);
-    SE->setSCEV(BCI, R);
-    return R;
-  }
-
-  // Scalar Evolutions doesn't know how to compute SCEV's for GEP instructions.
-  // If this is a GEP that SE doesn't know about, compute it now and insert it.
-  // If this is not a GEP, or if we have already done this computation, just let
-  // SE figure it out.
-  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Exp);
-  if (!GEP || SE->hasSCEV(GEP))
-    return SE->getSCEV(Exp);
-    
-  // Analyze all of the subscripts of this getelementptr instruction, looking
-  // for uses that are determined by the trip count of L.  First, skip all
-  // operands the are not dependent on the IV.
-
-  // Build up the base expression.  Insert an LLVM cast of the pointer to
-  // uintptr_t first.
-  SCEVHandle GEPVal = SCEVUnknown::get(
-      getCastedVersionOf(Instruction::PtrToInt, GEP->getOperand(0)));
-
-  gep_type_iterator GTI = gep_type_begin(GEP);
-  
-  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
-    // If this is a use of a recurrence that we can analyze, and it comes before
-    // Op does in the GEP operand list, we will handle this when we process this
-    // operand.
-    if (const StructType *STy = dyn_cast<StructType>(*GTI)) {
-      const StructLayout *SL = TD->getStructLayout(STy);
-      unsigned Idx = cast<ConstantInt>(GEP->getOperand(i))->getZExtValue();
-      uint64_t Offset = SL->getElementOffset(Idx);
-      GEPVal = SCEVAddExpr::get(GEPVal,
-                                SCEVUnknown::getIntegerSCEV(Offset, UIntPtrTy));
-    } else {
-      unsigned GEPOpiBits = 
-        GEP->getOperand(i)->getType()->getPrimitiveSizeInBits();
-      unsigned IntPtrBits = UIntPtrTy->getPrimitiveSizeInBits();
-      Instruction::CastOps opcode = (GEPOpiBits < IntPtrBits ? 
-          Instruction::SExt : (GEPOpiBits > IntPtrBits ? Instruction::Trunc :
-            Instruction::BitCast));
-      Value *OpVal = getCastedVersionOf(opcode, GEP->getOperand(i));
-      SCEVHandle Idx = SE->getSCEV(OpVal);
-
-      uint64_t TypeSize = TD->getTypeSize(GTI.getIndexedType());
-      if (TypeSize != 1)
-        Idx = SCEVMulExpr::get(Idx,
-                               SCEVConstant::get(ConstantInt::get(UIntPtrTy,
-                                                                   TypeSize)));
-      GEPVal = SCEVAddExpr::get(GEPVal, Idx);
-    }
-  }
-
-  SE->setSCEV(GEP, GEPVal);
-  return GEPVal;
-}
-
 /// getSCEVStartAndStride - Compute the start and stride of this expression,
 /// returning false if the expression is not a start/stride pair, or true if it
 /// is.  The stride must be a loop invariant expression, but the start may be
@@ -410,7 +319,7 @@
     return true;    // Instruction already handled.
   
   // Get the symbolic expression for this instruction.
-  SCEVHandle ISE = GetExpressionSCEV(I, L);
+  SCEVHandle ISE = SE->getSCEV(I);
   if (isa<SCEVCouldNotCompute>(ISE)) return false;
   
   // Get the start and stride for this expression.
@@ -565,7 +474,8 @@
     IP = Rewriter.getInsertionPoint();
   
   // Always emit the immediate (if non-zero) into the same block as the user.
-  SCEVHandle NewValSCEV = SCEVAddExpr::get(SCEVUnknown::get(Base), Imm);
+  SCEVHandle NewValSCEV =
+    SCEVAddExpr::get(SCEVUnknown::get(Base, NewBase->getType()), Imm);
   return Rewriter.expandCodeFor(NewValSCEV, IP);
   
 }
@@ -1150,10 +1060,10 @@
 
     // Emit the increment of the base value before the terminator of the loop
     // latch block, and add it to the Phi node.
-    SCEVHandle IncExp = SCEVUnknown::get(StrideV);
+    SCEVHandle IncExp = SCEVUnknown::get(StrideV, ReplacedTy);
     if (isNegative)
       IncExp = SCEV::getNegativeSCEV(IncExp);
-    IncExp = SCEVAddExpr::get(SCEVUnknown::get(NewPHI), IncExp);
+    IncExp = SCEVAddExpr::get(SCEVUnknown::get(NewPHI, ReplacedTy), IncExp);
   
     IncV = Rewriter.expandCodeFor(IncExp, LatchBlock->getTerminator());
     IncV->setName(NewPHI->getName()+".inc");
@@ -1167,7 +1077,8 @@
     Constant *C = dyn_cast<Constant>(CommonBaseV);
     if (!C ||
         (!C->isNullValue() &&
-         !isTargetConstant(SCEVUnknown::get(CommonBaseV), ReplacedTy, TLI)))
+         !isTargetConstant(SCEVUnknown::get(CommonBaseV, ReplacedTy),
+                           ReplacedTy, TLI)))
       // We want the common base emitted into the preheader! This is just
       // using cast as a copy so BitCast (no-op cast) is appropriate
       CommonBaseV = new BitCastInst(CommonBaseV, CommonBaseV->getType(), 
@@ -1196,7 +1107,7 @@
     // Get a base value.
     SCEVHandle Base = UsersToProcess[i].Base;
     
-    // Compact everything with this base to be consequetive with this one.
+    // Compact everything with this base to be consecutive with this one.
     for (unsigned j = i+1; j != e; ++j) {
       if (UsersToProcess[j].Base == Base) {
         std::swap(UsersToProcess[i+1], UsersToProcess[j]);
@@ -1256,7 +1167,7 @@
         RewriteOp = SCEVExpander::InsertCastOfTo(opcode, RewriteOp, ReplacedTy);
       }
 
-      SCEVHandle RewriteExpr = SCEVUnknown::get(RewriteOp);
+      SCEVHandle RewriteExpr = SCEVUnknown::get(RewriteOp, ReplacedTy);
 
       // Clear the SCEVExpander's expression map so that we are guaranteed
       // to have the code emitted where we expect it.
@@ -1276,14 +1187,16 @@
         if (!isa<ConstantInt>(CommonBaseV) ||
             !cast<ConstantInt>(CommonBaseV)->isZero())
           RewriteExpr = SCEVAddExpr::get(RewriteExpr,
-                                         SCEVUnknown::get(CommonBaseV));
+                                         SCEVUnknown::get(CommonBaseV,
+                                                          ReplacedTy));
       }
 
       // Now that we know what we need to do, insert code before User for the
       // immediate and any loop-variant expressions.
       if (!isa<ConstantInt>(BaseV) || !cast<ConstantInt>(BaseV)->isZero())
         // Add BaseV to the PHI value if needed.
-        RewriteExpr = SCEVAddExpr::get(RewriteExpr, SCEVUnknown::get(BaseV));
+        RewriteExpr = SCEVAddExpr::get(RewriteExpr,
+                                       SCEVUnknown::get(BaseV, ReplacedTy));
 
       User.RewriteInstructionToUseNewBase(RewriteExpr, Rewriter, L, this);
 
@@ -1407,6 +1320,14 @@
   };
 }
 
+/// Skip through single-use no-op casts and return the first interesting user.
+static Value *SkipNoopUsers(Value *V, const Type *IntPtrTy) {
+  while (isa<CastInst>(V) && cast<CastInst>(V)->isNoopCast(IntPtrTy) &&
+         V->hasOneUse())
+    V = *cast<CastInst>(V)->use_begin();
+  return V;
+}
+
 bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager &LPM) {
 
   LI = &getAnalysis<LoopInfo>();
@@ -1489,10 +1410,13 @@
       // FIXME: this needs to eliminate an induction variable even if it's being
       // compared against some value to decide loop termination.
       if (PN->hasOneUse()) {
-        Instruction *BO = dyn_cast<Instruction>(*PN->use_begin());
-        if (BO && (isa<BinaryOperator>(BO) || isa<CmpInst>(BO))) {
-          if (BO->hasOneUse() && PN == *(BO->use_begin())) {
-            DeadInsts.insert(BO);
+        Value *BO = SkipNoopUsers(*PN->use_begin(), UIntPtrTy);
+        if (BO->hasOneUse() &&
+            (isa<BinaryOperator>(BO) || isa<CmpInst>(BO) ||
+             isa<GetElementPtrInst>(BO))) {
+          Value *V = SkipNoopUsers(*(BO->use_begin()), UIntPtrTy);
+          if (PN == V) {
+            DeadInsts.insert(cast<Instruction>(BO));
             // Break the cycle, then delete the PHI.
             PN->replaceAllUsesWith(UndefValue::get(PN->getType()));
             SE->deleteValueFromRecords(PN);
@@ -1504,7 +1428,6 @@
     DeleteTriviallyDeadInstructions(DeadInsts);
   }
 
-  CastedPointers.clear();
   IVUsesByStride.clear();
   StrideOrder.clear();
   return false;


More information about the llvm-commits mailing list