[llvm] c86a982 - [SCEV] `getSequentialMinMaxExpr()`: rewrite deduplication to be fully recursive

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 14 04:42:58 PST 2022


Author: Roman Lebedev
Date: 2022-01-14T15:42:26+03:00
New Revision: c86a982d7dad36c31fcfc5b06fd03b2e3289175f

URL: https://github.com/llvm/llvm-project/commit/c86a982d7dad36c31fcfc5b06fd03b2e3289175f
DIFF: https://github.com/llvm/llvm-project/commit/c86a982d7dad36c31fcfc5b06fd03b2e3289175f.diff

LOG: [SCEV] `getSequentialMinMaxExpr()`: rewrite deduplication to be fully recursive

Since we don't merge/expand non-sequential umin exprs into umin_seq exprs,
we may have a umin_seq(umin(umin_seq())) chain, and the innermost umin_seq
may still have duplicate operands.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
index 27542bc554207..cd8e5fab6766f 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -531,6 +531,25 @@ class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
 public:
   Type *getType() const { return getOperand(0)->getType(); }
 
+  /// Map a sequential min/max SCEV type to its non-sequential counterpart
+  /// (currently only scSequentialUMinExpr -> scUMinExpr). \p Ty must be a
+  /// sequential min/max type.
+  static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
+    assert(isSequentialMinMaxType(Ty) && "Not a sequential min/max type.");
+    switch (Ty) {
+    case scSequentialUMinExpr:
+      return scUMinExpr;
+    default:
+      llvm_unreachable("Not a sequential min/max type.");
+    }
+  }
+
+  /// Return the non-sequential min/max SCEV type equivalent to this
+  /// expression's own (sequential) type.
+  SCEVTypes getEquivalentNonSequentialSCEVType() const {
+    return getEquivalentNonSequentialSCEVType(getSCEVType());
+  }
+
   static bool classof(const SCEV *S) {
     return isSequentialMinMaxType(S->getSCEVType());
   }

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 556472f544265..1025f30f0d8eb 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -3865,6 +3865,158 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
   return S;
 }
 
