[llvm] f8eeeff - [NFC][SCEV] Reflow `computeSCEVAtScope()` into an exhaustive switch

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Sat Jan 21 12:53:46 PST 2023


Author: Roman Lebedev
Date: 2023-01-21T23:38:14+03:00
New Revision: f8eeeffadad33585027a489aaac79ff64d1e3464

URL: https://github.com/llvm/llvm-project/commit/f8eeeffadad33585027a489aaac79ff64d1e3464
DIFF: https://github.com/llvm/llvm-project/commit/f8eeeffadad33585027a489aaac79ff64d1e3464.diff

LOG: [NFC][SCEV] Reflow `computeSCEVAtScope()` into an exhaustive switch

Otherwise, instead of a compile-time error telling you that you forgot to
update it, you'd get a run-time error, which happened every time I added a
new expr.

This is completely NFC, there are no other changes here.

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 6c71d69dd71b..8dcf11f59da1 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9798,12 +9798,112 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
 }
 
 const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
-  if (isa<SCEVConstant>(V))
+  switch (V->getSCEVType()) {
+  case scConstant:
     return V;
+  case scTruncate:
+  case scZeroExtend:
+  case scSignExtend:
+  case scPtrToInt: {
+    const SCEVCastExpr *Cast = cast<SCEVCastExpr>(V);
+    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
+    if (Op == Cast->getOperand())
+      return Cast; // must be loop invariant
+    return getCastExpr(Cast->getSCEVType(), Op, Cast->getType());
+  }
+  case scUDivExpr: {
+    const SCEVUDivExpr *Div = cast<SCEVUDivExpr>(V);
+    const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
+    const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
+    if (LHS == Div->getLHS() && RHS == Div->getRHS())
+      return Div; // must be loop invariant
+    return getUDivExpr(LHS, RHS);
+  }
+  case scAddRecExpr: {
+    // If this is a loop recurrence for a loop that does not contain L, then we
+    // are dealing with the final value computed by the loop.
+    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
+    // First, attempt to evaluate each operand.
+    // Avoid performing the look-up in the common case where the specified
+    // expression has no loop-variant portions.
+    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
+      const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
+      if (OpAtScope == AddRec->getOperand(i))
+        continue;
+
+      // Okay, at least one of these operands is loop variant but might be
+      // foldable.  Build a new instance of the folded commutative expression.
+      SmallVector<const SCEV *, 8> NewOps(AddRec->operands().take_front(i));
+      NewOps.push_back(OpAtScope);
+      for (++i; i != e; ++i)
+        NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
+
+      const SCEV *FoldedRec = getAddRecExpr(
+          NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
+      AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
+      // The addrec may be folded to a nonrecurrence, for example, if the
+      // induction variable is multiplied by zero after constant folding. Go
+      // ahead and return the folded value.
+      if (!AddRec)
+        return FoldedRec;
+      break;
+    }
+
+    // If the scope is outside the addrec's loop, evaluate it by using the
+    // loop exit value of the addrec.
+    if (!AddRec->getLoop()->contains(L)) {
+      // To evaluate this recurrence, we need to know how many times the AddRec
+      // loop iterates.  Compute this now.
+      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
+      if (BackedgeTakenCount == getCouldNotCompute())
+        return AddRec;
+
+      // Then, evaluate the AddRec.
+      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
+    }
+
+    return AddRec;
+  }
+  case scAddExpr:
+  case scMulExpr:
+  case scUMaxExpr:
+  case scSMaxExpr:
+  case scUMinExpr:
+  case scSMinExpr:
+  case scSequentialUMinExpr: {
+    const auto *Comm = cast<SCEVNAryExpr>(V);
+    // Avoid performing the look-up in the common case where the specified
+    // expression has no loop-variant portions.
+    for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
+      const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
+      if (OpAtScope != Comm->getOperand(i)) {
+        // Okay, at least one of these operands is loop variant but might be
+        // foldable.  Build a new instance of the folded commutative expression.
+        SmallVector<const SCEV *, 8> NewOps(Comm->operands().take_front(i));
+        NewOps.push_back(OpAtScope);
 
-  // If this instruction is evolved from a constant-evolving PHI, compute the
-  // exit value from the loop without using SCEVs.
-  if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
+        for (++i; i != e; ++i) {
+          OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
+          NewOps.push_back(OpAtScope);
+        }
+        if (isa<SCEVAddExpr>(Comm))
+          return getAddExpr(NewOps, Comm->getNoWrapFlags());
+        if (isa<SCEVMulExpr>(Comm))
+          return getMulExpr(NewOps, Comm->getNoWrapFlags());
+        if (isa<SCEVMinMaxExpr>(Comm))
+          return getMinMaxExpr(Comm->getSCEVType(), NewOps);
+        if (isa<SCEVSequentialMinMaxExpr>(Comm))
+          return getSequentialMinMaxExpr(Comm->getSCEVType(), NewOps);
+        llvm_unreachable("Unknown commutative / sequential min/max SCEV type!");
+      }
+    }
+    // If we got here, all operands are loop invariant.
+    return Comm;
+  }
+  case scUnknown: {
+    // If this instruction is evolved from a constant-evolving PHI, compute the
+    // exit value from the loop without using SCEVs.
+    const SCEVUnknown *SU = cast<SCEVUnknown>(V);
     if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
       if (PHINode *PN = dyn_cast<PHINode>(I)) {
         const Loop *CurrLoop = this->LI[I->getParent()];
@@ -9916,98 +10016,9 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
     // This is some other type of SCEVUnknown, just return it.
     return V;
   }
-
-  if (isa<SCEVCommutativeExpr>(V) || isa<SCEVSequentialMinMaxExpr>(V)) {
-    const auto *Comm = cast<SCEVNAryExpr>(V);
-    // Avoid performing the look-up in the common case where the specified
-    // expression has no loop-variant portions.
-    for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
-      const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
-      if (OpAtScope != Comm->getOperand(i)) {
-        // Okay, at least one of these operands is loop variant but might be
-        // foldable.  Build a new instance of the folded commutative expression.
-        SmallVector<const SCEV *, 8> NewOps(Comm->operands().take_front(i));
-        NewOps.push_back(OpAtScope);
-
-        for (++i; i != e; ++i) {
-          OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
-          NewOps.push_back(OpAtScope);
-        }
-        if (isa<SCEVAddExpr>(Comm))
-          return getAddExpr(NewOps, Comm->getNoWrapFlags());
-        if (isa<SCEVMulExpr>(Comm))
-          return getMulExpr(NewOps, Comm->getNoWrapFlags());
-        if (isa<SCEVMinMaxExpr>(Comm))
-          return getMinMaxExpr(Comm->getSCEVType(), NewOps);
-        if (isa<SCEVSequentialMinMaxExpr>(Comm))
-          return getSequentialMinMaxExpr(Comm->getSCEVType(), NewOps);
-        llvm_unreachable("Unknown commutative / sequential min/max SCEV type!");
-      }
-    }
-    // If we got here, all operands are loop invariant.
-    return Comm;
-  }
-
-  if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
-    const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
-    const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
-    if (LHS == Div->getLHS() && RHS == Div->getRHS())
-      return Div; // must be loop invariant
-    return getUDivExpr(LHS, RHS);
-  }
-
-  // If this is a loop recurrence for a loop that does not contain L, then we
-  // are dealing with the final value computed by the loop.
-  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
-    // First, attempt to evaluate each operand.
-    // Avoid performing the look-up in the common case where the specified
-    // expression has no loop-variant portions.
-    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
-      const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
-      if (OpAtScope == AddRec->getOperand(i))
-        continue;
-
-      // Okay, at least one of these operands is loop variant but might be
-      // foldable.  Build a new instance of the folded commutative expression.
-      SmallVector<const SCEV *, 8> NewOps(AddRec->operands().take_front(i));
-      NewOps.push_back(OpAtScope);
-      for (++i; i != e; ++i)
-        NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
-
-      const SCEV *FoldedRec = getAddRecExpr(
-          NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
-      AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
-      // The addrec may be folded to a nonrecurrence, for example, if the
-      // induction variable is multiplied by zero after constant folding. Go
-      // ahead and return the folded value.
-      if (!AddRec)
-        return FoldedRec;
-      break;
-    }
-
-    // If the scope is outside the addrec's loop, evaluate it by using the
-    // loop exit value of the addrec.
-    if (!AddRec->getLoop()->contains(L)) {
-      // To evaluate this recurrence, we need to know how many times the AddRec
-      // loop iterates.  Compute this now.
-      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
-      if (BackedgeTakenCount == getCouldNotCompute())
-        return AddRec;
-
-      // Then, evaluate the AddRec.
-      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
-    }
-
-    return AddRec;
-  }
-
-  if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
-    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
-    if (Op == Cast->getOperand())
-      return Cast; // must be loop invariant
-    return getCastExpr(Cast->getSCEVType(), Op, Cast->getType());
+  case scCouldNotCompute:
+    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
   }
-
   llvm_unreachable("Unknown SCEV type!");
 }
 


        


More information about the llvm-commits mailing list