[llvm] 5dad4c6 - [SCEV] Iteratively compute ranges for deeply nested expressions.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 21 13:58:07 PST 2022


Author: Florian Hahn
Date: 2022-11-21T21:56:14Z
New Revision: 5dad4c67882a5eac2485f722e8e4109642155058

URL: https://github.com/llvm/llvm-project/commit/5dad4c67882a5eac2485f722e8e4109642155058
DIFF: https://github.com/llvm/llvm-project/commit/5dad4c67882a5eac2485f722e8e4109642155058.diff

LOG: [SCEV] Iteratively compute ranges for deeply nested expressions.

At the moment, getRangeRef may overflow the stack for very deeply nested
expressions.

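This patch introduces a new getRangeRefIter function, which first
iteratively builds a worklist of N-ary expressions, udiv expressions and
phi nodes together with their operands, and then computes ranges for the
worklist entries in reverse order, so that operands are usually
evaluated before their users.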

getRangeRef has been extended to also take a Depth argument, and it
switches to getRangeRefIter once the depth exceeds a threshold
(-scev-range-iter-threshold, default 32).

This ensures compile-time is not impacted in general. Note that
the iterative algorithm may lead to a slightly different evaluation
order, which could result in slightly worse ranges for cyclic phis.

https://llvm-compile-time-tracker.com/compare.php?from=23c3eb7cdf3478c9db86f6cb5115821a8f0f5f40&to=e0e09fa338e77e53242bfc846e1484350ad79773&stat=instructions

Fixes #49579.

Reviewed By: mkazantsev

Differential Revision: https://reviews.llvm.org/D130728

Added: 
    llvm/test/Transforms/IndVarSimplify/range-iter-threshold.ll

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Analysis/ScalarEvolution/ranges.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 580fef9f7a5d6..49f8ae84f46f5 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1283,6 +1283,9 @@ class ScalarEvolution {
   /// Mark SCEVUnknown Phis currently being processed by getRangeRef.
   SmallPtrSet<const PHINode *, 6> PendingPhiRanges;
 
+  /// Mark SCEVUnknown Phis currently being processed by getRangeRefIter.
+  SmallPtrSet<const PHINode *, 6> PendingPhiRangesIter;
+
   // Mark SCEVUnknown Phis currently being processed by isImpliedViaMerge.
   SmallPtrSet<const PHINode *, 6> PendingMerges;
 
@@ -1560,7 +1563,12 @@ class ScalarEvolution {
   /// Determine the range for a particular SCEV.
   /// NOTE: This returns a reference to an entry in a cache. It must be
   /// copied if its needed for longer.
-  const ConstantRange &getRangeRef(const SCEV *S, RangeSignHint Hint);
+  const ConstantRange &getRangeRef(const SCEV *S, RangeSignHint Hint,
+                                   unsigned Depth = 0);
+
+  /// Determine the range for a particular SCEV, but evaluates ranges for
+  /// operands iteratively first.
+  const ConstantRange &getRangeRefIter(const SCEV *S, RangeSignHint Hint);
 
   /// Determines the range for the affine SCEVAddRecExpr {\p Start,+,\p Step}.
   /// Helper for \c getRange.

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e2dd015bc5715..b7766992590f8 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -220,6 +220,11 @@ static cl::opt<unsigned>
                   cl::desc("Size of the expression which is considered huge"),
                   cl::init(4096));
 
+static cl::opt<unsigned> RangeIterThreshold(
+    "scev-range-iter-threshold", cl::Hidden,
+    cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
+    cl::init(32));
+
 static cl::opt<bool>
 ClassifyExpressions("scalar-evolution-classify-expressions",
     cl::Hidden, cl::init(true),
@@ -6425,18 +6430,78 @@ getRangeForUnknownRecurrence(const SCEVUnknown *U) {
   return FullSet;
 }
 
+const ConstantRange &
+ScalarEvolution::getRangeRefIter(const SCEV *S,
+                                 ScalarEvolution::RangeSignHint SignHint) {
+  DenseMap<const SCEV *, ConstantRange> &Cache =
+      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
+                                                       : SignedRanges;
+  SmallVector<const SCEV *> WorkList;
+  SmallPtrSet<const SCEV *, 8> Seen;
+
+  // Add Expr to the worklist unless it was already seen or has a cached range.
+  // Only N-ary expressions, udivs and SCEVUnknown PHI nodes are queued.
+  auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
+    if (!Seen.insert(Expr).second)
+      return;
+    if (Cache.find(Expr) != Cache.end())
+      return;
+    if (isa<SCEVNAryExpr>(Expr) || isa<SCEVUDivExpr>(Expr))
+      WorkList.push_back(Expr);
+    else if (auto *UnknownS = dyn_cast<SCEVUnknown>(Expr))
+      if (isa<PHINode>(UnknownS->getValue()))
+        WorkList.push_back(Expr);
+  };
+  AddToWorklist(S);
+
+  // Build worklist by queuing operands of N-ary expressions, udivs and phis.
+  for (unsigned I = 0; I != WorkList.size(); ++I) {
+    const SCEV *P = WorkList[I];
+    if (auto *NaryS = dyn_cast<SCEVNAryExpr>(P)) {
+      for (const SCEV *Op : NaryS->operands())
+        AddToWorklist(Op);
+    } else if (auto *UDiv = dyn_cast<SCEVUDivExpr>(P)) {
+      AddToWorklist(UDiv->getLHS());
+      AddToWorklist(UDiv->getRHS());
+    } else {
+      auto *UnknownS = cast<SCEVUnknown>(P);
+      if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
+        if (!PendingPhiRangesIter.insert(P).second)
+          continue;
+        for (auto &Op : reverse(P->operands()))
+          AddToWorklist(getSCEV(Op));
+      }
+    }
+  }
+
+  if (!WorkList.empty()) {
+    // Use getRangeRef to compute ranges for items in the worklist in reverse
+    // order, skipping WorkList[0] (S itself), which is handled by the final
+    // call below. Operands are thus mostly evaluated before their users.
+    for (const SCEV *P :
+         reverse(make_range(WorkList.begin() + 1, WorkList.end()))) {
+      getRangeRef(P, SignHint);
+
+      if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
+        if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
+          PendingPhiRangesIter.erase(P);
+    }
+  }
+
+  return getRangeRef(S, SignHint, 0);
+}
+
 /// Determine the range for a particular SCEV.  If SignHint is
 /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
 /// with a "cleaner" unsigned (resp. signed) representation.
-const ConstantRange &
-ScalarEvolution::getRangeRef(const SCEV *S,
-                             ScalarEvolution::RangeSignHint SignHint) {
+const ConstantRange &ScalarEvolution::getRangeRef(
+    const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
   DenseMap<const SCEV *, ConstantRange> &Cache =
       SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
                                                        : SignedRanges;
   ConstantRange::PreferredRangeType RangeType =
-      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED
-          ? ConstantRange::Unsigned : ConstantRange::Signed;
+      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
+                                                       : ConstantRange::Signed;
 
   // See if we've computed this range already.
   DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
@@ -6446,6 +6511,11 @@ ScalarEvolution::getRangeRef(const SCEV *S,
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
     return setRange(C, SignHint, ConstantRange(C->getAPInt()));
 
+  // Switch to iteratively computing the range for S, if it is part of a deeply
+  // nested expression.
+  if (Depth > RangeIterThreshold)
+    return getRangeRefIter(S, SignHint);
+
   unsigned BitWidth = getTypeSizeInBits(S->getType());
   ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
   using OBO = OverflowingBinaryOperator;
@@ -6465,23 +6535,23 @@ ScalarEvolution::getRangeRef(const SCEV *S,
   }
 
   if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
-    ConstantRange X = getRangeRef(Add->getOperand(0), SignHint);
+    ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
     unsigned WrapType = OBO::AnyWrap;
     if (Add->hasNoSignedWrap())
       WrapType |= OBO::NoSignedWrap;
     if (Add->hasNoUnsignedWrap())
       WrapType |= OBO::NoUnsignedWrap;
     for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
-      X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint),
+      X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint, Depth + 1),
                           WrapType, RangeType);
     return setRange(Add, SignHint,
                     ConservativeResult.intersectWith(X, RangeType));
   }
 
   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
-    ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint);
+    ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
     for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
-      X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint));
+      X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint, Depth + 1));
     return setRange(Mul, SignHint,
                     ConservativeResult.intersectWith(X, RangeType));
   }
@@ -6507,41 +6577,42 @@ ScalarEvolution::getRangeRef(const SCEV *S,
     }
 
     const auto *NAry = cast<SCEVNAryExpr>(S);
-    ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint);
+    ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
     for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
-      X = X.intrinsic(ID, {X, getRangeRef(NAry->getOperand(i), SignHint)});
+      X = X.intrinsic(
+          ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
     return setRange(S, SignHint,
                     ConservativeResult.intersectWith(X, RangeType));
   }
 
   if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
-    ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint);
-    ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint);
+    ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
+    ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
     return setRange(UDiv, SignHint,
                     ConservativeResult.intersectWith(X.udiv(Y), RangeType));
   }
 
   if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
-    ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint);
+    ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
     return setRange(ZExt, SignHint,
                     ConservativeResult.intersectWith(X.zeroExtend(BitWidth),
                                                      RangeType));
   }
 
   if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
-    ConstantRange X = getRangeRef(SExt->getOperand(), SignHint);
+    ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
     return setRange(SExt, SignHint,
                     ConservativeResult.intersectWith(X.signExtend(BitWidth),
                                                      RangeType));
   }
 
   if (const SCEVPtrToIntExpr *PtrToInt = dyn_cast<SCEVPtrToIntExpr>(S)) {
-    ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint);
+    ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
     return setRange(PtrToInt, SignHint, X);
   }
 
   if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
-    ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint);
+    ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
     return setRange(Trunc, SignHint,
                     ConservativeResult.intersectWith(X.truncate(BitWidth),
                                                      RangeType));
@@ -6671,12 +6742,13 @@ ScalarEvolution::getRangeRef(const SCEV *S,
           RangeType);
 
     // A range of Phi is a subset of union of all ranges of its input.
-    if (const PHINode *Phi = dyn_cast<PHINode>(U->getValue())) {
+    if (PHINode *Phi = dyn_cast<PHINode>(U->getValue())) {
       // Make sure that we do not run over cycled Phis.
       if (PendingPhiRanges.insert(Phi).second) {
         ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
+
         for (const auto &Op : Phi->operands()) {
-          auto OpRange = getRangeRef(getSCEV(Op), SignHint);
+          auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
           RangeFromOps = RangeFromOps.unionWith(OpRange);
           // No point to continue if we already have a full set.
           if (RangeFromOps.isFullSet())

diff  --git a/llvm/test/Analysis/ScalarEvolution/ranges.ll b/llvm/test/Analysis/ScalarEvolution/ranges.ll
index f5583d091f148..1a72f10cdd3df 100644
--- a/llvm/test/Analysis/ScalarEvolution/ranges.ll
+++ b/llvm/test/Analysis/ScalarEvolution/ranges.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py
  ; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
+ ; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" -scev-range-iter-threshold=1 2>&1 | FileCheck %s
 
 target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64"
 

diff  --git a/llvm/test/Transforms/IndVarSimplify/range-iter-threshold.ll b/llvm/test/Transforms/IndVarSimplify/range-iter-threshold.ll
new file mode 100644
index 0000000000000..9b27d01f8eee0
--- /dev/null
+++ b/llvm/test/Transforms/IndVarSimplify/range-iter-threshold.ll
@@ -0,0 +1,73 @@
+; RUN: opt -passes=indvars -S %s | FileCheck --check-prefix=COMMON --check-prefix=DEFAULT %s
+; RUN: opt -passes=indvars -scev-range-iter-threshold=1 -S %s | FileCheck --check-prefix=COMMON --check-prefix=LIMIT %s
+
+target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+
+define i32 @test(i1 %c.0, i32 %m) {
+; COMMON-LABEL: @test(
+; COMMON-NEXT:  entry:
+; COMMON-NEXT:    br label [[OUTER_HEADER:%.*]]
+; COMMON:       outer.header:
+; DEFAULT-NEXT:   [[INDVARS_IV:%.*]] = phi i32 [ [[INDVARS_IV_NEXT:%.*]], [[OUTER_LATCH:%.*]] ], [ 2, [[ENTRY:%.*]] ]
+; COMMON-NEXT:    [[IV_1:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[IV_1_NEXT:%.*]], [[OUTER_LATCH:%.*]] ]
+; COMMON-NEXT:    [[MAX_0:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[MAX_1:%.*]], [[OUTER_LATCH]] ]
+; COMMON-NEXT:    [[TMP0:%.*]] = sext i32 [[IV_1]] to i64
+; COMMON-NEXT:    br label [[INNER_1:%.*]]
+; COMMON:       inner.1:
+; COMMON-NEXT:    [[C_1:%.*]] = icmp slt i64 0, [[TMP0]]
+; COMMON-NEXT:    br i1 [[C_1]], label [[INNER_1]], label [[INNER_2_HEADER_PREHEADER:%.*]]
+; COMMON:       inner.2.header.preheader:
+; COMMON-NEXT:    br label [[INNER_2_HEADER:%.*]]
+; COMMON:       inner.2.header:
+; COMMON-NEXT:    [[IV_3:%.*]] = phi i32 [ [[IV_3_NEXT:%.*]], [[INNER_2_LATCH:%.*]] ], [ 0, [[INNER_2_HEADER_PREHEADER]] ]
+; COMMON-NEXT:    br i1 [[C_0:%.*]], label [[OUTER_LATCH]], label [[INNER_2_LATCH]]
+; COMMON:       inner.2.latch:
+; COMMON-NEXT:    [[IV_3_NEXT]] = add i32 [[IV_3]], 1
+; DEFAULT-NEXT:   [[EXITCOND:%.*]] = icmp eq i32 [[IV_3_NEXT]], [[INDVARS_IV]]
+; LIMIT-NEXT:     [[EXITCOND:%.*]] = icmp ugt i32 [[IV_3]], [[IV_1]]
+; COMMON-NEXT:    br i1 [[EXITCOND]], label [[OUTER_LATCH]], label [[INNER_2_HEADER]]
+; COMMON:       outer.latch:
+; COMMON-NEXT:    [[MAX_1]] = phi i32 [ [[M:%.*]], [[INNER_2_LATCH]] ], [ 0, [[INNER_2_HEADER]] ]
+; COMMON-NEXT:    [[IV_1_NEXT]] = add nuw i32 [[IV_1]], 1
+; COMMON-NEXT:    [[C_3:%.*]] = icmp ugt i32 [[IV_1]], [[MAX_0]]
+; DEFAULT-NEXT:   [[INDVARS_IV_NEXT]] = add i32 [[INDVARS_IV]], 1
+; COMMON-NEXT:    br i1 [[C_3]], label [[EXIT:%.*]], label [[OUTER_HEADER]], !llvm.loop [[LOOP0:![0-9]+]]
+; COMMON:       exit:
+; COMMON-NEXT:    ret i32 0
+;
+entry:
+  br label %outer.header
+
+outer.header:
+  %iv.1 = phi i32 [ 0, %entry ], [ %iv.1.next, %outer.latch ]
+  %iv.2 = phi i32 [ 0, %entry ], [ %iv.2.next , %outer.latch ]
+  %max.0 = phi i32 [ 0, %entry ], [ %max.1, %outer.latch ]
+  %0 = sext i32 %iv.1 to i64
+  br label %inner.1
+
+inner.1:
+  %c.1 = icmp slt i64 0, %0
+  br i1 %c.1, label %inner.1, label %inner.2.header
+
+inner.2.header:
+  %iv.3 = phi i32 [ 0, %inner.1 ], [ %iv.3.next, %inner.2.latch ]
+  br i1 %c.0, label %outer.latch, label %inner.2.latch
+
+inner.2.latch:
+  %iv.3.next = add i32 %iv.3, 1
+  %c.2 = icmp ugt i32 %iv.3, %iv.2
+  br i1 %c.2, label %outer.latch, label %inner.2.header
+
+outer.latch:
+  %max.1 = phi i32 [ %m, %inner.2.latch ], [ %iv.3, %inner.2.header ]
+  %iv.1.next = add i32 %iv.1, 1
+  %iv.2.next = add i32 %iv.2, 1
+  %c.3 = icmp ugt i32 %iv.2, %max.0
+  br i1 %c.3, label %exit, label %outer.header, !llvm.loop !0
+
+exit:
+  ret i32 0
+}
+
+!0 = distinct !{!0, !1}
+!1 = !{!"llvm.loop.mustprogress"}


        


More information about the llvm-commits mailing list