[llvm] [NFC] [DSE] Refactor DSE (PR #100956)

Haopeng Liu via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 21 10:06:54 PDT 2024


https://github.com/haopliu updated https://github.com/llvm/llvm-project/pull/100956

>From 4e366574119ebc9d130e4f08aca12a3ded89ae34 Mon Sep 17 00:00:00 2001
From: Haopeng Liu <haopliu at google.com>
Date: Mon, 29 Jul 2024 01:26:21 +0000
Subject: [PATCH 1/9] Refactor DSE with MemoryDefWrapper and
 MemoryLocationWrapper

---
 .../Scalar/DeadStoreElimination.cpp           | 436 ++++++++++--------
 1 file changed, 245 insertions(+), 191 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 931606c6f8fe12..e64b7da818a985 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -806,6 +806,63 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
   return false;
 }
 
+/// Returns true if \p I is a memory terminator instruction like
+/// llvm.lifetime.end or free.
+bool isMemTerminatorInst(Instruction *I, const TargetLibraryInfo &TLI) {
+  auto *CB = dyn_cast<CallBase>(I);
+  return CB && (CB->getIntrinsicID() == Intrinsic::lifetime_end ||
+                getFreedOperand(CB, &TLI) != nullptr);
+}
+
+struct DSEState;
+enum class ChangeStateEnum : uint8_t {
+  NoChange,
+  DeleteByMemTerm,
+  CompleteDeleteByNonMemTerm,
+  PartiallyDeleteByNonMemTerm,
+};
+
+// A memory location wrapper that represents a MemoryLocation, `MemLoc`,
+// defined by `MemDef`.
+class MemoryLocationWrapper {
+public:
+  MemoryLocationWrapper(MemoryLocation MemLoc, DSEState &State,
+                        MemoryDef *MemDef)
+      : MemLoc(MemLoc), State(State), MemDef(MemDef) {
+    assert(MemLoc.Ptr && "MemLoc should be not null");
+    UnderlyingObject = getUnderlyingObject(MemLoc.Ptr);
+    DefInst = MemDef->getMemoryInst();
+  }
+
+  // Try to eliminate dead defs killed by this MemoryLocation and return the
+  // change state.
+  ChangeStateEnum eliminateDeadDefs();
+  MemoryAccess *GetDefiningAccess() const {
+    return MemDef->getDefiningAccess();
+  }
+
+  MemoryLocation MemLoc;
+  const Value *UnderlyingObject;
+  DSEState &State;
+  MemoryDef *MemDef;
+  Instruction *DefInst;
+};
+
+// A memory def wrapper that represents a MemoryDef and the MemoryLocation(s)
+// defined by this MemoryDef.
+class MemoryDefWrapper {
+public:
+  MemoryDefWrapper(MemoryDef *MemDef, DSEState &State);
+  // Try to eliminate dead defs killed by this MemoryDef and return the
+  // change state.
+  bool eliminateDeadDefs();
+
+  MemoryDef *MemDef;
+  Instruction *DefInst;
+  DSEState &State;
+  SmallVector<MemoryLocationWrapper, 1> DefinedLocations;
+};
+
 struct DSEState {
   Function &F;
   AliasAnalysis &AA;
@@ -883,7 +940,7 @@ struct DSEState {
 
         auto *MD = dyn_cast_or_null<MemoryDef>(MA);
         if (MD && MemDefs.size() < MemorySSADefsPerBlockLimit &&
-            (getLocForWrite(&I) || isMemTerminatorInst(&I)))
+            (getLocForWrite(&I) || isMemTerminatorInst(&I, TLI)))
           MemDefs.push_back(MD);
       }
     }
@@ -1225,14 +1282,6 @@ struct DSEState {
     return std::nullopt;
   }
 
