[llvm] r294181 - [SCEV] limit recursion depth and operands number in getAddExpr

Daniil Fukalov via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 6 04:38:06 PST 2017


Author: dfukalov
Date: Mon Feb  6 06:38:06 2017
New Revision: 294181

URL: http://llvm.org/viewvc/llvm-project?rev=294181&view=rev
Log:
[SCEV] limit recursion depth and operands number in getAddExpr

For a fairly large function with source like

%add = add nsw i32 %mul, %conv
%mul1 = mul nsw i32 %add, %conv
%add2 = add nsw i32 %mul1, %add
%mul3 = mul nsw i32 %add2, %add
; repeat a couple of thousand times
(which can be produced by loop unrolling), getAddExpr() tries to construct the SCEV recursively and takes an almost unbounded amount of time.

Added a limit on the recursion depth of getAddExpr(), controlled by the new hidden option -scalar-evolution-max-addexpr-depth (default 32).

Reviewers: sanjoy

Subscribers: hfinkel, llvm-commits, mzolotukhin

Differential Revision: https://reviews.llvm.org/D28158

Modified:
    llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
    llvm/trunk/lib/Analysis/ScalarEvolution.cpp
    llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp

Modified: llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/ScalarEvolution.h?rev=294181&r1=294180&r2=294181&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Analysis/ScalarEvolution.h (original)
+++ llvm/trunk/include/llvm/Analysis/ScalarEvolution.h Mon Feb  6 06:38:06 2017
@@ -1140,7 +1140,8 @@ public:
   const SCEV *getSignExtendExpr(const SCEV *Op, Type *Ty);
   const SCEV *getAnyExtendExpr(const SCEV *Op, Type *Ty);
   const SCEV *getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
-                         SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);
+                         SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
+                         unsigned Depth = 0);
   const SCEV *getAddExpr(const SCEV *LHS, const SCEV *RHS,
                          SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
     SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
@@ -1613,6 +1614,10 @@ private:
   bool doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, bool IsSigned,
                           bool NoWrap);
 
+  /// Get add expr already created or create a new one
+  const SCEV *getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops,
+                                 SCEV::NoWrapFlags Flags);
+
 private:
   FoldingSet<SCEV> UniqueSCEVs;
   FoldingSet<SCEVPredicate> UniquePreds;

Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=294181&r1=294180&r2=294181&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Mon Feb  6 06:38:06 2017
@@ -137,6 +137,11 @@ static cl::opt<unsigned>
                     cl::desc("Maximum depth of recursive compare complexity"),
                     cl::init(32));
 
+static cl::opt<unsigned>
+    MaxAddExprDepth("scalar-evolution-max-addexpr-depth", cl::Hidden,
+                    cl::desc("Maximum depth of recursive AddExpr"),
+                    cl::init(32));
+
 static cl::opt<unsigned> MaxConstantEvolvingDepth(
     "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
     cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
@@ -2100,7 +2105,8 @@ StrengthenNoWrapFlags(ScalarEvolution *S
 
 /// Get a canonical add expression, or something simpler if possible.
 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
-                                        SCEV::NoWrapFlags Flags) {
+                                        SCEV::NoWrapFlags Flags,
+                                        unsigned Depth) {
   assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
          "only nuw or nsw allowed");
   assert(!Ops.empty() && "Cannot get empty add!");
@@ -2139,6 +2145,10 @@ const SCEV *ScalarEvolution::getAddExpr(
     if (Ops.size() == 1) return Ops[0];
   }
 
+  // Limit recursion calls depth
+  if (Depth > MaxAddExprDepth)
+    return getOrCreateAddExpr(Ops, Flags);
+
   // Okay, check to see if the same value occurs in the operand list more than
   // once.  If so, merge them together into an multiply expression.  Since we
   // sorted the list, these values are required to be adjacent.
@@ -2210,7 +2220,7 @@ const SCEV *ScalarEvolution::getAddExpr(
     }
     if (Ok) {
       // Evaluate the expression in the larger type.
-      const SCEV *Fold = getAddExpr(LargeOps, Flags);
+      const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1);
       // If it folds to something simple, use it. Otherwise, don't.
       if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
         return getTruncateExpr(Fold, DstType);
@@ -2239,7 +2249,7 @@ const SCEV *ScalarEvolution::getAddExpr(
     // and they are not necessarily sorted.  Recurse to resort and resimplify
     // any operands we just acquired.
     if (DeletedAdd)
-      return getAddExpr(Ops);
+      return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
   }
 
   // Skip over the add expression until we get to a multiply.
@@ -2274,13 +2284,14 @@ const SCEV *ScalarEvolution::getAddExpr(
         Ops.push_back(getConstant(AccumulatedConstant));
       for (auto &MulOp : MulOpLists)
         if (MulOp.first != 0)
-          Ops.push_back(getMulExpr(getConstant(MulOp.first),
-                                   getAddExpr(MulOp.second)));
+          Ops.push_back(getMulExpr(
+              getConstant(MulOp.first),
+              getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1)));
       if (Ops.empty())
         return getZero(Ty);
       if (Ops.size() == 1)
         return Ops[0];
-      return getAddExpr(Ops);
+      return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
     }
   }
 
@@ -2305,8 +2316,8 @@ const SCEV *ScalarEvolution::getAddExpr(
             MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
             InnerMul = getMulExpr(MulOps);
           }