+namespace {
+
+/// Deduplicates the operands of a sequential min/max expression (e.g.
+/// umin_seq), keeping only the first occurrence of each operand in
+/// visitation order. The walk recurses into nested expressions whose kind
+/// is either the root's sequential kind or its non-sequential counterpart
+/// (e.g. umin inside umin_seq), so duplicates are removed at every level of
+/// a umin_seq(umin(umin_seq(...))) chain. Any other operand is opaque and
+/// can only be dropped if it was already seen in its entirety.
+class SCEVSequentialMinMaxDeduplicatingVisitor final
+    : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
+                         Optional<const SCEV *>> {
+  // Result of visiting an operand: the (possibly rebuilt) operand to keep,
+  // or None if the operand is redundant and must be dropped.
+  using RetVal = Optional<const SCEV *>;
+  using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
+
+  ScalarEvolution &SE;
+  const SCEVTypes RootKind; // Must be a sequential min/max expression.
+  const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
+  // Every operand visited so far, at any nesting depth. It is shared across
+  // recursion so that later duplicates anywhere in the tree are dropped.
+  SmallPtrSet<const SCEV *, 16> SeenOps;
+
+  bool canRecurseInto(SCEVTypes Kind) const {
+    // We can only recurse into the SCEV expression of the same effective type
+    // as the type of our root SCEV expression.
+    return RootKind == Kind || NonSequentialRootKind == Kind;
+  }
+
+  /// Deduplicate the operands of the min/max expression \p S. Returns \p S
+  /// itself if it is of a kind we may not recurse into or if nothing changed,
+  /// None if every operand was redundant, and otherwise a rebuilt expression
+  /// of the same kind over the surviving operands.
+  RetVal visitAnyMinMaxExpr(const SCEV *S) {
+    assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
+           "Only for min/max expressions.");
+    SCEVTypes Kind = S->getSCEVType();
+
+    // Min/max expressions of any other kind are opaque to us.
+    if (!canRecurseInto(Kind))
+      return S;
+
+    auto *NAry = cast<SCEVNAryExpr>(S);
+    SmallVector<const SCEV *> NewOps;
+    bool Changed =
+        visit(Kind, makeArrayRef(NAry->op_begin(), NAry->op_end()), NewOps);
+
+    if (!Changed)
+      return S;
+    // All operands were seen earlier, so the whole expression is redundant.
+    if (NewOps.empty())
+      return None;
+
+    // Rebuild over the surviving operands, keeping it (non-)sequential.
+    return isa<SCEVSequentialMinMaxExpr>(S)
+               ? SE.getSequentialMinMaxExpr(Kind, NewOps)
+               : SE.getMinMaxExpr(Kind, NewOps);
+  }
+
+  /// Visit a single operand: returns None if \p S was already seen as a
+  /// whole, otherwise records it and dispatches on its SCEV type.
+  RetVal visit(const SCEV *S) {
+    // Has the whole operand been seen already?
+    if (!SeenOps.insert(S).second)
+      return None;
+    return Base::visit(S);
+  }
+
+public:
+  /// \p RootKind is the kind of the sequential min/max expression whose
+  /// operands are being deduplicated.
+  SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
+                                           SCEVTypes RootKind)
+      : SE(SE), RootKind(RootKind),
+        NonSequentialRootKind(
+            SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
+                RootKind)) {}
+
+  /// Deduplicate \p OrigOps, in order, against every operand seen so far.
+  /// Returns true if any operand was dropped or rebuilt; only then is the
+  /// result written to \p NewOps (otherwise \p NewOps is left untouched).
+  /// \p OrigOps may alias \p NewOps, since the result is built in a local
+  /// vector first. NOTE(review): \p Kind is currently unused.
+  bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
+                         SmallVectorImpl<const SCEV *> &NewOps) {
+    bool Changed = false;
+    SmallVector<const SCEV *> Ops;
+    Ops.reserve(OrigOps.size());
+
+    for (const SCEV *Op : OrigOps) {
+      RetVal NewOp = visit(Op);
+      if (NewOp != Op)
+        Changed = true;
+      if (NewOp)
+        Ops.emplace_back(*NewOp);
+    }
+
+    if (Changed)
+      NewOps = std::move(Ops);
+    return Changed;
+  }
+
+  // Expressions that are not min/max are never looked into: they are kept
+  // as-is, and can only be dropped as a whole by visit(const SCEV *).
+  RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
+
+  RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
+
+  RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
+
+  RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
+
+  RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
+
+  RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
+
+  RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
+
+  RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
+
+  RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
+
+  // Every min/max kind goes through visitAnyMinMaxExpr(), which decides
+  // whether that expression may be recursed into.
+  RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
+    return visitAnyMinMaxExpr(Expr);
+  }
+
+  RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
+    return visitAnyMinMaxExpr(Expr);
+  }
+
+  RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
+    return visitAnyMinMaxExpr(Expr);
+  }
+
+  RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
+    return visitAnyMinMaxExpr(Expr);
+  }
+
+  RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
+    return visitAnyMinMaxExpr(Expr);
+  }
+
+  RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
+
+  RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
+};
+
+} // namespace
+
 const SCEV *
 ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
                                          SmallVectorImpl<const SCEV *> &Ops) {
@@ -3895,45 +4016,8 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
 
   // Keep only the first instance of an operand.
   {
-    SmallPtrSet<const SCEV *, 16> SeenOps;
-    unsigned Idx = 0;
-    bool Changed = false;
-    while (Idx < Ops.size()) {
-      // Has the whole operand been seen already?
-      if (!SeenOps.insert(Ops[Idx]).second) {
-        Ops.erase(Ops.begin() + Idx);
-        Changed = true;
-        continue; // Look at operand under this index again.
-      }
-
-      // Look into non-sequential same-typed min/max expressions,
-      // drop any of it's operands that we have already seen.
-      // FIXME: once there are other sequential min/max types, generalize.
-      if (const auto *CommUMinExpr = dyn_cast<SCEVUMinExpr>(Ops[Idx])) {
-        SmallVector<const SCEV *> InnerOps;
-        InnerOps.reserve(CommUMinExpr->getNumOperands());
-        for (const SCEV *InnerOp : CommUMinExpr->operands()) {
-          if (SeenOps.insert(InnerOp).second) // Operand not seen before?
-            InnerOps.emplace_back(InnerOp);   // Keep this inner operand.
-        }
-        // Were any operands of this 'umin' themselves redundant?
-        if (InnerOps.size() != CommUMinExpr->getNumOperands()) {
-          Changed = true;
-          // Was the whole operand effectively redundant? Note that it can
-          // happen even when the operand itself wasn't redundant as a whole.
-          if (InnerOps.empty()) {
-            Ops.erase(Ops.begin() + Idx);
-            continue; // Look at operand under this index again.
-          }
-          // Recreate our operand.
-          Ops[Idx] = getMinMaxExpr(Ops[Idx]->getSCEVType(), InnerOps);
-        }
-      }
-
-      // Ok, can't do anything else about this operand, move onto the next one.
-      ++Idx;
-    }
-
+    SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
+    bool Changed = Deduplicator.visit(Kind, Ops, Ops);
     if (Changed)
       return getSequentialMinMaxExpr(Kind, Ops);
   }

diff  --git a/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll b/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll
index 5be01913f31f5..fd75fa31218dd 100644
--- a/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll
+++ b/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll
@@ -359,9 +359,9 @@ define i32 @logical_or_5ops_redundant_opearand_of_inner_uminseq(i32 %a, i32 %b,
 ; CHECK-NEXT:    %cond_p4 = select i1 %cond_p3, i1 true, i1 %cond_p2
 ; CHECK-NEXT:    --> %cond_p4 U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %first.loop: Variant }
 ; CHECK-NEXT:    %i = phi i32 [ 0, %first.loop.exit ], [ %i.next, %loop ]
-; CHECK-NEXT:    --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%a umin_seq %b umin_seq ((%e umin_seq %d umin_seq %a) umin %c umin %d)) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT:    --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%a umin_seq %b umin_seq ((%e umin_seq %d) umin %c)) LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT:    %i.next = add i32 %i, 1
-; CHECK-NEXT:    --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%a umin_seq %b umin_seq ((%e umin_seq %d umin_seq %a) umin %c umin %d))) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT:    --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%a umin_seq %b umin_seq ((%e umin_seq %d) umin %c))) LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT:    %umin = call i32 @llvm.umin.i32(i32 %c, i32 %d)
 ; CHECK-NEXT:    --> (%c umin %d) U: full-set S: full-set Exits: (%c umin %d) LoopDispositions: { %loop: Invariant }
 ; CHECK-NEXT:    %umin2 = call i32 @llvm.umin.i32(i32 %umin, i32 %first.i)
