[llvm] e056758 - [NFCI][SCEV] Always refer to enum SCEVTypes as enum, not integer

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 19 14:11:15 PDT 2020


Author: Roman Lebedev
Date: 2020-10-20T00:10:22+03:00
New Revision: e0567582b8b1def8656f4a5addce0909fa51c86e

URL: https://github.com/llvm/llvm-project/commit/e0567582b8b1def8656f4a5addce0909fa51c86e
DIFF: https://github.com/llvm/llvm-project/commit/e0567582b8b1def8656f4a5addce0909fa51c86e.diff

LOG: [NFCI][SCEV] Always refer to enum SCEVTypes as enum, not integer

The main tricky thing here is forward-declaring the enum:
we have to specify its underlying type.

In particular, this avoids the danger of intending to switch over SCEVTypes
while actually switching over an integer, and thus not being warned
when some case is not handled.

I have updated most such switches to be exhaustive and to have
no default case where that is clearly the intent;
however, not all of them.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
    llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 19e3607bc0dd..85c04a79ba45 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -70,6 +70,7 @@ class StructType;
 class TargetLibraryInfo;
 class Type;
 class Value;
+enum SCEVTypes : unsigned short;
 
 /// This class represents an analyzed expression in the program.  These are
 /// opaque objects that the client is not allowed to do much with directly.