-  /// Returns true if \p I is a memory terminator instruction like
-  /// llvm.lifetime.end or free.
-  bool isMemTerminatorInst(Instruction *I) const {
-    auto *CB = dyn_cast<CallBase>(I);
-    return CB && (CB->getIntrinsicID() == Intrinsic::lifetime_end ||
-                  getFreedOperand(CB, &TLI) != nullptr);
-  }
-
   /// Returns true if \p MaybeTerm is a memory terminator for \p Loc from
   /// instruction \p AccessI.
   bool isMemTerminator(const MemoryLocation &Loc, Instruction *AccessI,
@@ -1325,17 +1374,15 @@ struct DSEState {
   // (completely) overwrite \p KillingLoc. Currently we bail out when we
   // encounter an aliasing MemoryUse (read).
   std::optional<MemoryAccess *>
-  getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *StartAccess,
-                  const MemoryLocation &KillingLoc, const Value *KillingUndObj,
+  getDomMemoryDef(MemoryLocationWrapper &KillingLoc, MemoryAccess *StartAccess,
                   unsigned &ScanLimit, unsigned &WalkerStepLimit,
-                  bool IsMemTerm, unsigned &PartialLimit) {
+                  unsigned &PartialLimit) {
     if (ScanLimit == 0 || WalkerStepLimit == 0) {
       LLVM_DEBUG(dbgs() << "\n    ...  hit scan limit\n");
       return std::nullopt;
     }
 
     MemoryAccess *Current = StartAccess;
-    Instruction *KillingI = KillingDef->getMemoryInst();
     LLVM_DEBUG(dbgs() << "  trying to get dominating access\n");
 
     // Only optimize defining access of KillingDef when directly starting at its
@@ -1344,8 +1391,8 @@ struct DSEState {
     // it should be sufficient to disable optimizations for instructions that
     // also read from memory.
     bool CanOptimize = OptimizeMemorySSA &&
-                       KillingDef->getDefiningAccess() == StartAccess &&
-                       !KillingI->mayReadFromMemory();
+                       KillingLoc.GetDefiningAccess() == StartAccess &&
+                       !KillingLoc.DefInst->mayReadFromMemory();
 
     // Find the next clobbering Mod access for DefLoc, starting at StartAccess.
     std::optional<MemoryLocation> CurrentLoc;
@@ -1361,15 +1408,15 @@ struct DSEState {
       // Reached TOP.
       if (MSSA.isLiveOnEntryDef(Current)) {
         LLVM_DEBUG(dbgs() << "   ...  found LiveOnEntryDef\n");
-        if (CanOptimize && Current != KillingDef->getDefiningAccess())
+        if (CanOptimize && Current != KillingLoc.GetDefiningAccess())
           // The first clobbering def is... none.
-          KillingDef->setOptimized(Current);
+          KillingLoc.MemDef->setOptimized(Current);
         return std::nullopt;
       }
 
       // Cost of a step. Accesses in the same block are more likely to be valid
       // candidates for elimination, hence consider them cheaper.
-      unsigned StepCost = KillingDef->getBlock() == Current->getBlock()
+      unsigned StepCost = KillingLoc.MemDef->getBlock() == Current->getBlock()
                               ? MemorySSASameBBStepCost
                               : MemorySSAOtherBBStepCost;
       if (WalkerStepLimit <= StepCost) {
@@ -1390,21 +1437,23 @@ struct DSEState {
       MemoryDef *CurrentDef = cast<MemoryDef>(Current);
       Instruction *CurrentI = CurrentDef->getMemoryInst();
 
-      if (canSkipDef(CurrentDef, !isInvisibleToCallerOnUnwind(KillingUndObj))) {
+      if (canSkipDef(CurrentDef, !isInvisibleToCallerOnUnwind(
+                                     KillingLoc.UnderlyingObject))) {
         CanOptimize = false;
         continue;
       }
 
       // Before we try to remove anything, check for any extra throwing
       // instructions that block us from DSEing
-      if (mayThrowBetween(KillingI, CurrentI, KillingUndObj)) {
+      if (mayThrowBetween(KillingLoc.DefInst, CurrentI,
+                          KillingLoc.UnderlyingObject)) {
         LLVM_DEBUG(dbgs() << "  ... skip, may throw!\n");
         return std::nullopt;
       }
 
       // Check for anything that looks like it will be a barrier to further
       // removal
-      if (isDSEBarrier(KillingUndObj, CurrentI)) {
+      if (isDSEBarrier(KillingLoc.UnderlyingObject, CurrentI)) {
         LLVM_DEBUG(dbgs() << "  ... skip, barrier\n");
         return std::nullopt;
       }
@@ -1413,14 +1462,16 @@ struct DSEState {
       // clobber, bail out, as the path is not profitable. We skip this check
       // for intrinsic calls, because the code knows how to handle memcpy
       // intrinsics.
-      if (!isa<IntrinsicInst>(CurrentI) && isReadClobber(KillingLoc, CurrentI))
+      if (!isa<IntrinsicInst>(CurrentI) &&
+          isReadClobber(KillingLoc.MemLoc, CurrentI))
         return std::nullopt;
 
       // Quick check if there are direct uses that are read-clobbers.
       if (any_of(Current->uses(), [this, &KillingLoc, StartAccess](Use &U) {
             if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(U.getUser()))
               return !MSSA.dominates(StartAccess, UseOrDef) &&
-                     isReadClobber(KillingLoc, UseOrDef->getMemoryInst());
+                     isReadClobber(KillingLoc.MemLoc,
+                                   UseOrDef->getMemoryInst());
             return false;
           })) {
         LLVM_DEBUG(dbgs() << "   ...  found a read clobber\n");
@@ -1438,32 +1489,33 @@ struct DSEState {
       // AliasAnalysis does not account for loops. Limit elimination to
       // candidates for which we can guarantee they always store to the same
       // memory location and not located in different loops.
-      if (!isGuaranteedLoopIndependent(CurrentI, KillingI, *CurrentLoc)) {
+      if (!isGuaranteedLoopIndependent(CurrentI, KillingLoc.DefInst,
+                                       *CurrentLoc)) {
         LLVM_DEBUG(dbgs() << "  ... not guaranteed loop independent\n");
         CanOptimize = false;
         continue;
       }
 
-      if (IsMemTerm) {
+      if (isMemTerminatorInst(KillingLoc.DefInst, TLI)) {
         // If the killing def is a memory terminator (e.g. lifetime.end), check
         // the next candidate if the current Current does not write the same
         // underlying object as the terminator.
-        if (!isMemTerminator(*CurrentLoc, CurrentI, KillingI)) {
+        if (!isMemTerminator(*CurrentLoc, CurrentI, KillingLoc.DefInst)) {
           CanOptimize = false;
           continue;
         }
       } else {
         int64_t KillingOffset = 0;
         int64_t DeadOffset = 0;
-        auto OR = isOverwrite(KillingI, CurrentI, KillingLoc, *CurrentLoc,
-                              KillingOffset, DeadOffset);
+        auto OR = isOverwrite(KillingLoc.DefInst, CurrentI, KillingLoc.MemLoc,
+                              *CurrentLoc, KillingOffset, DeadOffset);
         if (CanOptimize) {
           // CurrentDef is the earliest write clobber of KillingDef. Use it as
           // optimized access. Do not optimize if CurrentDef is already the
           // defining access of KillingDef.
-          if (CurrentDef != KillingDef->getDefiningAccess() &&
+          if (CurrentDef != KillingLoc.GetDefiningAccess() &&
               (OR == OW_Complete || OR == OW_MaybePartial))
-            KillingDef->setOptimized(CurrentDef);
+            KillingLoc.MemDef->setOptimized(CurrentDef);
 
           // Once a may-aliasing def is encountered do not set an optimized
           // access.
@@ -1496,7 +1548,7 @@ struct DSEState {
     // the blocks with killing (=completely overwriting MemoryDefs) and check if
     // they cover all paths from MaybeDeadAccess to any function exit.
     SmallPtrSet<Instruction *, 16> KillingDefs;
-    KillingDefs.insert(KillingDef->getMemoryInst());
+    KillingDefs.insert(KillingLoc.DefInst);
     MemoryAccess *MaybeDeadAccess = Current;
     MemoryLocation MaybeDeadLoc = *CurrentLoc;
     Instruction *MaybeDeadI = cast<MemoryDef>(MaybeDeadAccess)->getMemoryInst();
@@ -1558,7 +1610,8 @@ struct DSEState {
         continue;
       }
 
-      if (UseInst->mayThrow() && !isInvisibleToCallerOnUnwind(KillingUndObj)) {
+      if (UseInst->mayThrow() &&
+          !isInvisibleToCallerOnUnwind(KillingLoc.UnderlyingObject)) {
         LLVM_DEBUG(dbgs() << "  ... found throwing instruction\n");
         return std::nullopt;
       }
@@ -1582,7 +1635,7 @@ struct DSEState {
       // if it reads the memory location.
       // TODO: It would probably be better to check for self-reads before
       // calling the function.
-      if (KillingDef == UseAccess || MaybeDeadAccess == UseAccess) {
+      if (KillingLoc.MemDef == UseAccess || MaybeDeadAccess == UseAccess) {
         LLVM_DEBUG(dbgs() << "    ... skipping killing def/dom access\n");
         continue;
       }
@@ -1602,7 +1655,7 @@ struct DSEState {
           BasicBlock *MaybeKillingBlock = UseInst->getParent();
           if (PostOrderNumbers.find(MaybeKillingBlock)->second <
               PostOrderNumbers.find(MaybeDeadAccess->getBlock())->second) {
-            if (!isInvisibleToCallerAfterRet(KillingUndObj)) {
+            if (!isInvisibleToCallerAfterRet(KillingLoc.UnderlyingObject)) {
               LLVM_DEBUG(dbgs()
                          << "    ... found killing def " << *UseInst << "\n");
               KillingDefs.insert(UseInst);
@@ -1620,7 +1673,7 @@ struct DSEState {
     // For accesses to locations visible after the function returns, make sure
     // that the location is dead (=overwritten) along all paths from
     // MaybeDeadAccess to the exit.
-    if (!isInvisibleToCallerAfterRet(KillingUndObj)) {
+    if (!isInvisibleToCallerAfterRet(KillingLoc.UnderlyingObject)) {
       SmallPtrSet<BasicBlock *, 16> KillingBlocks;
       for (Instruction *KD : KillingDefs)
         KillingBlocks.insert(KD->getParent());
@@ -2134,180 +2187,181 @@ struct DSEState {
   }
 };
 
-static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
-                                DominatorTree &DT, PostDominatorTree &PDT,
-                                const TargetLibraryInfo &TLI,
-                                const LoopInfo &LI) {
-  bool MadeChange = false;
+MemoryDefWrapper::MemoryDefWrapper(MemoryDef *MemDef, DSEState &State)
+    : MemDef(MemDef), State(State) {
+  DefInst = MemDef->getMemoryInst();
 
-  DSEState State(F, AA, MSSA, DT, PDT, TLI, LI);
-  // For each store:
-  for (unsigned I = 0; I < State.MemDefs.size(); I++) {
-    MemoryDef *KillingDef = State.MemDefs[I];
-    if (State.SkipStores.count(KillingDef))
+  if (isMemTerminatorInst(DefInst, State.TLI)) {
+    if (auto KillingLoc = State.getLocForTerminator(DefInst)) {
+      DefinedLocations.push_back(
+          MemoryLocationWrapper(KillingLoc->first, State, MemDef));
+    }
+    return;
+  }
+
+  if (auto KillingLoc = State.getLocForWrite(DefInst)) {
+    DefinedLocations.push_back(
+        MemoryLocationWrapper(*KillingLoc, State, MemDef));
+  }
+}
+
+ChangeStateEnum MemoryLocationWrapper::eliminateDeadDefs() {
+  ChangeStateEnum ChangeState = ChangeStateEnum::NoChange;
+  unsigned ScanLimit = MemorySSAScanLimit;
+  unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
+  unsigned PartialLimit = MemorySSAPartialStoreLimit;
+  // Worklist of MemoryAccesses that may be killed by KillingDef.
+  SmallSetVector<MemoryAccess *, 8> ToCheck;
+  ToCheck.insert(GetDefiningAccess());
+
+  // Check if MemoryAccesses in the worklist are killed by KillingDef.
+  for (unsigned I = 0; I < ToCheck.size(); I++) {
+    MemoryAccess *Current = ToCheck[I];
+    if (State.SkipStores.count(Current))
       continue;
-    Instruction *KillingI = KillingDef->getMemoryInst();
+    std::optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef(
+        *this, Current, ScanLimit, WalkerStepLimit, PartialLimit);
 
-    std::optional<MemoryLocation> MaybeKillingLoc;
-    if (State.isMemTerminatorInst(KillingI)) {
-      if (auto KillingLoc = State.getLocForTerminator(KillingI))
-        MaybeKillingLoc = KillingLoc->first;
-    } else {
-      MaybeKillingLoc = State.getLocForWrite(KillingI);
+    if (!MaybeDeadAccess) {
+      LLVM_DEBUG(dbgs() << "  finished walk\n");
+      continue;
     }
-
-    if (!MaybeKillingLoc) {
-      LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
-                        << *KillingI << "\n");
+    MemoryAccess *DeadAccess = *MaybeDeadAccess;
+    LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess);
+    if (isa<MemoryPhi>(DeadAccess)) {
+      LLVM_DEBUG(dbgs() << "\n  ... adding incoming values to worklist\n");
+      for (Value *V : cast<MemoryPhi>(DeadAccess)->incoming_values()) {
+        MemoryAccess *IncomingAccess = cast<MemoryAccess>(V);
+        BasicBlock *IncomingBlock = IncomingAccess->getBlock();
+        BasicBlock *PhiBlock = DeadAccess->getBlock();
+
+        // We only consider incoming MemoryAccesses that come before the
+        // MemoryPhi. Otherwise we could discover candidates that do not
+        // strictly dominate our starting def.
+        if (State.PostOrderNumbers[IncomingBlock] >
+            State.PostOrderNumbers[PhiBlock])
+          ToCheck.insert(IncomingAccess);
+      }
       continue;
     }
-    MemoryLocation KillingLoc = *MaybeKillingLoc;
-    assert(KillingLoc.Ptr && "KillingLoc should not be null");
-    const Value *KillingUndObj = getUnderlyingObject(KillingLoc.Ptr);
-    LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
-                      << *KillingDef << " (" << *KillingI << ")\n");
-
-    unsigned ScanLimit = MemorySSAScanLimit;
-    unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
-    unsigned PartialLimit = MemorySSAPartialStoreLimit;
-    // Worklist of MemoryAccesses that may be killed by KillingDef.
-    SmallSetVector<MemoryAccess *, 8> ToCheck;
-    // Track MemoryAccesses that have been deleted in the loop below, so we can
-    // skip them. Don't use SkipStores for this, which may contain reused
-    // MemoryAccess addresses.
-    SmallPtrSet<MemoryAccess *, 8> Deleted;
-    [[maybe_unused]] unsigned OrigNumSkipStores = State.SkipStores.size();
-    ToCheck.insert(KillingDef->getDefiningAccess());
-
-    bool Shortend = false;
-    bool IsMemTerm = State.isMemTerminatorInst(KillingI);
-    // Check if MemoryAccesses in the worklist are killed by KillingDef.
-    for (unsigned I = 0; I < ToCheck.size(); I++) {
-      MemoryAccess *Current = ToCheck[I];
-      if (Deleted.contains(Current))
-        continue;
+    MemoryDefWrapper DeadDefWrapper(cast<MemoryDef>(DeadAccess), State);
+    MemoryLocationWrapper &DeadLoc = DeadDefWrapper.DefinedLocations.front();
+    LLVM_DEBUG(dbgs() << " (" << *DeadDefWrapper.DefInst << ")\n");
+    ToCheck.insert(DeadLoc.GetDefiningAccess());
+    NumGetDomMemoryDefPassed++;
 
-      std::optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef(
-          KillingDef, Current, KillingLoc, KillingUndObj, ScanLimit,
-          WalkerStepLimit, IsMemTerm, PartialLimit);
-
-      if (!MaybeDeadAccess) {
-        LLVM_DEBUG(dbgs() << "  finished walk\n");
+    if (!DebugCounter::shouldExecute(MemorySSACounter))
+      continue;
+    if (isMemTerminatorInst(DefInst, State.TLI)) {
+      if (!(UnderlyingObject == DeadLoc.UnderlyingObject))
         continue;
+      LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: "
+                        << *DeadDefWrapper.DefInst << "\n  KILLER: " << *DefInst
+                        << '\n');
+      State.deleteDeadInstruction(DeadDefWrapper.DefInst);
+      ++NumFastStores;
+      ChangeState = ChangeStateEnum::DeleteByMemTerm;
+    } else {
+      // Check if DeadI overwrites KillingI.
+      int64_t KillingOffset = 0;
+      int64_t DeadOffset = 0;
+      OverwriteResult OR =
+          State.isOverwrite(DefInst, DeadDefWrapper.DefInst, MemLoc,
+                            DeadLoc.MemLoc, KillingOffset, DeadOffset);
+      if (OR == OW_MaybePartial) {
+        auto Iter = State.IOLs.insert(
+            std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
+                DeadDefWrapper.DefInst->getParent(), InstOverlapIntervalsTy()));
+        auto &IOL = Iter.first->second;
+        OR = isPartialOverwrite(MemLoc, DeadLoc.MemLoc, KillingOffset,
+                                DeadOffset, DeadDefWrapper.DefInst, IOL);
       }
-
-      MemoryAccess *DeadAccess = *MaybeDeadAccess;
-      LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess);
-      if (isa<MemoryPhi>(DeadAccess)) {
-        LLVM_DEBUG(dbgs() << "\n  ... adding incoming values to worklist\n");
-        for (Value *V : cast<MemoryPhi>(DeadAccess)->incoming_values()) {
-          MemoryAccess *IncomingAccess = cast<MemoryAccess>(V);
-          BasicBlock *IncomingBlock = IncomingAccess->getBlock();
-          BasicBlock *PhiBlock = DeadAccess->getBlock();
-
-          // We only consider incoming MemoryAccesses that come before the
-          // MemoryPhi. Otherwise we could discover candidates that do not
-          // strictly dominate our starting def.
-          if (State.PostOrderNumbers[IncomingBlock] >
-              State.PostOrderNumbers[PhiBlock])
-            ToCheck.insert(IncomingAccess);
+      if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
+        auto *DeadSI = dyn_cast<StoreInst>(DeadDefWrapper.DefInst);
+        auto *KillingSI = dyn_cast<StoreInst>(DefInst);
+        // We are re-using tryToMergePartialOverlappingStores, which requires
+        // DeadSI to dominate KillingSI.
+        // TODO: implement tryToMergeParialOverlappingStores using MemorySSA.
+        if (DeadSI && KillingSI && State.DT.dominates(DeadSI, KillingSI)) {
+          if (Constant *Merged = tryToMergePartialOverlappingStores(
+                  KillingSI, DeadSI, KillingOffset, DeadOffset, State.DL,
+                  State.BatchAA, &State.DT)) {
+
+            // Update stored value of earlier store to merged constant.
+            DeadSI->setOperand(0, Merged);
+            ++NumModifiedStores;
+            ChangeState = ChangeStateEnum::PartiallyDeleteByNonMemTerm;
+
+            // Remove killing store and remove any outstanding overlap
+            // intervals for the updated store.
+            State.deleteDeadInstruction(KillingSI);
+            auto I = State.IOLs.find(DeadSI->getParent());
+            if (I != State.IOLs.end())
+              I->second.erase(DeadSI);
+            break;
+          }
         }
-        continue;
       }
-      auto *DeadDefAccess = cast<MemoryDef>(DeadAccess);
-      Instruction *DeadI = DeadDefAccess->getMemoryInst();
-      LLVM_DEBUG(dbgs() << " (" << *DeadI << ")\n");
-      ToCheck.insert(DeadDefAccess->getDefiningAccess());
-      NumGetDomMemoryDefPassed++;
-
-      if (!DebugCounter::shouldExecute(MemorySSACounter))
-        continue;
-
-      MemoryLocation DeadLoc = *State.getLocForWrite(DeadI);
-
-      if (IsMemTerm) {
-        const Value *DeadUndObj = getUnderlyingObject(DeadLoc.Ptr);
-        if (KillingUndObj != DeadUndObj)
-          continue;
-        LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: " << *DeadI
-                          << "\n  KILLER: " << *KillingI << '\n');
-        State.deleteDeadInstruction(DeadI, &Deleted);
+      if (OR == OW_Complete) {
+        LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: "
+                          << *DeadDefWrapper.DefInst
+                          << "\n  KILLER: " << *DefInst << '\n');
+        State.deleteDeadInstruction(DeadDefWrapper.DefInst);
         ++NumFastStores;
-        MadeChange = true;
-      } else {
-        // Check if DeadI overwrites KillingI.
-        int64_t KillingOffset = 0;
-        int64_t DeadOffset = 0;
-        OverwriteResult OR = State.isOverwrite(
-            KillingI, DeadI, KillingLoc, DeadLoc, KillingOffset, DeadOffset);
-        if (OR == OW_MaybePartial) {
-          auto Iter = State.IOLs.insert(
-              std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
-                  DeadI->getParent(), InstOverlapIntervalsTy()));
-          auto &IOL = Iter.first->second;
-          OR = isPartialOverwrite(KillingLoc, DeadLoc, KillingOffset,
-                                  DeadOffset, DeadI, IOL);
-        }
-
-        if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
-          auto *DeadSI = dyn_cast<StoreInst>(DeadI);
-          auto *KillingSI = dyn_cast<StoreInst>(KillingI);
-          // We are re-using tryToMergePartialOverlappingStores, which requires
-          // DeadSI to dominate KillingSI.
-          // TODO: implement tryToMergeParialOverlappingStores using MemorySSA.
-          if (DeadSI && KillingSI && DT.dominates(DeadSI, KillingSI)) {
-            if (Constant *Merged = tryToMergePartialOverlappingStores(
-                    KillingSI, DeadSI, KillingOffset, DeadOffset, State.DL,
-                    State.BatchAA, &DT)) {
-
-              // Update stored value of earlier store to merged constant.
-              DeadSI->setOperand(0, Merged);
-              ++NumModifiedStores;
-              MadeChange = true;
-
-              Shortend = true;
-              // Remove killing store and remove any outstanding overlap
-              // intervals for the updated store.
-              State.deleteDeadInstruction(KillingSI, &Deleted);
-              auto I = State.IOLs.find(DeadSI->getParent());
-              if (I != State.IOLs.end())
-                I->second.erase(DeadSI);
-              break;
-            }
-          }
-        }
-
-        if (OR == OW_Complete) {
-          LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: " << *DeadI
-                            << "\n  KILLER: " << *KillingI << '\n');
-          State.deleteDeadInstruction(DeadI, &Deleted);
-          ++NumFastStores;
-          MadeChange = true;
-        }
+        ChangeState = ChangeStateEnum::CompleteDeleteByNonMemTerm;
       }
     }
+  }
+  return ChangeState;
+}
 
-    assert(State.SkipStores.size() - OrigNumSkipStores == Deleted.size() &&
-           "SkipStores and Deleted out of sync?");
+bool MemoryDefWrapper::eliminateDeadDefs() {
+  if (DefinedLocations.empty()) {
+    LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
+                      << *DefInst << "\n");
+    return false;
+  }
+  LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by " << *MemDef
+                    << " (" << *DefInst << ")\n");
+
+  assert(DefinedLocations.size() == 1 && "Expected a single defined location");
+  auto &KillingLoc = DefinedLocations.front();
+  ChangeStateEnum ChangeState = KillingLoc.eliminateDeadDefs();
+  bool Shortend = ChangeState == ChangeStateEnum::PartiallyDeleteByNonMemTerm;
+
+  // Check if the store is a no-op.
+  if (!Shortend && State.storeIsNoop(MemDef, KillingLoc.UnderlyingObject)) {
+    LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n  DEAD: " << *DefInst
+                      << '\n');
+    State.deleteDeadInstruction(DefInst);
+    NumRedundantStores++;
+    return true;
+  }
+  // Can we form a calloc from a memset/malloc pair?
+  if (!Shortend &&
+      State.tryFoldIntoCalloc(MemDef, KillingLoc.UnderlyingObject)) {
+    LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
+                      << "  DEAD: " << *DefInst << '\n');
+    State.deleteDeadInstruction(DefInst);
+    return true;
+  }
+  return ChangeState != ChangeStateEnum::NoChange;
+}
 
-    // Check if the store is a no-op.
-    if (!Shortend && State.storeIsNoop(KillingDef, KillingUndObj)) {
-      LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n  DEAD: " << *KillingI
-                        << '\n');
-      State.deleteDeadInstruction(KillingI);
-      NumRedundantStores++;
-      MadeChange = true;
+static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
+                                DominatorTree &DT, PostDominatorTree &PDT,
+                                const TargetLibraryInfo &TLI,
+                                const LoopInfo &LI) {
+  bool MadeChange = false;
+  DSEState State(F, AA, MSSA, DT, PDT, TLI, LI);
+  // For each store:
+  for (unsigned I = 0; I < State.MemDefs.size(); I++) {
+    MemoryDef *KillingDef = State.MemDefs[I];
+    if (State.SkipStores.count(KillingDef))
       continue;
-    }
 
-    // Can we form a calloc from a memset/malloc pair?
-    if (!Shortend && State.tryFoldIntoCalloc(KillingDef, KillingUndObj)) {
-      LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
-                        << "  DEAD: " << *KillingI << '\n');
-      State.deleteDeadInstruction(KillingI);
-      MadeChange = true;
-      continue;
-    }
+    MemoryDefWrapper KillingDefWrapper(KillingDef, State);
+    MadeChange |= KillingDefWrapper.eliminateDeadDefs();
   }
 
   if (EnablePartialOverwriteTracking)

>From 3347089c9dabd906460dd4373f1fadfc037449d1 Mon Sep 17 00:00:00 2001
From: Haopeng Liu <haopliu at google.com>
Date: Wed, 7 Aug 2024 21:25:33 +0000
Subject: [PATCH 2/9] Change DefinedLocations (SmallVector) to DefinedLocation
 (optional)

---
 .../lib/Transforms/Scalar/DeadStoreElimination.cpp | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index e64b7da818a985..7aae43c3a3713b 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -860,7 +860,7 @@ class MemoryDefWrapper {
   MemoryDef *MemDef;
   Instruction *DefInst;
   DSEState &State;
-  SmallVector<MemoryLocationWrapper, 1> DefinedLocations;
+  std::optional<MemoryLocationWrapper> DefinedLocation = std::nullopt;
 };
 
 struct DSEState {
@@ -2193,15 +2193,14 @@ MemoryDefWrapper::MemoryDefWrapper(MemoryDef *MemDef, DSEState &State)
 
   if (isMemTerminatorInst(DefInst, State.TLI)) {
     if (auto KillingLoc = State.getLocForTerminator(DefInst)) {
-      DefinedLocations.push_back(
+      DefinedLocation.emplace(
           MemoryLocationWrapper(KillingLoc->first, State, MemDef));
     }
     return;
   }
 
   if (auto KillingLoc = State.getLocForWrite(DefInst)) {
-    DefinedLocations.push_back(
-        MemoryLocationWrapper(*KillingLoc, State, MemDef));
+    DefinedLocation.emplace(MemoryLocationWrapper(*KillingLoc, State, MemDef));
   }
 }
 
@@ -2245,7 +2244,7 @@ ChangeStateEnum MemoryLocationWrapper::eliminateDeadDefs() {
       continue;
     }
     MemoryDefWrapper DeadDefWrapper(cast<MemoryDef>(DeadAccess), State);
-    MemoryLocationWrapper &DeadLoc = DeadDefWrapper.DefinedLocations.front();
+    MemoryLocationWrapper &DeadLoc = *DeadDefWrapper.DefinedLocation;
     LLVM_DEBUG(dbgs() << " (" << *DeadDefWrapper.DefInst << ")\n");
     ToCheck.insert(DeadLoc.GetDefiningAccess());
     NumGetDomMemoryDefPassed++;
@@ -2316,7 +2315,7 @@ ChangeStateEnum MemoryLocationWrapper::eliminateDeadDefs() {
 }
 
 bool MemoryDefWrapper::eliminateDeadDefs() {
-  if (DefinedLocations.empty()) {
+  if (!DefinedLocation.has_value()) {
     LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
                       << *DefInst << "\n");
     return false;
@@ -2324,8 +2323,7 @@ bool MemoryDefWrapper::eliminateDeadDefs() {
   LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by " << *MemDef
                     << " (" << *DefInst << ")\n");
 
-  assert(DefinedLocations.size() == 1 && "Expected a single defined location");
-  auto &KillingLoc = DefinedLocations.front();
+  auto &KillingLoc = *DefinedLocation;
   ChangeStateEnum ChangeState = KillingLoc.eliminateDeadDefs();
   bool Shortend = ChangeState == ChangeStateEnum::PartiallyDeleteByNonMemTerm;
 

>From b0bba88dc7e3fcfaf2b953ab20ec75b8cb6a3edb Mon Sep 17 00:00:00 2001
From: Haopeng Liu <haopliu at google.com>
Date: Thu, 8 Aug 2024 20:20:15 +0000
Subject: [PATCH 3/9] Move MemoryDefWrapper and MemoryLocationWrapper to
 DSEState

---
 .../Scalar/DeadStoreElimination.cpp           | 160 +++++++++---------
 1 file changed, 80 insertions(+), 80 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 7aae43c3a3713b..ed4c69727aabd2 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -806,63 +806,6 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
   return false;
 }
 
-/// Returns true if \p I is a memory terminator instruction like
-/// llvm.lifetime.end or free.
-bool isMemTerminatorInst(Instruction *I, const TargetLibraryInfo &TLI) {
-  auto *CB = dyn_cast<CallBase>(I);
-  return CB && (CB->getIntrinsicID() == Intrinsic::lifetime_end ||
-                getFreedOperand(CB, &TLI) != nullptr);
-}
-
-struct DSEState;
-enum class ChangeStateEnum : uint8_t {
-  NoChange,
-  DeleteByMemTerm,
-  CompleteDeleteByNonMemTerm,
-  PartiallyDeleteByNonMemTerm,
-};
-
-// A memory location wrapper that represents a MemoryLocation, `MemLoc`,
-// defined by `MemDef`.
-class MemoryLocationWrapper {
-public:
-  MemoryLocationWrapper(MemoryLocation MemLoc, DSEState &State,
-                        MemoryDef *MemDef)
-      : MemLoc(MemLoc), State(State), MemDef(MemDef) {
-    assert(MemLoc.Ptr && "MemLoc should be not null");
-    UnderlyingObject = getUnderlyingObject(MemLoc.Ptr);
-    DefInst = MemDef->getMemoryInst();
-  }
-
-  // Try to eliminate dead defs killed by this MemoryLocation and return the
-  // change state.
-  ChangeStateEnum eliminateDeadDefs();
-  MemoryAccess *GetDefiningAccess() const {
-    return MemDef->getDefiningAccess();
-  }
-
-  MemoryLocation MemLoc;
-  const Value *UnderlyingObject;
-  DSEState &State;
-  MemoryDef *MemDef;
-  Instruction *DefInst;
-};
-
-// A memory def wrapper that represents a MemoryDef and the MemoryLocation(s)
-// defined by this MemoryDef.
-class MemoryDefWrapper {
-public:
-  MemoryDefWrapper(MemoryDef *MemDef, DSEState &State);
-  // Try to eliminate dead defs killed by this MemoryDef and return the
-  // change state.
-  bool eliminateDeadDefs();
-
-  MemoryDef *MemDef;
-  Instruction *DefInst;
-  DSEState &State;
-  std::optional<MemoryLocationWrapper> DefinedLocation = std::nullopt;
-};
-
 struct DSEState {
   Function &F;
   AliasAnalysis &AA;
@@ -919,6 +862,72 @@ struct DSEState {
   /// Dead instructions to be removed at the end of DSE.
   SmallVector<Instruction *> ToRemove;
 
+  enum class ChangeStateEnum : uint8_t {
+    NoChange,
+    DeleteByMemTerm,
+    CompleteDeleteByNonMemTerm,
+    PartiallyDeleteByNonMemTerm,
+  };
+
+  // A memory location wrapper that represents a MemoryLocation, `MemLoc`,
+  // defined by `MemDef`.
+  class MemoryLocationWrapper {
+  public:
+    MemoryLocationWrapper(MemoryLocation MemLoc, DSEState &State,
+                          MemoryDef *MemDef)
+        : MemLoc(MemLoc), State(State), MemDef(MemDef) {
+      assert(MemLoc.Ptr && "MemLoc should be not null");
+      UnderlyingObject = getUnderlyingObject(MemLoc.Ptr);
+      DefInst = MemDef->getMemoryInst();
+    }
+
+    // Try to eliminate dead defs killed by this MemoryLocation and return the
+    // change state.
+    ChangeStateEnum eliminateDeadDefs();
+
+    MemoryAccess *GetDefiningAccess() const {
+      return MemDef->getDefiningAccess();
+    }
+
+    MemoryLocation MemLoc;
+    const Value *UnderlyingObject;
+    DSEState &State;
+    MemoryDef *MemDef;
+    Instruction *DefInst;
+  };
+
+  // A memory def wrapper that represents a MemoryDef and the MemoryLocation(s)
+  // defined by this MemoryDef.
+  class MemoryDefWrapper {
+  public:
+    MemoryDefWrapper(MemoryDef *MemDef, DSEState &State)
+        : MemDef(MemDef), State(State) {
+      DefInst = MemDef->getMemoryInst();
+
+      if (State.isMemTerminatorInst(DefInst)) {
+        if (auto KillingLoc = State.getLocForTerminator(DefInst)) {
+          DefinedLocation.emplace(
+              MemoryLocationWrapper(KillingLoc->first, State, MemDef));
+        }
+        return;
+      }
+
+      if (auto KillingLoc = State.getLocForWrite(DefInst)) {
+        DefinedLocation.emplace(
+            MemoryLocationWrapper(*KillingLoc, State, MemDef));
+      }
+    }
+
+    // Try to eliminate dead defs killed by this MemoryDef and return the
+    // change state.
+    bool eliminateDeadDefs();
+
+    MemoryDef *MemDef;
+    Instruction *DefInst;
+    DSEState &State;
+    std::optional<MemoryLocationWrapper> DefinedLocation = std::nullopt;
+  };
+
   // Class contains self-reference, make sure it's not copied/moved.
   DSEState(const DSEState &) = delete;
   DSEState &operator=(const DSEState &) = delete;
@@ -940,7 +949,7 @@ struct DSEState {
 
         auto *MD = dyn_cast_or_null<MemoryDef>(MA);
         if (MD && MemDefs.size() < MemorySSADefsPerBlockLimit &&
-            (getLocForWrite(&I) || isMemTerminatorInst(&I, TLI)))
+            (getLocForWrite(&I) || isMemTerminatorInst(&I)))
           MemDefs.push_back(MD);
       }
     }
@@ -1282,6 +1291,14 @@ struct DSEState {
     return std::nullopt;
   }
 
+  /// Returns true if \p I is a memory terminator instruction like
+  /// llvm.lifetime.end or free.
+  bool isMemTerminatorInst(Instruction *I) {
+    auto *CB = dyn_cast<CallBase>(I);
+    return CB && (CB->getIntrinsicID() == Intrinsic::lifetime_end ||
+                  getFreedOperand(CB, &TLI) != nullptr);
+  }
+
   /// Returns true if \p MaybeTerm is a memory terminator for \p Loc from
   /// instruction \p AccessI.
   bool isMemTerminator(const MemoryLocation &Loc, Instruction *AccessI,
@@ -1496,7 +1513,7 @@ struct DSEState {
         continue;
       }
 
-      if (isMemTerminatorInst(KillingLoc.DefInst, TLI)) {
+      if (isMemTerminatorInst(KillingLoc.DefInst)) {
         // If the killing def is a memory terminator (e.g. lifetime.end), check
         // the next candidate if the current Current does not write the same
         // underlying object as the terminator.
@@ -2187,24 +2204,7 @@ struct DSEState {
   }
 };
 
-MemoryDefWrapper::MemoryDefWrapper(MemoryDef *MemDef, DSEState &State)
-    : MemDef(MemDef), State(State) {
-  DefInst = MemDef->getMemoryInst();
-
-  if (isMemTerminatorInst(DefInst, State.TLI)) {
-    if (auto KillingLoc = State.getLocForTerminator(DefInst)) {
-      DefinedLocation.emplace(
-          MemoryLocationWrapper(KillingLoc->first, State, MemDef));
-    }
-    return;
-  }
-
-  if (auto KillingLoc = State.getLocForWrite(DefInst)) {
-    DefinedLocation.emplace(MemoryLocationWrapper(*KillingLoc, State, MemDef));
-  }
-}
-
-ChangeStateEnum MemoryLocationWrapper::eliminateDeadDefs() {
+DSEState::ChangeStateEnum DSEState::MemoryLocationWrapper::eliminateDeadDefs() {
   ChangeStateEnum ChangeState = ChangeStateEnum::NoChange;
   unsigned ScanLimit = MemorySSAScanLimit;
   unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
@@ -2251,7 +2251,7 @@ ChangeStateEnum MemoryLocationWrapper::eliminateDeadDefs() {
 
     if (!DebugCounter::shouldExecute(MemorySSACounter))
       continue;
-    if (isMemTerminatorInst(DefInst, State.TLI)) {
+    if (State.isMemTerminatorInst(DefInst)) {
       if (!(UnderlyingObject == DeadLoc.UnderlyingObject))
         continue;
       LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: "
@@ -2314,7 +2314,7 @@ ChangeStateEnum MemoryLocationWrapper::eliminateDeadDefs() {
   return ChangeState;
 }
 
-bool MemoryDefWrapper::eliminateDeadDefs() {
+bool DSEState::MemoryDefWrapper::eliminateDeadDefs() {
   if (!DefinedLocation.has_value()) {
     LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
                       << *DefInst << "\n");
@@ -2358,7 +2358,7 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
     if (State.SkipStores.count(KillingDef))
       continue;
 
-    MemoryDefWrapper KillingDefWrapper(KillingDef, State);
+    DSEState::MemoryDefWrapper KillingDefWrapper(KillingDef, State);
     MadeChange |= KillingDefWrapper.eliminateDeadDefs();
   }
 

>From b737ebfa5e7b3f06a8f4bb548dd3b75949b9437a Mon Sep 17 00:00:00 2001
From: Haopeng Liu <haopliu at google.com>
Date: Thu, 8 Aug 2024 23:58:58 +0000
Subject: [PATCH 4/9] Move MemoryDefWrapper and MemoryLocationWrapper back
 out of DSEState, but keep eliminateDeadDefs()

---
 .../Scalar/DeadStoreElimination.cpp           | 218 +++++++++---------
 1 file changed, 107 insertions(+), 111 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index ed4c69727aabd2..366db7b598782a 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -806,6 +806,38 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
   return false;
 }
 
+// A memory location wrapper that represents a MemoryLocation, `MemLoc`,
+// defined by `MemDef`.
+class MemoryLocationWrapper {
+public:
+  MemoryLocationWrapper(MemoryLocation MemLoc, MemoryDef *MemDef)
+      : MemLoc(MemLoc), MemDef(MemDef) {
+    assert(MemLoc.Ptr && "MemLoc should be not null");
+    UnderlyingObject = getUnderlyingObject(MemLoc.Ptr);
+    DefInst = MemDef->getMemoryInst();
+  }
+
+  MemoryAccess *GetDefiningAccess() const {
+    return MemDef->getDefiningAccess();
+  }
+
+  MemoryLocation MemLoc;
+  const Value *UnderlyingObject;
+  MemoryDef *MemDef;
+  Instruction *DefInst;
+};
+
+// A memory def wrapper that represents a MemoryDef and the MemoryLocation(s)
+// defined by this MemoryDef.
+class MemoryDefWrapper {
+public:
+  MemoryDefWrapper(MemoryDef *MemDef, std::optional<MemoryLocation> MemLoc) {
+    if (MemLoc.has_value())
+      DefinedLocation = MemoryLocationWrapper(*MemLoc, MemDef);
+  }
+  std::optional<MemoryLocationWrapper> DefinedLocation = std::nullopt;
+};
+
 struct DSEState {
   Function &F;
   AliasAnalysis &AA;
@@ -862,72 +894,6 @@ struct DSEState {
   /// Dead instructions to be removed at the end of DSE.
   SmallVector<Instruction *> ToRemove;
 
-  enum class ChangeStateEnum : uint8_t {
-    NoChange,
-    DeleteByMemTerm,
-    CompleteDeleteByNonMemTerm,
-    PartiallyDeleteByNonMemTerm,
-  };
-
-  // A memory location wrapper that represents a MemoryLocation, `MemLoc`,
-  // defined by `MemDef`.
-  class MemoryLocationWrapper {
-  public:
-    MemoryLocationWrapper(MemoryLocation MemLoc, DSEState &State,
-                          MemoryDef *MemDef)
-        : MemLoc(MemLoc), State(State), MemDef(MemDef) {
-      assert(MemLoc.Ptr && "MemLoc should be not null");
-      UnderlyingObject = getUnderlyingObject(MemLoc.Ptr);
-      DefInst = MemDef->getMemoryInst();
-    }
-
-    // Try to eliminate dead defs killed by this MemoryLocation and return the
-    // change state.
-    ChangeStateEnum eliminateDeadDefs();
-
-    MemoryAccess *GetDefiningAccess() const {
-      return MemDef->getDefiningAccess();
-    }
-
-    MemoryLocation MemLoc;
-    const Value *UnderlyingObject;
-    DSEState &State;
-    MemoryDef *MemDef;
-    Instruction *DefInst;
-  };
-
-  // A memory def wrapper that represents a MemoryDef and the MemoryLocation(s)
-  // defined by this MemoryDef.
-  class MemoryDefWrapper {
-  public:
-    MemoryDefWrapper(MemoryDef *MemDef, DSEState &State)
-        : MemDef(MemDef), State(State) {
-      DefInst = MemDef->getMemoryInst();
-
-      if (State.isMemTerminatorInst(DefInst)) {
-        if (auto KillingLoc = State.getLocForTerminator(DefInst)) {
-          DefinedLocation.emplace(
-              MemoryLocationWrapper(KillingLoc->first, State, MemDef));
-        }
-        return;
-      }
-
-      if (auto KillingLoc = State.getLocForWrite(DefInst)) {
-        DefinedLocation.emplace(
-            MemoryLocationWrapper(*KillingLoc, State, MemDef));
-      }
-    }
-
-    // Try to eliminate dead defs killed by this MemoryDef and return the
-    // change state.
-    bool eliminateDeadDefs();
-
-    MemoryDef *MemDef;
-    Instruction *DefInst;
-    DSEState &State;
-    std::optional<MemoryLocationWrapper> DefinedLocation = std::nullopt;
-  };
-
   // Class contains self-reference, make sure it's not copied/moved.
   DSEState(const DSEState &) = delete;
   DSEState &operator=(const DSEState &) = delete;
@@ -1185,6 +1151,15 @@ struct DSEState {
     return MemoryLocation::getOrNone(I);
   }
 
+  std::optional<MemoryLocation> getLocForInst(Instruction *I) {
+    if (isMemTerminatorInst(I)) {
+      if (auto Loc = getLocForTerminator(I)) {
+        return Loc->first;
+      }
+    }
+    return getLocForWrite(I);
+  }
+
   /// Assuming this instruction has a dead analyzable write, can we delete
   /// this instruction?
   bool isRemovable(Instruction *I) {
@@ -2202,24 +2177,39 @@ struct DSEState {
     }
     return MadeChange;
   }
+
+  enum class ChangeStateEnum : uint8_t {
+    NoChange,
+    DeleteByMemTerm,
+    CompleteDeleteByNonMemTerm,
+    PartiallyDeleteByNonMemTerm,
+  };
+  // Try to eliminate dead defs killed by `KillingLocWrapper` and return the
+  // change state.
+  ChangeStateEnum eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper);
+
+  // Try to eliminate dead defs killed by `KillingDefWrapper` and return the
+  // change state.
+  bool eliminateDeadDefs(MemoryDefWrapper &KillingDefWrapper);
 };
 
-DSEState::ChangeStateEnum DSEState::MemoryLocationWrapper::eliminateDeadDefs() {
+DSEState::ChangeStateEnum
+DSEState::eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper) {
   ChangeStateEnum ChangeState = ChangeStateEnum::NoChange;
   unsigned ScanLimit = MemorySSAScanLimit;
   unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
   unsigned PartialLimit = MemorySSAPartialStoreLimit;
   // Worklist of MemoryAccesses that may be killed by KillingDef.
   SmallSetVector<MemoryAccess *, 8> ToCheck;
-  ToCheck.insert(GetDefiningAccess());
+  ToCheck.insert(KillingLocWrapper.GetDefiningAccess());
 
   // Check if MemoryAccesses in the worklist are killed by KillingDef.
   for (unsigned I = 0; I < ToCheck.size(); I++) {
     MemoryAccess *Current = ToCheck[I];
-    if (State.SkipStores.count(Current))
+    if (SkipStores.count(Current))
       continue;
-    std::optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef(
-        *this, Current, ScanLimit, WalkerStepLimit, PartialLimit);
+    std::optional<MemoryAccess *> MaybeDeadAccess = getDomMemoryDef(
+        KillingLocWrapper, Current, ScanLimit, WalkerStepLimit, PartialLimit);
 
     if (!MaybeDeadAccess) {
       LLVM_DEBUG(dbgs() << "  finished walk\n");
@@ -2237,27 +2227,29 @@ DSEState::ChangeStateEnum DSEState::MemoryLocationWrapper::eliminateDeadDefs() {
         // We only consider incoming MemoryAccesses that come before the
         // MemoryPhi. Otherwise we could discover candidates that do not
         // strictly dominate our starting def.
-        if (State.PostOrderNumbers[IncomingBlock] >
-            State.PostOrderNumbers[PhiBlock])
+        if (PostOrderNumbers[IncomingBlock] > PostOrderNumbers[PhiBlock])
           ToCheck.insert(IncomingAccess);
       }
       continue;
     }
-    MemoryDefWrapper DeadDefWrapper(cast<MemoryDef>(DeadAccess), State);
-    MemoryLocationWrapper &DeadLoc = *DeadDefWrapper.DefinedLocation;
-    LLVM_DEBUG(dbgs() << " (" << *DeadDefWrapper.DefInst << ")\n");
-    ToCheck.insert(DeadLoc.GetDefiningAccess());
+    MemoryDefWrapper DeadDefWrapper(
+        cast<MemoryDef>(DeadAccess),
+        getLocForInst(cast<MemoryDef>(DeadAccess)->getMemoryInst()));
+    MemoryLocationWrapper &DeadLocWrapper = *DeadDefWrapper.DefinedLocation;
+    LLVM_DEBUG(dbgs() << " (" << *DeadLocWrapper.DefInst << ")\n");
+    ToCheck.insert(DeadLocWrapper.GetDefiningAccess());
     NumGetDomMemoryDefPassed++;
 
     if (!DebugCounter::shouldExecute(MemorySSACounter))
       continue;
-    if (State.isMemTerminatorInst(DefInst)) {
-      if (!(UnderlyingObject == DeadLoc.UnderlyingObject))
+    if (isMemTerminatorInst(KillingLocWrapper.DefInst)) {
+      if (!(KillingLocWrapper.UnderlyingObject ==
+            DeadLocWrapper.UnderlyingObject))
         continue;
       LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: "
-                        << *DeadDefWrapper.DefInst << "\n  KILLER: " << *DefInst
-                        << '\n');
-      State.deleteDeadInstruction(DeadDefWrapper.DefInst);
+                        << *DeadLocWrapper.DefInst << "\n  KILLER: "
+                        << *KillingLocWrapper.DefInst << '\n');
+      deleteDeadInstruction(DeadLocWrapper.DefInst);
       ++NumFastStores;
       ChangeState = ChangeStateEnum::DeleteByMemTerm;
     } else {
@@ -2265,26 +2257,28 @@ DSEState::ChangeStateEnum DSEState::MemoryLocationWrapper::eliminateDeadDefs() {
       int64_t KillingOffset = 0;
       int64_t DeadOffset = 0;
       OverwriteResult OR =
-          State.isOverwrite(DefInst, DeadDefWrapper.DefInst, MemLoc,
-                            DeadLoc.MemLoc, KillingOffset, DeadOffset);
+          isOverwrite(KillingLocWrapper.DefInst, DeadLocWrapper.DefInst,
+                      KillingLocWrapper.MemLoc, DeadLocWrapper.MemLoc,
+                      KillingOffset, DeadOffset);
       if (OR == OW_MaybePartial) {
-        auto Iter = State.IOLs.insert(
-            std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
-                DeadDefWrapper.DefInst->getParent(), InstOverlapIntervalsTy()));
+        auto Iter =
+            IOLs.insert(std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
+                DeadLocWrapper.DefInst->getParent(), InstOverlapIntervalsTy()));
         auto &IOL = Iter.first->second;
-        OR = isPartialOverwrite(MemLoc, DeadLoc.MemLoc, KillingOffset,
-                                DeadOffset, DeadDefWrapper.DefInst, IOL);
+        OR = isPartialOverwrite(KillingLocWrapper.MemLoc, DeadLocWrapper.MemLoc,
+                                KillingOffset, DeadOffset,
+                                DeadLocWrapper.DefInst, IOL);
       }
       if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
-        auto *DeadSI = dyn_cast<StoreInst>(DeadDefWrapper.DefInst);
-        auto *KillingSI = dyn_cast<StoreInst>(DefInst);
+        auto *DeadSI = dyn_cast<StoreInst>(DeadLocWrapper.DefInst);
+        auto *KillingSI = dyn_cast<StoreInst>(KillingLocWrapper.DefInst);
         // We are re-using tryToMergePartialOverlappingStores, which requires
         // DeadSI to dominate KillingSI.
         // TODO: implement tryToMergeParialOverlappingStores using MemorySSA.
-        if (DeadSI && KillingSI && State.DT.dominates(DeadSI, KillingSI)) {
+        if (DeadSI && KillingSI && DT.dominates(DeadSI, KillingSI)) {
           if (Constant *Merged = tryToMergePartialOverlappingStores(
-                  KillingSI, DeadSI, KillingOffset, DeadOffset, State.DL,
-                  State.BatchAA, &State.DT)) {
+                  KillingSI, DeadSI, KillingOffset, DeadOffset, DL, BatchAA,
+                  &DT)) {
 
             // Update stored value of earlier store to merged constant.
             DeadSI->setOperand(0, Merged);
@@ -2293,9 +2287,9 @@ DSEState::ChangeStateEnum DSEState::MemoryLocationWrapper::eliminateDeadDefs() {
 
             // Remove killing store and remove any outstanding overlap
             // intervals for the updated store.
-            State.deleteDeadInstruction(KillingSI);
-            auto I = State.IOLs.find(DeadSI->getParent());
-            if (I != State.IOLs.end())
+            deleteDeadInstruction(KillingSI);
+            auto I = IOLs.find(DeadSI->getParent());
+            if (I != IOLs.end())
               I->second.erase(DeadSI);
             break;
           }
@@ -2303,9 +2297,9 @@ DSEState::ChangeStateEnum DSEState::MemoryLocationWrapper::eliminateDeadDefs() {
       }
       if (OR == OW_Complete) {
         LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: "
-                          << *DeadDefWrapper.DefInst
+                          << *DeadLocWrapper.DefInst
                           << "\n  KILLER: " << *DefInst << '\n');
-        State.deleteDeadInstruction(DeadDefWrapper.DefInst);
+        deleteDeadInstruction(DeadLocWrapper.DefInst);
         ++NumFastStores;
         ChangeState = ChangeStateEnum::CompleteDeleteByNonMemTerm;
       }
@@ -2314,8 +2308,8 @@ DSEState::ChangeStateEnum DSEState::MemoryLocationWrapper::eliminateDeadDefs() {
   return ChangeState;
 }
 
-bool DSEState::MemoryDefWrapper::eliminateDeadDefs() {
-  if (!DefinedLocation.has_value()) {
+bool DSEState::eliminateDeadDefs(MemoryDefWrapper &KillingDefWrapper) {
+  if (!KillingDefWrapper.DefinedLocation.has_value()) {
     LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
                       << *DefInst << "\n");
     return false;
@@ -2323,24 +2317,25 @@ bool DSEState::MemoryDefWrapper::eliminateDeadDefs() {
   LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by " << *MemDef
                     << " (" << *DefInst << ")\n");
 
-  auto &KillingLoc = *DefinedLocation;
-  ChangeStateEnum ChangeState = KillingLoc.eliminateDeadDefs();
+  auto &KillingLocWrapper = *KillingDefWrapper.DefinedLocation;
+  ChangeStateEnum ChangeState = eliminateDeadDefs(KillingLocWrapper);
   bool Shortend = ChangeState == ChangeStateEnum::PartiallyDeleteByNonMemTerm;
 
   // Check if the store is a no-op.
-  if (!Shortend && State.storeIsNoop(MemDef, KillingLoc.UnderlyingObject)) {
+  if (!Shortend && storeIsNoop(KillingLocWrapper.MemDef,
+                               KillingLocWrapper.UnderlyingObject)) {
     LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n  DEAD: " << *DefInst
                       << '\n');
-    State.deleteDeadInstruction(DefInst);
+    deleteDeadInstruction(KillingLocWrapper.DefInst);
     NumRedundantStores++;
     return true;
   }
   // Can we form a calloc from a memset/malloc pair?
-  if (!Shortend &&
-      State.tryFoldIntoCalloc(MemDef, KillingLoc.UnderlyingObject)) {
+  if (!Shortend && tryFoldIntoCalloc(KillingLocWrapper.MemDef,
+                                     KillingLocWrapper.UnderlyingObject)) {
     LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
-                      << "  DEAD: " << *DefInst << '\n');
-    State.deleteDeadInstruction(DefInst);
+                      << "  DEAD: " << *KillingLocWrapper.DefInst << '\n');
+    deleteDeadInstruction(KillingLocWrapper.DefInst);
     return true;
   }
   return ChangeState != ChangeStateEnum::NoChange;
@@ -2358,8 +2353,9 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
     if (State.SkipStores.count(KillingDef))
       continue;
 
-    DSEState::MemoryDefWrapper KillingDefWrapper(KillingDef, State);
-    MadeChange |= KillingDefWrapper.eliminateDeadDefs();
+    MemoryDefWrapper KillingDefWrapper(
+        KillingDef, State.getLocForInst(KillingDef->getMemoryInst()));
+    MadeChange |= State.eliminateDeadDefs(KillingDefWrapper);
   }
 
   if (EnablePartialOverwriteTracking)

>From 06b293f7f5c96922f6af30b5c2f5898a577ac311 Mon Sep 17 00:00:00 2001
From: Haopeng Liu <haopliu at google.com>
Date: Mon, 12 Aug 2024 22:17:01 +0000
Subject: [PATCH 5/9] Change ChangeStateEnum to std::pair<bool, bool>

---
 .../Scalar/DeadStoreElimination.cpp           | 41 +++++++++----------
 1 file changed, 19 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 366db7b598782a..002826c6285178 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -2178,24 +2178,21 @@ struct DSEState {
     return MadeChange;
   }
 
-  enum class ChangeStateEnum : uint8_t {
-    NoChange,
-    DeleteByMemTerm,
-    CompleteDeleteByNonMemTerm,
-    PartiallyDeleteByNonMemTerm,
-  };
   // Try to eliminate dead defs killed by `KillingLocWrapper` and return the
-  // change state.
-  ChangeStateEnum eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper);
+  // change state: whether make any change, and whether make a partial delete
+  // by a non memory-terminator instruction.
+  std::pair<bool, bool>
+  eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper);
 
   // Try to eliminate dead defs killed by `KillingDefWrapper` and return the
-  // change state.
+  // change state: whether make any change.
   bool eliminateDeadDefs(MemoryDefWrapper &KillingDefWrapper);
 };
 
-DSEState::ChangeStateEnum
+std::pair<bool, bool>
 DSEState::eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper) {
-  ChangeStateEnum ChangeState = ChangeStateEnum::NoChange;
+  bool Changed = false;
+  bool Shortened = false;
   unsigned ScanLimit = MemorySSAScanLimit;
   unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
   unsigned PartialLimit = MemorySSAPartialStoreLimit;
@@ -2251,7 +2248,7 @@ DSEState::eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper) {
                         << *KillingLocWrapper.DefInst << '\n');
       deleteDeadInstruction(DeadLocWrapper.DefInst);
       ++NumFastStores;
-      ChangeState = ChangeStateEnum::DeleteByMemTerm;
+      Changed = true;
     } else {
       // Check if DeadI overwrites KillingI.
       int64_t KillingOffset = 0;
@@ -2283,7 +2280,8 @@ DSEState::eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper) {
             // Update stored value of earlier store to merged constant.
             DeadSI->setOperand(0, Merged);
             ++NumModifiedStores;
-            ChangeState = ChangeStateEnum::PartiallyDeleteByNonMemTerm;
+            Changed = true;
+            Shortened = true;
 
             // Remove killing store and remove any outstanding overlap
             // intervals for the updated store.
@@ -2301,11 +2299,11 @@ DSEState::eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper) {
                           << "\n  KILLER: " << *DefInst << '\n');
         deleteDeadInstruction(DeadLocWrapper.DefInst);
         ++NumFastStores;
-        ChangeState = ChangeStateEnum::CompleteDeleteByNonMemTerm;
+        Changed = true;
       }
     }
   }
-  return ChangeState;
+  return {Changed, Shortened};
 }
 
 bool DSEState::eliminateDeadDefs(MemoryDefWrapper &KillingDefWrapper) {
@@ -2318,12 +2316,11 @@ bool DSEState::eliminateDeadDefs(MemoryDefWrapper &KillingDefWrapper) {
                     << " (" << *DefInst << ")\n");
 
   auto &KillingLocWrapper = *KillingDefWrapper.DefinedLocation;
-  ChangeStateEnum ChangeState = eliminateDeadDefs(KillingLocWrapper);
-  bool Shortend = ChangeState == ChangeStateEnum::PartiallyDeleteByNonMemTerm;
+  auto [Changed, Shortened] = eliminateDeadDefs(KillingLocWrapper);
 
   // Check if the store is a no-op.
-  if (!Shortend && storeIsNoop(KillingLocWrapper.MemDef,
-                               KillingLocWrapper.UnderlyingObject)) {
+  if (!Shortened && storeIsNoop(KillingLocWrapper.MemDef,
+                                KillingLocWrapper.UnderlyingObject)) {
     LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n  DEAD: " << *DefInst
                       << '\n');
     deleteDeadInstruction(KillingLocWrapper.DefInst);
@@ -2331,14 +2328,14 @@ bool DSEState::eliminateDeadDefs(MemoryDefWrapper &KillingDefWrapper) {
     return true;
   }
   // Can we form a calloc from a memset/malloc pair?
-  if (!Shortend && tryFoldIntoCalloc(KillingLocWrapper.MemDef,
-                                     KillingLocWrapper.UnderlyingObject)) {
+  if (!Shortened && tryFoldIntoCalloc(KillingLocWrapper.MemDef,
+                                      KillingLocWrapper.UnderlyingObject)) {
     LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
                       << "  DEAD: " << *KillingLocWrapper.DefInst << '\n');
     deleteDeadInstruction(KillingLocWrapper.DefInst);
     return true;
   }
-  return ChangeState != ChangeStateEnum::NoChange;
+  return Changed;
 }
 
 static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,

>From 65971be4f3569b7f2d8f9882419190e7c729bccb Mon Sep 17 00:00:00 2001
From: Haopeng Liu <haopliu at google.com>
Date: Sat, 17 Aug 2024 00:05:42 +0000
Subject: [PATCH 6/9] Update variable names and comments

---
 .../Scalar/DeadStoreElimination.cpp           | 184 ++++++++++--------
 1 file changed, 98 insertions(+), 86 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 002826c6285178..585fd4c037d536 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -817,10 +817,6 @@ class MemoryLocationWrapper {
     DefInst = MemDef->getMemoryInst();
   }
 
-  MemoryAccess *GetDefiningAccess() const {
-    return MemDef->getDefiningAccess();
-  }
-
   MemoryLocation MemLoc;
   const Value *UnderlyingObject;
   MemoryDef *MemDef;
@@ -832,9 +828,11 @@ class MemoryLocationWrapper {
 class MemoryDefWrapper {
 public:
   MemoryDefWrapper(MemoryDef *MemDef, std::optional<MemoryLocation> MemLoc) {
+    DefInst = MemDef->getMemoryInst();
     if (MemLoc.has_value())
       DefinedLocation = MemoryLocationWrapper(*MemLoc, MemDef);
   }
+  Instruction *DefInst;
   std::optional<MemoryLocationWrapper> DefinedLocation = std::nullopt;
 };
 
@@ -1268,7 +1266,7 @@ struct DSEState {
 
   /// Returns true if \p I is a memory terminator instruction like
   /// llvm.lifetime.end or free.
-  bool isMemTerminatorInst(Instruction *I) {
+  bool isMemTerminatorInst(Instruction *I) const {
     auto *CB = dyn_cast<CallBase>(I);
     return CB && (CB->getIntrinsicID() == Intrinsic::lifetime_end ||
                   getFreedOperand(CB, &TLI) != nullptr);
@@ -1359,16 +1357,17 @@ struct DSEState {
     return true;
   }
 
-  // Find a MemoryDef writing to \p KillingLoc and dominating \p StartAccess,
-  // with no read access between them or on any other path to a function exit
-  // block if \p KillingLoc is not accessible after the function returns. If
-  // there is no such MemoryDef, return std::nullopt. The returned value may not
-  // (completely) overwrite \p KillingLoc. Currently we bail out when we
-  // encounter an aliasing MemoryUse (read).
+  // Find a MemoryDef writing to \p KillingLocWrapper.MemLoc and dominating
+  // \p StartAccess, with no read access between them or on any other path to
+  // a function exit block if \p KillingLocWrapper.MemLoc is not accessible
+  // after the function returns. If there is no such MemoryDef, return
+  // std::nullopt. The returned value may not (completely) overwrite
+  // \p KillingLocWrapper.MemLoc. Currently we bail out when we encounter
+  // an aliasing MemoryUse (read).
   std::optional<MemoryAccess *>
-  getDomMemoryDef(MemoryLocationWrapper &KillingLoc, MemoryAccess *StartAccess,
-                  unsigned &ScanLimit, unsigned &WalkerStepLimit,
-                  unsigned &PartialLimit) {
+  getDomMemoryDef(MemoryLocationWrapper &KillingLocWrapper,
+                  MemoryAccess *StartAccess, unsigned &ScanLimit,
+                  unsigned &WalkerStepLimit, unsigned &PartialLimit) {
     if (ScanLimit == 0 || WalkerStepLimit == 0) {
       LLVM_DEBUG(dbgs() << "\n    ...  hit scan limit\n");
       return std::nullopt;
@@ -1377,14 +1376,15 @@ struct DSEState {
     MemoryAccess *Current = StartAccess;
     LLVM_DEBUG(dbgs() << "  trying to get dominating access\n");
 
-    // Only optimize defining access of KillingDef when directly starting at its
-    // defining access. The defining access also must only access KillingLoc. At
-    // the moment we only support instructions with a single write location, so
-    // it should be sufficient to disable optimizations for instructions that
-    // also read from memory.
-    bool CanOptimize = OptimizeMemorySSA &&
-                       KillingLoc.GetDefiningAccess() == StartAccess &&
-                       !KillingLoc.DefInst->mayReadFromMemory();
+    // Only optimize defining access of "KillingLocWrapper.MemDef" when directly
+    // starting at its defining access. The defining access also must only
+    // access "KillingLocWrapper.MemLoc". At the moment we only support
+    // instructions with a single write location, so it should be sufficient
+    // to disable optimizations for instructions that also read from memory.
+    bool CanOptimize =
+        OptimizeMemorySSA &&
+        KillingLocWrapper.MemDef->getDefiningAccess() == StartAccess &&
+        !KillingLocWrapper.DefInst->mayReadFromMemory();
 
     // Find the next clobbering Mod access for DefLoc, starting at StartAccess.
     std::optional<MemoryLocation> CurrentLoc;
@@ -1400,17 +1400,19 @@ struct DSEState {
       // Reached TOP.
       if (MSSA.isLiveOnEntryDef(Current)) {
         LLVM_DEBUG(dbgs() << "   ...  found LiveOnEntryDef\n");
-        if (CanOptimize && Current != KillingLoc.GetDefiningAccess())
+        if (CanOptimize &&
+            Current != KillingLocWrapper.MemDef->getDefiningAccess())
           // The first clobbering def is... none.
-          KillingLoc.MemDef->setOptimized(Current);
+          KillingLocWrapper.MemDef->setOptimized(Current);
         return std::nullopt;
       }
 
       // Cost of a step. Accesses in the same block are more likely to be valid
       // candidates for elimination, hence consider them cheaper.
-      unsigned StepCost = KillingLoc.MemDef->getBlock() == Current->getBlock()
-                              ? MemorySSASameBBStepCost
-                              : MemorySSAOtherBBStepCost;
+      unsigned StepCost =
+          KillingLocWrapper.MemDef->getBlock() == Current->getBlock()
+              ? MemorySSASameBBStepCost
+              : MemorySSAOtherBBStepCost;
       if (WalkerStepLimit <= StepCost) {
         LLVM_DEBUG(dbgs() << "   ...  hit walker step limit\n");
         return std::nullopt;
@@ -1425,27 +1427,27 @@ struct DSEState {
       }
 
       // Below, check if CurrentDef is a valid candidate to be eliminated by
-      // KillingDef. If it is not, check the next candidate.
+      // "KillingLocWrapper.MemDef". If it is not, check the next candidate.
       MemoryDef *CurrentDef = cast<MemoryDef>(Current);
       Instruction *CurrentI = CurrentDef->getMemoryInst();
 
       if (canSkipDef(CurrentDef, !isInvisibleToCallerOnUnwind(
-                                     KillingLoc.UnderlyingObject))) {
+                                     KillingLocWrapper.UnderlyingObject))) {
         CanOptimize = false;
         continue;
       }
 
       // Before we try to remove anything, check for any extra throwing
       // instructions that block us from DSEing
-      if (mayThrowBetween(KillingLoc.DefInst, CurrentI,
-                          KillingLoc.UnderlyingObject)) {
+      if (mayThrowBetween(KillingLocWrapper.DefInst, CurrentI,
+                          KillingLocWrapper.UnderlyingObject)) {
         LLVM_DEBUG(dbgs() << "  ... skip, may throw!\n");
         return std::nullopt;
       }
 
       // Check for anything that looks like it will be a barrier to further
       // removal
-      if (isDSEBarrier(KillingLoc.UnderlyingObject, CurrentI)) {
+      if (isDSEBarrier(KillingLocWrapper.UnderlyingObject, CurrentI)) {
         LLVM_DEBUG(dbgs() << "  ... skip, barrier\n");
         return std::nullopt;
       }
@@ -1455,17 +1457,18 @@ struct DSEState {
       // for intrinsic calls, because the code knows how to handle memcpy
       // intrinsics.
       if (!isa<IntrinsicInst>(CurrentI) &&
-          isReadClobber(KillingLoc.MemLoc, CurrentI))
+          isReadClobber(KillingLocWrapper.MemLoc, CurrentI))
         return std::nullopt;
 
       // Quick check if there are direct uses that are read-clobbers.
-      if (any_of(Current->uses(), [this, &KillingLoc, StartAccess](Use &U) {
-            if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(U.getUser()))
-              return !MSSA.dominates(StartAccess, UseOrDef) &&
-                     isReadClobber(KillingLoc.MemLoc,
-                                   UseOrDef->getMemoryInst());
-            return false;
-          })) {
+      if (any_of(Current->uses(),
+                 [this, &KillingLocWrapper, StartAccess](Use &U) {
+                   if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(U.getUser()))
+                     return !MSSA.dominates(StartAccess, UseOrDef) &&
+                            isReadClobber(KillingLocWrapper.MemLoc,
+                                          UseOrDef->getMemoryInst());
+                   return false;
+                 })) {
         LLVM_DEBUG(dbgs() << "   ...  found a read clobber\n");
         return std::nullopt;
       }
@@ -1481,33 +1484,36 @@ struct DSEState {
       // AliasAnalysis does not account for loops. Limit elimination to
       // candidates for which we can guarantee they always store to the same
       // memory location and not located in different loops.
-      if (!isGuaranteedLoopIndependent(CurrentI, KillingLoc.DefInst,
+      if (!isGuaranteedLoopIndependent(CurrentI, KillingLocWrapper.DefInst,
                                        *CurrentLoc)) {
         LLVM_DEBUG(dbgs() << "  ... not guaranteed loop independent\n");
         CanOptimize = false;
         continue;
       }
 
-      if (isMemTerminatorInst(KillingLoc.DefInst)) {
-        // If the killing def is a memory terminator (e.g. lifetime.end), check
-        // the next candidate if the current Current does not write the same
-        // underlying object as the terminator.
-        if (!isMemTerminator(*CurrentLoc, CurrentI, KillingLoc.DefInst)) {
+      if (isMemTerminatorInst(KillingLocWrapper.DefInst)) {
+        // If "KillingLocWrapper.DefInst" is a memory terminator
+        // (e.g. lifetime.end), check the next candidate if the current
+        // Current does not write the same underlying object as the terminator.
+        if (!isMemTerminator(*CurrentLoc, CurrentI,
+                             KillingLocWrapper.DefInst)) {
           CanOptimize = false;
           continue;
         }
       } else {
         int64_t KillingOffset = 0;
         int64_t DeadOffset = 0;
-        auto OR = isOverwrite(KillingLoc.DefInst, CurrentI, KillingLoc.MemLoc,
-                              *CurrentLoc, KillingOffset, DeadOffset);
+        auto OR = isOverwrite(KillingLocWrapper.DefInst, CurrentI,
+                              KillingLocWrapper.MemLoc, *CurrentLoc,
+                              KillingOffset, DeadOffset);
         if (CanOptimize) {
-          // CurrentDef is the earliest write clobber of KillingDef. Use it as
-          // optimized access. Do not optimize if CurrentDef is already the
-          // defining access of KillingDef.
-          if (CurrentDef != KillingLoc.GetDefiningAccess() &&
+          // CurrentDef is the earliest write clobber of
+          // "KillingLocWrapper.MemDef". Use it as optimized access. Do not
+          // optimize if CurrentDef is already the defining access of
+          // "KillingLocWrapper.MemDef".
+          if (CurrentDef != KillingLocWrapper.MemDef->getDefiningAccess() &&
               (OR == OW_Complete || OR == OW_MaybePartial))
-            KillingLoc.MemDef->setOptimized(CurrentDef);
+            KillingLocWrapper.MemDef->setOptimized(CurrentDef);
 
           // Once a may-aliasing def is encountered do not set an optimized
           // access.
@@ -1515,15 +1521,15 @@ struct DSEState {
             CanOptimize = false;
         }
 
-        // If Current does not write to the same object as KillingDef, check
-        // the next candidate.
+        // If Current does not write to the same object as
+        // "KillingLocWrapper.MemDef", check the next candidate.
         if (OR == OW_Unknown || OR == OW_None)
           continue;
         else if (OR == OW_MaybePartial) {
-          // If KillingDef only partially overwrites Current, check the next
-          // candidate if the partial step limit is exceeded. This aggressively
-          // limits the number of candidates for partial store elimination,
-          // which are less likely to be removable in the end.
+          // If "KillingLocWrapper.MemDef" only partially overwrites Current,
+          // check the next candidate if the partial step limit is exceeded.
+          // This aggressively limits the number of candidates for partial store
+          // elimination, which are less likely to be removable in the end.
           if (PartialLimit <= 1) {
             WalkerStepLimit -= 1;
             LLVM_DEBUG(dbgs() << "   ... reached partial limit ... continue with next access\n");
@@ -1540,7 +1546,7 @@ struct DSEState {
     // the blocks with killing (=completely overwriting MemoryDefs) and check if
     // they cover all paths from MaybeDeadAccess to any function exit.
     SmallPtrSet<Instruction *, 16> KillingDefs;
-    KillingDefs.insert(KillingLoc.DefInst);
+    KillingDefs.insert(KillingLocWrapper.DefInst);
     MemoryAccess *MaybeDeadAccess = Current;
     MemoryLocation MaybeDeadLoc = *CurrentLoc;
     Instruction *MaybeDeadI = cast<MemoryDef>(MaybeDeadAccess)->getMemoryInst();
@@ -1603,7 +1609,7 @@ struct DSEState {
       }
 
       if (UseInst->mayThrow() &&
-          !isInvisibleToCallerOnUnwind(KillingLoc.UnderlyingObject)) {
+          !isInvisibleToCallerOnUnwind(KillingLocWrapper.UnderlyingObject)) {
         LLVM_DEBUG(dbgs() << "  ... found throwing instruction\n");
         return std::nullopt;
       }
@@ -1623,11 +1629,12 @@ struct DSEState {
         LLVM_DEBUG(dbgs() << "    ... found not loop invariant self access\n");
         return std::nullopt;
       }
-      // Otherwise, for the KillingDef and MaybeDeadAccess we only have to check
-      // if it reads the memory location.
+      // Otherwise, for the "KillingLocWrapper.MemDef" and MaybeDeadAccess we
+      // only have to check if it reads the memory location.
       // TODO: It would probably be better to check for self-reads before
       // calling the function.
-      if (KillingLoc.MemDef == UseAccess || MaybeDeadAccess == UseAccess) {
+      if (KillingLocWrapper.MemDef == UseAccess ||
+          MaybeDeadAccess == UseAccess) {
         LLVM_DEBUG(dbgs() << "    ... skipping killing def/dom access\n");
         continue;
       }
@@ -1647,7 +1654,8 @@ struct DSEState {
           BasicBlock *MaybeKillingBlock = UseInst->getParent();
           if (PostOrderNumbers.find(MaybeKillingBlock)->second <
               PostOrderNumbers.find(MaybeDeadAccess->getBlock())->second) {
-            if (!isInvisibleToCallerAfterRet(KillingLoc.UnderlyingObject)) {
+            if (!isInvisibleToCallerAfterRet(
+                    KillingLocWrapper.UnderlyingObject)) {
               LLVM_DEBUG(dbgs()
                          << "    ... found killing def " << *UseInst << "\n");
               KillingDefs.insert(UseInst);
@@ -1665,7 +1673,7 @@ struct DSEState {
     // For accesses to locations visible after the function returns, make sure
     // that the location is dead (=overwritten) along all paths from
     // MaybeDeadAccess to the exit.
-    if (!isInvisibleToCallerAfterRet(KillingLoc.UnderlyingObject)) {
+    if (!isInvisibleToCallerAfterRet(KillingLocWrapper.UnderlyingObject)) {
       SmallPtrSet<BasicBlock *, 16> KillingBlocks;
       for (Instruction *KD : KillingDefs)
         KillingBlocks.insert(KD->getParent());
@@ -2178,9 +2186,9 @@ struct DSEState {
     return MadeChange;
   }
 
-  // Try to eliminate dead defs killed by `KillingLocWrapper` and return the
-  // change state: whether make any change, and whether make a partial delete
-  // by a non memory-terminator instruction.
+  // Try to eliminate dead defs that access `KillingLocWrapper.MemLoc` and are
+  // killed by `KillingLocWrapper.MemDef`. Return whether
+  // any changes were made, and whether `KillingLocWrapper.DefInst` was deleted.
   std::pair<bool, bool>
   eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper);
 
@@ -2192,15 +2200,17 @@ struct DSEState {
 std::pair<bool, bool>
 DSEState::eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper) {
   bool Changed = false;
-  bool Shortened = false;
+  bool DeletedKillingLoc = false;
   unsigned ScanLimit = MemorySSAScanLimit;
   unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
   unsigned PartialLimit = MemorySSAPartialStoreLimit;
-  // Worklist of MemoryAccesses that may be killed by KillingDef.
+  // Worklist of MemoryAccesses that may be killed by
+  // "KillingLocWrapper.MemDef".
   SmallSetVector<MemoryAccess *, 8> ToCheck;
-  ToCheck.insert(KillingLocWrapper.GetDefiningAccess());
+  ToCheck.insert(KillingLocWrapper.MemDef->getDefiningAccess());
 
-  // Check if MemoryAccesses in the worklist are killed by KillingDef.
+  // Check if MemoryAccesses in the worklist are killed by
+  // "KillingLocWrapper.MemDef".
   for (unsigned I = 0; I < ToCheck.size(); I++) {
     MemoryAccess *Current = ToCheck[I];
     if (SkipStores.count(Current))
@@ -2234,7 +2244,7 @@ DSEState::eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper) {
         getLocForInst(cast<MemoryDef>(DeadAccess)->getMemoryInst()));
     MemoryLocationWrapper &DeadLocWrapper = *DeadDefWrapper.DefinedLocation;
     LLVM_DEBUG(dbgs() << " (" << *DeadLocWrapper.DefInst << ")\n");
-    ToCheck.insert(DeadLocWrapper.GetDefiningAccess());
+    ToCheck.insert(DeadLocWrapper.MemDef->getDefiningAccess());
     NumGetDomMemoryDefPassed++;
 
     if (!DebugCounter::shouldExecute(MemorySSACounter))
@@ -2281,7 +2291,7 @@ DSEState::eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper) {
             DeadSI->setOperand(0, Merged);
             ++NumModifiedStores;
             Changed = true;
-            Shortened = true;
+            DeletedKillingLoc = true;
 
             // Remove killing store and remove any outstanding overlap
             // intervals for the updated store.
@@ -2303,33 +2313,35 @@ DSEState::eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper) {
       }
     }
   }
-  return {Changed, Shortened};
+  return {Changed, DeletedKillingLoc};
 }
 
 bool DSEState::eliminateDeadDefs(MemoryDefWrapper &KillingDefWrapper) {
   if (!KillingDefWrapper.DefinedLocation.has_value()) {
     LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
-                      << *DefInst << "\n");
+                      << *KillingDefWrapper.DefInst << "\n");
     return false;
   }
-  LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by " << *MemDef
-                    << " (" << *DefInst << ")\n");
 
   auto &KillingLocWrapper = *KillingDefWrapper.DefinedLocation;
-  auto [Changed, Shortened] = eliminateDeadDefs(KillingLocWrapper);
+  LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
+                    << *KillingLocWrapper.MemDef << " ("
+                    << *KillingLocWrapper.DefInst << ")\n");
+  auto [Changed, DeletedKillingLoc] = eliminateDeadDefs(KillingLocWrapper);
 
   // Check if the store is a no-op.
-  if (!Shortened && storeIsNoop(KillingLocWrapper.MemDef,
-                                KillingLocWrapper.UnderlyingObject)) {
-    LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n  DEAD: " << *DefInst
-                      << '\n');
+  if (!DeletedKillingLoc && storeIsNoop(KillingLocWrapper.MemDef,
+                                        KillingLocWrapper.UnderlyingObject)) {
+    LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n  DEAD: "
+                      << *KillingLocWrapper.DefInst << '\n');
     deleteDeadInstruction(KillingLocWrapper.DefInst);
     NumRedundantStores++;
     return true;
   }
   // Can we form a calloc from a memset/malloc pair?
-  if (!Shortened && tryFoldIntoCalloc(KillingLocWrapper.MemDef,
-                                      KillingLocWrapper.UnderlyingObject)) {
+  if (!DeletedKillingLoc &&
+      tryFoldIntoCalloc(KillingLocWrapper.MemDef,
+                        KillingLocWrapper.UnderlyingObject)) {
     LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
                       << "  DEAD: " << *KillingLocWrapper.DefInst << '\n');
     deleteDeadInstruction(KillingLocWrapper.DefInst);

>From c488bf89d7f495db614f8aeed92cb4ee56b02909 Mon Sep 17 00:00:00 2001
From: Haopeng Liu <haopliu at google.com>
Date: Tue, 20 Aug 2024 17:46:49 +0000
Subject: [PATCH 7/9] Revert changes in getDomMemoryDef()

---
 .../Scalar/DeadStoreElimination.cpp           | 138 ++++++++----------
 1 file changed, 63 insertions(+), 75 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 585fd4c037d536..f071827f253369 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -1357,34 +1357,34 @@ struct DSEState {
     return true;
   }
 
-  // Find a MemoryDef writing to \p KillingLocWrapper.MemLoc and dominating
-  // \p StartAccess, with no read access between them or on any other path to
-  // a function exit block if \p KillingLocWrapper.MemLoc is not accessible
-  // after the function returns. If there is no such MemoryDef, return
-  // std::nullopt. The returned value may not (completely) overwrite
-  // \p KillingLocWrapper.MemLoc. Currently we bail out when we encounter
-  // an aliasing MemoryUse (read).
+  // Find a MemoryDef writing to \p KillingLoc and dominating \p StartAccess,
+  // with no read access between them or on any other path to a function exit
+  // block if \p KillingLoc is not accessible after the function returns. If
+  // there is no such MemoryDef, return std::nullopt. The returned value may not
+  // (completely) overwrite \p KillingLoc. Currently we bail out when we
+  // encounter an aliasing MemoryUse (read).
   std::optional<MemoryAccess *>
-  getDomMemoryDef(MemoryLocationWrapper &KillingLocWrapper,
-                  MemoryAccess *StartAccess, unsigned &ScanLimit,
-                  unsigned &WalkerStepLimit, unsigned &PartialLimit) {
+  getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *StartAccess,
+                  const MemoryLocation &KillingLoc, const Value *KillingUndObj,
+                  unsigned &ScanLimit, unsigned &WalkerStepLimit,
+                  bool IsMemTerm, unsigned &PartialLimit) {
     if (ScanLimit == 0 || WalkerStepLimit == 0) {
       LLVM_DEBUG(dbgs() << "\n    ...  hit scan limit\n");
       return std::nullopt;
     }
 
     MemoryAccess *Current = StartAccess;
+    Instruction *KillingI = KillingDef->getMemoryInst();
     LLVM_DEBUG(dbgs() << "  trying to get dominating access\n");
 
-    // Only optimize defining access of "KillingLocWrapper.MemDef" when directly
-    // starting at its defining access. The defining access also must only
-    // access "KillingLocWrapper.MemLoc". At the moment we only support
-    // instructions with a single write location, so it should be sufficient
-    // to disable optimizations for instructions that also read from memory.
-    bool CanOptimize =
-        OptimizeMemorySSA &&
-        KillingLocWrapper.MemDef->getDefiningAccess() == StartAccess &&
-        !KillingLocWrapper.DefInst->mayReadFromMemory();
+    // Only optimize defining access of KillingDef when directly starting at its
+    // defining access. The defining access also must only access KillingLoc. At
+    // the moment we only support instructions with a single write location, so
+    // it should be sufficient to disable optimizations for instructions that
+    // also read from memory.
+    bool CanOptimize = OptimizeMemorySSA &&
+                       KillingDef->getDefiningAccess() == StartAccess &&
+                       !KillingI->mayReadFromMemory();
 
     // Find the next clobbering Mod access for DefLoc, starting at StartAccess.
     std::optional<MemoryLocation> CurrentLoc;
@@ -1400,19 +1400,17 @@ struct DSEState {
       // Reached TOP.
       if (MSSA.isLiveOnEntryDef(Current)) {
         LLVM_DEBUG(dbgs() << "   ...  found LiveOnEntryDef\n");
-        if (CanOptimize &&
-            Current != KillingLocWrapper.MemDef->getDefiningAccess())
+        if (CanOptimize && Current != KillingDef->getDefiningAccess())
           // The first clobbering def is... none.
-          KillingLocWrapper.MemDef->setOptimized(Current);
+          KillingDef->setOptimized(Current);
         return std::nullopt;
       }
 
       // Cost of a step. Accesses in the same block are more likely to be valid
       // candidates for elimination, hence consider them cheaper.
-      unsigned StepCost =
-          KillingLocWrapper.MemDef->getBlock() == Current->getBlock()
-              ? MemorySSASameBBStepCost
-              : MemorySSAOtherBBStepCost;
+      unsigned StepCost = KillingDef->getBlock() == Current->getBlock()
+                              ? MemorySSASameBBStepCost
+                              : MemorySSAOtherBBStepCost;
       if (WalkerStepLimit <= StepCost) {
         LLVM_DEBUG(dbgs() << "   ...  hit walker step limit\n");
         return std::nullopt;
@@ -1427,27 +1425,25 @@ struct DSEState {
       }
 
       // Below, check if CurrentDef is a valid candidate to be eliminated by
-      // "KillingLocWrapper.MemDef". If it is not, check the next candidate.
+      // KillingDef. If it is not, check the next candidate.
       MemoryDef *CurrentDef = cast<MemoryDef>(Current);
       Instruction *CurrentI = CurrentDef->getMemoryInst();
 
-      if (canSkipDef(CurrentDef, !isInvisibleToCallerOnUnwind(
-                                     KillingLocWrapper.UnderlyingObject))) {
+      if (canSkipDef(CurrentDef, !isInvisibleToCallerOnUnwind(KillingUndObj))) {
         CanOptimize = false;
         continue;
       }
 
       // Before we try to remove anything, check for any extra throwing
       // instructions that block us from DSEing
-      if (mayThrowBetween(KillingLocWrapper.DefInst, CurrentI,
-                          KillingLocWrapper.UnderlyingObject)) {
+      if (mayThrowBetween(KillingI, CurrentI, KillingUndObj)) {
         LLVM_DEBUG(dbgs() << "  ... skip, may throw!\n");
         return std::nullopt;
       }
 
       // Check for anything that looks like it will be a barrier to further
       // removal
-      if (isDSEBarrier(KillingLocWrapper.UnderlyingObject, CurrentI)) {
+      if (isDSEBarrier(KillingUndObj, CurrentI)) {
         LLVM_DEBUG(dbgs() << "  ... skip, barrier\n");
         return std::nullopt;
       }
@@ -1456,19 +1452,16 @@ struct DSEState {
       // clobber, bail out, as the path is not profitable. We skip this check
       // for intrinsic calls, because the code knows how to handle memcpy
       // intrinsics.
-      if (!isa<IntrinsicInst>(CurrentI) &&
-          isReadClobber(KillingLocWrapper.MemLoc, CurrentI))
+      if (!isa<IntrinsicInst>(CurrentI) && isReadClobber(KillingLoc, CurrentI))
         return std::nullopt;
 
       // Quick check if there are direct uses that are read-clobbers.
-      if (any_of(Current->uses(),
-                 [this, &KillingLocWrapper, StartAccess](Use &U) {
-                   if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(U.getUser()))
-                     return !MSSA.dominates(StartAccess, UseOrDef) &&
-                            isReadClobber(KillingLocWrapper.MemLoc,
-                                          UseOrDef->getMemoryInst());
-                   return false;
-                 })) {
+      if (any_of(Current->uses(), [this, &KillingLoc, StartAccess](Use &U) {
+            if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(U.getUser()))
+              return !MSSA.dominates(StartAccess, UseOrDef) &&
+                     isReadClobber(KillingLoc, UseOrDef->getMemoryInst());
+            return false;
+          })) {
         LLVM_DEBUG(dbgs() << "   ...  found a read clobber\n");
         return std::nullopt;
       }
@@ -1484,36 +1477,32 @@ struct DSEState {
       // AliasAnalysis does not account for loops. Limit elimination to
       // candidates for which we can guarantee they always store to the same
       // memory location and not located in different loops.
-      if (!isGuaranteedLoopIndependent(CurrentI, KillingLocWrapper.DefInst,
-                                       *CurrentLoc)) {
+      if (!isGuaranteedLoopIndependent(CurrentI, KillingI, *CurrentLoc)) {
         LLVM_DEBUG(dbgs() << "  ... not guaranteed loop independent\n");
         CanOptimize = false;
         continue;
       }
 
-      if (isMemTerminatorInst(KillingLocWrapper.DefInst)) {
-        // If "KillingLocWrapper.DefInst" is a memory terminator
-        // (e.g. lifetime.end), check the next candidate if the current
-        // Current does not write the same underlying object as the terminator.
-        if (!isMemTerminator(*CurrentLoc, CurrentI,
-                             KillingLocWrapper.DefInst)) {
+      if (IsMemTerm) {
+        // If the killing def is a memory terminator (e.g. lifetime.end), check
+        // the next candidate if the current Current does not write the same
+        // underlying object as the terminator.
+        if (!isMemTerminator(*CurrentLoc, CurrentI, KillingI)) {
           CanOptimize = false;
           continue;
         }
       } else {
         int64_t KillingOffset = 0;
         int64_t DeadOffset = 0;
-        auto OR = isOverwrite(KillingLocWrapper.DefInst, CurrentI,
-                              KillingLocWrapper.MemLoc, *CurrentLoc,
+        auto OR = isOverwrite(KillingI, CurrentI, KillingLoc, *CurrentLoc,
                               KillingOffset, DeadOffset);
         if (CanOptimize) {
-          // CurrentDef is the earliest write clobber of
-          // "KillingLocWrapper.MemDef". Use it as optimized access. Do not
-          // optimize if CurrentDef is already the defining access of
-          // "KillingLocWrapper.MemDef".
-          if (CurrentDef != KillingLocWrapper.MemDef->getDefiningAccess() &&
+          // CurrentDef is the earliest write clobber of KillingDef. Use it as
+          // optimized access. Do not optimize if CurrentDef is already the
+          // defining access of KillingDef.
+          if (CurrentDef != KillingDef->getDefiningAccess() &&
               (OR == OW_Complete || OR == OW_MaybePartial))
-            KillingLocWrapper.MemDef->setOptimized(CurrentDef);
+            KillingDef->setOptimized(CurrentDef);
 
           // Once a may-aliasing def is encountered do not set an optimized
           // access.
@@ -1521,15 +1510,15 @@ struct DSEState {
             CanOptimize = false;
         }
 
-        // If Current does not write to the same object as
-        // "KillingLocWrapper.MemDef", check the next candidate.
+        // If Current does not write to the same object as KillingDef, check
+        // the next candidate.
         if (OR == OW_Unknown || OR == OW_None)
           continue;
         else if (OR == OW_MaybePartial) {
-          // If "KillingLocWrapper.MemDef" only partially overwrites Current,
-          // check the next candidate if the partial step limit is exceeded.
-          // This aggressively limits the number of candidates for partial store
-          // elimination, which are less likely to be removable in the end.
+          // If KillingDef only partially overwrites Current, check the next
+          // candidate if the partial step limit is exceeded. This aggressively
+          // limits the number of candidates for partial store elimination,
+          // which are less likely to be removable in the end.
           if (PartialLimit <= 1) {
             WalkerStepLimit -= 1;
             LLVM_DEBUG(dbgs() << "   ... reached partial limit ... continue with next access\n");
@@ -1546,7 +1535,7 @@ struct DSEState {
     // the blocks with killing (=completely overwriting MemoryDefs) and check if
     // they cover all paths from MaybeDeadAccess to any function exit.
     SmallPtrSet<Instruction *, 16> KillingDefs;
-    KillingDefs.insert(KillingLocWrapper.DefInst);
+    KillingDefs.insert(KillingDef->getMemoryInst());
     MemoryAccess *MaybeDeadAccess = Current;
     MemoryLocation MaybeDeadLoc = *CurrentLoc;
     Instruction *MaybeDeadI = cast<MemoryDef>(MaybeDeadAccess)->getMemoryInst();
@@ -1608,8 +1597,7 @@ struct DSEState {
         continue;
       }
 
-      if (UseInst->mayThrow() &&
-          !isInvisibleToCallerOnUnwind(KillingLocWrapper.UnderlyingObject)) {
+      if (UseInst->mayThrow() && !isInvisibleToCallerOnUnwind(KillingUndObj)) {
         LLVM_DEBUG(dbgs() << "  ... found throwing instruction\n");
         return std::nullopt;
       }
@@ -1629,12 +1617,11 @@ struct DSEState {
         LLVM_DEBUG(dbgs() << "    ... found not loop invariant self access\n");
         return std::nullopt;
       }
-      // Otherwise, for the "KillingLocWrapper.MemDef" and MaybeDeadAccess we
-      // only have to check if it reads the memory location.
+      // Otherwise, for the KillingDef and MaybeDeadAccess we only have to check
+      // if it reads the memory location.
       // TODO: It would probably be better to check for self-reads before
       // calling the function.
-      if (KillingLocWrapper.MemDef == UseAccess ||
-          MaybeDeadAccess == UseAccess) {
+      if (KillingDef == UseAccess || MaybeDeadAccess == UseAccess) {
         LLVM_DEBUG(dbgs() << "    ... skipping killing def/dom access\n");
         continue;
       }
@@ -1654,8 +1641,7 @@ struct DSEState {
           BasicBlock *MaybeKillingBlock = UseInst->getParent();
           if (PostOrderNumbers.find(MaybeKillingBlock)->second <
               PostOrderNumbers.find(MaybeDeadAccess->getBlock())->second) {
-            if (!isInvisibleToCallerAfterRet(
-                    KillingLocWrapper.UnderlyingObject)) {
+            if (!isInvisibleToCallerAfterRet(KillingUndObj)) {
               LLVM_DEBUG(dbgs()
                          << "    ... found killing def " << *UseInst << "\n");
               KillingDefs.insert(UseInst);
@@ -1673,7 +1659,7 @@ struct DSEState {
     // For accesses to locations visible after the function returns, make sure
     // that the location is dead (=overwritten) along all paths from
     // MaybeDeadAccess to the exit.
-    if (!isInvisibleToCallerAfterRet(KillingLocWrapper.UnderlyingObject)) {
+    if (!isInvisibleToCallerAfterRet(KillingUndObj)) {
       SmallPtrSet<BasicBlock *, 16> KillingBlocks;
       for (Instruction *KD : KillingDefs)
         KillingBlocks.insert(KD->getParent());
@@ -2216,7 +2202,9 @@ DSEState::eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper) {
     if (SkipStores.count(Current))
       continue;
     std::optional<MemoryAccess *> MaybeDeadAccess = getDomMemoryDef(
-        KillingLocWrapper, Current, ScanLimit, WalkerStepLimit, PartialLimit);
+        KillingLocWrapper.MemDef, Current, KillingLocWrapper.MemLoc,
+        KillingLocWrapper.UnderlyingObject, ScanLimit, WalkerStepLimit,
+        isMemTerminatorInst(KillingLocWrapper.DefInst), PartialLimit);
 
     if (!MaybeDeadAccess) {
       LLVM_DEBUG(dbgs() << "  finished walk\n");

>From ed91944e6639a9306da5885a986d26186cc72578 Mon Sep 17 00:00:00 2001
From: Haopeng Liu <haopliu at google.com>
Date: Tue, 20 Aug 2024 19:20:31 +0000
Subject: [PATCH 8/9] Fix a mistake in LLVM_DEBUG

---
 llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index f071827f253369..66196f07686356 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -2293,8 +2293,8 @@ DSEState::eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper) {
       }
       if (OR == OW_Complete) {
         LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: "
-                          << *DeadLocWrapper.DefInst
-                          << "\n  KILLER: " << *DefInst << '\n');
+                          << *DeadLocWrapper.DefInst << "\n  KILLER: "
+                          << *KillingLocWrapper.DefInst << '\n');
         deleteDeadInstruction(DeadLocWrapper.DefInst);
         ++NumFastStores;
         Changed = true;

>From adaf8c4f501d119d3c9fa1ce79e7fb5292383051 Mon Sep 17 00:00:00 2001
From: Haopeng Liu <haopliu at google.com>
Date: Wed, 21 Aug 2024 17:06:37 +0000
Subject: [PATCH 9/9] Mark eliminateDeadDefs() parameter with const

---
 .../lib/Transforms/Scalar/DeadStoreElimination.cpp | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 66196f07686356..589ecb762e5186 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -808,8 +808,7 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
 
 // A memory location wrapper that represents a MemoryLocation, `MemLoc`,
 // defined by `MemDef`.
-class MemoryLocationWrapper {
-public:
+struct MemoryLocationWrapper {
   MemoryLocationWrapper(MemoryLocation MemLoc, MemoryDef *MemDef)
       : MemLoc(MemLoc), MemDef(MemDef) {
     assert(MemLoc.Ptr && "MemLoc should be not null");
@@ -825,8 +824,7 @@ class MemoryLocationWrapper {
 
 // A memory def wrapper that represents a MemoryDef and the MemoryLocation(s)
 // defined by this MemoryDef.
-class MemoryDefWrapper {
-public:
+struct MemoryDefWrapper {
   MemoryDefWrapper(MemoryDef *MemDef, std::optional<MemoryLocation> MemLoc) {
     DefInst = MemDef->getMemoryInst();
     if (MemLoc.has_value())
@@ -2176,15 +2174,15 @@ struct DSEState {
   // killed by `KillingLocWrapper.MemDef`. Return whether
   // any changes were made, and whether `KillingLocWrapper.DefInst` was deleted.
   std::pair<bool, bool>
-  eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper);
+  eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper);
 
   // Try to eliminate dead defs killed by `KillingDefWrapper` and return the
   // change state: whether make any change.
-  bool eliminateDeadDefs(MemoryDefWrapper &KillingDefWrapper);
+  bool eliminateDeadDefs(const MemoryDefWrapper &KillingDefWrapper);
 };
 
 std::pair<bool, bool>
-DSEState::eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper) {
+DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
   bool Changed = false;
   bool DeletedKillingLoc = false;
   unsigned ScanLimit = MemorySSAScanLimit;
@@ -2304,7 +2302,7 @@ DSEState::eliminateDeadDefs(MemoryLocationWrapper &KillingLocWrapper) {
   return {Changed, DeletedKillingLoc};
 }
 
-bool DSEState::eliminateDeadDefs(MemoryDefWrapper &KillingDefWrapper) {
+bool DSEState::eliminateDeadDefs(const MemoryDefWrapper &KillingDefWrapper) {
   if (!KillingDefWrapper.DefinedLocation.has_value()) {
     LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
                       << *KillingDefWrapper.DefInst << "\n");



More information about the llvm-commits mailing list