-          const SCEV *One = getOne(Ty);
-          const SCEV *AddOne = getAddExpr(One, InnerMul);
+          SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
+          const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
           const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV);
           if (Ops.size() == 2) return OuterMul;
           if (AddOp < Idx) {
@@ -2317,7 +2328,7 @@ const SCEV *ScalarEvolution::getAddExpr(
             Ops.erase(Ops.begin()+AddOp-1);
           }
           Ops.push_back(OuterMul);
-          return getAddExpr(Ops);
+          return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
         }
 
       // Check this multiply against other multiplies being added together.
@@ -2345,13 +2356,15 @@ const SCEV *ScalarEvolution::getAddExpr(
               MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
               InnerMul2 = getMulExpr(MulOps);
             }
-            const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
+            SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
+            const SCEV *InnerMulSum =
+                getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
             const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
             if (Ops.size() == 2) return OuterMul;
             Ops.erase(Ops.begin()+Idx);
             Ops.erase(Ops.begin()+OtherMulIdx-1);
             Ops.push_back(OuterMul);
-            return getAddExpr(Ops);
+            return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
           }
       }
     }
@@ -2387,7 +2400,7 @@ const SCEV *ScalarEvolution::getAddExpr(
       // This follows from the fact that the no-wrap flags on the outer add
       // expression are applicable on the 0th iteration, when the add recurrence
       // will be equal to its start value.
-      AddRecOps[0] = getAddExpr(LIOps, Flags);
+      AddRecOps[0] = getAddExpr(LIOps, Flags, Depth + 1);
 
       // Build the new addrec. Propagate the NUW and NSW flags if both the
       // outer add and the inner addrec are guaranteed to have no overflow.
@@ -2404,7 +2417,7 @@ const SCEV *ScalarEvolution::getAddExpr(
           Ops[i] = NewRec;
           break;
         }
-      return getAddExpr(Ops);
+      return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
     }
 
     // Okay, if there weren't any loop invariants to be folded, check to see if
@@ -2428,14 +2441,15 @@ const SCEV *ScalarEvolution::getAddExpr(
                                    OtherAddRec->op_end());
                   break;
                 }
-                AddRecOps[i] = getAddExpr(AddRecOps[i],
-                                          OtherAddRec->getOperand(i));
+                SmallVector<const SCEV *, 2> TwoOps = {
+                    AddRecOps[i], OtherAddRec->getOperand(i)};
+                AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
               }
               Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
             }
         // Step size has changed, so we cannot guarantee no self-wraparound.
         Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
-        return getAddExpr(Ops);
+        return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
       }
 
     // Otherwise couldn't fold anything into this recurrence.  Move onto the
@@ -2444,18 +2458,24 @@ const SCEV *ScalarEvolution::getAddExpr(
 
   // Okay, it looks like we really DO need an add expr.  Check to see if we
   // already have one, otherwise create a new one.
+  return getOrCreateAddExpr(Ops, Flags);
+}
+
+const SCEV *
+ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops,
+                                    SCEV::NoWrapFlags Flags) {
   FoldingSetNodeID ID;
   ID.AddInteger(scAddExpr);
   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
     ID.AddPointer(Ops[i]);
   void *IP = nullptr;
   SCEVAddExpr *S =
-    static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+      static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
   if (!S) {
     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
-    S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator),
-                                        O, Ops.size());
+    S = new (SCEVAllocator)
+        SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
     UniqueSCEVs.InsertNode(S, IP);
   }
   S->setNoWrapFlags(Flags);

Modified: llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp?rev=294181&r1=294180&r2=294181&view=diff
==============================================================================
--- llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp (original)
+++ llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp Mon Feb  6 06:38:06 2017
@@ -532,5 +532,33 @@ TEST_F(ScalarEvolutionsTest, SCEVCompare
   EXPECT_NE(nullptr, SE.getSCEV(Acc[0]));
 }
 
+TEST_F(ScalarEvolutionsTest, SCEVAddExpr) {
+  Type *Ty32 = Type::getInt32Ty(Context);
+  Type *ArgTys[] = {Type::getInt64Ty(Context), Ty32};
+
+  FunctionType *FTy =
+      FunctionType::get(Type::getVoidTy(Context), ArgTys, false);
+  Function *F = cast<Function>(M.getOrInsertFunction("f", FTy));
+
+  Argument *A1 = &*F->arg_begin();
+  Argument *A2 = &*(std::next(F->arg_begin()));
+  BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
+
+  Instruction *Trunc = CastInst::CreateTruncOrBitCast(A1, Ty32, "", EntryBB);
+  Instruction *Mul1 = BinaryOperator::CreateMul(Trunc, A2, "", EntryBB);
+  Instruction *Add1 = BinaryOperator::CreateAdd(Mul1, Trunc, "", EntryBB);
+  Mul1 = BinaryOperator::CreateMul(Add1, Trunc, "", EntryBB);
+  Instruction *Add2 = BinaryOperator::CreateAdd(Mul1, Add1, "", EntryBB);
+  for (int i = 0; i < 1000; i++) {
+    Mul1 = BinaryOperator::CreateMul(Add2, Add1, "", EntryBB);
+    Add1 = Add2;
+    Add2 = BinaryOperator::CreateAdd(Mul1, Add1, "", EntryBB);
+  }
+
+  ReturnInst::Create(Context, nullptr, EntryBB);
+  ScalarEvolution SE = buildSE(*F);
+  EXPECT_NE(nullptr, SE.getSCEV(Mul1));
+}
+
 }  // end anonymous namespace
 }  // end namespace llvm




More information about the llvm-commits mailing list