[llvm] 9c4baf5 - [ScalarEvolution] Strictly enforce pointer/int type rules.

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 9 17:33:08 PDT 2021


Author: Eli Friedman
Date: 2021-07-09T17:29:26-07:00
New Revision: 9c4baf5101e9ee55bfae1ce7657a7b89a06b4ccf

URL: https://github.com/llvm/llvm-project/commit/9c4baf5101e9ee55bfae1ce7657a7b89a06b4ccf
DIFF: https://github.com/llvm/llvm-project/commit/9c4baf5101e9ee55bfae1ce7657a7b89a06b4ccf.diff

LOG: [ScalarEvolution] Strictly enforce pointer/int type rules.

Rules for determining whether a SCEV expression is a pointer, an integer, or invalid:

1. SCEVUnknown is a pointer if and only if the LLVM IR value is a
   pointer.
2. SCEVPtrToInt is never a pointer.
3. If any other SCEV expression has no pointer operands, the result is
   an integer.
4. If a SCEVAddExpr has exactly one pointer operand, the result is a
   pointer.
5. If a SCEVAddRecExpr's first operand is a pointer, and it has no other
   pointer operands, the result is a pointer.
6. If every operand of a SCEVMinMaxExpr is a pointer, the result is a
   pointer.
7. Otherwise, the SCEV expression is invalid.

I'm not sure how useful rule 6 is in practice.  If we exclude it, we can
guarantee that ScalarEvolution::getPointerBase always returns a
SCEVUnknown, which might be a helpful property. Anyway, I'll leave that
for a followup.

This is basically mop-up at this point; all the changes with significant
functional effects have landed.  Some of the remaining changes could be
split off, but I don't see much point.

Differential Revision: https://reviews.llvm.org/D105510

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
    llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
    llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 221b116d4371e..c0ce5de7c60c1 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -1192,6 +1192,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
          "This is not a truncating conversion!");
   assert(isSCEVable(Ty) &&
          "This is not a conversion to a SCEVable type!");
+  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
   Ty = getEffectiveSCEVType(Ty);
 
   FoldingSetNodeID ID;
@@ -1581,6 +1582,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
          "This is not an extending conversion!");
   assert(isSCEVable(Ty) &&
          "This is not a conversion to a SCEVable type!");
+  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
   Ty = getEffectiveSCEVType(Ty);
 
   // Fold if the operand is constant.
@@ -1883,6 +1885,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
          "This is not an extending conversion!");
   assert(isSCEVable(Ty) &&
          "This is not a conversion to a SCEVable type!");
+  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
   Ty = getEffectiveSCEVType(Ty);
 
   // Fold if the operand is constant.
@@ -2410,6 +2413,9 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
            "SCEVAddExpr operand types don't match!");
+  unsigned NumPtrs = count_if(
+      Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
+  assert(NumPtrs <= 1 && "add has at most one pointer operand");
 #endif
 
   // Sort by complexity, this groups all similar expression types together.
@@ -2645,12 +2651,16 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
       Ops.clear();
       if (AccumulatedConstant != 0)
         Ops.push_back(getConstant(AccumulatedConstant));
-      for (auto &MulOp : MulOpLists)
-        if (MulOp.first != 0)
+      for (auto &MulOp : MulOpLists) {
+        if (MulOp.first == 1) {
+          Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
+        } else if (MulOp.first != 0) {
           Ops.push_back(getMulExpr(
               getConstant(MulOp.first),
               getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
               SCEV::FlagAnyWrap, Depth + 1));
+        }
+      }
       if (Ops.empty())
         return getZero(Ty);
       if (Ops.size() == 1)
@@ -2969,9 +2979,10 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
   assert(!Ops.empty() && "Cannot get empty mul!");
   if (Ops.size() == 1) return Ops[0];
 #ifndef NDEBUG
-  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
+  Type *ETy = Ops[0]->getType();
+  assert(!ETy->isPointerTy());
   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
-    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
+    assert(Ops[i]->getType() == ETy &&
            "SCEVMulExpr operand types don't match!");
 #endif
 
@@ -3256,8 +3267,9 @@ const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
 /// possible.
 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
                                          const SCEV *RHS) {
-  assert(getEffectiveSCEVType(LHS->getType()) ==
-         getEffectiveSCEVType(RHS->getType()) &&
+  assert(!LHS->getType()->isPointerTy() &&
+         "SCEVUDivExpr operand can't be pointer!");
+  assert(LHS->getType() == RHS->getType() &&
          "SCEVUDivExpr operand types don't match!");
 
   FoldingSetNodeID ID;
@@ -3506,9 +3518,11 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
   if (Operands.size() == 1) return Operands[0];
 #ifndef NDEBUG
   Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
-  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
+  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
     assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
            "SCEVAddRecExpr operand types don't match!");
+    assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
+  }
   for (unsigned i = 0, e = Operands.size(); i != e; ++i)
     assert(isLoopInvariant(Operands[i], L) &&
            "SCEVAddRecExpr operand is not loop-invariant!");
@@ -3662,9 +3676,13 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
   if (Ops.size() == 1) return Ops[0];
 #ifndef NDEBUG
   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
-  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
+  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
            "Operand types don't match!");
+    assert(Ops[0]->getType()->isPointerTy() ==
+               Ops[i]->getType()->isPointerTy() &&
+           "min/max should be consistently pointerish");
+  }
 #endif
 
   bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
@@ -10579,6 +10597,8 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
       }
     }
 
