[llvm] r300494 - [SCEV] Add a local cache for getZeroExtendExpr and getSignExtendExpr to prevent
    Wei Mi via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Mon Apr 17 13:40:06 PDT 2017
    
    
  
Author: wmi
Date: Mon Apr 17 15:40:05 2017
New Revision: 300494
URL: http://llvm.org/viewvc/llvm-project?rev=300494&view=rev
Log:
[SCEV] Add a local cache for getZeroExtendExpr and getSignExtendExpr to prevent
the exponential behavior.
This patch fixes PR32043. Functions getZeroExtendExpr and getSignExtendExpr
may call themselves recursively more than once, which is potentially 2^N
complexity behavior. The exponential behavior was not commonly exposed before
because of existing global cache mechanisms like UniqueSCEVs, or early-return
mechanisms when the flags FlagNSW or FlagNUW are seen. However, there are still
cases that can expose the exponential behavior, like the one in PR32043, so we
add a local cache in getZeroExtendExpr and getSignExtendExpr. If the input of
the functions -- a (SCEV, type) pair -- has been seen before, we can find the
extended expression directly in the local cache.
Differential Revision: https://reviews.llvm.org/D30350
Modified:
    llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
    llvm/trunk/lib/Analysis/ScalarEvolution.cpp
    llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp
Modified: llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/ScalarEvolution.h?rev=300494&r1=300493&r2=300494&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Analysis/ScalarEvolution.h (original)
+++ llvm/trunk/include/llvm/Analysis/ScalarEvolution.h Mon Apr 17 15:40:05 2017
@@ -1159,8 +1159,20 @@ public:
   const SCEV *getConstant(const APInt &Val);
   const SCEV *getConstant(Type *Ty, uint64_t V, bool isSigned = false);
   const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty);
+
+  typedef SmallDenseMap<std::pair<const SCEV *, Type *>, const SCEV *, 8>
+      ExtendCacheTy;
   const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty);
+  const SCEV *getZeroExtendExprCached(const SCEV *Op, Type *Ty,
+                                      ExtendCacheTy &Cache);
+  const SCEV *getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
+                                    ExtendCacheTy &Cache);
+
   const SCEV *getSignExtendExpr(const SCEV *Op, Type *Ty);
+  const SCEV *getSignExtendExprCached(const SCEV *Op, Type *Ty,
+                                      ExtendCacheTy &Cache);
+  const SCEV *getSignExtendExprImpl(const SCEV *Op, Type *Ty,
+                                    ExtendCacheTy &Cache);
   const SCEV *getAnyExtendExpr(const SCEV *Op, Type *Ty);
   const SCEV *getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
                          SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=300494&r1=300493&r2=300494&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Mon Apr 17 15:40:05 2017
@@ -1276,7 +1276,8 @@ static const SCEV *getUnsignedOverflowLi
 namespace {
 
 struct ExtendOpTraitsBase {
-  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *);
+  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(
+      const SCEV *, Type *, ScalarEvolution::ExtendCacheTy &Cache);
 };
 
 // Used to make code generic over signed and unsigned overflow.
@@ -1305,8 +1306,9 @@ struct ExtendOpTraits<SCEVSignExtendExpr
   }
 };
 
-const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
-    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
+const ExtendOpTraitsBase::GetExtendExprTy
+    ExtendOpTraits<SCEVSignExtendExpr>::GetExtendExpr =
+        &ScalarEvolution::getSignExtendExprCached;
 
 template <>
 struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
@@ -1321,8 +1323,9 @@ struct ExtendOpTraits<SCEVZeroExtendExpr
   }
 };
 
-const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
-    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
+const ExtendOpTraitsBase::GetExtendExprTy
+    ExtendOpTraits<SCEVZeroExtendExpr>::GetExtendExpr =
+        &ScalarEvolution::getZeroExtendExprCached;
 }
 
 // The recurrence AR has been shown to have no signed/unsigned wrap or something
@@ -1334,7 +1337,8 @@ const ExtendOpTraitsBase::GetExtendExprT
 // "sext/zext(PostIncAR)"
 template <typename ExtendOpTy>
 static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
