[llvm] f8eeeff - [NFC][SCEV] Reflow `computeSCEVAtScope()` into an exhaustive switch
Roman Lebedev via llvm-commits
llvm-commits at lists.llvm.org
Sat Jan 21 12:53:46 PST 2023
Author: Roman Lebedev
Date: 2023-01-21T23:38:14+03:00
New Revision: f8eeeffadad33585027a489aaac79ff64d1e3464
URL: https://github.com/llvm/llvm-project/commit/f8eeeffadad33585027a489aaac79ff64d1e3464
DIFF: https://github.com/llvm/llvm-project/commit/f8eeeffadad33585027a489aaac79ff64d1e3464.diff
LOG: [NFC][SCEV] Reflow `computeSCEVAtScope()` into an exhaustive switch
Otherwise, instead of a compile-time error telling you that you forgot to update it,
you'd get a run-time error, which happened every time I've added a new expression type.
This is completely NFC; there are no other changes here.
Added:
Modified:
llvm/lib/Analysis/ScalarEvolution.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 6c71d69dd71b..8dcf11f59da1 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9798,12 +9798,112 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
}
const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
- if (isa<SCEVConstant>(V))
+ switch (V->getSCEVType()) {
+ case scConstant:
return V;
+ case scTruncate:
+ case scZeroExtend:
+ case scSignExtend:
+ case scPtrToInt: {
+ const SCEVCastExpr *Cast = cast<SCEVCastExpr>(V);
+ const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
+ if (Op == Cast->getOperand())
+ return Cast; // must be loop invariant
+ return getCastExpr(Cast->getSCEVType(), Op, Cast->getType());
+ }
+ case scUDivExpr: {
+ const SCEVUDivExpr *Div = cast<SCEVUDivExpr>(V);
+ const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
+ const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
+ if (LHS == Div->getLHS() && RHS == Div->getRHS())
+ return Div; // must be loop invariant
+ return getUDivExpr(LHS, RHS);
+ }
+ case scAddRecExpr: {
+ // If this is a loop recurrence for a loop that does not contain L, then we
+ // are dealing with the final value computed by the loop.
+ const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
+ // First, attempt to evaluate each operand.
+ // Avoid performing the look-up in the common case where the specified
+ // expression has no loop-variant portions.
+ for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
+ const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
+ if (OpAtScope == AddRec->getOperand(i))
+ continue;
+
+ // Okay, at least one of these operands is loop variant but might be
+ // foldable. Build a new instance of the folded commutative expression.
+ SmallVector<const SCEV *, 8> NewOps(AddRec->operands().take_front(i));
+ NewOps.push_back(OpAtScope);
+ for (++i; i != e; ++i)
+ NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
+
+ const SCEV *FoldedRec = getAddRecExpr(
+ NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
+ AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
+ // The addrec may be folded to a nonrecurrence, for example, if the
+ // induction variable is multiplied by zero after constant folding. Go
+ // ahead and return the folded value.
+ if (!AddRec)
+ return FoldedRec;
+ break;
+ }
+
+ // If the scope is outside the addrec's loop, evaluate it by using the
+ // loop exit value of the addrec.
+ if (!AddRec->getLoop()->contains(L)) {
+ // To evaluate this recurrence, we need to know how many times the AddRec
+ // loop iterates. Compute this now.
+ const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
+ if (BackedgeTakenCount == getCouldNotCompute())
+ return AddRec;
+
+ // Then, evaluate the AddRec.
+ return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
+ }
+
+ return AddRec;
+ }
+ case scAddExpr:
+ case scMulExpr:
+ case scUMaxExpr:
+ case scSMaxExpr:
+ case scUMinExpr:
+ case scSMinExpr:
+ case scSequentialUMinExpr: {
+ const auto *Comm = cast<SCEVNAryExpr>(V);
+ // Avoid performing the look-up in the common case where the specified
+ // expression has no loop-variant portions.
+ for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
+ const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
+ if (OpAtScope != Comm->getOperand(i)) {
+ // Okay, at least one of these operands is loop variant but might be
+ // foldable. Build a new instance of the folded commutative expression.
+ SmallVector<const SCEV *, 8> NewOps(Comm->operands().take_front(i));
+ NewOps.push_back(OpAtScope);
- // If this instruction is evolved from a constant-evolving PHI, compute the
- // exit value from the loop without using SCEVs.
- if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
+ for (++i; i != e; ++i) {
+ OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
+ NewOps.push_back(OpAtScope);
+ }
+ if (isa<SCEVAddExpr>(Comm))
+ return getAddExpr(NewOps, Comm->getNoWrapFlags());
+ if (isa<SCEVMulExpr>(Comm))
+ return getMulExpr(NewOps, Comm->getNoWrapFlags());
+ if (isa<SCEVMinMaxExpr>(Comm))
+ return getMinMaxExpr(Comm->getSCEVType(), NewOps);
+ if (isa<SCEVSequentialMinMaxExpr>(Comm))
+ return getSequentialMinMaxExpr(Comm->getSCEVType(), NewOps);
+ llvm_unreachable("Unknown commutative / sequential min/max SCEV type!");
+ }
+ }
+ // If we got here, all operands are loop invariant.
+ return Comm;
+ }
+ case scUnknown: {
+ // If this instruction is evolved from a constant-evolving PHI, compute the
+ // exit value from the loop without using SCEVs.
+ const SCEVUnknown *SU = cast<SCEVUnknown>(V);
if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
if (PHINode *PN = dyn_cast<PHINode>(I)) {
const Loop *CurrLoop = this->LI[I->getParent()];
@@ -9916,98 +10016,9 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
// This is some other type of SCEVUnknown, just return it.
return V;
}
-
- if (isa<SCEVCommutativeExpr>(V) || isa<SCEVSequentialMinMaxExpr>(V)) {
- const auto *Comm = cast<SCEVNAryExpr>(V);
- // Avoid performing the look-up in the common case where the specified
- // expression has no loop-variant portions.
- for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
- const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
- if (OpAtScope != Comm->getOperand(i)) {
- // Okay, at least one of these operands is loop variant but might be
- // foldable. Build a new instance of the folded commutative expression.
- SmallVector<const SCEV *, 8> NewOps(Comm->operands().take_front(i));
- NewOps.push_back(OpAtScope);
-
- for (++i; i != e; ++i) {
- OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
- NewOps.push_back(OpAtScope);
- }
- if (isa<SCEVAddExpr>(Comm))
- return getAddExpr(NewOps, Comm->getNoWrapFlags());
- if (isa<SCEVMulExpr>(Comm))
- return getMulExpr(NewOps, Comm->getNoWrapFlags());
- if (isa<SCEVMinMaxExpr>(Comm))
- return getMinMaxExpr(Comm->getSCEVType(), NewOps);
- if (isa<SCEVSequentialMinMaxExpr>(Comm))
- return getSequentialMinMaxExpr(Comm->getSCEVType(), NewOps);
- llvm_unreachable("Unknown commutative / sequential min/max SCEV type!");
- }
- }
- // If we got here, all operands are loop invariant.
- return Comm;
- }
-
- if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
- const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
- const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
- if (LHS == Div->getLHS() && RHS == Div->getRHS())
- return Div; // must be loop invariant
- return getUDivExpr(LHS, RHS);
- }
-
- // If this is a loop recurrence for a loop that does not contain L, then we
- // are dealing with the final value computed by the loop.
- if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
- // First, attempt to evaluate each operand.
- // Avoid performing the look-up in the common case where the specified
- // expression has no loop-variant portions.
- for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
- const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
- if (OpAtScope == AddRec->getOperand(i))
- continue;
-
- // Okay, at least one of these operands is loop variant but might be
- // foldable. Build a new instance of the folded commutative expression.
- SmallVector<const SCEV *, 8> NewOps(AddRec->operands().take_front(i));
- NewOps.push_back(OpAtScope);
- for (++i; i != e; ++i)
- NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
-
- const SCEV *FoldedRec = getAddRecExpr(
- NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
- AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
- // The addrec may be folded to a nonrecurrence, for example, if the
- // induction variable is multiplied by zero after constant folding. Go
- // ahead and return the folded value.
- if (!AddRec)
- return FoldedRec;
- break;
- }
-
- // If the scope is outside the addrec's loop, evaluate it by using the
- // loop exit value of the addrec.
- if (!AddRec->getLoop()->contains(L)) {
- // To evaluate this recurrence, we need to know how many times the AddRec
- // loop iterates. Compute this now.
- const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
- if (BackedgeTakenCount == getCouldNotCompute())
- return AddRec;
-
- // Then, evaluate the AddRec.
- return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
- }
-
- return AddRec;
- }
-
- if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
- const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
- if (Op == Cast->getOperand())
- return Cast; // must be loop invariant
- return getCastExpr(Cast->getSCEVType(), Op, Cast->getType());
+ case scCouldNotCompute:
+ llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
}
-
llvm_unreachable("Unknown SCEV type!");
}
More information about the llvm-commits
mailing list