[llvm] c302f1e - [SCEV] Generalize SCEVEqualPredicate to any compare [NFC]

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 8 08:18:16 PST 2022


Author: Philip Reames
Date: 2022-02-08T08:18:09-08:00
New Revision: c302f1e6771b0cbfad466e56f19e36d4dcaecd11

URL: https://github.com/llvm/llvm-project/commit/c302f1e6771b0cbfad466e56f19e36d4dcaecd11
DIFF: https://github.com/llvm/llvm-project/commit/c302f1e6771b0cbfad466e56f19e36d4dcaecd11.diff

LOG: [SCEV] Generalize SCEVEqualPredicate to any compare [NFC]

PredicatedScalarEvolution has a predicate type for representing A == B.  This change generalizes it into something which can represent any A <pred> B.

This generality is currently unused, but is motivated by a couple of recent cases which have come up.  In particular, I'm currently playing around with using this to simplify the runtime checking code in LoopVectorizer. Regardless of the outcome of that prototyping, generalizing the compare node seemed useful.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index c39de9d72cf49..51b15510a3add 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -219,7 +219,7 @@ class SCEVPredicate : public FoldingSetNode {
   FoldingSetNodeIDRef FastID;
 
 public:
-  enum SCEVPredicateKind { P_Union, P_Equal, P_Wrap };
+  enum SCEVPredicateKind { P_Union, P_Compare, P_Wrap };
 
 protected:
   SCEVPredicateKind Kind;
@@ -276,16 +276,18 @@ struct FoldingSetTrait<SCEVPredicate> : DefaultFoldingSetTrait<SCEVPredicate> {
   }
 };
 
-/// This class represents an assumption that two SCEV expressions are equal,
-/// and this can be checked at run-time.
-class SCEVEqualPredicate final : public SCEVPredicate {
-  /// We assume that LHS == RHS.
+/// This class represents an assumption that the expression LHS Pred RHS
+/// evaluates to true, and this can be checked at run-time.
+class SCEVComparePredicate final : public SCEVPredicate {
+  /// We assume that LHS Pred RHS is true.
+  const ICmpInst::Predicate Pred;
   const SCEV *LHS;
   const SCEV *RHS;
 
 public:
-  SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEV *LHS,
-                     const SCEV *RHS);
+  SCEVComparePredicate(const FoldingSetNodeIDRef ID,
+                       const ICmpInst::Predicate Pred,
+                       const SCEV *LHS, const SCEV *RHS);
 
   /// Implementation of the SCEVPredicate interface
   bool implies(const SCEVPredicate *N) const override;
@@ -293,15 +295,17 @@ class SCEVEqualPredicate final : public SCEVPredicate {
   bool isAlwaysTrue() const override;
   const SCEV *getExpr() const override;
 
-  /// Returns the left hand side of the equality.
+  ICmpInst::Predicate getPredicate() const { return Pred; }
+
+  /// Returns the left hand side of the predicate.
   const SCEV *getLHS() const { return LHS; }
 
-  /// Returns the right hand side of the equality.
+  /// Returns the right hand side of the predicate.
   const SCEV *getRHS() const { return RHS; }
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const SCEVPredicate *P) {
-    return P->getKind() == P_Equal;
+    return P->getKind() == P_Compare;
   }
 };
 

diff  --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
index 277eb7acf238c..60b772b94a6f7 100644
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -293,8 +293,9 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
   Value *expandCodeForPredicate(const SCEVPredicate *Pred, Instruction *Loc);
 
   /// A specialized variant of expandCodeForPredicate, handling the case when
