[llvm] 27963cc - [NFC][ScalarEvolution] Refactor createNodeForSelectOrPHI

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 16 12:32:48 PDT 2021


Author: Eli Friedman
Date: 2021-06-16T12:32:32-07:00
New Revision: 27963ccf07683df951235948622bd2bef33c0c6d

URL: https://github.com/llvm/llvm-project/commit/27963ccf07683df951235948622bd2bef33c0c6d
DIFF: https://github.com/llvm/llvm-project/commit/27963ccf07683df951235948622bd2bef33c0c6d.diff

LOG: [NFC][ScalarEvolution] Refactor createNodeForSelectOrPHI

In preparation for D103660.

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 0dd7be566ad4d..c5cb7c4cfe743 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5534,48 +5534,34 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
   switch (ICI->getPredicate()) {
   case ICmpInst::ICMP_SLT:
   case ICmpInst::ICMP_SLE:
-    std::swap(LHS, RHS);
-    LLVM_FALLTHROUGH;
-  case ICmpInst::ICMP_SGT:
-  case ICmpInst::ICMP_SGE:
-    // a >s b ? a+x : b+x  ->  smax(a, b)+x
-    // a >s b ? b+x : a+x  ->  smin(a, b)+x
-    if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
-      const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType());
-      const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType());
-      const SCEV *LA = getSCEV(TrueVal);
-      const SCEV *RA = getSCEV(FalseVal);
-      const SCEV *LDiff = getMinusSCEV(LA, LS);
-      const SCEV *RDiff = getMinusSCEV(RA, RS);
-      if (LDiff == RDiff)
-        return getAddExpr(getSMaxExpr(LS, RS), LDiff);
-      LDiff = getMinusSCEV(LA, RS);
-      RDiff = getMinusSCEV(RA, LS);
-      if (LDiff == RDiff)
-        return getAddExpr(getSMinExpr(LS, RS), LDiff);
-    }
-    break;
   case ICmpInst::ICMP_ULT:
   case ICmpInst::ICMP_ULE:
     std::swap(LHS, RHS);
     LLVM_FALLTHROUGH;
+  case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_SGE:
   case ICmpInst::ICMP_UGT:
   case ICmpInst::ICMP_UGE:
-    // a >u b ? a+x : b+x  ->  umax(a, b)+x
-    // a >u b ? b+x : a+x  ->  umin(a, b)+x
+    // a > b ? a+x : b+x  ->  max(a, b)+x
+    // a > b ? b+x : a+x  ->  min(a, b)+x
     if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
-      const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
-      const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType());
+      bool Signed = ICI->isSigned();
+      const SCEV *LS = Signed ? getNoopOrSignExtend(getSCEV(LHS), I->getType())
+                              : getNoopOrZeroExtend(getSCEV(LHS), I->getType());
+      const SCEV *RS = Signed ? getNoopOrSignExtend(getSCEV(RHS), I->getType())
+                              : getNoopOrZeroExtend(getSCEV(RHS), I->getType());
       const SCEV *LA = getSCEV(TrueVal);
       const SCEV *RA = getSCEV(FalseVal);
       const SCEV *LDiff = getMinusSCEV(LA, LS);
       const SCEV *RDiff = getMinusSCEV(RA, RS);
       if (LDiff == RDiff)
-        return getAddExpr(getUMaxExpr(LS, RS), LDiff);
+        return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
+                          LDiff);
       LDiff = getMinusSCEV(LA, RS);
       RDiff = getMinusSCEV(RA, LS);
       if (LDiff == RDiff)
-        return getAddExpr(getUMinExpr(LS, RS), LDiff);
+        return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
+                          LDiff);
     }
     break;
   case ICmpInst::ICMP_NE:


        


More information about the llvm-commits mailing list