-                                        ScalarEvolution *SE) {
+                                        ScalarEvolution *SE,
+                                        ScalarEvolution::ExtendCacheTy &Cache) {
   auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
   auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
 
@@ -1381,9 +1385,9 @@ static const SCEV *getPreStartForExtend(
   unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
   Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
   const SCEV *OperandExtendedStart =
-      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy),
-                     (SE->*GetExtendExpr)(Step, WideTy));
-  if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) {
+      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Cache),
+                     (SE->*GetExtendExpr)(Step, WideTy, Cache));
+  if ((SE->*GetExtendExpr)(Start, WideTy, Cache) == OperandExtendedStart) {
     if (PreAR && AR->getNoWrapFlags(WrapType)) {
       // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
       // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
@@ -1408,15 +1412,17 @@ static const SCEV *getPreStartForExtend(
 // Get the normalized zero or sign extended expression for this AddRec's Start.
 template <typename ExtendOpTy>
 static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
-                                        ScalarEvolution *SE) {
+                                        ScalarEvolution *SE,
+                                        ScalarEvolution::ExtendCacheTy &Cache) {
   auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
 
-  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE);
+  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Cache);
   if (!PreStart)
-    return (SE->*GetExtendExpr)(AR->getStart(), Ty);
+    return (SE->*GetExtendExpr)(AR->getStart(), Ty, Cache);
 
-  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty),
-                        (SE->*GetExtendExpr)(PreStart, Ty));
+  return SE->getAddExpr(
+      (SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty, Cache),
+      (SE->*GetExtendExpr)(PreStart, Ty, Cache));
 }
 
 // Try to prove away overflow by looking at "nearby" add recurrences.  A
@@ -1496,8 +1502,30 @@ bool ScalarEvolution::proveNoWrapByVaryi
   return false;
 }
 
-const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
-                                               Type *Ty) {
+const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty) {
+  // Use the local cache to prevent exponential behavior of
+  // getZeroExtendExprImpl.
+  ExtendCacheTy Cache;
+  return getZeroExtendExprCached(Op, Ty, Cache);
+}
+
+/// Query \p Cache before calling getZeroExtendExprImpl. If there is no
+/// related entry in the \p Cache, call getZeroExtendExprImpl and save
+/// the result in the \p Cache.
+const SCEV *ScalarEvolution::getZeroExtendExprCached(const SCEV *Op, Type *Ty,
+                                                     ExtendCacheTy &Cache) {
+  auto It = Cache.find({Op, Ty});
+  if (It != Cache.end())
+    return It->second;
+  const SCEV *ZExt = getZeroExtendExprImpl(Op, Ty, Cache);
+  auto InsertResult = Cache.insert({{Op, Ty}, ZExt});
+  assert(InsertResult.second && "Expect the key was not in the cache");
+  return ZExt;
+}
+
+/// The real implementation of getZeroExtendExpr.
+const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
+                                                   ExtendCacheTy &Cache) {
   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
          "This is not an extending conversion!");
   assert(isSCEVable(Ty) &&
@@ -1507,11 +1535,11 @@ const SCEV *ScalarEvolution::getZeroExte
   // Fold if the operand is constant.
   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
     return getConstant(
-      cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
+        cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
 
   // zext(zext(x)) --> zext(x)
   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
-    return getZeroExtendExpr(SZ->getOperand(), Ty);
+    return getZeroExtendExprCached(SZ->getOperand(), Ty, Cache);
 
   // Before doing any expensive analysis, check to see if we've already
   // computed a SCEV for this Op and Ty.
@@ -1555,8 +1583,8 @@ const SCEV *ScalarEvolution::getZeroExte
       // we don't need to do any further analysis.
       if (AR->hasNoUnsignedWrap())
         return getAddRecExpr(
-            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
-            getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
+            getZeroExtendExprCached(Step, Ty, Cache), L, AR->getNoWrapFlags());
 
       // Check whether the backedge-taken count is SCEVCouldNotCompute.
       // Note that this serves two purposes: It filters out loops that are
@@ -1581,21 +1609,22 @@ const SCEV *ScalarEvolution::getZeroExte
           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
           // Check whether Start+Step*MaxBECount has no unsigned overflow.
           const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step);
-          const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul), WideTy);
-          const SCEV *WideStart = getZeroExtendExpr(Start, WideTy);
+          const SCEV *ZAdd =
+              getZeroExtendExprCached(getAddExpr(Start, ZMul), WideTy, Cache);
+          const SCEV *WideStart = getZeroExtendExprCached(Start, WideTy, Cache);
           const SCEV *WideMaxBECount =