-  /// we are expanding code for a SCEVEqualPredicate.
-  Value *expandEqualPredicate(const SCEVEqualPredicate *Pred, Instruction *Loc);
+  /// we are expanding code for a SCEVComparePredicate.
+  Value *expandComparePredicate(const SCEVComparePredicate *Pred,
+                                Instruction *Loc);
 
   /// Generates code that evaluates if the \p AR expression will overflow.
   Value *generateOverflowCheck(const SCEVAddRecExpr *AR, Instruction *Loc,

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 977fc09113550..fca5614c7469e 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13541,14 +13541,15 @@ const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
   assert(LHS->getType() == RHS->getType() &&
          "Type mismatch between LHS and RHS");
   // Unique this node based on the arguments
-  ID.AddInteger(SCEVPredicate::P_Equal);
+  ID.AddInteger(SCEVPredicate::P_Compare);
+  ID.AddInteger(ICmpInst::ICMP_EQ);
   ID.AddPointer(LHS);
   ID.AddPointer(RHS);
   void *IP = nullptr;
   if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
     return S;
-  SCEVEqualPredicate *Eq = new (SCEVAllocator)
-      SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
+  SCEVComparePredicate *Eq = new (SCEVAllocator)
+    SCEVComparePredicate(ID.Intern(SCEVAllocator), ICmpInst::ICMP_EQ, LHS, RHS);
   UniquePreds.InsertNode(Eq, IP);
   return Eq;
 }
@@ -13594,8 +13595,9 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
     if (Pred) {
       auto ExprPreds = Pred->getPredicatesForExpr(Expr);
       for (auto *Pred : ExprPreds)
-        if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
-          if (IPred->getLHS() == Expr)
+        if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
+          if (IPred->getLHS() == Expr &&
+              IPred->getPredicate() == ICmpInst::ICMP_EQ)
             return IPred->getRHS();
     }
     return convertToAddRecWithPreds(Expr);
@@ -13715,28 +13717,38 @@ SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
                              SCEVPredicateKind Kind)
     : FastID(ID), Kind(Kind) {}
 
-SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
-                                       const SCEV *LHS, const SCEV *RHS)
-    : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {
+SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
+                                   const ICmpInst::Predicate Pred,
+                                   const SCEV *LHS, const SCEV *RHS)
+  : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
   assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
   assert(LHS != RHS && "LHS and RHS are the same SCEV");
 }
 
-bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
-  const auto *Op = dyn_cast<SCEVEqualPredicate>(N);
+bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
+  const auto *Op = dyn_cast<SCEVComparePredicate>(N);
 
   if (!Op)
     return false;
 
+  if (Pred != ICmpInst::ICMP_EQ)
+    return false;
+
   return Op->LHS == LHS && Op->RHS == RHS;
 }
 
-bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
+bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
 
-const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
+const SCEV *SCEVComparePredicate::getExpr() const { return LHS; }
+
+void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
+  if (Pred == ICmpInst::ICMP_EQ)
+    OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
+  else
+    OS.indent(Depth) << "Compare predicate: " << *LHS
+                     << " " << CmpInst::getPredicateName(Pred) << ") "
+                     << *RHS << "\n";
 
-void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
-  OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
 }
 
 SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,

diff  --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 277431fbbfd77..fbd42ce41a999 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -2469,8 +2469,8 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
   switch (Pred->getKind()) {
   case SCEVPredicate::P_Union:
     return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP);
-  case SCEVPredicate::P_Equal:
-    return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP);
+  case SCEVPredicate::P_Compare:
+    return expandComparePredicate(cast<SCEVComparePredicate>(Pred), IP);
   case SCEVPredicate::P_Wrap: {
     auto *AddRecPred = cast<SCEVWrapPredicate>(Pred);
     return expandWrapPredicate(AddRecPred, IP);
@@ -2479,15 +2479,16 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
   llvm_unreachable("Unknown SCEV predicate type");
 }
 
-Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred,
-                                          Instruction *IP) {
+Value *SCEVExpander::expandComparePredicate(const SCEVComparePredicate *Pred,
+                                            Instruction *IP) {
   Value *Expr0 =
       expandCodeForImpl(Pred->getLHS(), Pred->getLHS()->getType(), IP, false);
   Value *Expr1 =
       expandCodeForImpl(Pred->getRHS(), Pred->getRHS()->getType(), IP, false);
 
   Builder.SetInsertPoint(IP);
-  auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check");
+  auto InvPred = ICmpInst::getInversePredicate(Pred->getPredicate());
+  auto *I = Builder.CreateICmp(InvPred, Expr0, Expr1, "ident.check");
   return I;
 }
 


        


More information about the llvm-commits mailing list