@@ -82,7 +83,7 @@ class SCEV : public FoldingSetNode {
   FoldingSetNodeIDRef FastID;
 
   // The SCEV baseclass this node corresponds to
-  const unsigned short SCEVType;
+  const SCEVTypes SCEVType;
 
 protected:
   // Estimated complexity of this node's expression tree size.
@@ -119,13 +120,13 @@ class SCEV : public FoldingSetNode {
     NoWrapMask = (1 << 3) - 1
   };
 
-  explicit SCEV(const FoldingSetNodeIDRef ID, unsigned SCEVTy,
+  explicit SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                 unsigned short ExpressionSize)
       : FastID(ID), SCEVType(SCEVTy), ExpressionSize(ExpressionSize) {}
   SCEV(const SCEV &) = delete;
   SCEV &operator=(const SCEV &) = delete;
 
-  unsigned getSCEVType() const { return SCEVType; }
+  SCEVTypes getSCEVType() const { return SCEVType; }
 
   /// Return the LLVM type of this SCEV expression.
   Type *getType() const;
@@ -574,7 +575,7 @@ class ScalarEvolution {
                          const SmallVectorImpl<const SCEV *> &IndexExprs);
   const SCEV *getAbsExpr(const SCEV *Op, bool IsNSW);
   const SCEV *getSignumExpr(const SCEV *Op);
-  const SCEV *getMinMaxExpr(unsigned Kind,
+  const SCEV *getMinMaxExpr(SCEVTypes Kind,
                             SmallVectorImpl<const SCEV *> &Operands);
   const SCEV *getSMaxExpr(const SCEV *LHS, const SCEV *RHS);
   const SCEV *getSMaxExpr(SmallVectorImpl<const SCEV *> &Operands);
@@ -1958,7 +1959,7 @@ class ScalarEvolution {
   /// constructed to look up the SCEV and the third component is the insertion
   /// point.
   std::tuple<SCEV *, FoldingSetNodeID, void *>
-  findExistingSCEVInCache(int SCEVType, ArrayRef<const SCEV *> Ops);
+  findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef<const SCEV *> Ops);
 
   FoldingSet<SCEV> UniqueSCEVs;
   FoldingSet<SCEVPredicate> UniquePreds;

diff  --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
index e94dddbadec1..b995de7d88d5 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -35,7 +35,7 @@ class ConstantRange;
 class Loop;
 class Type;
 
-  enum SCEVTypes {
+  enum SCEVTypes : unsigned short {
     // These should be ordered in terms of increasing complexity to make the
     // folders simpler.
     scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr,
@@ -77,7 +77,7 @@ class Type;
     std::array<const SCEV *, 1> Operands;
     Type *Ty;
 
-    SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, unsigned SCEVTy,
+    SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                          const SCEV *op, Type *ty);
 
   public:
@@ -412,7 +412,7 @@ class Type;
 
   public:
     static bool classof(const SCEV *S) {
-      return isMinMaxType(static_cast<SCEVTypes>(S->getSCEVType()));
+      return isMinMaxType(S->getSCEVType());
     }
 
     static enum SCEVTypes negate(enum SCEVTypes T) {
@@ -567,9 +567,8 @@ class Type;
         return ((SC*)this)->visitUnknown((const SCEVUnknown*)S);
       case scCouldNotCompute:
         return ((SC*)this)->visitCouldNotCompute((const SCEVCouldNotCompute*)S);
-      default:
-        llvm_unreachable("Unknown SCEV type!");
       }
+      llvm_unreachable("Unknown SCEV kind!");
     }
 
     RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
@@ -606,12 +605,12 @@ class Type;
         switch (S->getSCEVType()) {
         case scConstant:
         case scUnknown:
-          break;
+          continue;
         case scTruncate:
         case scZeroExtend:
         case scSignExtend:
           push(cast<SCEVIntegralCastExpr>(S)->getOperand());
-          break;
+          continue;
         case scAddExpr:
         case scMulExpr:
         case scSMaxExpr:
@@ -621,18 +620,17 @@ class Type;
         case scAddRecExpr:
           for (const auto *Op : cast<SCEVNAryExpr>(S)->operands())
             push(Op);
-          break;
+          continue;
         case scUDivExpr: {
           const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
           push(UDiv->getLHS());
           push(UDiv->getRHS());
-          break;
+          continue;
         }
         case scCouldNotCompute:
           llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
-        default:
-          llvm_unreachable("Unknown SCEV kind!");
         }
+        llvm_unreachable("Unknown SCEV kind!");
       }
     }
   };

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 8441095467d5..efc4600e248f 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -243,7 +243,7 @@ LLVM_DUMP_METHOD void SCEV::dump() const {
 #endif
 
 void SCEV::print(raw_ostream &OS) const {
-  switch (static_cast<SCEVTypes>(getSCEVType())) {
+  switch (getSCEVType()) {
   case scConstant:
     cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
     return;
@@ -304,6 +304,8 @@ void SCEV::print(raw_ostream &OS) const {
     case scSMinExpr:
       OpStr = " smin ";
       break;
+    default:
+      llvm_unreachable("There are no other nary expression types.");
     }
     OS << "(";
     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
@@ -320,6 +322,10 @@ void SCEV::print(raw_ostream &OS) const {
         OS << "<nuw>";
       if (NAry->hasNoSignedWrap())
         OS << "<nsw>";
+      break;
+    default:
+      // Nothing to print for other nary expressions.
+      break;
     }
     return;
   }
@@ -361,7 +367,7 @@ void SCEV::print(raw_ostream &OS) const {
 }
 
 Type *SCEV::getType() const {
-  switch (static_cast<SCEVTypes>(getSCEVType())) {
+  switch (getSCEVType()) {
   case scConstant:
     return cast<SCEVConstant>(this)->getType();
   case scTruncate:
@@ -446,7 +452,7 @@ ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
 }
 
 SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
-                                           unsigned SCEVTy, const SCEV *op,
+                                           SCEVTypes SCEVTy, const SCEV *op,
                                            Type *ty)
     : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
   Operands[0] = op;
@@ -668,7 +674,7 @@ static int CompareSCEVComplexity(
     return 0;
 
   // Primarily, sort the SCEVs by their getSCEVType().
-  unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
+  SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
   if (LType != RType)
     return (int)LType - (int)RType;
 
@@ -677,7 +683,7 @@ static int CompareSCEVComplexity(
   // Aside from the getSCEVType() ordering, the particular ordering
   // isn't very important except that it's beneficial to be consistent,
   // so that (a + b) and (b + a) don't end up as different expressions.
-  switch (static_cast<SCEVTypes>(LType)) {
+  switch (LType) {
   case scUnknown: {
     const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
     const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
@@ -3325,7 +3331,7 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP,
 }
 
 std::tuple<SCEV *, FoldingSetNodeID, void *>
-ScalarEvolution::findExistingSCEVInCache(int SCEVType,
+ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
                                          ArrayRef<const SCEV *> Ops) {
   FoldingSetNodeID ID;
   void *IP = nullptr;
@@ -3346,7 +3352,7 @@ const SCEV *ScalarEvolution::getSignumExpr(const SCEV *Op) {
   return getSMinExpr(getSMaxExpr(Op, getMinusOne(Ty)), getOne(Ty));
 }
 
-const SCEV *ScalarEvolution::getMinMaxExpr(unsigned Kind,
+const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
                                            SmallVectorImpl<const SCEV *> &Ops) {
   assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
   if (Ops.size() == 1) return Ops[0];
@@ -3470,8 +3476,8 @@ const SCEV *ScalarEvolution::getMinMaxExpr(unsigned Kind,
     return ExistingSCEV;
   const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
-  SCEV *S = new (SCEVAllocator) SCEVMinMaxExpr(
-      ID.Intern(SCEVAllocator), static_cast<SCEVTypes>(Kind), O, Ops.size());
+  SCEV *S = new (SCEVAllocator)
+      SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
 
   UniqueSCEVs.InsertNode(S, IP);
   addToLoopUseLists(S);
@@ -3792,9 +3798,8 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
           return (const SCEV *)nullptr;
         MatchedOperands.push_back(Matched);
       }
-      return getMinMaxExpr(
-          SCEVMinMaxExpr::negate(static_cast<SCEVTypes>(MME->getSCEVType())),
-          MatchedOperands);
+      return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
+                           MatchedOperands);
     };
     if (const SCEV *Replaced = MatchMinMaxNegation(MME))
       return Replaced;
@@ -5036,7 +5041,7 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
         // We do not try to smart about these at all.
         return setUnavailable();
       }
-      llvm_unreachable("switch should be fully covered!");
+      llvm_unreachable("Unknown SCEV kind!");
     }
 
     bool isDone() { return TraversalDone; }
@@ -7970,10 +7975,10 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
 /// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
 /// Returns NULL if the SCEV isn't representable as a Constant.
 static Constant *BuildConstantFromSCEV(const SCEV *V) {
-  switch (static_cast<SCEVTypes>(V->getSCEVType())) {
+  switch (V->getSCEVType()) {
   case scCouldNotCompute:
   case scAddRecExpr:
-    break;
+    return nullptr;
   case scConstant:
     return cast<SCEVConstant>(V)->getValue();
   case scUnknown:
@@ -7982,19 +7987,19 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
     const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
     if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
       return ConstantExpr::getSExt(CastOp, SS->getType());
-    break;
+    return nullptr;
   }
   case scZeroExtend: {
     const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
     if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
       return ConstantExpr::getZExt(CastOp, SZ->getType());
-    break;
+    return nullptr;
   }
   case scTruncate: {
     const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
     if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
       return ConstantExpr::getTrunc(CastOp, ST->getType());
-    break;
+    return nullptr;
   }
   case scAddExpr: {
     const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
@@ -8034,7 +8039,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
       }
       return C;
     }
-    break;
+    return nullptr;
   }
   case scMulExpr: {
     const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
@@ -8050,7 +8055,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
       }
       return C;
     }
-    break;
+    return nullptr;
   }
   case scUDivExpr: {
     const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
@@ -8058,15 +8063,15 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
       if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
         if (LHS->getType() == RHS->getType())
           return ConstantExpr::getUDiv(LHS, RHS);
-    break;
+    return nullptr;
   }
   case scSMaxExpr:
   case scUMaxExpr:
   case scSMinExpr:
   case scUMinExpr:
-    break; // TODO: smax, umax, smin, umax.
+    return nullptr; // TODO: smax, umax, smin, umax.
   }
-  return nullptr;
+  llvm_unreachable("Unknown SCEV kind!");
 }
 
 const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
@@ -11794,7 +11799,7 @@ ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
 
 ScalarEvolution::LoopDisposition
 ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
-  switch (static_cast<SCEVTypes>(S->getSCEVType())) {
+  switch (S->getSCEVType()) {
   case scConstant:
     return LoopInvariant;
   case scTruncate:
@@ -11901,7 +11906,7 @@ ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
 
 ScalarEvolution::BlockDisposition
 ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
-  switch (static_cast<SCEVTypes>(S->getSCEVType())) {
+  switch (S->getSCEVType()) {
   case scConstant:
     return ProperlyDominatesBlock;
   case scTruncate:

diff  --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index b6d987978f87..b67b68cfbd37 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -937,6 +937,8 @@ static bool isHighCostExpansion(const SCEV *S,
   case scSignExtend:
     return isHighCostExpansion(cast<SCEVSignExtendExpr>(S)->getOperand(),
                                Processed, SE);
+  default:
+    break;
   }
 
   if (!Processed.insert(S).second)
@@ -2788,6 +2790,7 @@ static const SCEV *getExprBase(const SCEV *S) {
   case scAddRecExpr:
     return getExprBase(cast<SCEVAddRecExpr>(S)->getStart());
   }
+  llvm_unreachable("Unknown SCEV kind!");
 }
 
 /// Return true if the chain increment is profitable to expand into a loop

diff  --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index ef5664f5f004..3e22b30a211c 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -2236,8 +2236,8 @@ template<typename T> static int costAndCollectOperands(
   };
 
   switch (S->getSCEVType()) {
-  default:
-    llvm_unreachable("No other scev expressions possible.");
+  case scCouldNotCompute:
+    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
   case scUnknown:
   case scConstant:
     return 0;
@@ -2351,6 +2351,8 @@ bool SCEVExpander::isHighCostExpansionHelper(
           : TargetTransformInfo::TCK_RecipThroughput;
 
   switch (S->getSCEVType()) {
+  case scCouldNotCompute:
+    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
   case scUnknown:
     // Assume to be zero-cost.
     return false;
@@ -2416,7 +2418,7 @@ bool SCEVExpander::isHighCostExpansionHelper(
     return BudgetRemaining < 0;
   }
   }
-  llvm_unreachable("Switch is exaustive and we return in all of them.");
+  llvm_unreachable("Unknown SCEV kind!");
 }
 
 Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,


        


More information about the llvm-commits mailing list