[llvm] 6421dcc - [NFC] [DSE] Refactor DSE (#100956)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 29 11:28:52 PDT 2024


Author: Haopeng Liu
Date: 2024-08-29T11:28:49-07:00
New Revision: 6421dcc0a978900091cc7aa8fa443746602cb442

URL: https://github.com/llvm/llvm-project/commit/6421dcc0a978900091cc7aa8fa443746602cb442
DIFF: https://github.com/llvm/llvm-project/commit/6421dcc0a978900091cc7aa8fa443746602cb442.diff

LOG: [NFC] [DSE] Refactor DSE (#100956)

Refactor DSE with MemoryDefWrapper and MemoryLocationWrapper.

Normally, one MemoryDef accesses one MemoryLocation. With the "initializes"
attribute, one MemoryDef (such as a call instruction) can initialize
multiple MemoryLocations.

Refactor DSE in preparation for applying the "initializes" attribute in DSE
in a follow-up PR
(https://github.com/llvm/llvm-project/commit/58dd8a440343055b1a4929d72317218e912c16fd).

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 992139a95a43d3..a37f295abbd31c 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -806,6 +806,34 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
   return false;
 }
 
+// Wraps a MemoryLocation, `MemLoc`, defined by `MemDef`, and caches the
+// location's underlying object and `MemDef`'s memory instruction.
+struct MemoryLocationWrapper {
+  MemoryLocationWrapper(MemoryLocation MemLoc, MemoryDef *MemDef)
+      : MemLoc(MemLoc), MemDef(MemDef) {
+    assert(MemLoc.Ptr && "MemLoc should be not null");
+    UnderlyingObject = getUnderlyingObject(MemLoc.Ptr);
+    DefInst = MemDef->getMemoryInst();
+  }
+
+  MemoryLocation MemLoc;
+  const Value *UnderlyingObject;
+  MemoryDef *MemDef;
+  Instruction *DefInst;
+};
+
+// Wraps a MemoryDef's memory instruction and the MemoryLocation it defines;
+// `DefinedLocation` is unset when no location was provided.
+struct MemoryDefWrapper {
+  MemoryDefWrapper(MemoryDef *MemDef, std::optional<MemoryLocation> MemLoc) {
+    DefInst = MemDef->getMemoryInst();
+    if (MemLoc.has_value())
+      DefinedLocation = MemoryLocationWrapper(*MemLoc, MemDef);
+  }
+  Instruction *DefInst;
+  std::optional<MemoryLocationWrapper> DefinedLocation = std::nullopt;
+};
+
 struct DSEState {
   Function &F;
   AliasAnalysis &AA;
@@ -1119,6 +1147,15 @@ struct DSEState {
     return MemoryLocation::getOrNone(I);
   }
 
+  std::optional<MemoryLocation> getLocForInst(Instruction *I) {
+    if (isMemTerminatorInst(I)) {
+      if (auto Loc = getLocForTerminator(I)) {
+        return Loc->first;
+      }
+    }
+    return getLocForWrite(I);
+  }
+
   /// Assuming this instruction has a dead analyzable write, can we delete
   /// this instruction?
   bool isRemovable(Instruction *I) {
@@ -2132,182 +2169,196 @@ struct DSEState {
     }
     return MadeChange;
   }
-};
 
-static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
-                                DominatorTree &DT, PostDominatorTree &PDT,
-                                const TargetLibraryInfo &TLI,
-                                const LoopInfo &LI) {
-  bool MadeChange = false;
+  // Try to eliminate dead defs that access `KillingLocWrapper.MemLoc` and are
+  // killed by `KillingLocWrapper.MemDef`. Returns a pair: whether any change
+  // was made, and whether `KillingLocWrapper.DefInst` itself was deleted.
+  std::pair<bool, bool>
+  eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper);
 
-  DSEState State(F, AA, MSSA, DT, PDT, TLI, LI);
-  // For each store:
-  for (unsigned I = 0; I < State.MemDefs.size(); I++) {
-    MemoryDef *KillingDef = State.MemDefs[I];
-    if (State.SkipStores.count(KillingDef))
+  // Try to eliminate dead defs killed by `KillingDefWrapper`. Returns whether
+  // any change was made; the killing instruction itself may be deleted.
+  bool eliminateDeadDefs(const MemoryDefWrapper &KillingDefWrapper);
+};
+
+std::pair<bool, bool>
+DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
+  bool Changed = false;
+  bool DeletedKillingLoc = false;
+  unsigned ScanLimit = MemorySSAScanLimit;
+  unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
+  unsigned PartialLimit = MemorySSAPartialStoreLimit;
+  // Worklist of MemoryAccesses that may be killed by
+  // `KillingLocWrapper.MemDef`.
+  SmallSetVector<MemoryAccess *, 8> ToCheck;
+  // Track MemoryAccesses that have been deleted in the loop below, so we can
+  // skip them. Don't use SkipStores for this, which may contain reused
+  // MemoryAccess addresses.
+  SmallPtrSet<MemoryAccess *, 8> Deleted;
+  [[maybe_unused]] unsigned OrigNumSkipStores = SkipStores.size();
+  ToCheck.insert(KillingLocWrapper.MemDef->getDefiningAccess());
+
+  // Check if MemoryAccesses in the worklist are killed by
+  // `KillingLocWrapper.MemDef`.
+  for (unsigned I = 0; I < ToCheck.size(); I++) {
+    MemoryAccess *Current = ToCheck[I];
+    if (Deleted.contains(Current))
       continue;
-    Instruction *KillingI = KillingDef->getMemoryInst();
+    std::optional<MemoryAccess *> MaybeDeadAccess = getDomMemoryDef(
+        KillingLocWrapper.MemDef, Current, KillingLocWrapper.MemLoc,
+        KillingLocWrapper.UnderlyingObject, ScanLimit, WalkerStepLimit,
+        isMemTerminatorInst(KillingLocWrapper.DefInst), PartialLimit);
 
-    std::optional<MemoryLocation> MaybeKillingLoc;
-    if (State.isMemTerminatorInst(KillingI)) {
-      if (auto KillingLoc = State.getLocForTerminator(KillingI))
-        MaybeKillingLoc = KillingLoc->first;
-    } else {
-      MaybeKillingLoc = State.getLocForWrite(KillingI);
+    if (!MaybeDeadAccess) {
+      LLVM_DEBUG(dbgs() << "  finished walk\n");
+      continue;
     }
-
-    if (!MaybeKillingLoc) {
-      LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
-                        << *KillingI << "\n");
+    MemoryAccess *DeadAccess = *MaybeDeadAccess;
+    LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess);
+    if (isa<MemoryPhi>(DeadAccess)) {
+      LLVM_DEBUG(dbgs() << "\n  ... adding incoming values to worklist\n");
+      for (Value *V : cast<MemoryPhi>(DeadAccess)->incoming_values()) {
+        MemoryAccess *IncomingAccess = cast<MemoryAccess>(V);
+        BasicBlock *IncomingBlock = IncomingAccess->getBlock();
+        BasicBlock *PhiBlock = DeadAccess->getBlock();
+
+        // We only consider incoming MemoryAccesses that come before the
+        // MemoryPhi. Otherwise we could discover candidates that do not
+        // strictly dominate our starting def.
+        if (PostOrderNumbers[IncomingBlock] > PostOrderNumbers[PhiBlock])
+          ToCheck.insert(IncomingAccess);
+      }
       continue;
     }
-    MemoryLocation KillingLoc = *MaybeKillingLoc;
-    assert(KillingLoc.Ptr && "KillingLoc should not be null");
-    const Value *KillingUndObj = getUnderlyingObject(KillingLoc.Ptr);
-    LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
-                      << *KillingDef << " (" << *KillingI << ")\n");
-
-    unsigned ScanLimit = MemorySSAScanLimit;
-    unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
-    unsigned PartialLimit = MemorySSAPartialStoreLimit;
-    // Worklist of MemoryAccesses that may be killed by KillingDef.
-    SmallSetVector<MemoryAccess *, 8> ToCheck;
-    // Track MemoryAccesses that have been deleted in the loop below, so we can
-    // skip them. Don't use SkipStores for this, which may contain reused
-    // MemoryAccess addresses.
-    SmallPtrSet<MemoryAccess *, 8> Deleted;
-    [[maybe_unused]] unsigned OrigNumSkipStores = State.SkipStores.size();
-    ToCheck.insert(KillingDef->getDefiningAccess());
-
-    bool Shortend = false;
-    bool IsMemTerm = State.isMemTerminatorInst(KillingI);
-    // Check if MemoryAccesses in the worklist are killed by KillingDef.
-    for (unsigned I = 0; I < ToCheck.size(); I++) {
-      MemoryAccess *Current = ToCheck[I];
-      if (Deleted.contains(Current))
-        continue;
-
-      std::optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef(
-          KillingDef, Current, KillingLoc, KillingUndObj, ScanLimit,
-          WalkerStepLimit, IsMemTerm, PartialLimit);
-
-      if (!MaybeDeadAccess) {
-        LLVM_DEBUG(dbgs() << "  finished walk\n");
+    MemoryDefWrapper DeadDefWrapper(
+        cast<MemoryDef>(DeadAccess),
+        getLocForInst(cast<MemoryDef>(DeadAccess)->getMemoryInst()));
+    MemoryLocationWrapper &DeadLocWrapper = *DeadDefWrapper.DefinedLocation;
+    LLVM_DEBUG(dbgs() << " (" << *DeadLocWrapper.DefInst << ")\n");
+    ToCheck.insert(DeadLocWrapper.MemDef->getDefiningAccess());
+    NumGetDomMemoryDefPassed++;
+
+    if (!DebugCounter::shouldExecute(MemorySSACounter))
+      continue;
+    if (isMemTerminatorInst(KillingLocWrapper.DefInst)) {
+      if (KillingLocWrapper.UnderlyingObject != DeadLocWrapper.UnderlyingObject)
         continue;
+      LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: "
+                        << *DeadLocWrapper.DefInst << "\n  KILLER: "
+                        << *KillingLocWrapper.DefInst << '\n');
+      deleteDeadInstruction(DeadLocWrapper.DefInst, &Deleted);
+      ++NumFastStores;
+      Changed = true;
+    } else {
+      // Check if KillingLocWrapper.DefInst overwrites DeadLocWrapper.DefInst.
+      int64_t KillingOffset = 0;
+      int64_t DeadOffset = 0;
+      OverwriteResult OR =
+          isOverwrite(KillingLocWrapper.DefInst, DeadLocWrapper.DefInst,
+                      KillingLocWrapper.MemLoc, DeadLocWrapper.MemLoc,
+                      KillingOffset, DeadOffset);
+      if (OR == OW_MaybePartial) {
+        auto Iter =
+            IOLs.insert(std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
+                DeadLocWrapper.DefInst->getParent(), InstOverlapIntervalsTy()));
+        auto &IOL = Iter.first->second;
+        OR = isPartialOverwrite(KillingLocWrapper.MemLoc, DeadLocWrapper.MemLoc,
+                                KillingOffset, DeadOffset,
+                                DeadLocWrapper.DefInst, IOL);
       }
-
-      MemoryAccess *DeadAccess = *MaybeDeadAccess;
-      LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess);
-      if (isa<MemoryPhi>(DeadAccess)) {
-        LLVM_DEBUG(dbgs() << "\n  ... adding incoming values to worklist\n");
-        for (Value *V : cast<MemoryPhi>(DeadAccess)->incoming_values()) {
-          MemoryAccess *IncomingAccess = cast<MemoryAccess>(V);
-          BasicBlock *IncomingBlock = IncomingAccess->getBlock();
-          BasicBlock *PhiBlock = DeadAccess->getBlock();
-
-          // We only consider incoming MemoryAccesses that come before the
-          // MemoryPhi. Otherwise we could discover candidates that do not
-          // strictly dominate our starting def.
-          if (State.PostOrderNumbers[IncomingBlock] >
-              State.PostOrderNumbers[PhiBlock])
-            ToCheck.insert(IncomingAccess);
+      if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
+        auto *DeadSI = dyn_cast<StoreInst>(DeadLocWrapper.DefInst);
+        auto *KillingSI = dyn_cast<StoreInst>(KillingLocWrapper.DefInst);
+        // We are re-using tryToMergePartialOverlappingStores, which requires
+        // DeadSI to dominate KillingSI.
+        // TODO: implement tryToMergePartialOverlappingStores using MemorySSA.
+        if (DeadSI && KillingSI && DT.dominates(DeadSI, KillingSI)) {
+          if (Constant *Merged = tryToMergePartialOverlappingStores(
+                  KillingSI, DeadSI, KillingOffset, DeadOffset, DL, BatchAA,
+                  &DT)) {
+
+            // Update stored value of earlier store to merged constant.
+            DeadSI->setOperand(0, Merged);
+            ++NumModifiedStores;
+            Changed = true;
+            DeletedKillingLoc = true;
+
+            // Remove killing store and remove any outstanding overlap
+            // intervals for the updated store.
+            deleteDeadInstruction(KillingSI, &Deleted);
+            auto I = IOLs.find(DeadSI->getParent());
+            if (I != IOLs.end())
+              I->second.erase(DeadSI);
+            break;
+          }
         }
-        continue;
       }
-      auto *DeadDefAccess = cast<MemoryDef>(DeadAccess);
-      Instruction *DeadI = DeadDefAccess->getMemoryInst();
-      LLVM_DEBUG(dbgs() << " (" << *DeadI << ")\n");
-      ToCheck.insert(DeadDefAccess->getDefiningAccess());
-      NumGetDomMemoryDefPassed++;
-
-      if (!DebugCounter::shouldExecute(MemorySSACounter))
-        continue;
-
-      MemoryLocation DeadLoc = *State.getLocForWrite(DeadI);
-
-      if (IsMemTerm) {
-        const Value *DeadUndObj = getUnderlyingObject(DeadLoc.Ptr);
-        if (KillingUndObj != DeadUndObj)
-          continue;
-        LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: " << *DeadI
-                          << "\n  KILLER: " << *KillingI << '\n');
-        State.deleteDeadInstruction(DeadI, &Deleted);
+      if (OR == OW_Complete) {
+        LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: "
+                          << *DeadLocWrapper.DefInst << "\n  KILLER: "
+                          << *KillingLocWrapper.DefInst << '\n');
+        deleteDeadInstruction(DeadLocWrapper.DefInst, &Deleted);
         ++NumFastStores;
-        MadeChange = true;
-      } else {
-        // Check if DeadI overwrites KillingI.
-        int64_t KillingOffset = 0;
-        int64_t DeadOffset = 0;
-        OverwriteResult OR = State.isOverwrite(
-            KillingI, DeadI, KillingLoc, DeadLoc, KillingOffset, DeadOffset);
-        if (OR == OW_MaybePartial) {
-          auto Iter = State.IOLs.insert(
-              std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
-                  DeadI->getParent(), InstOverlapIntervalsTy()));
-          auto &IOL = Iter.first->second;
-          OR = isPartialOverwrite(KillingLoc, DeadLoc, KillingOffset,
-                                  DeadOffset, DeadI, IOL);
-        }
-
-        if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
-          auto *DeadSI = dyn_cast<StoreInst>(DeadI);
-          auto *KillingSI = dyn_cast<StoreInst>(KillingI);
-          // We are re-using tryToMergePartialOverlappingStores, which requires
-          // DeadSI to dominate KillingSI.
-          // TODO: implement tryToMergeParialOverlappingStores using MemorySSA.
-          if (DeadSI && KillingSI && DT.dominates(DeadSI, KillingSI)) {
-            if (Constant *Merged = tryToMergePartialOverlappingStores(
-                    KillingSI, DeadSI, KillingOffset, DeadOffset, State.DL,
-                    State.BatchAA, &DT)) {
-
-              // Update stored value of earlier store to merged constant.
-              DeadSI->setOperand(0, Merged);
-              ++NumModifiedStores;
-              MadeChange = true;
-
-              Shortend = true;
-              // Remove killing store and remove any outstanding overlap
-              // intervals for the updated store.
-              State.deleteDeadInstruction(KillingSI, &Deleted);
-              auto I = State.IOLs.find(DeadSI->getParent());
-              if (I != State.IOLs.end())
-                I->second.erase(DeadSI);
-              break;
-            }
-          }
-        }
-
-        if (OR == OW_Complete) {
-          LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: " << *DeadI
-                            << "\n  KILLER: " << *KillingI << '\n');
-          State.deleteDeadInstruction(DeadI, &Deleted);
-          ++NumFastStores;
-          MadeChange = true;
-        }
+        Changed = true;
       }
     }
