[llvm] [NFC] [DSE] Refactor DSE (PR #100956)

Jan Voung via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 22 10:40:02 PDT 2024


================
@@ -2132,182 +2169,197 @@ struct DSEState {
     }
     return MadeChange;
   }
-};
 
-static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
-                                DominatorTree &DT, PostDominatorTree &PDT,
-                                const TargetLibraryInfo &TLI,
-                                const LoopInfo &LI) {
-  bool MadeChange = false;
+  // Try to eliminate dead defs that access `KillingLocWrapper.MemLoc` and are
+  // killed by `KillingLocWrapper.MemDef`. Return whether
+  // any changes were made, and whether `KillingLocWrapper.DefInst` was deleted.
+  std::pair<bool, bool>
+  eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper);
 
-  DSEState State(F, AA, MSSA, DT, PDT, TLI, LI);
-  // For each store:
-  for (unsigned I = 0; I < State.MemDefs.size(); I++) {
-    MemoryDef *KillingDef = State.MemDefs[I];
-    if (State.SkipStores.count(KillingDef))
+  // Try to eliminate dead defs killed by `KillingDefWrapper`. Return whether
+  // any changes were made.
+  bool eliminateDeadDefs(const MemoryDefWrapper &KillingDefWrapper);
+};
+
+std::pair<bool, bool>
+DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
+  bool Changed = false;
+  bool DeletedKillingLoc = false;
+  unsigned ScanLimit = MemorySSAScanLimit;
+  unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
+  unsigned PartialLimit = MemorySSAPartialStoreLimit;
+  // Worklist of MemoryAccesses that may be killed by
+  // "KillingLocWrapper.MemDef".
+  SmallSetVector<MemoryAccess *, 8> ToCheck;
+  // Track MemoryAccesses that have been deleted in the loop below, so we can
+  // skip them. Don't use SkipStores for this, which may contain reused
+  // MemoryAccess addresses.
+  SmallPtrSet<MemoryAccess *, 8> Deleted;
+  [[maybe_unused]] unsigned OrigNumSkipStores = SkipStores.size();
+  ToCheck.insert(KillingLocWrapper.MemDef->getDefiningAccess());
+
+  // Check if MemoryAccesses in the worklist are killed by
+  // "KillingLocWrapper.MemDef".
+  for (unsigned I = 0; I < ToCheck.size(); I++) {
+    MemoryAccess *Current = ToCheck[I];
+    if (Deleted.contains(Current))
       continue;
-    Instruction *KillingI = KillingDef->getMemoryInst();
+    std::optional<MemoryAccess *> MaybeDeadAccess = getDomMemoryDef(
+        KillingLocWrapper.MemDef, Current, KillingLocWrapper.MemLoc,
+        KillingLocWrapper.UnderlyingObject, ScanLimit, WalkerStepLimit,
+        isMemTerminatorInst(KillingLocWrapper.DefInst), PartialLimit);
 
-    std::optional<MemoryLocation> MaybeKillingLoc;
-    if (State.isMemTerminatorInst(KillingI)) {
-      if (auto KillingLoc = State.getLocForTerminator(KillingI))
-        MaybeKillingLoc = KillingLoc->first;
-    } else {
-      MaybeKillingLoc = State.getLocForWrite(KillingI);
+    if (!MaybeDeadAccess) {
+      LLVM_DEBUG(dbgs() << "  finished walk\n");
+      continue;
     }
-
-    if (!MaybeKillingLoc) {
-      LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
-                        << *KillingI << "\n");
+    MemoryAccess *DeadAccess = *MaybeDeadAccess;
+    LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess);
+    if (isa<MemoryPhi>(DeadAccess)) {
+      LLVM_DEBUG(dbgs() << "\n  ... adding incoming values to worklist\n");
+      for (Value *V : cast<MemoryPhi>(DeadAccess)->incoming_values()) {
+        MemoryAccess *IncomingAccess = cast<MemoryAccess>(V);
+        BasicBlock *IncomingBlock = IncomingAccess->getBlock();
+        BasicBlock *PhiBlock = DeadAccess->getBlock();
+
+        // We only consider incoming MemoryAccesses that come before the
+        // MemoryPhi. Otherwise we could discover candidates that do not
+        // strictly dominate our starting def.
+        if (PostOrderNumbers[IncomingBlock] > PostOrderNumbers[PhiBlock])
+          ToCheck.insert(IncomingAccess);
+      }
       continue;
     }
-    MemoryLocation KillingLoc = *MaybeKillingLoc;
-    assert(KillingLoc.Ptr && "KillingLoc should not be null");
-    const Value *KillingUndObj = getUnderlyingObject(KillingLoc.Ptr);
-    LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
-                      << *KillingDef << " (" << *KillingI << ")\n");
-
-    unsigned ScanLimit = MemorySSAScanLimit;
-    unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
-    unsigned PartialLimit = MemorySSAPartialStoreLimit;
-    // Worklist of MemoryAccesses that may be killed by KillingDef.
-    SmallSetVector<MemoryAccess *, 8> ToCheck;
-    // Track MemoryAccesses that have been deleted in the loop below, so we can
-    // skip them. Don't use SkipStores for this, which may contain reused
-    // MemoryAccess addresses.
-    SmallPtrSet<MemoryAccess *, 8> Deleted;
-    [[maybe_unused]] unsigned OrigNumSkipStores = State.SkipStores.size();
-    ToCheck.insert(KillingDef->getDefiningAccess());
-
-    bool Shortend = false;
-    bool IsMemTerm = State.isMemTerminatorInst(KillingI);
-    // Check if MemoryAccesses in the worklist are killed by KillingDef.
-    for (unsigned I = 0; I < ToCheck.size(); I++) {
-      MemoryAccess *Current = ToCheck[I];
-      if (Deleted.contains(Current))
-        continue;
-
-      std::optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef(
-          KillingDef, Current, KillingLoc, KillingUndObj, ScanLimit,
-          WalkerStepLimit, IsMemTerm, PartialLimit);
-
-      if (!MaybeDeadAccess) {
-        LLVM_DEBUG(dbgs() << "  finished walk\n");
+    MemoryDefWrapper DeadDefWrapper(
+        cast<MemoryDef>(DeadAccess),
+        getLocForInst(cast<MemoryDef>(DeadAccess)->getMemoryInst()));
+    MemoryLocationWrapper &DeadLocWrapper = *DeadDefWrapper.DefinedLocation;
+    LLVM_DEBUG(dbgs() << " (" << *DeadLocWrapper.DefInst << ")\n");
+    ToCheck.insert(DeadLocWrapper.MemDef->getDefiningAccess());
+    NumGetDomMemoryDefPassed++;
+
+    if (!DebugCounter::shouldExecute(MemorySSACounter))
+      continue;
+    if (isMemTerminatorInst(KillingLocWrapper.DefInst)) {
+      if (!(KillingLocWrapper.UnderlyingObject ==
----------------
jvoung wrote:

nit: could this be written as `LHS != RHS` instead of `!(LHS == RHS)`?

https://github.com/llvm/llvm-project/pull/100956


More information about the llvm-commits mailing list