[llvm] 5e71b9f - Explicitly pass type to cast load constant folding result

Arthur Eubanks via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 20 00:59:32 PDT 2021


Author: Arthur Eubanks
Date: 2021-04-20T00:53:21-07:00
New Revision: 5e71b9fa933a279458f31adffb2a092493a9e1eb

URL: https://github.com/llvm/llvm-project/commit/5e71b9fa933a279458f31adffb2a092493a9e1eb
DIFF: https://github.com/llvm/llvm-project/commit/5e71b9fa933a279458f31adffb2a092493a9e1eb.diff

LOG: Explicitly pass type to cast load constant folding result

Previously we would use the pointee type to determine which type to
cast the result of constant folding a load to. To aid with opaque
pointer types, we should explicitly pass the type of the load rather
than looking at pointee types.

ConstantFoldLoadThroughBitcast() converts the constant-propagated value
to the proper load type (e.g. [1 x i32] -> i32). Instead of calling it
at every intermediate step (such as bitcasts), we only call it once we
actually see the global initializer value.

In some existing uses of this API, we don't know the exact type we're
loading from immediately (e.g. first we visit a bitcast, then we visit
the load using the bitcast). In those cases we have to manually call
ConstantFoldLoadThroughBitcast() when simplifying the load to make sure
that we cast to the proper type.

Reviewed By: dblaikie

Differential Revision: https://reviews.llvm.org/D100718

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ConstantFolding.h
    llvm/include/llvm/Transforms/Utils/Evaluator.h
    llvm/lib/Analysis/ConstantFolding.cpp
    llvm/lib/Transforms/IPO/GlobalOpt.cpp
    llvm/lib/Transforms/Utils/Evaluator.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ConstantFolding.h b/llvm/include/llvm/Analysis/ConstantFolding.h
index ef6e66b2b88ea..62742fdf9a915 100644
--- a/llvm/include/llvm/Analysis/ConstantFolding.h
+++ b/llvm/include/llvm/Analysis/ConstantFolding.h
@@ -136,7 +136,9 @@ Constant *ConstantFoldLoadFromConstPtr(Constant *C, Type *Ty, const DataLayout &
 /// ConstantFoldLoadThroughGEPConstantExpr - Given a constant and a
 /// getelementptr constantexpr, return the constant value being addressed by the
 /// constant expression, or null if something is funny and we can't decide.
-Constant *ConstantFoldLoadThroughGEPConstantExpr(Constant *C, ConstantExpr *CE);
+Constant *ConstantFoldLoadThroughGEPConstantExpr(Constant *C, ConstantExpr *CE,
+                                                 Type *Ty,
+                                                 const DataLayout &DL);
 
 /// ConstantFoldLoadThroughGEPIndices - Given a constant and getelementptr
 /// indices (with an *implied* zero pointer index that is not in the list),

diff  --git a/llvm/include/llvm/Transforms/Utils/Evaluator.h b/llvm/include/llvm/Transforms/Utils/Evaluator.h
index 71e58e57af20d..1b93b0af86e29 100644
--- a/llvm/include/llvm/Transforms/Utils/Evaluator.h
+++ b/llvm/include/llvm/Transforms/Utils/Evaluator.h
@@ -92,7 +92,7 @@ class Evaluator {
   bool getFormalParams(CallBase &CB, Function *F,
                        SmallVectorImpl<Constant *> &Formals);
 
-  Constant *ComputeLoadResult(Constant *P);
+  Constant *ComputeLoadResult(Constant *P, Type *Ty);
 
   /// As we compute SSA register values, we store their contents here. The back
   /// of the deque contains the current function and the stack contains the

diff  --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index c1975d0e22e6d..d326af554e5a4 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -658,16 +658,10 @@ Constant *FoldReinterpretLoadFromConstPtr(Constant *C, Type *LoadTy,
 Constant *ConstantFoldLoadThroughBitcastExpr(ConstantExpr *CE, Type *DestTy,
                                              const DataLayout &DL) {
   auto *SrcPtr = CE->getOperand(0);
-  auto *SrcPtrTy = dyn_cast<PointerType>(SrcPtr->getType());
-  if (!SrcPtrTy)
+  if (!SrcPtr->getType()->isPointerTy())
     return nullptr;
-  Type *SrcTy = SrcPtrTy->getPointerElementType();
 
-  Constant *C = ConstantFoldLoadFromConstPtr(SrcPtr, SrcTy, DL);
-  if (!C)
-    return nullptr;
-
-  return llvm::ConstantFoldLoadThroughBitcast(C, DestTy, DL);
+  return ConstantFoldLoadFromConstPtr(SrcPtr, DestTy, DL);
 }
 
 } // end anonymous namespace
@@ -677,7 +671,7 @@ Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, Type *Ty,
   // First, try the easy cases:
   if (auto *GV = dyn_cast<GlobalVariable>(C))
     if (GV->isConstant() && GV->hasDefinitiveInitializer())
-      return GV->getInitializer();
+      return ConstantFoldLoadThroughBitcast(GV->getInitializer(), Ty, DL);
 
   if (auto *GA = dyn_cast<GlobalAlias>(C))
     if (GA->getAliasee() && !GA->isInterposable())
@@ -691,8 +685,8 @@ Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, Type *Ty,
   if (CE->getOpcode() == Instruction::GetElementPtr) {
     if (auto *GV = dyn_cast<GlobalVariable>(CE->getOperand(0))) {
       if (GV->isConstant() && GV->hasDefinitiveInitializer()) {
-        if (Constant *V =
-             ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE))
+        if (Constant *V = ConstantFoldLoadThroughGEPConstantExpr(
+                GV->getInitializer(), CE, Ty, DL))
           return V;
       }
     }
@@ -1410,7 +1404,9 @@ Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C,
 }
 
 Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C,
-                                                       ConstantExpr *CE) {
+                                                       ConstantExpr *CE,
+                                                       Type *Ty,
+                                                       const DataLayout &DL) {
   if (!CE->getOperand(1)->isNullValue())
     return nullptr;  // Do not allow stepping over the value!
 
@@ -1421,7 +1417,7 @@ Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C,
     if (!C)
       return nullptr;
   }
