[llvm] 921b8f4 - [SCEV][NFC] GetMinTrailingZeros switch case and naming cleanup

Joshua Cao via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 10 22:57:12 PDT 2023


Author: Joshua Cao
Date: 2023-04-10T22:56:29-07:00
New Revision: 921b8f40e816f980c39758087290c995cafbfd46

URL: https://github.com/llvm/llvm-project/commit/921b8f40e816f980c39758087290c995cafbfd46
DIFF: https://github.com/llvm/llvm-project/commit/921b8f40e816f980c39758087290c995cafbfd46.diff

LOG: [SCEV][NFC] GetMinTrailingZeros switch case and naming cleanup

* combine zext and sext into one switch case
* group the vscale case next to the udiv case in the switch (both return 0)
* renames according to LLVM style

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 8a25cf9b37c86..f27cf22639c96 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -962,7 +962,7 @@ class ScalarEvolution {
   /// (at every loop iteration).  It is, at the same time, the minimum number
   /// of times S is divisible by 2.  For example, given {4,+,8} it returns 2.
   /// If S is guaranteed to be 0, it returns the bitwidth of S.
-  uint32_t GetMinTrailingZeros(const SCEV *S);
+  uint32_t getMinTrailingZeros(const SCEV *S);
 
   /// Determine the unsigned range for a particular SCEV.
   /// NOTE: This returns a copy of the reference returned by getRangeRef.
@@ -1438,7 +1438,7 @@ class ScalarEvolution {
   ArrayRef<Value *> getSCEVValues(const SCEV *S);
 
   /// Private helper method for the GetMinTrailingZeros method
-  uint32_t GetMinTrailingZerosImpl(const SCEV *S);
+  uint32_t getMinTrailingZerosImpl(const SCEV *S);
 
   /// Information about the number of times a particular loop exit may be
   /// reached before exiting the loop.

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 283175201b919..4ee6d249e9531 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -1212,7 +1212,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
   }
 
   // Return zero if truncating to known zeros.
-  uint32_t MinTrailingZeros = GetMinTrailingZeros(Op);
+  uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
   if (MinTrailingZeros >= getTypeSizeInBits(Ty))
     return getZero(Ty);
 
@@ -1497,7 +1497,7 @@ static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
   // Find number of trailing zeros of (x + y + ...) w/o the C first:
   uint32_t TZ = BitWidth;
   for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
-    TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
+    TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
   if (TZ) {
     // Set D to be as many least significant bits of C as possible while still
     // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
@@ -1514,7 +1514,7 @@ static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                             const APInt &ConstantStart,
                                             const SCEV *Step) {
   const unsigned BitWidth = ConstantStart.getBitWidth();
-  const uint32_t TZ = SE.GetMinTrailingZeros(Step);
+  const uint32_t TZ = SE.getMinTrailingZeros(Step);
   if (TZ)
     return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
                          : ConstantStart;
@@ -6270,27 +6270,19 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
   return getGEPExpr(GEP, IndexExprs);
 }
 
-uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
+uint32_t ScalarEvolution::getMinTrailingZerosImpl(const SCEV *S) {
   switch (S->getSCEVType()) {
   case scConstant:
     return cast<SCEVConstant>(S)->getAPInt().countr_zero();
-  case scVScale:
-    return 0;
   case scTruncate: {
     const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
-    return std::min(GetMinTrailingZeros(T->getOperand()),
+    return std::min(getMinTrailingZeros(T->getOperand()),
                     (uint32_t)getTypeSizeInBits(T->getType()));
   }
-  case scZeroExtend: {
-    const SCEVZeroExtendExpr *E = cast<SCEVZeroExtendExpr>(S);
-    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
-    return OpRes == getTypeSizeInBits(E->getOperand()->getType())
-               ? getTypeSizeInBits(E->getType())
-               : OpRes;
-  }
+  case scZeroExtend:
   case scSignExtend: {
-    const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
-    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
+    const SCEVIntegralCastExpr *E = cast<SCEVIntegralCastExpr>(S);
+    uint32_t OpRes = getMinTrailingZeros(E->getOperand());
     return OpRes == getTypeSizeInBits(E->getOperand()->getType())
                ? getTypeSizeInBits(E->getType())
                : OpRes;
@@ -6298,14 +6290,16 @@ uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
   case scMulExpr: {
     const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
     // The result is the sum of all operands results.
-    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
+    uint32_t SumOpRes = getMinTrailingZeros(M->getOperand(0));
     uint32_t BitWidth = getTypeSizeInBits(M->getType());
-    for (unsigned i = 1, e = M->getNumOperands();
-         SumOpRes != BitWidth && i != e; ++i)
+    for (unsigned I = 1, E = M->getNumOperands();
+         SumOpRes != BitWidth && I != E; ++I)
       SumOpRes =
-          std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth);
+          std::min(SumOpRes + getMinTrailingZeros(M->getOperand(I)), BitWidth);
     return SumOpRes;
   }
+  case scVScale:
+    return 0;
   case scUDivExpr:
     return 0;
   case scPtrToInt:
@@ -6318,9 +6312,9 @@ uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
   case scSequentialUMinExpr: {
     // The result is the min of all operands results.
     ArrayRef<const SCEV *> Ops = S->operands();
-    uint32_t MinOpRes = GetMinTrailingZeros(Ops[0]);
+    uint32_t MinOpRes = getMinTrailingZeros(Ops[0]);
     for (unsigned I = 1, E = Ops.size(); MinOpRes && I != E; ++I)
-      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(Ops[I]));
+      MinOpRes = std::min(MinOpRes, getMinTrailingZeros(Ops[I]));
     return MinOpRes;
   }
   case scUnknown: {
@@ -6336,12 +6330,12 @@ uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
   llvm_unreachable("Unknown SCEV kind!");
 }
 
-uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
+uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
   auto I = MinTrailingZerosCache.find(S);
   if (I != MinTrailingZerosCache.end())
     return I->second;
 
-  uint32_t Result = GetMinTrailingZerosImpl(S);
+  uint32_t Result = getMinTrailingZerosImpl(S);
   auto InsertPair = MinTrailingZerosCache.insert({S, Result});
   assert(InsertPair.second && "Should insert a new key");
   return InsertPair.first->second;
@@ -6595,7 +6589,7 @@ const ConstantRange &ScalarEvolution::getRangeRef(
 
   // If the value has known zeros, the maximum value will have those known zeros
   // as well.
-  uint32_t TZ = GetMinTrailingZeros(S);
+  uint32_t TZ = getMinTrailingZeros(S);
   if (TZ != 0) {
     if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED)
       ConservativeResult =
@@ -8245,7 +8239,7 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
   if (!TC)
     // Attempt to factor more general cases. Returns the greatest power of
     // two divisor.
-    return GetSmallMultiple(GetMinTrailingZeros(TCExpr));
+    return GetSmallMultiple(getMinTrailingZeros(TCExpr));
 
   ConstantInt *Result = TC->getValue();
   assert(Result && "SCEVConstant expected to have non-null ConstantInt");
@@ -10108,7 +10102,7 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
   //
   // B is divisible by D if and only if the multiplicity of prime factor 2 for B
   // is not less than multiplicity of this prime factor for D.
-  if (SE.GetMinTrailingZeros(B) < Mult2)
+  if (SE.getMinTrailingZeros(B) < Mult2)
     return SE.getCouldNotCompute();
 
   // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic


        


More information about the llvm-commits mailing list