[llvm] r286386 - [SCEV] Refactor out a useful pattern; NFC

Sanjoy Das via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 9 10:22:44 PST 2016


Author: sanjoy
Date: Wed Nov  9 12:22:43 2016
New Revision: 286386

URL: http://llvm.org/viewvc/llvm-project?rev=286386&view=rev
Log:
[SCEV] Refactor out a useful pattern; NFC

Modified:
    llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpressions.h
    llvm/trunk/lib/Analysis/ScalarEvolution.cpp

Modified: llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpressions.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpressions.h?rev=286386&r1=286385&r2=286386&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpressions.h (original)
+++ llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpressions.h Wed Nov  9 12:22:43 2016
@@ -537,6 +537,31 @@ namespace llvm {
     T.visitAll(Root);
   }
 
+  /// Return true if any node in \p Root satisfies the predicate \p Pred.
+  template <typename PredTy>
+  bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
+    struct FindClosure {
+      bool Found = false;
+      PredTy Pred;
+
+      FindClosure(PredTy Pred) : Pred(Pred) {}
+
+      bool follow(const SCEV *S) {
+        if (!Pred(S))
+          return true;
+
+        Found = true;
+        return false;
+      }
+
+      bool isDone() const { return Found; }
+    };
+
+    FindClosure FC(Pred);
+    visitAll(Root, FC);
+    return FC.Found;
+  }
+
   /// This visitor recursively visits a SCEV expression and re-writes it.
   /// The result from each visit is cached, so it will return the same
   /// SCEV for the same input.

Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=286386&r1=286385&r2=286386&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Wed Nov  9 12:22:43 2016
@@ -3356,69 +3356,24 @@ const SCEV *ScalarEvolution::getCouldNot
   return CouldNotCompute.get();
 }
 
-
 bool ScalarEvolution::checkValidity(const SCEV *S) const {
-  // Helper class working with SCEVTraversal to figure out if a SCEV contains
-  // a SCEVUnknown with null value-pointer. FindInvalidSCEVUnknown::FindOne
-  // is set iff if find such SCEVUnknown.
-  //
-  struct FindInvalidSCEVUnknown {
-    bool FindOne;
-    FindInvalidSCEVUnknown() { FindOne = false; }
-    bool follow(const SCEV *S) {
-      switch (static_cast<SCEVTypes>(S->getSCEVType())) {
-      case scConstant:
-        return false;
-      case scUnknown:
-        if (!cast<SCEVUnknown>(S)->getValue())
-          FindOne = true;
-        return false;
-      default:
-        return true;
-      }
-    }
-    bool isDone() const { return FindOne; }
-  };
-
-  FindInvalidSCEVUnknown F;
-  SCEVTraversal<FindInvalidSCEVUnknown> ST(F);
-  ST.visitAll(S);
+  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
+    auto *SU = dyn_cast<SCEVUnknown>(S);
+    return SU && SU->getValue() == nullptr;
+  });
 
-  return !F.FindOne;
+  return !ContainsNulls;
 }
 
 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
-  // Helper class working with SCEVTraversal to figure out if a SCEV contains a
-  // sub SCEV of scAddRecExpr type.  FindInvalidSCEVUnknown::FoundOne is set iff
-  // if such sub scAddRecExpr type SCEV is found.
-  struct FindAddRecurrence {
-    bool FoundOne;
-    FindAddRecurrence() : FoundOne(false) {}
-
-    bool follow(const SCEV *S) {
-      switch (static_cast<SCEVTypes>(S->getSCEVType())) {
-      case scAddRecExpr:
-        FoundOne = true;
-      case scConstant:
-      case scUnknown:
-      case scCouldNotCompute:
-        return false;
-      default:
-        return true;
-      }
-    }
-    bool isDone() const { return FoundOne; }
-  };
-
   HasRecMapType::iterator I = HasRecMap.find(S);
   if (I != HasRecMap.end())
     return I->second;
 
-  FindAddRecurrence F;
-  SCEVTraversal<FindAddRecurrence> ST(F);
-  ST.visitAll(S);
-  HasRecMap.insert({S, F.FoundOne});
-  return F.FoundOne;
+  bool FoundAddRec =
+      SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
+  HasRecMap.insert({S, FoundAddRec});
+  return FoundAddRec;
 }
 
 /// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
@@ -8993,38 +8948,15 @@ const SCEV *SCEVAddRecExpr::getNumIterat
   return SE.getCouldNotCompute();
 }
 
-namespace {
-struct FindUndefs {
-  bool Found;
-  FindUndefs() : Found(false) {}
-
-  bool follow(const SCEV *S) {
-    if (const SCEVUnknown *C = dyn_cast<SCEVUnknown>(S)) {
-      if (isa<UndefValue>(C->getValue()))
-        Found = true;
-    } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
-      if (isa<UndefValue>(C->getValue()))
-        Found = true;
-    }
-
-    // Keep looking if we haven't found it yet.
-    return !Found;
-  }
-  bool isDone() const {
-    // Stop recursion if we have found an undef.
-    return Found;
-  }
-};
-}
-
 // Return true when S contains at least an undef value.
-static inline bool
-containsUndefs(const SCEV *S) {
-  FindUndefs F;
-  SCEVTraversal<FindUndefs> ST(F);
-  ST.visitAll(S);
-
-  return F.Found;
+static inline bool containsUndefs(const SCEV *S) {
+  return SCEVExprContains(S, [](const SCEV *S) {
+    if (const auto *SU = dyn_cast<SCEVUnknown>(S))
+      return isa<UndefValue>(SU->getValue());
+    else if (const auto *SC = dyn_cast<SCEVConstant>(S))
+      return isa<UndefValue>(SC->getValue());
+    return false;
+  });
 }
 
 namespace {
@@ -9217,40 +9149,11 @@ static bool findArrayDimensionsRec(Scala
   return true;
 }
 
-// Returns true when S contains at least a SCEVUnknown parameter.
-static inline bool
-containsParameters(const SCEV *S) {
-  struct FindParameter {
-    bool FoundParameter;
-    FindParameter() : FoundParameter(false) {}
-
-    bool follow(const SCEV *S) {
-      if (isa<SCEVUnknown>(S)) {
-        FoundParameter = true;
-        // Stop recursion: we found a parameter.
-        return false;
-      }
-      // Keep looking.
-      return true;
-    }
-    bool isDone() const {
-      // Stop recursion if we have found a parameter.
-      return FoundParameter;
-    }
-  };
-
-  FindParameter F;
-  SCEVTraversal<FindParameter> ST(F);
-  ST.visitAll(S);
-
-  return F.FoundParameter;
-}
 
 // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter.
-static inline bool
-containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
+static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
   for (const SCEV *T : Terms)
-    if (containsParameters(T))
+    if (SCEVExprContains(T, [](const SCEV *S) { return isa<SCEVUnknown>(S); }))
       return true;
   return false;
 }
@@ -9977,24 +9880,7 @@ bool ScalarEvolution::properlyDominates(
 }
 
 bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
-  // Search for a SCEV expression node within an expression tree.
-  // Implements SCEVTraversal::Visitor.
-  struct SCEVSearch {
-    const SCEV *Node;
-    bool IsFound;
-
-    SCEVSearch(const SCEV *N): Node(N), IsFound(false) {}
-
-    bool follow(const SCEV *S) {
-      IsFound |= (S == Node);
-      return !IsFound;
-    }
-    bool isDone() const { return IsFound; }
-  };
-
-  SCEVSearch Search(Op);
-  visitAll(S, Search);
-  return Search.IsFound;
+  return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
 }
 
 void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {




More information about the llvm-commits mailing list