-  return C;
+  return ConstantFoldLoadThroughBitcast(C, Ty, DL);
 }
 
 Constant *

diff  --git a/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/llvm/lib/Transforms/IPO/GlobalOpt.cpp
index f275f371c77ae..94278a0d501d4 100644
--- a/llvm/lib/Transforms/IPO/GlobalOpt.cpp
+++ b/llvm/lib/Transforms/IPO/GlobalOpt.cpp
@@ -296,10 +296,13 @@ static bool CleanupConstantGlobalUsers(
 
     if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
       if (Init) {
-        // Replace the load with the initializer.
-        LI->replaceAllUsesWith(Init);
-        LI->eraseFromParent();
-        Changed = true;
+        if (auto *Casted =
+                ConstantFoldLoadThroughBitcast(Init, LI->getType(), DL)) {
+          // Replace the load with the initializer.
+          LI->replaceAllUsesWith(Casted);
+          LI->eraseFromParent();
+          Changed = true;
+        }
       }
     } else if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
       // Store must be unreachable or storing Init into the global.
@@ -309,7 +312,8 @@ static bool CleanupConstantGlobalUsers(
       if (CE->getOpcode() == Instruction::GetElementPtr) {
         Constant *SubInit = nullptr;
         if (Init)
-          SubInit = ConstantFoldLoadThroughGEPConstantExpr(Init, CE);
+          SubInit = ConstantFoldLoadThroughGEPConstantExpr(
+              Init, CE, V->getType()->getPointerElementType(), DL);
         Changed |= CleanupConstantGlobalUsers(CE, SubInit, DL, GetTLI);
       } else if ((CE->getOpcode() == Instruction::BitCast &&
                   CE->getType()->isPointerTy()) ||
@@ -331,7 +335,8 @@ static bool CleanupConstantGlobalUsers(
         ConstantExpr *CE = dyn_cast_or_null<ConstantExpr>(
             ConstantFoldInstruction(GEP, DL, &GetTLI(*GEP->getFunction())));
         if (Init && CE && CE->getOpcode() == Instruction::GetElementPtr)
-          SubInit = ConstantFoldLoadThroughGEPConstantExpr(Init, CE);
+          SubInit = ConstantFoldLoadThroughGEPConstantExpr(
+              Init, CE, V->getType()->getPointerElementType(), DL);
 
         // If the initializer is an all-null value and we have an inbounds GEP,
         // we already know what the result of any load from that GEP is.

diff  --git a/llvm/lib/Transforms/Utils/Evaluator.cpp b/llvm/lib/Transforms/Utils/Evaluator.cpp
index ac29565567a90..7404ceb83ab87 100644
--- a/llvm/lib/Transforms/Utils/Evaluator.cpp
+++ b/llvm/lib/Transforms/Utils/Evaluator.cpp
@@ -127,7 +127,7 @@ isSimpleEnoughValueToCommit(Constant *C,
 /// another pointer type, we punt.  We basically just support direct accesses to
 /// globals and GEP's of globals.  This should be kept up to date with
 /// CommitValueTo.
-static bool isSimpleEnoughPointerToCommit(Constant *C) {
+static bool isSimpleEnoughPointerToCommit(Constant *C, const DataLayout &DL) {
   // Conservatively, avoid aggregate types. This is because we don't
   // want to worry about them partially overlapping other stores.
   if (!cast<PointerType>(C->getType())->getElementType()->isSingleValueType())
@@ -157,13 +157,14 @@ static bool isSimpleEnoughPointerToCommit(Constant *C) {
       if (!CE->isGEPWithNoNotionalOverIndexing())
         return false;
 
-      return ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE);
-
-    // A constantexpr bitcast from a pointer to another pointer is a no-op,
-    // and we know how to evaluate it by moving the bitcast from the pointer
-    // operand to the value operand.
+      return ConstantFoldLoadThroughGEPConstantExpr(
+          GV->getInitializer(), CE,
+          cast<GEPOperator>(CE)->getResultElementType(), DL);
     } else if (CE->getOpcode() == Instruction::BitCast &&
                isa<GlobalVariable>(CE->getOperand(0))) {
+      // A constantexpr bitcast from a pointer to another pointer is a no-op,
+      // and we know how to evaluate it by moving the bitcast from the pointer
+      // operand to the value operand.
       // Do not allow weak/*_odr/linkonce/dllimport/dllexport linkage or
       // external globals.
       return cast<GlobalVariable>(CE->getOperand(0))->hasUniqueInitializer();
@@ -207,7 +208,7 @@ static Constant *getInitializer(Constant *C) {
 
 /// Return the value that would be computed by a load from P after the stores
 /// reflected by 'memory' have been performed.  If we can't decide, return null.
-Constant *Evaluator::ComputeLoadResult(Constant *P) {
+Constant *Evaluator::ComputeLoadResult(Constant *P, Type *Ty) {
   // If this memory location has been recently stored, use the stored value: it
   // is the most up-to-date.
   auto findMemLoc = [this](Constant *Ptr) { return MutatedMemory.lookup(Ptr); };
@@ -227,7 +228,7 @@ Constant *Evaluator::ComputeLoadResult(Constant *P) {
     // Handle a constantexpr getelementptr.
     case Instruction::GetElementPtr:
       if (auto *I = getInitializer(CE->getOperand(0)))
-        return ConstantFoldLoadThroughGEPConstantExpr(I, CE);
+        return ConstantFoldLoadThroughGEPConstantExpr(I, CE, Ty, DL);
       break;
     // Handle a constantexpr bitcast.
     case Instruction::BitCast:
@@ -340,7 +341,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, BasicBlock *&NextBB,
         Ptr = FoldedPtr;
         LLVM_DEBUG(dbgs() << "; To: " << *Ptr << "\n");
       }
-      if (!isSimpleEnoughPointerToCommit(Ptr)) {
+      if (!isSimpleEnoughPointerToCommit(Ptr, DL)) {
         // If this is too complex for us to commit, reject it.
         LLVM_DEBUG(
             dbgs() << "Pointer is too complex for us to evaluate store.");
@@ -450,7 +451,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, BasicBlock *&NextBB,
                              "folding: "
                           << *Ptr << "\n");
       }
-      InstResult = ComputeLoadResult(Ptr);
+      InstResult = ComputeLoadResult(Ptr, LI->getType());
       if (!InstResult) {
         LLVM_DEBUG(
             dbgs() << "Failed to compute load result. Can not evaluate load."
@@ -496,7 +497,8 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, BasicBlock *&NextBB,
           }
           Constant *Ptr = getVal(MSI->getDest());
           Constant *Val = getVal(MSI->getValue());
-          Constant *DestVal = ComputeLoadResult(getVal(Ptr));
+          Constant *DestVal =
+              ComputeLoadResult(getVal(Ptr), MSI->getValue()->getType());
           if (Val->isNullValue() && DestVal && DestVal->isNullValue()) {
             // This memset is a no-op.
             LLVM_DEBUG(dbgs() << "Ignoring no-op memset.\n");


        


More information about the llvm-commits mailing list