[llvm-commits] [llvm] r52615 - /llvm/trunk/lib/Analysis/ScalarEvolution.cpp

Dan Gohman gohman at apple.com
Sun Jun 22 12:56:46 PDT 2008


Author: djg
Date: Sun Jun 22 14:56:46 2008
New Revision: 52615

URL: http://llvm.org/viewvc/llvm-project?rev=52615&view=rev
Log:
Generalize createSCEV to be able to form SCEV expressions from
ConstantExprs.

Modified:
    llvm/trunk/lib/Analysis/ScalarEvolution.cpp

Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=52615&r1=52614&r2=52615&view=diff

==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Sun Jun 22 14:56:46 2008
@@ -1704,118 +1704,125 @@
   if (!isa<IntegerType>(V->getType()))
     return SE.getUnknown(V);
     
-  if (Instruction *I = dyn_cast<Instruction>(V)) {
-    switch (I->getOpcode()) {
-    case Instruction::Add:
-      return SE.getAddExpr(getSCEV(I->getOperand(0)),
-                           getSCEV(I->getOperand(1)));
-    case Instruction::Mul:
-      return SE.getMulExpr(getSCEV(I->getOperand(0)),
-                           getSCEV(I->getOperand(1)));
-    case Instruction::UDiv:
-      return SE.getUDivExpr(getSCEV(I->getOperand(0)),
-                            getSCEV(I->getOperand(1)));
-    case Instruction::Sub:
-      return SE.getMinusSCEV(getSCEV(I->getOperand(0)),
-                             getSCEV(I->getOperand(1)));
-    case Instruction::Or:
-      // If the RHS of the Or is a constant, we may have something like:
-      // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
-      // optimizations will transparently handle this case.
-      //
-      // In order for this transformation to be safe, the LHS must be of the
-      // form X*(2^n) and the Or constant must be less than 2^n.
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
-        SCEVHandle LHS = getSCEV(I->getOperand(0));
-        const APInt &CIVal = CI->getValue();
-        if (GetMinTrailingZeros(LHS) >=
-            (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
-          return SE.getAddExpr(LHS, getSCEV(I->getOperand(1)));
-      }
-      break;
-    case Instruction::Xor:
-      // If the RHS of the xor is a signbit, then this is just an add.
-      // Instcombine turns add of signbit into xor as a strength reduction step.
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
-        if (CI->getValue().isSignBit())
-          return SE.getAddExpr(getSCEV(I->getOperand(0)),
-                               getSCEV(I->getOperand(1)));
-        else if (CI->isAllOnesValue())
-          return SE.getNotSCEV(getSCEV(I->getOperand(0)));
-      }
-      break;
+  unsigned Opcode = Instruction::UserOp1;
+  if (Instruction *I = dyn_cast<Instruction>(V))
+    Opcode = I->getOpcode();
+  else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
+    Opcode = CE->getOpcode();
+  else
+    return SE.getUnknown(V);
 
-    case Instruction::Shl:
-      // Turn shift left of a constant amount into a multiply.
-      if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
-        uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
-        Constant *X = ConstantInt::get(
-          APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
-        return SE.getMulExpr(getSCEV(I->getOperand(0)), getSCEV(X));
-      }
-      break;
+  User *U = cast<User>(V);
+  switch (Opcode) {
+  case Instruction::Add:
+    return SE.getAddExpr(getSCEV(U->getOperand(0)),
+                         getSCEV(U->getOperand(1)));
+  case Instruction::Mul:
+    return SE.getMulExpr(getSCEV(U->getOperand(0)),
+                         getSCEV(U->getOperand(1)));
+  case Instruction::UDiv:
+    return SE.getUDivExpr(getSCEV(U->getOperand(0)),
+                          getSCEV(U->getOperand(1)));
+  case Instruction::Sub:
+    return SE.getMinusSCEV(getSCEV(U->getOperand(0)),
+                           getSCEV(U->getOperand(1)));
+  case Instruction::Or:
+    // If the RHS of the Or is a constant, we may have something like:
+    // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
+    // optimizations will transparently handle this case.
+    //
+    // In order for this transformation to be safe, the LHS must be of the
+    // form X*(2^n) and the Or constant must be less than 2^n.
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
+      SCEVHandle LHS = getSCEV(U->getOperand(0));
+      const APInt &CIVal = CI->getValue();
+      if (GetMinTrailingZeros(LHS) >=
+          (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
+        return SE.getAddExpr(LHS, getSCEV(U->getOperand(1)));
+    }
+    break;
+  case Instruction::Xor:
+    // If the RHS of the xor is a signbit, then this is just an add.
+    // Instcombine turns add of signbit into xor as a strength reduction step.
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
+      if (CI->getValue().isSignBit())
+        return SE.getAddExpr(getSCEV(U->getOperand(0)),
+                             getSCEV(U->getOperand(1)));
+      else if (CI->isAllOnesValue())
+        return SE.getNotSCEV(getSCEV(U->getOperand(0)));
+    }
+    break;
 
-    case Instruction::Trunc:
-      return SE.getTruncateExpr(getSCEV(I->getOperand(0)), I->getType());
+  case Instruction::Shl:
+    // Turn shift left of a constant amount into a multiply.
+    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
+      uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
+      Constant *X = ConstantInt::get(
+        APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
+      return SE.getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
+    }
+    break;
 
-    case Instruction::ZExt:
-      return SE.getZeroExtendExpr(getSCEV(I->getOperand(0)), I->getType());
+  case Instruction::Trunc:
+    return SE.getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
 
-    case Instruction::SExt:
-      return SE.getSignExtendExpr(getSCEV(I->getOperand(0)), I->getType());
-
-    case Instruction::BitCast:
-      // BitCasts are no-op casts so we just eliminate the cast.
-      if (I->getType()->isInteger() &&
-          I->getOperand(0)->getType()->isInteger())
-        return getSCEV(I->getOperand(0));
-      break;
-
-    case Instruction::PHI:
-      return createNodeForPHI(cast<PHINode>(I));
-
-    case Instruction::Select:
-      // This could be a smax or umax that was lowered earlier.
-      // Try to recover it.
-      if (ICmpInst *ICI = dyn_cast<ICmpInst>(I->getOperand(0))) {
-        Value *LHS = ICI->getOperand(0);
-        Value *RHS = ICI->getOperand(1);
-        switch (ICI->getPredicate()) {
-        case ICmpInst::ICMP_SLT:
-        case ICmpInst::ICMP_SLE:
-          std::swap(LHS, RHS);
-          // fall through
-        case ICmpInst::ICMP_SGT:
-        case ICmpInst::ICMP_SGE:
-          if (LHS == I->getOperand(1) && RHS == I->getOperand(2))
-            return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
-          else if (LHS == I->getOperand(2) && RHS == I->getOperand(1))
-            // -smax(-x, -y) == smin(x, y).
-            return SE.getNegativeSCEV(SE.getSMaxExpr(
-                                          SE.getNegativeSCEV(getSCEV(LHS)),
-                                          SE.getNegativeSCEV(getSCEV(RHS))));
-          break;
-        case ICmpInst::ICMP_ULT:
-        case ICmpInst::ICMP_ULE:
-          std::swap(LHS, RHS);
-          // fall through
-        case ICmpInst::ICMP_UGT:
-        case ICmpInst::ICMP_UGE:
-          if (LHS == I->getOperand(1) && RHS == I->getOperand(2))
-            return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
-          else if (LHS == I->getOperand(2) && RHS == I->getOperand(1))
-            // ~umax(~x, ~y) == umin(x, y)
-            return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)),
-                                                SE.getNotSCEV(getSCEV(RHS))));
-          break;
-        default:
-          break;
-        }
-      }
+  case Instruction::ZExt:
+    return SE.getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
+
+  case Instruction::SExt:
+    return SE.getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
+
+  case Instruction::BitCast:
+    // BitCasts are no-op casts so we just eliminate the cast.
+    if (U->getType()->isInteger() &&
+        U->getOperand(0)->getType()->isInteger())
+      return getSCEV(U->getOperand(0));
+    break;
+
+  case Instruction::PHI:
+    return createNodeForPHI(cast<PHINode>(U));
 