@@ -371,9 +371,9 @@ define i32 @logical_or_5ops_redundant_opearand_of_inner_uminseq(i32 %a, i32 %b,
 ; CHECK-NEXT:    %cond = select i1 %cond_p8, i1 true, i1 %cond_p7
 ; CHECK-NEXT:    --> %cond U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Variant }
 ; CHECK-NEXT:  Determining loop execution counts for: @logical_or_5ops_redundant_opearand_of_inner_uminseq
-; CHECK-NEXT:  Loop %loop: backedge-taken count is (%a umin_seq %b umin_seq ((%e umin_seq %d umin_seq %a) umin %c umin %d))
+; CHECK-NEXT:  Loop %loop: backedge-taken count is (%a umin_seq %b umin_seq ((%e umin_seq %d) umin %c))
 ; CHECK-NEXT:  Loop %loop: max backedge-taken count is -1
-; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is (%a umin_seq %b umin_seq ((%e umin_seq %d umin_seq %a) umin %c umin %d))
+; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is (%a umin_seq %b umin_seq ((%e umin_seq %d) umin %c))
 ; CHECK-NEXT:   Predicates:
 ; CHECK:       Loop %loop: Trip multiple is 1
 ; CHECK-NEXT:  Loop %first.loop: backedge-taken count is (%e umin_seq %d umin_seq %a)


        


More information about the llvm-commits mailing list