+    if (LHS->getType()->isPointerTy())
+      return false;
     if (CmpInst::isSigned(Pred)) {
       LHS = getSignExtendExpr(LHS, FoundLHS->getType());
       RHS = getSignExtendExpr(RHS, FoundLHS->getType());
@@ -10588,6 +10608,8 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
     }
   } else if (getTypeSizeInBits(LHS->getType()) >
       getTypeSizeInBits(FoundLHS->getType())) {
+    if (FoundLHS->getType()->isPointerTy())
+      return false;
     if (CmpInst::isSigned(FoundPred)) {
       FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
       FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());

diff  --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index d49df59a9d6c7..acf741d270fea 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -682,8 +682,11 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
     const APInt &RA = RC->getAPInt();
     // Handle x /s -1 as x * -1, to give ScalarEvolution a chance to do
     // some folding.
-    if (RA.isAllOnesValue())
+    if (RA.isAllOnesValue()) {
+      if (LHS->getType()->isPointerTy())
+        return nullptr;
       return SE.getMulExpr(LHS, RC);
+    }
     // Handle x /s 1 as x.
     if (RA == 1)
       return LHS;
@@ -4063,7 +4066,8 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) {
   // Determine the integer type for the base formula.
   Type *DstTy = Base.getType();
   if (!DstTy) return;
-  DstTy = SE.getEffectiveSCEVType(DstTy);
+  if (DstTy->isPointerTy())
+    return;
 
   for (Type *SrcTy : Types) {
     if (SrcTy != DstTy && TTI.isTruncateFree(SrcTy, DstTy)) {
@@ -5301,7 +5305,7 @@ Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF,
   if (F.BaseGV) {
     // Flush the operand list to suppress SCEVExpander hoisting.
     if (!Ops.empty()) {
-      Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), Ty);
+      Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), IntTy);
       Ops.clear();
       Ops.push_back(SE.getUnknown(FullV));
     }

diff  --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 1cf3f97fba0f3..5af1c37e6197b 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -1147,6 +1147,10 @@ static bool canBeCheaplyTransformed(ScalarEvolution &SE,
                                     const SCEVAddRecExpr *Phi,
                                     const SCEVAddRecExpr *Requested,
                                     bool &InvertStep) {
+  // We can't transform to match a pointer PHI.
+  if (Phi->getType()->isPointerTy())
+    return false;
+
   Type *PhiTy = SE.getEffectiveSCEVType(Phi->getType());
   Type *RequestedTy = SE.getEffectiveSCEVType(Requested->getType());
 
@@ -1165,8 +1169,7 @@ static bool canBeCheaplyTransformed(ScalarEvolution &SE,
   }
 
   // Check whether inverting will help: {R,+,-1} == R - {0,+,1}.
-  if (SE.getAddExpr(Requested->getStart(),
-                    SE.getNegativeSCEV(Requested)) == Phi) {
+  if (SE.getMinusSCEV(Requested->getStart(), Requested) == Phi) {
     InvertStep = true;
     return true;
   }
@@ -1577,8 +1580,8 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
   // Rewrite an AddRec in terms of the canonical induction variable, if
   // its type is more narrow.
   if (CanonicalIV &&
-      SE.getTypeSizeInBits(CanonicalIV->getType()) >
-      SE.getTypeSizeInBits(Ty)) {
+      SE.getTypeSizeInBits(CanonicalIV->getType()) > SE.getTypeSizeInBits(Ty) &&
+      !S->getType()->isPointerTy()) {
     SmallVector<const SCEV *, 4> NewOps(S->getNumOperands());
     for (unsigned i = 0, e = S->getNumOperands(); i != e; ++i)
       NewOps[i] = SE.getAnyExtendExpr(S->op_begin()[i], CanonicalIV->getType());

diff  --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 060b0456bae47..8a21646d9b68d 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -96,13 +96,13 @@ TEST_F(ScalarEvolutionsTest, SCEVUnknownRAUW) {
   const SCEV *S1 = SE.getSCEV(V1);
   const SCEV *S2 = SE.getSCEV(V2);
 
-  const SCEV *P0 = SE.getAddExpr(S0, S0);
-  const SCEV *P1 = SE.getAddExpr(S1, S1);
-  const SCEV *P2 = SE.getAddExpr(S2, S2);
+  const SCEV *P0 = SE.getAddExpr(S0, SE.getConstant(S0->getType(), 2));
+  const SCEV *P1 = SE.getAddExpr(S1, SE.getConstant(S0->getType(), 2));
+  const SCEV *P2 = SE.getAddExpr(S2, SE.getConstant(S0->getType(), 2));
 
-  const SCEVMulExpr *M0 = cast<SCEVMulExpr>(P0);
-  const SCEVMulExpr *M1 = cast<SCEVMulExpr>(P1);
-  const SCEVMulExpr *M2 = cast<SCEVMulExpr>(P2);
+  auto *M0 = cast<SCEVAddExpr>(P0);
+  auto *M1 = cast<SCEVAddExpr>(P1);
+  auto *M2 = cast<SCEVAddExpr>(P2);
 
   EXPECT_EQ(cast<SCEVConstant>(M0->getOperand(0))->getValue()->getZExtValue(),
             2u);
@@ -707,6 +707,7 @@ TEST_F(ScalarEvolutionsTest, SCEVZeroExtendExpr) {
   ReturnInst::Create(Context, nullptr, EndBB);
   ScalarEvolution SE = buildSE(*F);
   const SCEV *S = SE.getSCEV(Accum);
+  S = SE.getLosslessPtrToIntExpr(S);
   Type *I128Ty = Type::getInt128Ty(Context);
   SE.getZeroExtendExpr(S, I128Ty);
 }


        


More information about the llvm-commits mailing list