-    default: // We cannot analyze this expression.
-      break;
+  case Instruction::Select:
+    // This could be a smax or umax that was lowered earlier.
+    // Try to recover it.
+    if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
+      Value *LHS = ICI->getOperand(0);
+      Value *RHS = ICI->getOperand(1);
+      switch (ICI->getPredicate()) {
+      case ICmpInst::ICMP_SLT:
+      case ICmpInst::ICMP_SLE:
+        std::swap(LHS, RHS);
+        // fall through
+      case ICmpInst::ICMP_SGT:
+      case ICmpInst::ICMP_SGE:
+        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
+          return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
+        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
+          // -smax(-x, -y) == smin(x, y).
+          return SE.getNegativeSCEV(SE.getSMaxExpr(
+                                        SE.getNegativeSCEV(getSCEV(LHS)),
+                                        SE.getNegativeSCEV(getSCEV(RHS))));
+        break;
+      case ICmpInst::ICMP_ULT:
+      case ICmpInst::ICMP_ULE:
+        std::swap(LHS, RHS);
+        // fall through
+      case ICmpInst::ICMP_UGT:
+      case ICmpInst::ICMP_UGE:
+        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
+          return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
+        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
+          // ~umax(~x, ~y) == umin(x, y)
+          return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)),
+                                              SE.getNotSCEV(getSCEV(RHS))));
+        break;
+      default:
+        break;
+      }
     }
+
+  default: // We cannot analyze this expression.
+    break;
   }
 
   return SE.getUnknown(V);





More information about the llvm-commits mailing list