[llvm] e92a8e0 - [SCEV] BuildConstantFromSCEV(): actually properly handle SExt-of-pointer case
Roman Lebedev via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 13 12:23:11 PDT 2020
Author: Roman Lebedev
Date: 2020-10-13T22:22:30+03:00
New Revision: e92a8e0c743f83552fac37ecf21e625ba3a4b11e
URL: https://github.com/llvm/llvm-project/commit/e92a8e0c743f83552fac37ecf21e625ba3a4b11e
DIFF: https://github.com/llvm/llvm-project/commit/e92a8e0c743f83552fac37ecf21e625ba3a4b11e.diff
LOG: [SCEV] BuildConstantFromSCEV(): actually properly handle SExt-of-pointer case
As pointed out by @efriedma in
https://reviews.llvm.org/rGaaafe350bb65#inline-4883
we obviously can't just call ptrtoint in the sign-extending case
and be done with it, because ptrtoint will zero-extend.
I'm not sure what I was thinking there.
This is very much not an NFC; however, looking at the user of
BuildConstantFromSCEV(), I'm not sure how to actually show that
it results in a different constant expression.
Added:
Modified:
llvm/lib/Analysis/ScalarEvolution.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 152351c10ad4..4f1d888ca0a2 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7976,7 +7976,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
/// will return Constants for objects which aren't represented by a
/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
/// Returns NULL if the SCEV isn't representable as a Constant.
-static Constant *BuildConstantFromSCEV(const SCEV *V) {
+static Constant *BuildConstantFromSCEV(const SCEV *V, const DataLayout &DL) {
switch (static_cast<SCEVTypes>(V->getSCEVType())) {
case scCouldNotCompute:
case scAddRecExpr:
@@ -7987,16 +7987,22 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
case scSignExtend: {
const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
- if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand())) {
- if (!CastOp->getType()->isPointerTy())
- return ConstantExpr::getSExt(CastOp, SS->getType());
- return ConstantExpr::getPtrToInt(CastOp, SS->getType());
+ if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand(), DL)) {
+ if (CastOp->getType()->isPointerTy())
+ // Note that for SExt, unlike ZExt/Trunc, it is incorrect to just call
+ // ConstantExpr::getPtrToInt() and be done with it, because PtrToInt
+ // will zero-extend (otherwise ZExt case wouldn't work). So we need to
+ // first cast to the same-bitwidth integer, and then SExt it.
+ CastOp = ConstantExpr::getPtrToInt(
+ CastOp, DL.getIntPtrType(CastOp->getType()));
+ // And now, we can actually perform the sign-extension.
+ return ConstantExpr::getSExt(CastOp, SS->getType());
}
break;
}
case scZeroExtend: {
const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
- if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand())) {
+ if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand(), DL)) {
if (!CastOp->getType()->isPointerTy())
return ConstantExpr::getZExt(CastOp, SZ->getType());
return ConstantExpr::getPtrToInt(CastOp, SZ->getType());
@@ -8005,7 +8011,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
}
case scTruncate: {
const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
- if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand())) {
+ if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand(), DL)) {
if (!CastOp->getType()->isPointerTy())
return ConstantExpr::getTrunc(CastOp, ST->getType());
return ConstantExpr::getPtrToInt(CastOp, ST->getType());
@@ -8014,14 +8020,14 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
}
case scAddExpr: {
const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
- if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) {
+ if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0), DL)) {
if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
unsigned AS = PTy->getAddressSpace();
Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
C = ConstantExpr::getBitCast(C, DestPtrTy);
}
for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
- Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i));
+ Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i), DL);
if (!C2) return nullptr;
// First pointer!
@@ -8053,11 +8059,11 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
}
case scMulExpr: {
const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
- if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) {
+ if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0), DL)) {
// Don't bother with pointers at all.
if (C->getType()->isPointerTy()) return nullptr;
for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
- Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i));
+ Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i), DL);
if (!C2 || C2->getType()->isPointerTy()) return nullptr;
C = ConstantExpr::getMul(C, C2);
}
@@ -8067,8 +8073,8 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
}
case scUDivExpr: {
const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
- if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
- if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
+ if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS(), DL))
+ if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS(), DL))
if (LHS->getType() == RHS->getType())
return ConstantExpr::getUDiv(LHS, RHS);
break;
@@ -8173,7 +8179,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
const SCEV *OpV = getSCEVAtScope(OrigV, L);
MadeImprovement |= OrigV != OpV;
- Constant *C = BuildConstantFromSCEV(OpV);
+ Constant *C = BuildConstantFromSCEV(OpV, getDataLayout());
if (!C) return V;
if (C->getType() != Op->getType())
C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
More information about the llvm-commits
mailing list