[llvm] c86a982 - [SCEV] `getSequentialMinMaxExpr()`: rewrite deduplication to be fully recursive
Roman Lebedev via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 14 04:42:58 PST 2022
Author: Roman Lebedev
Date: 2022-01-14T15:42:26+03:00
New Revision: c86a982d7dad36c31fcfc5b06fd03b2e3289175f
URL: https://github.com/llvm/llvm-project/commit/c86a982d7dad36c31fcfc5b06fd03b2e3289175f
DIFF: https://github.com/llvm/llvm-project/commit/c86a982d7dad36c31fcfc5b06fd03b2e3289175f.diff
LOG: [SCEV] `getSequentialMinMaxExpr()`: rewrite deduplication to be fully recursive
Since we don't merge/expand non-sequential umin exprs into umin_seq exprs,
we may have a umin_seq(umin(umin_seq())) chain, and the innermost umin_seq
can still have operands that duplicate operands of the outer expressions.
Added:
Modified:
llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
index 27542bc554207..cd8e5fab6766f 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -531,6 +531,20 @@ class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
public:
+  /// The type of a sequential min/max expression is taken from its first
+  /// operand.
Type *getType() const { return getOperand(0)->getType(); }
+  /// Returns the non-sequential counterpart of the sequential min/max kind
+  /// \p Ty, e.g. scUMinExpr for scSequentialUMinExpr.
+  /// \p Ty must be a sequential min/max kind.
+  static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
+    assert(isSequentialMinMaxType(Ty) && "Not a sequential min/max type.");
+    switch (Ty) {
+    case scSequentialUMinExpr:
+      return scUMinExpr;
+    default:
+      llvm_unreachable("Not a sequential min/max type.");
+    }
+  }
+
+  /// Returns the non-sequential counterpart of this expression's own kind.
+  SCEVTypes getEquivalentNonSequentialSCEVType() const {
+    return getEquivalentNonSequentialSCEVType(getSCEVType());
+  }
+
+  /// LLVM-style RTTI support: matches every SCEV whose kind is a sequential
+  /// min/max kind.
static bool classof(const SCEV *S) {
return isSequentialMinMaxType(S->getSCEVType());
}
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 556472f544265..1025f30f0d8eb 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -3865,6 +3865,127 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
return S;
}
+namespace {
+
+/// Removes repeated operands from the operand list of a sequential min/max
+/// expression (e.g. umin_seq), also looking through nested min/max
+/// expressions of the same effective kind (umin and umin_seq for a umin_seq
+/// root). Only the first occurrence of an operand is kept.
+///
+/// Dropping a later occurrence must not introduce poison. In a sequential
+/// expression, every operand but the last one guards the evaluation of the
+/// operands after it. Such a guard can only be dropped if an earlier
+/// occurrence of it short-circuits the whole sequence when it is zero. An
+/// earlier occurrence under a *non-sequential* sibling does not do that: in
+///   umin_seq(%p, umin(X, umin_seq(X, %d)))      (X being e.g. (1 + %a))
+/// dropping the inner X would yield umin(X, %d). That is poison when X == 0
+/// and %d is poison, whereas the original expression is 0.
+///
+/// Therefore, inside a nested sequential expression an operand is only
+/// treated as redundant if it was seen under an earlier operand of the root,
+/// or earlier within that same nested sequential expression. Outside of
+/// nested sequential expressions, any earlier occurrence suffices.
+class SCEVSequentialMinMaxDeduplicatingVisitor final
+    : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
+                         Optional<const SCEV *>> {
+  using RetVal = Optional<const SCEV *>;
+  using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
+
+  ScalarEvolution &SE;
+  const SCEVTypes RootKind; // Must be a sequential min/max expression.
+  const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
+
+  /// Operands seen outside of any in-progress nested sequential expression.
+  /// Each is mapped to the index of the root operand under which it was
+  /// first seen.
+  DenseMap<const SCEV *, unsigned> RootSeenOps;
+  /// Index of the root operand currently being visited.
+  unsigned CurRootIdx = 0;
+  /// Operands seen within each nested sequential expression that is
+  /// currently being visited; the innermost one is at the back.
+  SmallVector<SmallPtrSet<const SCEV *, 8>, 4> NestedSeenOps;
+
+  bool canRecurseInto(SCEVTypes Kind) const {
+    // We can only recurse into the SCEV expression of the same effective type
+    // as the type of our root SCEV expression.
+    return RootKind == Kind || NonSequentialRootKind == Kind;
+  }
+
+  /// Returns true if an earlier occurrence of \p S makes an occurrence of it
+  /// at the current position redundant.
+  bool isRedundant(const SCEV *S) const {
+    auto It = RootSeenOps.find(S);
+    if (NestedSeenOps.empty())
+      return It != RootSeenOps.end();
+    // Inside a nested sequential expression, only occurrences under an
+    // earlier root operand, or earlier within this very nested expression,
+    // short-circuit the current position.
+    if (It != RootSeenOps.end() && It->second < CurRootIdx)
+      return true;
+    return NestedSeenOps.back().count(S);
+  }
+
+  /// Records that \p S has been seen at the current position.
+  void markSeen(const SCEV *S) {
+    if (NestedSeenOps.empty())
+      RootSeenOps.try_emplace(S, CurRootIdx);
+    else
+      NestedSeenOps.back().insert(S);
+  }
+
+  /// Leaves the innermost nested sequential expression. From now on,
+  /// everything seen inside it counts as seen in the enclosing context.
+  void leaveNestedSequentialExpr() {
+    SmallPtrSet<const SCEV *, 8> Inner = NestedSeenOps.pop_back_val();
+    for (const SCEV *Op : Inner)
+      markSeen(Op);
+  }
+
+  RetVal visitAnyMinMaxExpr(const SCEV *S) {
+    assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
+           "Only for min/max expressions.");
+    SCEVTypes Kind = S->getSCEVType();
+
+    if (!canRecurseInto(Kind))
+      return S;
+
+    auto *NAry = cast<SCEVNAryExpr>(S);
+    const bool IsSequential = isa<SCEVSequentialMinMaxExpr>(S);
+    SmallVector<const SCEV *> NewOps;
+    // A nested sequential expression gets its own scope of seen operands.
+    if (IsSequential)
+      NestedSeenOps.emplace_back();
+    bool Changed =
+        visitOperands(makeArrayRef(NAry->op_begin(), NAry->op_end()), NewOps);
+    if (IsSequential)
+      leaveNestedSequentialExpr();
+
+    if (!Changed)
+      return S;
+    if (NewOps.empty())
+      return None;
+
+    return IsSequential ? SE.getSequentialMinMaxExpr(Kind, NewOps)
+                        : SE.getMinMaxExpr(Kind, NewOps);
+  }
+
+  RetVal visit(const SCEV *S) {
+    // Has the whole operand already been seen at a place that makes this
+    // occurrence redundant?
+    if (isRedundant(S))
+      return None;
+    markSeen(S);
+    return Base::visit(S);
+  }
+
+  /// Visits \p OrigOps in order. If any operand was dropped or rewritten,
+  /// stores the new operand list into \p NewOps and returns true.
+  /// \p IsRoot is set when \p OrigOps are the operands of the root expression.
+  bool visitOperands(ArrayRef<const SCEV *> OrigOps,
+                     SmallVectorImpl<const SCEV *> &NewOps,
+                     bool IsRoot = false) {
+    bool Changed = false;
+    SmallVector<const SCEV *> Ops;
+    Ops.reserve(OrigOps.size());
+
+    for (const SCEV *Op : OrigOps) {
+      RetVal NewOp = visit(Op);
+      if (NewOp != Op)
+        Changed = true;
+      if (NewOp)
+        Ops.emplace_back(*NewOp);
+      // Everything seen so far now lies under an earlier root operand.
+      if (IsRoot)
+        ++CurRootIdx;
+    }
+
+    // Only overwrite NewOps at the very end: it may alias OrigOps.
+    if (Changed)
+      NewOps = std::move(Ops);
+    return Changed;
+  }
+
+public:
+  SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
+                                           SCEVTypes RootKind)
+      : SE(SE), RootKind(RootKind),
+        NonSequentialRootKind(
+            SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
+                RootKind)) {}
+
+  /// Deduplicates \p OrigOps, the operands of the root sequential min/max
+  /// expression of kind \p Kind. If anything changed, stores the new operand
+  /// list into \p NewOps and returns true. \p NewOps may alias \p OrigOps.
+  bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
+                         SmallVectorImpl<const SCEV *> &NewOps) {
+    assert(Kind == RootKind && "Expected operands of the root expression.");
+    return visitOperands(OrigOps, NewOps, /*IsRoot=*/true);
+  }
+
+  RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
+
+  RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
+
+  RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
+
+  RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
+
+  RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
+
+  RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
+
+  RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
+
+  RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
+
+  RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
+
+  RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
+    return visitAnyMinMaxExpr(Expr);
+  }
+
+  RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
+    return visitAnyMinMaxExpr(Expr);
+  }
+
+  RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
+    return visitAnyMinMaxExpr(Expr);
+  }
+
+  RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
+    return visitAnyMinMaxExpr(Expr);
+  }
+
+  RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
+    return visitAnyMinMaxExpr(Expr);
+  }
+
+  RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
+
+  RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
+};
+
+} // namespace
+
const SCEV *
ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
SmallVectorImpl<const SCEV *> &Ops) {
@@ -3895,45 +4016,8 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
// Keep only the first instance of an operand.
{
- SmallPtrSet<const SCEV *, 16> SeenOps;
- unsigned Idx = 0;
- bool Changed = false;
- while (Idx < Ops.size()) {
- // Has the whole operand been seen already?
- if (!SeenOps.insert(Ops[Idx]).second) {
- Ops.erase(Ops.begin() + Idx);
- Changed = true;
- continue; // Look at operand under this index again.
- }
-
- // Look into non-sequential same-typed min/max expressions,
- // drop any of it's operands that we have already seen.
- // FIXME: once there are other sequential min/max types, generalize.
- if (const auto *CommUMinExpr = dyn_cast<SCEVUMinExpr>(Ops[Idx])) {
- SmallVector<const SCEV *> InnerOps;
- InnerOps.reserve(CommUMinExpr->getNumOperands());
- for (const SCEV *InnerOp : CommUMinExpr->operands()) {
- if (SeenOps.insert(InnerOp).second) // Operand not seen before?
- InnerOps.emplace_back(InnerOp); // Keep this inner operand.
- }
- // Were any operands of this 'umin' themselves redundant?
- if (InnerOps.size() != CommUMinExpr->getNumOperands()) {
- Changed = true;
- // Was the whole operand effectively redundant? Note that it can
- // happen even when the operand itself wasn't redundant as a whole.
- if (InnerOps.empty()) {
- Ops.erase(Ops.begin() + Idx);
- continue; // Look at operand under this index again.
- }
- // Recreate our operand.
- Ops[Idx] = getMinMaxExpr(Ops[Idx]->getSCEVType(), InnerOps);
- }
- }
-
- // Ok, can't do anything else about this operand, move onto the next one.
- ++Idx;
- }
-
+ SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
+ bool Changed = Deduplicator.visit(Kind, Ops, Ops);
if (Changed)
return getSequentialMinMaxExpr(Kind, Ops);
}
diff --git a/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll b/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll
index 5be01913f31f5..fd75fa31218dd 100644
--- a/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll
+++ b/llvm/test/Analysis/ScalarEvolution/exit-count-select-safe.ll
@@ -359,9 +359,9 @@ define i32 @logical_or_5ops_redundant_opearand_of_inner_uminseq(i32 %a, i32 %b,
; CHECK-NEXT: %cond_p4 = select i1 %cond_p3, i1 true, i1 %cond_p2
; CHECK-NEXT: --> %cond_p4 U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %first.loop: Variant }
; CHECK-NEXT: %i = phi i32 [ 0, %first.loop.exit ], [ %i.next, %loop ]
-; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%a umin_seq %b umin_seq ((%e umin_seq %d umin_seq %a) umin %c umin %d)) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {0,+,1}<%loop> U: full-set S: full-set Exits: (%a umin_seq %b umin_seq ((%e umin_seq %d) umin %c)) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %i.next = add i32 %i, 1
-; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%a umin_seq %b umin_seq ((%e umin_seq %d umin_seq %a) umin %c umin %d))) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {1,+,1}<%loop> U: full-set S: full-set Exits: (1 + (%a umin_seq %b umin_seq ((%e umin_seq %d) umin %c))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %umin = call i32 @llvm.umin.i32(i32 %c, i32 %d)
; CHECK-NEXT: --> (%c umin %d) U: full-set S: full-set Exits: (%c umin %d) LoopDispositions: { %loop: Invariant }
; CHECK-NEXT: %umin2 = call i32 @llvm.umin.i32(i32 %umin, i32 %first.i)
@@ -371,9 +371,9 @@ define i32 @logical_or_5ops_redundant_opearand_of_inner_uminseq(i32 %a, i32 %b,
; CHECK-NEXT: %cond = select i1 %cond_p8, i1 true, i1 %cond_p7
; CHECK-NEXT: --> %cond U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Variant }
; CHECK-NEXT: Determining loop execution counts for: @logical_or_5ops_redundant_opearand_of_inner_uminseq
-; CHECK-NEXT: Loop %loop: backedge-taken count is (%a umin_seq %b umin_seq ((%e umin_seq %d umin_seq %a) umin %c umin %d))
+; CHECK-NEXT: Loop %loop: backedge-taken count is (%a umin_seq %b umin_seq ((%e umin_seq %d) umin %c))
; CHECK-NEXT: Loop %loop: max backedge-taken count is -1
-; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (%a umin_seq %b umin_seq ((%e umin_seq %d umin_seq %a) umin %c umin %d))
+; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (%a umin_seq %b umin_seq ((%e umin_seq %d) umin %c))
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1
; CHECK-NEXT: Loop %first.loop: backedge-taken count is (%e umin_seq %d umin_seq %a)
More information about the llvm-commits
mailing list