-            getZeroExtendExpr(CastedMaxBECount, WideTy);
-          const SCEV *OperandExtendedAdd =
-            getAddExpr(WideStart,
-                       getMulExpr(WideMaxBECount,
-                                  getZeroExtendExpr(Step, WideTy)));
+              getZeroExtendExprCached(CastedMaxBECount, WideTy, Cache);
+          const SCEV *OperandExtendedAdd = getAddExpr(
+              WideStart, getMulExpr(WideMaxBECount, getZeroExtendExprCached(
+                                                        Step, WideTy, Cache)));
           if (ZAdd == OperandExtendedAdd) {
             // Cache knowledge of AR NUW, which is propagated to this AddRec.
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
             // Return the expression with the addrec on the outside.
             return getAddRecExpr(
-                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
-                getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
+                getZeroExtendExprCached(Step, Ty, Cache), L,
+                AR->getNoWrapFlags());
           }
           // Similar to above, only this time treat the step value as signed.
           // This covers loops that count down.
@@ -1609,7 +1638,7 @@ const SCEV *ScalarEvolution::getZeroExte
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
             // Return the expression with the addrec on the outside.
             return getAddRecExpr(
-                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
+                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
                 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
           }
         }
@@ -1641,8 +1670,9 @@ const SCEV *ScalarEvolution::getZeroExte
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
             // Return the expression with the addrec on the outside.
             return getAddRecExpr(
-                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
-                getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
+                getZeroExtendExprCached(Step, Ty, Cache), L,
+                AR->getNoWrapFlags());
           }
         } else if (isKnownNegative(Step)) {
           const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
@@ -1657,7 +1687,7 @@ const SCEV *ScalarEvolution::getZeroExte
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
             // Return the expression with the addrec on the outside.
             return getAddRecExpr(
-                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
+                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
                 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
           }
         }
@@ -1666,8 +1696,8 @@ const SCEV *ScalarEvolution::getZeroExte
       if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
         return getAddRecExpr(
-            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
-            getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
+            getZeroExtendExprCached(Step, Ty, Cache), L, AR->getNoWrapFlags());
       }
     }
 
@@ -1678,7 +1708,7 @@ const SCEV *ScalarEvolution::getZeroExte
       // commute the zero extension with the addition operation.
       SmallVector<const SCEV *, 4> Ops;
       for (const auto *Op : SA->operands())
-        Ops.push_back(getZeroExtendExpr(Op, Ty));
+        Ops.push_back(getZeroExtendExprCached(Op, Ty, Cache));
       return getAddExpr(Ops, SCEV::FlagNUW);
     }
   }
@@ -1692,8 +1722,30 @@ const SCEV *ScalarEvolution::getZeroExte
   return S;
 }
 
-const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
-                                               Type *Ty) {
+const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty) {
+  // Use the local cache to prevent exponential behavior of
+  // getSignExtendExprImpl.
+  ExtendCacheTy Cache;
+  return getSignExtendExprCached(Op, Ty, Cache);
+}
+
+/// Query \p Cache before calling getSignExtendExprImpl. If there is no
+/// related entry in the \p Cache, call getSignExtendExprImpl and save
+/// the result in the \p Cache.
+const SCEV *ScalarEvolution::getSignExtendExprCached(const SCEV *Op, Type *Ty,
+                                                     ExtendCacheTy &Cache) {
+  auto It = Cache.find({Op, Ty});
+  if (It != Cache.end())
+    return It->second;
+  const SCEV *SExt = getSignExtendExprImpl(Op, Ty, Cache);
+  auto InsertResult = Cache.insert({{Op, Ty}, SExt});
+  assert(InsertResult.second && "Expect the key was not in the cache");
+  return SExt;
+}
+
+/// The real implementation of getSignExtendExpr.
+const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
+                                                   ExtendCacheTy &Cache) {
   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
          "This is not an extending conversion!");
   assert(isSCEVable(Ty) &&
@@ -1703,11 +1755,11 @@ const SCEV *ScalarEvolution::getSignExte
   // Fold if the operand is constant.
   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
     return getConstant(
-      cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
+        cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
 
   // sext(sext(x)) --> sext(x)
   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
-    return getSignExtendExpr(SS->getOperand(), Ty);
+    return getSignExtendExprCached(SS->getOperand(), Ty, Cache);
 
   // sext(zext(x)) --> zext(x)
   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
@@ -1746,8 +1798,8 @@ const SCEV *ScalarEvolution::getSignExte
           const APInt &C2 = SC2->getAPInt();
           if (C1.isStrictlyPositive() && C2.isStrictlyPositive() &&
               C2.ugt(C1) && C2.isPowerOf2())
-            return getAddExpr(getSignExtendExpr(SC1, Ty),
-                              getSignExtendExpr(SMul, Ty));
+            return getAddExpr(getSignExtendExprCached(SC1, Ty, Cache),
+                              getSignExtendExprCached(SMul, Ty, Cache));
         }
       }
     }
