[llvm] e92a8e0 - [SCEV] BuildConstantFromSCEV(): actually properly handle SExt-of-pointer case

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 13 12:23:11 PDT 2020


Author: Roman Lebedev
Date: 2020-10-13T22:22:30+03:00
New Revision: e92a8e0c743f83552fac37ecf21e625ba3a4b11e

URL: https://github.com/llvm/llvm-project/commit/e92a8e0c743f83552fac37ecf21e625ba3a4b11e
DIFF: https://github.com/llvm/llvm-project/commit/e92a8e0c743f83552fac37ecf21e625ba3a4b11e.diff

LOG: [SCEV] BuildConstantFromSCEV(): actually properly handle SExt-of-pointer case

As pointed out by @efriedma in
https://reviews.llvm.org/rGaaafe350bb65#inline-4883
we of course can't just call ptrtoint in the sign-extending case
and be done with it, because ptrtoint zero-extends.

I'm not sure what I was thinking there.

This is very much not an NFC; however, looking at the user of
BuildConstantFromSCEV(), I'm not sure how to actually show that
it results in a different constant expression.

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 152351c10ad4..4f1d888ca0a2 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7976,7 +7976,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
 /// will return Constants for objects which aren't represented by a
 /// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
 /// Returns NULL if the SCEV isn't representable as a Constant.
-static Constant *BuildConstantFromSCEV(const SCEV *V) {
+static Constant *BuildConstantFromSCEV(const SCEV *V, const DataLayout &DL) {
   switch (static_cast<SCEVTypes>(V->getSCEVType())) {
     case scCouldNotCompute:
     case scAddRecExpr:
@@ -7987,16 +7987,22 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
       return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
     case scSignExtend: {
       const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
-      if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand())) {
-        if (!CastOp->getType()->isPointerTy())
-          return ConstantExpr::getSExt(CastOp, SS->getType());
-        return ConstantExpr::getPtrToInt(CastOp, SS->getType());
+      if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand(), DL)) {
+        if (CastOp->getType()->isPointerTy())
+          // Note that for SExt, unlike ZExt/Trunc, it is incorrect to just call
+          // ConstantExpr::getPtrToInt() and be done with it, because PtrToInt
+          // will zero-extend (otherwise ZExt case wouldn't work). So we need to
+          // first cast to the same-bitwidth integer, and then SExt it.
+          CastOp = ConstantExpr::getPtrToInt(
+              CastOp, DL.getIntPtrType(CastOp->getType()));
+        // And now, we can actually perform the sign-extension.
+        return ConstantExpr::getSExt(CastOp, SS->getType());
       }
       break;
     }
     case scZeroExtend: {
       const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
-      if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand())) {
+      if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand(), DL)) {
         if (!CastOp->getType()->isPointerTy())
           return ConstantExpr::getZExt(CastOp, SZ->getType());
         return ConstantExpr::getPtrToInt(CastOp, SZ->getType());
@@ -8005,7 +8011,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
     }
     case scTruncate: {
       const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
-      if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand())) {
+      if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand(), DL)) {
         if (!CastOp->getType()->isPointerTy())
           return ConstantExpr::getTrunc(CastOp, ST->getType());
         return ConstantExpr::getPtrToInt(CastOp, ST->getType());
@@ -8014,14 +8020,14 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
     }
     case scAddExpr: {
       const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
-      if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) {
+      if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0), DL)) {
         if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
           unsigned AS = PTy->getAddressSpace();
           Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
           C = ConstantExpr::getBitCast(C, DestPtrTy);
         }
         for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
-          Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i));
+          Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i), DL);
           if (!C2) return nullptr;
 
           // First pointer!
@@ -8053,11 +8059,11 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
     }
     case scMulExpr: {
       const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
-      if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) {
+      if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0), DL)) {
         // Don't bother with pointers at all.
         if (C->getType()->isPointerTy()) return nullptr;
         for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
-          Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i));
+          Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i), DL);
           if (!C2 || C2->getType()->isPointerTy()) return nullptr;
           C = ConstantExpr::getMul(C, C2);
         }
@@ -8067,8 +8073,8 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
     }
     case scUDivExpr: {
       const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
-      if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
-        if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
+      if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS(), DL))
+        if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS(), DL))
           if (LHS->getType() == RHS->getType())
             return ConstantExpr::getUDiv(LHS, RHS);
       break;
@@ -8173,7 +8179,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
           const SCEV *OpV = getSCEVAtScope(OrigV, L);
           MadeImprovement |= OrigV != OpV;
 
-          Constant *C = BuildConstantFromSCEV(OpV);
+          Constant *C = BuildConstantFromSCEV(OpV, getDataLayout());
           if (!C) return V;
           if (C->getType() != Op->getType())
             C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,


        


More information about the llvm-commits mailing list