[llvm] d486fdf - [NFC][SCEV] Reflow `getRangeRef()` into an exhaustive switch

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Sat Jan 21 14:40:43 PST 2023


Author: Roman Lebedev
Date: 2023-01-22T01:34:22+03:00
New Revision: d486fdffdaa6f7f820955ed615c44117e498168f

URL: https://github.com/llvm/llvm-project/commit/d486fdffdaa6f7f820955ed615c44117e498168f
DIFF: https://github.com/llvm/llvm-project/commit/d486fdffdaa6f7f820955ed615c44117e498168f.diff

LOG: [NFC][SCEV] Reflow `getRangeRef()` into an exhaustive switch

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 732f5921b535..4e381c4599d9 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -6610,7 +6610,37 @@ const ConstantRange &ScalarEvolution::getRangeRef(
           APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
   }
 
-  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
+  switch (S->getSCEVType()) {
+  case scConstant:
+    llvm_unreachable("Already handled above.");
+  case scTruncate: {
+    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
+    ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
+    return setRange(
+        Trunc, SignHint,
+        ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
+  }
+  case scZeroExtend: {
+    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
+    ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
+    return setRange(
+        ZExt, SignHint,
+        ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
+  }
+  case scSignExtend: {
+    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
+    ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
+    return setRange(
+        SExt, SignHint,
+        ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
+  }
+  case scPtrToInt: {
+    const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
+    ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
+    return setRange(PtrToInt, SignHint, X);
+  }
+  case scAddExpr: {
+    const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
     ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
     unsigned WrapType = OBO::AnyWrap;
     if (Add->hasNoSignedWrap())
@@ -6623,78 +6653,23 @@ const ConstantRange &ScalarEvolution::getRangeRef(
     return setRange(Add, SignHint,
                     ConservativeResult.intersectWith(X, RangeType));
   }
-
-  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
+  case scMulExpr: {
+    const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
     ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
     for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
       X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint, Depth + 1));
     return setRange(Mul, SignHint,
                     ConservativeResult.intersectWith(X, RangeType));
   }
-
-  if (isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) {
-    Intrinsic::ID ID;
-    switch (S->getSCEVType()) {
-    case scUMaxExpr:
-      ID = Intrinsic::umax;
-      break;
-    case scSMaxExpr:
-      ID = Intrinsic::smax;
-      break;
-    case scUMinExpr:
-    case scSequentialUMinExpr:
-      ID = Intrinsic::umin;
-      break;
-    case scSMinExpr:
-      ID = Intrinsic::smin;
-      break;
-    default:
-      llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
-    }
-
-    const auto *NAry = cast<SCEVNAryExpr>(S);
-    ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
-    for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
-      X = X.intrinsic(
-          ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
-    return setRange(S, SignHint,
-                    ConservativeResult.intersectWith(X, RangeType));
-  }
-
-  if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
+  case scUDivExpr: {
+    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
     ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
     ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
     return setRange(UDiv, SignHint,
                     ConservativeResult.intersectWith(X.udiv(Y), RangeType));
   }
-
-  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
-    ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
-    return setRange(ZExt, SignHint,
-                    ConservativeResult.intersectWith(X.zeroExtend(BitWidth),
-                                                     RangeType));
-  }
-
-  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
-    ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
-    return setRange(SExt, SignHint,
-                    ConservativeResult.intersectWith(X.signExtend(BitWidth),
-                                                     RangeType));
-  }
-
-  if (const SCEVPtrToIntExpr *PtrToInt = dyn_cast<SCEVPtrToIntExpr>(S)) {
-    ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
-    return setRange(PtrToInt, SignHint, X);
-  }
-
-  if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
-    ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
-    return setRange(Trunc, SignHint,
-                    ConservativeResult.intersectWith(X.truncate(BitWidth),
-                                                     RangeType));
-  }
-
-  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
+  case scAddRecExpr: {
+    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
     // If there's no unsigned wrap, the value will never be less than its
     // initial value.
     if (AddRec->hasNoUnsignedWrap()) {
@@ -6725,15 +6700,16 @@ const ConstantRange &ScalarEvolution::getRangeRef(
             RangeType);
       else if (AllNonPos)
         ConservativeResult = ConservativeResult.intersectWith(
-            ConstantRange::getNonEmpty(
-                APInt::getSignedMinValue(BitWidth),
-                getSignedRangeMax(AddRec->getStart()) + 1),
+            ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
+                                       getSignedRangeMax(AddRec->getStart()) +
+                                           1),
             RangeType);
     }
 
     // TODO: non-affine addrec
     if (AddRec->isAffine()) {
-      const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(AddRec->getLoop());
+      const SCEV *MaxBECount =
+          getConstantMaxBackedgeTakenCount(AddRec->getLoop());
       if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
           getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
         auto RangeFromAffine = getRangeForAffineAR(
@@ -6766,8 +6742,40 @@ const ConstantRange &ScalarEvolution::getRangeRef(
 
     return setRange(AddRec, SignHint, std::move(ConservativeResult));
   }
+  case scUMaxExpr:
+  case scSMaxExpr:
+  case scUMinExpr:
+  case scSMinExpr:
+  case scSequentialUMinExpr: {
+    Intrinsic::ID ID;
+    switch (S->getSCEVType()) {
+    case scUMaxExpr:
+      ID = Intrinsic::umax;
+      break;
+    case scSMaxExpr:
+      ID = Intrinsic::smax;
+      break;
+    case scUMinExpr:
+    case scSequentialUMinExpr:
+      ID = Intrinsic::umin;
+      break;
+    case scSMinExpr:
+      ID = Intrinsic::smin;
+      break;
+    default:
+      llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
+    }
 
-  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+    const auto *NAry = cast<SCEVNAryExpr>(S);
+    ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
+    for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
+      X = X.intrinsic(
+          ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
+    return setRange(S, SignHint,
+                    ConservativeResult.intersectWith(X, RangeType));
+  }
+  case scUnknown: {
+    const SCEVUnknown *U = cast<SCEVUnknown>(S);
 
     // Check if the IR explicitly contains !range metadata.
     std::optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
@@ -6834,7 +6842,7 @@ const ConstantRange &ScalarEvolution::getRangeRef(
             ConservativeResult.intersectWith(RangeFromOps, RangeType);
         bool Erased = PendingPhiRanges.erase(Phi);
         assert(Erased && "Failed to erase Phi properly?");
-        (void) Erased;
+        (void)Erased;
       }
     }
 
@@ -6847,6 +6855,9 @@ const ConstantRange &ScalarEvolution::getRangeRef(
 
     return setRange(U, SignHint, std::move(ConservativeResult));
   }
+  case scCouldNotCompute:
+    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
+  }
 
   return setRange(S, SignHint, std::move(ConservativeResult));
 }


        


More information about the llvm-commits mailing list