+  }
 
-    assert(State.SkipStores.size() - OrigNumSkipStores == Deleted.size() &&
-           "SkipStores and Deleted out of sync?");
+  assert(SkipStores.size() - OrigNumSkipStores == Deleted.size() &&
+         "SkipStores and Deleted out of sync?");
 
-    // Check if the store is a no-op.
-    if (!Shortend && State.storeIsNoop(KillingDef, KillingUndObj)) {
-      LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n  DEAD: " << *KillingI
-                        << '\n');
-      State.deleteDeadInstruction(KillingI);
-      NumRedundantStores++;
-      MadeChange = true;
-      continue;
-    }
+  return {Changed, DeletedKillingLoc};
+}
 
-    // Can we form a calloc from a memset/malloc pair?
-    if (!Shortend && State.tryFoldIntoCalloc(KillingDef, KillingUndObj)) {
-      LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
-                        << "  DEAD: " << *KillingI << '\n');
-      State.deleteDeadInstruction(KillingI);
-      MadeChange = true;
+bool DSEState::eliminateDeadDefs(const MemoryDefWrapper &KillingDefWrapper) {
+  if (!KillingDefWrapper.DefinedLocation.has_value()) {
+    LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
+                      << *KillingDefWrapper.DefInst << "\n");
+    return false;
+  }
+
+  auto &KillingLocWrapper = *KillingDefWrapper.DefinedLocation;
+  LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
+                    << *KillingLocWrapper.MemDef << " ("
+                    << *KillingLocWrapper.DefInst << ")\n");
+  auto [Changed, DeletedKillingLoc] = eliminateDeadDefs(KillingLocWrapper);
+
+  // Check if the store is a no-op.
+  if (!DeletedKillingLoc && storeIsNoop(KillingLocWrapper.MemDef,
+                                        KillingLocWrapper.UnderlyingObject)) {
+    LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n  DEAD: "
+                      << *KillingLocWrapper.DefInst << '\n');
+    deleteDeadInstruction(KillingLocWrapper.DefInst);
+    NumRedundantStores++;
+    return true;
+  }
+  // Can we form a calloc from a memset/malloc pair?
+  if (!DeletedKillingLoc &&
+      tryFoldIntoCalloc(KillingLocWrapper.MemDef,
+                        KillingLocWrapper.UnderlyingObject)) {
+    LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
+                      << "  DEAD: " << *KillingLocWrapper.DefInst << '\n');
+    deleteDeadInstruction(KillingLocWrapper.DefInst);
+    return true;
+  }
+  return Changed;
+}
+
+static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
+                                DominatorTree &DT, PostDominatorTree &PDT,
+                                const TargetLibraryInfo &TLI,
+                                const LoopInfo &LI) {
+  bool MadeChange = false;
+  DSEState State(F, AA, MSSA, DT, PDT, TLI, LI);
+  // Visit each MemoryDef as a potential killing def:
+  for (unsigned I = 0; I < State.MemDefs.size(); I++) {
+    MemoryDef *KillingDef = State.MemDefs[I];
+    if (State.SkipStores.count(KillingDef))
       continue;
-    }
+
+    MemoryDefWrapper KillingDefWrapper(
+        KillingDef, State.getLocForInst(KillingDef->getMemoryInst()));
+    MadeChange |= State.eliminateDeadDefs(KillingDefWrapper);
   }
 
   if (EnablePartialOverwriteTracking)


        


More information about the llvm-commits mailing list