[llvm] Refactor DSE (PR #100956)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 28 18:31:09 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-llvm-transforms

Author: Haopeng Liu (haopliu)

Changes:

Refactor DSE with MemoryDefWrapper and MemoryLocationWrapper.

Normally, one MemoryDef accesses one MemoryLocation. With the "initializes" attribute, one MemoryDef (e.g., a call instruction) can initialize multiple MemoryLocations.

This refactoring prepares DSE to apply the "initializes" attribute in a follow-up PR.
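
To make the motivation concrete: a call such as `call void @f(ptr initializes((0, 8)) %p, ptr initializes((0, 4)) %q)` writes two disjoint locations, so its single MemoryDef corresponds to two MemoryLocations. Below is a minimal sketch of the driver shape this refactor enables. The wrapper names come from this patch, but the body is illustrative: `MemoryDefWrapper::eliminateDeadDefs` is defined in the truncated part of the diff, so treat this as an assumed shape, not the verbatim implementation.

```cpp
// Illustrative sketch, not the verbatim patch: walk every MemoryLocation
// defined by one MemoryDef. Today DefinedLocations holds at most one entry;
// with "initializes" support it could hold one entry per initialized range.
bool MemoryDefWrapper::eliminateDeadDefs() {
  bool Changed = false;
  for (MemoryLocationWrapper &KillingLoc : DefinedLocations)
    Changed |= KillingLoc.eliminateDeadDefs() != ChangeStateEnum::NoChange;
  return Changed;
}
```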

---

Patch is 26.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/100956.diff


1 file affected:

- (modified) llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp (+245-191) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 931606c6f8fe1..e64b7da818a98 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -806,6 +806,63 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
   return false;
 }
 
+/// Returns true if \p I is a memory terminator instruction like
+/// llvm.lifetime.end or free.
+bool isMemTerminatorInst(Instruction *I, const TargetLibraryInfo &TLI) {
+  auto *CB = dyn_cast<CallBase>(I);
+  return CB && (CB->getIntrinsicID() == Intrinsic::lifetime_end ||
+                getFreedOperand(CB, &TLI) != nullptr);
+}
+
+struct DSEState;
+enum class ChangeStateEnum : uint8_t {
+  NoChange,
+  DeleteByMemTerm,
+  CompleteDeleteByNonMemTerm,
+  PartiallyDeleteByNonMemTerm,
+};
+
+// A memory location wrapper that represents a MemoryLocation, `MemLoc`,
+// defined by `MemDef`.
+class MemoryLocationWrapper {
+public:
+  MemoryLocationWrapper(MemoryLocation MemLoc, DSEState &State,
+                        MemoryDef *MemDef)
+      : MemLoc(MemLoc), State(State), MemDef(MemDef) {
+    assert(MemLoc.Ptr && "MemLoc should not be null");
+    UnderlyingObject = getUnderlyingObject(MemLoc.Ptr);
+    DefInst = MemDef->getMemoryInst();
+  }
+
+  // Try to eliminate dead defs killed by this MemoryLocation and return the
+  // change state.
+  ChangeStateEnum eliminateDeadDefs();
+  MemoryAccess *GetDefiningAccess() const {
+    return MemDef->getDefiningAccess();
+  }
+
+  MemoryLocation MemLoc;
+  const Value *UnderlyingObject;
+  DSEState &State;
+  MemoryDef *MemDef;
+  Instruction *DefInst;
+};
+
+// A memory def wrapper that represents a MemoryDef and the MemoryLocation(s)
+// defined by this MemoryDef.
+class MemoryDefWrapper {
+public:
+  MemoryDefWrapper(MemoryDef *MemDef, DSEState &State);
+  // Try to eliminate dead defs killed by this MemoryDef and return the
+  // change state.
+  bool eliminateDeadDefs();
+
+  MemoryDef *MemDef;
+  Instruction *DefInst;
+  DSEState &State;
+  SmallVector<MemoryLocationWrapper, 1> DefinedLocations;
+};
+
 struct DSEState {
   Function &F;
   AliasAnalysis &AA;
@@ -883,7 +940,7 @@ struct DSEState {
 
         auto *MD = dyn_cast_or_null<MemoryDef>(MA);
         if (MD && MemDefs.size() < MemorySSADefsPerBlockLimit &&
-            (getLocForWrite(&I) || isMemTerminatorInst(&I)))
+            (getLocForWrite(&I) || isMemTerminatorInst(&I, TLI)))
           MemDefs.push_back(MD);
       }
     }
@@ -1225,14 +1282,6 @@ struct DSEState {
     return std::nullopt;
   }
 
-  /// Returns true if \p I is a memory terminator instruction like
-  /// llvm.lifetime.end or free.
-  bool isMemTerminatorInst(Instruction *I) const {
-    auto *CB = dyn_cast<CallBase>(I);
-    return CB && (CB->getIntrinsicID() == Intrinsic::lifetime_end ||
-                  getFreedOperand(CB, &TLI) != nullptr);
-  }
-
   /// Returns true if \p MaybeTerm is a memory terminator for \p Loc from
   /// instruction \p AccessI.
   bool isMemTerminator(const MemoryLocation &Loc, Instruction *AccessI,
@@ -1325,17 +1374,15 @@ struct DSEState {
   // (completely) overwrite \p KillingLoc. Currently we bail out when we
   // encounter an aliasing MemoryUse (read).
   std::optional<MemoryAccess *>
-  getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *StartAccess,
-                  const MemoryLocation &KillingLoc, const Value *KillingUndObj,
+  getDomMemoryDef(MemoryLocationWrapper &KillingLoc, MemoryAccess *StartAccess,
                   unsigned &ScanLimit, unsigned &WalkerStepLimit,
-                  bool IsMemTerm, unsigned &PartialLimit) {
+                  unsigned &PartialLimit) {
     if (ScanLimit == 0 || WalkerStepLimit == 0) {
       LLVM_DEBUG(dbgs() << "\n    ...  hit scan limit\n");
       return std::nullopt;
     }
 
     MemoryAccess *Current = StartAccess;
-    Instruction *KillingI = KillingDef->getMemoryInst();
     LLVM_DEBUG(dbgs() << "  trying to get dominating access\n");
 
     // Only optimize defining access of KillingDef when directly starting at its
@@ -1344,8 +1391,8 @@ struct DSEState {
     // it should be sufficient to disable optimizations for instructions that
     // also read from memory.
     bool CanOptimize = OptimizeMemorySSA &&
-                       KillingDef->getDefiningAccess() == StartAccess &&
-                       !KillingI->mayReadFromMemory();
+                       KillingLoc.GetDefiningAccess() == StartAccess &&
+                       !KillingLoc.DefInst->mayReadFromMemory();
 
     // Find the next clobbering Mod access for DefLoc, starting at StartAccess.
     std::optional<MemoryLocation> CurrentLoc;
@@ -1361,15 +1408,15 @@ struct DSEState {
       // Reached TOP.
       if (MSSA.isLiveOnEntryDef(Current)) {
         LLVM_DEBUG(dbgs() << "   ...  found LiveOnEntryDef\n");
-        if (CanOptimize && Current != KillingDef->getDefiningAccess())
+        if (CanOptimize && Current != KillingLoc.GetDefiningAccess())
           // The first clobbering def is... none.
-          KillingDef->setOptimized(Current);
+          KillingLoc.MemDef->setOptimized(Current);
         return std::nullopt;
       }
 
       // Cost of a step. Accesses in the same block are more likely to be valid
       // candidates for elimination, hence consider them cheaper.
-      unsigned StepCost = KillingDef->getBlock() == Current->getBlock()
+      unsigned StepCost = KillingLoc.MemDef->getBlock() == Current->getBlock()
                               ? MemorySSASameBBStepCost
                               : MemorySSAOtherBBStepCost;
       if (WalkerStepLimit <= StepCost) {
@@ -1390,21 +1437,23 @@ struct DSEState {
       MemoryDef *CurrentDef = cast<MemoryDef>(Current);
       Instruction *CurrentI = CurrentDef->getMemoryInst();
 
-      if (canSkipDef(CurrentDef, !isInvisibleToCallerOnUnwind(KillingUndObj))) {
+      if (canSkipDef(CurrentDef, !isInvisibleToCallerOnUnwind(
+                                     KillingLoc.UnderlyingObject))) {
         CanOptimize = false;
         continue;
       }
 
       // Before we try to remove anything, check for any extra throwing
       // instructions that block us from DSEing
-      if (mayThrowBetween(KillingI, CurrentI, KillingUndObj)) {
+      if (mayThrowBetween(KillingLoc.DefInst, CurrentI,
+                          KillingLoc.UnderlyingObject)) {
         LLVM_DEBUG(dbgs() << "  ... skip, may throw!\n");
         return std::nullopt;
       }
 
       // Check for anything that looks like it will be a barrier to further
       // removal
-      if (isDSEBarrier(KillingUndObj, CurrentI)) {
+      if (isDSEBarrier(KillingLoc.UnderlyingObject, CurrentI)) {
         LLVM_DEBUG(dbgs() << "  ... skip, barrier\n");
         return std::nullopt;
       }
@@ -1413,14 +1462,16 @@ struct DSEState {
       // clobber, bail out, as the path is not profitable. We skip this check
       // for intrinsic calls, because the code knows how to handle memcpy
       // intrinsics.
-      if (!isa<IntrinsicInst>(CurrentI) && isReadClobber(KillingLoc, CurrentI))
+      if (!isa<IntrinsicInst>(CurrentI) &&
+          isReadClobber(KillingLoc.MemLoc, CurrentI))
         return std::nullopt;
 
       // Quick check if there are direct uses that are read-clobbers.
       if (any_of(Current->uses(), [this, &KillingLoc, StartAccess](Use &U) {
             if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(U.getUser()))
               return !MSSA.dominates(StartAccess, UseOrDef) &&
-                     isReadClobber(KillingLoc, UseOrDef->getMemoryInst());
+                     isReadClobber(KillingLoc.MemLoc,
+                                   UseOrDef->getMemoryInst());
             return false;
           })) {
         LLVM_DEBUG(dbgs() << "   ...  found a read clobber\n");
@@ -1438,32 +1489,33 @@ struct DSEState {
       // AliasAnalysis does not account for loops. Limit elimination to
       // candidates for which we can guarantee they always store to the same
       // memory location and not located in different loops.
-      if (!isGuaranteedLoopIndependent(CurrentI, KillingI, *CurrentLoc)) {
+      if (!isGuaranteedLoopIndependent(CurrentI, KillingLoc.DefInst,
+                                       *CurrentLoc)) {
         LLVM_DEBUG(dbgs() << "  ... not guaranteed loop independent\n");
         CanOptimize = false;
         continue;
       }
 
-      if (IsMemTerm) {
+      if (isMemTerminatorInst(KillingLoc.DefInst, TLI)) {
         // If the killing def is a memory terminator (e.g. lifetime.end), check
         // the next candidate if the current Current does not write the same
         // underlying object as the terminator.
-        if (!isMemTerminator(*CurrentLoc, CurrentI, KillingI)) {
+        if (!isMemTerminator(*CurrentLoc, CurrentI, KillingLoc.DefInst)) {
           CanOptimize = false;
           continue;
         }
       } else {
         int64_t KillingOffset = 0;
         int64_t DeadOffset = 0;
-        auto OR = isOverwrite(KillingI, CurrentI, KillingLoc, *CurrentLoc,
-                              KillingOffset, DeadOffset);
+        auto OR = isOverwrite(KillingLoc.DefInst, CurrentI, KillingLoc.MemLoc,
+                              *CurrentLoc, KillingOffset, DeadOffset);
         if (CanOptimize) {
           // CurrentDef is the earliest write clobber of KillingDef. Use it as
           // optimized access. Do not optimize if CurrentDef is already the
           // defining access of KillingDef.
-          if (CurrentDef != KillingDef->getDefiningAccess() &&
+          if (CurrentDef != KillingLoc.GetDefiningAccess() &&
               (OR == OW_Complete || OR == OW_MaybePartial))
-            KillingDef->setOptimized(CurrentDef);
+            KillingLoc.MemDef->setOptimized(CurrentDef);
 
           // Once a may-aliasing def is encountered do not set an optimized
           // access.
@@ -1496,7 +1548,7 @@ struct DSEState {
     // the blocks with killing (=completely overwriting MemoryDefs) and check if
     // they cover all paths from MaybeDeadAccess to any function exit.
     SmallPtrSet<Instruction *, 16> KillingDefs;
-    KillingDefs.insert(KillingDef->getMemoryInst());
+    KillingDefs.insert(KillingLoc.DefInst);
     MemoryAccess *MaybeDeadAccess = Current;
     MemoryLocation MaybeDeadLoc = *CurrentLoc;
     Instruction *MaybeDeadI = cast<MemoryDef>(MaybeDeadAccess)->getMemoryInst();
@@ -1558,7 +1610,8 @@ struct DSEState {
         continue;
       }
 
-      if (UseInst->mayThrow() && !isInvisibleToCallerOnUnwind(KillingUndObj)) {
+      if (UseInst->mayThrow() &&
+          !isInvisibleToCallerOnUnwind(KillingLoc.UnderlyingObject)) {
         LLVM_DEBUG(dbgs() << "  ... found throwing instruction\n");
         return std::nullopt;
       }
@@ -1582,7 +1635,7 @@ struct DSEState {
       // if it reads the memory location.
       // TODO: It would probably be better to check for self-reads before
       // calling the function.
-      if (KillingDef == UseAccess || MaybeDeadAccess == UseAccess) {
+      if (KillingLoc.MemDef == UseAccess || MaybeDeadAccess == UseAccess) {
         LLVM_DEBUG(dbgs() << "    ... skipping killing def/dom access\n");
         continue;
       }
@@ -1602,7 +1655,7 @@ struct DSEState {
           BasicBlock *MaybeKillingBlock = UseInst->getParent();
           if (PostOrderNumbers.find(MaybeKillingBlock)->second <
               PostOrderNumbers.find(MaybeDeadAccess->getBlock())->second) {
-            if (!isInvisibleToCallerAfterRet(KillingUndObj)) {
+            if (!isInvisibleToCallerAfterRet(KillingLoc.UnderlyingObject)) {
               LLVM_DEBUG(dbgs()
                          << "    ... found killing def " << *UseInst << "\n");
               KillingDefs.insert(UseInst);
@@ -1620,7 +1673,7 @@ struct DSEState {
     // For accesses to locations visible after the function returns, make sure
     // that the location is dead (=overwritten) along all paths from
     // MaybeDeadAccess to the exit.
-    if (!isInvisibleToCallerAfterRet(KillingUndObj)) {
+    if (!isInvisibleToCallerAfterRet(KillingLoc.UnderlyingObject)) {
       SmallPtrSet<BasicBlock *, 16> KillingBlocks;
       for (Instruction *KD : KillingDefs)
         KillingBlocks.insert(KD->getParent());
@@ -2134,180 +2187,181 @@ struct DSEState {
   }
 };
 
-static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
-                                DominatorTree &DT, PostDominatorTree &PDT,
-                                const TargetLibraryInfo &TLI,
-                                const LoopInfo &LI) {
-  bool MadeChange = false;
+MemoryDefWrapper::MemoryDefWrapper(MemoryDef *MemDef, DSEState &State)
+    : MemDef(MemDef), State(State) {
+  DefInst = MemDef->getMemoryInst();
 
-  DSEState State(F, AA, MSSA, DT, PDT, TLI, LI);
-  // For each store:
-  for (unsigned I = 0; I < State.MemDefs.size(); I++) {
-    MemoryDef *KillingDef = State.MemDefs[I];
-    if (State.SkipStores.count(KillingDef))
+  if (isMemTerminatorInst(DefInst, State.TLI)) {
+    if (auto KillingLoc = State.getLocForTerminator(DefInst)) {
+      DefinedLocations.push_back(
+          MemoryLocationWrapper(KillingLoc->first, State, MemDef));
+    }
+    return;
+  }
+
+  if (auto KillingLoc = State.getLocForWrite(DefInst)) {
+    DefinedLocations.push_back(
+        MemoryLocationWrapper(*KillingLoc, State, MemDef));
+  }
+}
+
+ChangeStateEnum MemoryLocationWrapper::eliminateDeadDefs() {
+  ChangeStateEnum ChangeState = ChangeStateEnum::NoChange;
+  unsigned ScanLimit = MemorySSAScanLimit;
+  unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
+  unsigned PartialLimit = MemorySSAPartialStoreLimit;
+  // Worklist of MemoryAccesses that may be killed by KillingDef.
+  SmallSetVector<MemoryAccess *, 8> ToCheck;
+  ToCheck.insert(GetDefiningAccess());
+
+  // Check if MemoryAccesses in the worklist are killed by KillingDef.
+  for (unsigned I = 0; I < ToCheck.size(); I++) {
+    MemoryAccess *Current = ToCheck[I];
+    if (State.SkipStores.count(Current))
       continue;
-    Instruction *KillingI = KillingDef->getMemoryInst();
+    std::optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef(
+        *this, Current, ScanLimit, WalkerStepLimit, PartialLimit);
 
-    std::optional<MemoryLocation> MaybeKillingLoc;
-    if (State.isMemTerminatorInst(KillingI)) {
-      if (auto KillingLoc = State.getLocForTerminator(KillingI))
-        MaybeKillingLoc = KillingLoc->first;
-    } else {
-      MaybeKillingLoc = State.getLocForWrite(KillingI);
+    if (!MaybeDeadAccess) {
+      LLVM_DEBUG(dbgs() << "  finished walk\n");
+      continue;
     }
-
-    if (!MaybeKillingLoc) {
-      LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
-                        << *KillingI << "\n");
+    MemoryAccess *DeadAccess = *MaybeDeadAccess;
+    LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess);
+    if (isa<MemoryPhi>(DeadAccess)) {
+      LLVM_DEBUG(dbgs() << "\n  ... adding incoming values to worklist\n");
+      for (Value *V : cast<MemoryPhi>(DeadAccess)->incoming_values()) {
+        MemoryAccess *IncomingAccess = cast<MemoryAccess>(V);
+        BasicBlock *IncomingBlock = IncomingAccess->getBlock();
+        BasicBlock *PhiBlock = DeadAccess->getBlock();
+
+        // We only consider incoming MemoryAccesses that come before the
+        // MemoryPhi. Otherwise we could discover candidates that do not
+        // strictly dominate our starting def.
+        if (State.PostOrderNumbers[IncomingBlock] >
+            State.PostOrderNumbers[PhiBlock])
+          ToCheck.insert(IncomingAccess);
+      }
       continue;
     }
-    MemoryLocation KillingLoc = *MaybeKillingLoc;
-    assert(KillingLoc.Ptr && "KillingLoc should not be null");
-    const Value *KillingUndObj = getUnderlyingObject(KillingLoc.Ptr);
-    LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
-                      << *KillingDef << " (" << *KillingI << ")\n");
-
-    unsigned ScanLimit = MemorySSAScanLimit;
-    unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
-    unsigned PartialLimit = MemorySSAPartialStoreLimit;
-    // Worklist of MemoryAccesses that may be killed by KillingDef.
-    SmallSetVector<MemoryAccess *, 8> ToCheck;
-    // Track MemoryAccesses that have been deleted in the loop below, so we can
-    // skip them. Don't use SkipStores for this, which may contain reused
-    // MemoryAccess addresses.
-    SmallPtrSet<MemoryAccess *, 8> Deleted;
-    [[maybe_unused]] unsigned OrigNumSkipStores = State.SkipStores.size();
-    ToCheck.insert(KillingDef->getDefiningAccess());
-
-    bool Shortend = false;
-    bool IsMemTerm = State.isMemTerminatorInst(KillingI);
-    // Check if MemoryAccesses in the worklist are killed by KillingDef.
-    for (unsigned I = 0; I < ToCheck.size(); I++) {
-      MemoryAccess *Current = ToCheck[I];
-      if (Deleted.contains(Current))
-        continue;
+    MemoryDefWrapper DeadDefWrapper(cast<MemoryDef>(DeadAccess), State);
+    MemoryLocationWrapper &DeadLoc = DeadDefWrapper.DefinedLocations.front();
+    LLVM_DEBUG(dbgs() << " (" << *DeadDefWrapper.DefInst << ")\n");
+    ToCheck.insert(DeadLoc.GetDefiningAccess());
+    NumGetDomMemoryDefPassed++;
 
-      std::optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef(
-          KillingDef, Current, KillingLoc, KillingUndObj, ScanLimit,
-          WalkerStepLimit, IsMemTerm, PartialLimit);
-
-      if (!MaybeDeadAccess) {
-        LLVM_DEBUG(dbgs() << "  finished walk\n");
+    if (!DebugCounter::shouldExecute(MemorySSACounter))
+      continue;
+    if (isMemTerminatorInst(DefInst, State.TLI)) {
+      if (UnderlyingObject != DeadLoc.UnderlyingObject)
         continue;
+      LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: "
+                        << *DeadDefWrapper.DefInst << "\n  KILLER: " << *DefInst
+                        << '\n');
+      State.deleteDeadInstruction(DeadDefWrapper.DefInst);
+      ++NumFastStores;
+      ChangeState = ChangeStateEnum::DeleteByMemTerm;
+    } else {
+      // Check if DeadI overwrites KillingI.
+      int64_t KillingOffset = 0;
+      int64_t DeadOffset = 0;
+      OverwriteResult OR =
+          State.isOverwrite(DefInst, DeadDefWrapper.DefInst, MemLoc,
+                            DeadLoc.MemLoc, KillingOffset, DeadOffset);
+      if (OR == OW_MaybePartial) {
+        auto Iter = State.IOLs.insert(
+            std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
+                DeadDefWrapper.DefInst->getParent(), InstOverlapIntervalsTy()));
+        auto &IOL = Iter.first->second;
+        OR = isPartialOverwrite(MemLoc, DeadLoc.MemLoc, KillingOffset,
+                                DeadOffset, DeadDefWrapper.DefInst, IOL);
       }
-
-      MemoryAccess *DeadAccess = *MaybeDeadAccess;
-      LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess);
-      if (isa<MemoryPhi>(DeadAccess)) {
-        LLVM_DEBUG(dbgs() << "\n  ... adding incoming values to worklist\n");
-        for (Value *V : cast<MemoryPhi>(DeadAccess)->incoming_values()) {
-          MemoryAccess *IncomingAccess = cast<MemoryAccess>(V);
-          BasicBlock *IncomingBlock = IncomingAccess->getBlock();
-          BasicBlock *PhiBlock = DeadAccess->getBlock();
-
-          // We only consider incoming MemoryAccesses that come before the
-          // MemoryPhi. Otherwise we could discover candidates that do not
-          // strictly dominate our starting def.
-          if (State.PostOrderNumbers[IncomingBlock] >
-              State.PostOrderNumbers[PhiBlock])
-            ToCheck.insert(IncomingAccess);
+      if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
+        auto *DeadSI = dyn_cast<StoreInst>(DeadDefWrapper.DefInst);
+        auto *KillingSI = dyn_cast<StoreInst>(DefInst);
+        // We are re-using tryToMergePartialOverlappingStores, which requires
+        // DeadSI to dominate KillingSI.
+        // TODO: implement tryToMergePartialOverlappingStores using MemorySSA.
+        if (DeadSI && KillingSI && State.DT.dominates(DeadSI, KillingSI)) {
+          if (Constant *Merged = tryToMergePartialOverlappingStores(
+                  KillingSI, DeadSI, KillingOffset, DeadOffset, State.DL,
+                  State.BatchAA, &State.DT)) {
+
+            // Update stored value of earlier store to merged constant.
+            D...
[truncated]

``````````
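
For readers stopping at the truncation: based on the pieces shown above, the new top-level loop plausibly collapses to something like the sketch below. This is an assumed shape, not the verbatim patch (see the full diff at the PR link); the post-loop cleanup calls are the pre-existing DSE phases from DeadStoreElimination.cpp, which this refactor does not touch.

```cpp
// Sketch of the refactored driver: wrap each MemoryDef, then let the wrapper
// drive the per-location elimination walk shown above.
static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
                                DominatorTree &DT, PostDominatorTree &PDT,
                                const TargetLibraryInfo &TLI,
                                const LoopInfo &LI) {
  bool MadeChange = false;
  DSEState State(F, AA, MSSA, DT, PDT, TLI, LI);
  // For each store-like MemoryDef:
  for (unsigned I = 0; I < State.MemDefs.size(); I++) {
    MemoryDef *KillingDef = State.MemDefs[I];
    if (State.SkipStores.count(KillingDef))
      continue;
    MemoryDefWrapper KillingDefWrapper(KillingDef, State);
    MadeChange |= KillingDefWrapper.eliminateDeadDefs();
  }
  // Pre-existing post-processing phases (unchanged by the refactor).
  if (EnablePartialOverwriteTracking)
    for (auto &KV : State.IOLs)
      MadeChange |= State.removePartiallyOverlappedStores(KV.second);
  MadeChange |= State.eliminateRedundantStoresOfExistingValues();
  MadeChange |= State.eliminateDeadWritesAtEndOfFunction();
  return MadeChange;
}
```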



https://github.com/llvm/llvm-project/pull/100956