@@ -1758,7 +1810,7 @@ const SCEV *ScalarEvolution::getSignExte
       // commute the sign extension with the addition operation.
       SmallVector<const SCEV *, 4> Ops;
       for (const auto *Op : SA->operands())
-        Ops.push_back(getSignExtendExpr(Op, Ty));
+        Ops.push_back(getSignExtendExprCached(Op, Ty, Cache));
       return getAddExpr(Ops, SCEV::FlagNSW);
     }
   }
@@ -1782,8 +1834,8 @@ const SCEV *ScalarEvolution::getSignExte
       // we don't need to do any further analysis.
       if (AR->hasNoSignedWrap())
         return getAddRecExpr(
-            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
-            getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW);
+            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
+            getSignExtendExprCached(Step, Ty, Cache), L, SCEV::FlagNSW);
 
       // Check whether the backedge-taken count is SCEVCouldNotCompute.
       // Note that this serves two purposes: It filters out loops that are
@@ -1808,21 +1860,22 @@ const SCEV *ScalarEvolution::getSignExte
           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
           // Check whether Start+Step*MaxBECount has no signed overflow.
           const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
-          const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul), WideTy);
-          const SCEV *WideStart = getSignExtendExpr(Start, WideTy);
+          const SCEV *SAdd =
+              getSignExtendExprCached(getAddExpr(Start, SMul), WideTy, Cache);
+          const SCEV *WideStart = getSignExtendExprCached(Start, WideTy, Cache);
           const SCEV *WideMaxBECount =
-            getZeroExtendExpr(CastedMaxBECount, WideTy);
-          const SCEV *OperandExtendedAdd =
-            getAddExpr(WideStart,
-                       getMulExpr(WideMaxBECount,
-                                  getSignExtendExpr(Step, WideTy)));
+              getZeroExtendExpr(CastedMaxBECount, WideTy);
+          const SCEV *OperandExtendedAdd = getAddExpr(
+              WideStart, getMulExpr(WideMaxBECount, getSignExtendExprCached(
+                                                        Step, WideTy, Cache)));
           if (SAdd == OperandExtendedAdd) {
             // Cache knowledge of AR NSW, which is propagated to this AddRec.
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
             // Return the expression with the addrec on the outside.
             return getAddRecExpr(
-                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
-                getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
+                getSignExtendExprCached(Step, Ty, Cache), L,
+                AR->getNoWrapFlags());
           }
           // Similar to above, only this time treat the step value as unsigned.
           // This covers loops that count up with an unsigned step.
@@ -1843,7 +1896,7 @@ const SCEV *ScalarEvolution::getSignExte
 
             // Return the expression with the addrec on the outside.
             return getAddRecExpr(
-                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
+                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
                 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
           }
         }
@@ -1875,8 +1928,9 @@ const SCEV *ScalarEvolution::getSignExte
           // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
           const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
           return getAddRecExpr(
-              getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
-              getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+              getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
+              getSignExtendExprCached(Step, Ty, Cache), L,
+              AR->getNoWrapFlags());
         }
       }
 
@@ -1890,18 +1944,18 @@ const SCEV *ScalarEvolution::getSignExte
         const APInt &C2 = SC2->getAPInt();
         if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) &&
             C2.isPowerOf2()) {
-          Start = getSignExtendExpr(Start, Ty);
+          Start = getSignExtendExprCached(Start, Ty, Cache);
           const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L,
                                             AR->getNoWrapFlags());
-          return getAddExpr(Start, getSignExtendExpr(NewAR, Ty));
+          return getAddExpr(Start, getSignExtendExprCached(NewAR, Ty, Cache));
         }
       }
 
       if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
         return getAddRecExpr(
-            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
-            getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
+            getSignExtendExprCached(Step, Ty, Cache), L, AR->getNoWrapFlags());
       }
     }
 
Modified: llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp?rev=300494&r1=300493&r2=300494&view=diff
==============================================================================
--- llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp (original)
+++ llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp Mon Apr 17 15:40:05 2017
@@ -666,5 +666,95 @@ TEST_F(ScalarEvolutionsTest, SCEVNormali
   });
 }
 
+// Expect the call of getZeroExtendExpr will not cost exponential time.
+TEST_F(ScalarEvolutionsTest, SCEVZeroExtendExpr) {
+  LLVMContext C;
+  SMDiagnostic Err;
+
+  // Generate a function like below:
+  // define void @foo() {
+  // entry:
+  //   br label %for.cond
+  //
+  // for.cond:
+  //   %0 = phi i64 [ 100, %entry ], [ %dec, %for.inc ]
+  //   %cmp = icmp sgt i64 %0, 90
+  //   br i1 %cmp, label %for.inc, label %for.cond1
+  //
+  // for.inc:
+  //   %dec = add nsw i64 %0, -1
+  //   br label %for.cond
+  //
+  // for.cond1:
+  //   %1 = phi i64 [ 100, %for.cond ], [ %dec5, %for.inc2 ]
+  //   %cmp3 = icmp sgt i64 %1, 90
+  //   br i1 %cmp3, label %for.inc2, label %for.cond4
+  //
+  // for.inc2:
+  //   %dec5 = add nsw i64 %1, -1
+  //   br label %for.cond1
+  //
+  // ......
+  //
+  // for.cond89:
+  //   %19 = phi i64 [ 100, %for.cond84 ], [ %dec94, %for.inc92 ]
+  //   %cmp93 = icmp sgt i64 %19, 90
+  //   br i1 %cmp93, label %for.inc92, label %for.end
+  //
+  // for.inc92:
+  //   %dec94 = add nsw i64 %19, -1
+  //   br label %for.cond89
+  //
+  // for.end:
+  //   %gep = getelementptr i8, i8* null, i64 %dec
+  //   %gep6 = getelementptr i8, i8* %gep, i64 %dec5
+  //   ......
+  //   %gep95 = getelementptr i8, i8* %gep91, i64 %dec94
+  //   ret void
+  // }
+  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), {}, false);
+  Function *F = cast<Function>(M.getOrInsertFunction("foo", FTy));
+
+  BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
+  BasicBlock *CondBB = BasicBlock::Create(Context, "for.cond", F);
+  BasicBlock *EndBB = BasicBlock::Create(Context, "for.end", F);
+  BranchInst::Create(CondBB, EntryBB);
+  BasicBlock *PrevBB = EntryBB;
+
+  Type *I64Ty = Type::getInt64Ty(Context);
+  Type *I8Ty = Type::getInt8Ty(Context);
+  Type *I8PtrTy = Type::getInt8PtrTy(Context);
+  Value *Accum = Constant::getNullValue(I8PtrTy);
+  int Iters = 20;
+  for (int i = 0; i < Iters; i++) {
+    BasicBlock *IncBB = BasicBlock::Create(Context, "for.inc", F, EndBB);
+    auto *PN = PHINode::Create(I64Ty, 2, "", CondBB);
+    PN->addIncoming(ConstantInt::get(Context, APInt(64, 100)), PrevBB);
+    auto *Cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, PN,
+                                ConstantInt::get(Context, APInt(64, 90)), "cmp",
+                                CondBB);
+    BasicBlock *NextBB;
+    if (i != Iters - 1)
+      NextBB = BasicBlock::Create(Context, "for.cond", F, EndBB);
+    else
+      NextBB = EndBB;
+    BranchInst::Create(IncBB, NextBB, Cmp, CondBB);
+    auto *Dec = BinaryOperator::CreateNSWAdd(
+        PN, ConstantInt::get(Context, APInt(64, -1)), "dec", IncBB);
+    PN->addIncoming(Dec, IncBB);
+    BranchInst::Create(CondBB, IncBB);
+
+    Accum = GetElementPtrInst::Create(I8Ty, Accum, Dec, "gep", EndBB);
+
+    PrevBB = CondBB;
+    CondBB = NextBB;
+  }
+  ReturnInst::Create(Context, nullptr, EndBB);
+  ScalarEvolution SE = buildSE(*F);
+  const SCEV *S = SE.getSCEV(Accum);
+  Type *I128Ty = Type::getInt128Ty(Context);
+  SE.getZeroExtendExpr(S, I128Ty);
+}
+
 }  // end anonymous namespace
 }  // end namespace llvm
    
    
More information about the llvm-commits
mailing list