[llvm] 6421dcc - [NFC] [DSE] Refactor DSE (#100956)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 29 11:28:52 PDT 2024
Author: Haopeng Liu
Date: 2024-08-29T11:28:49-07:00
New Revision: 6421dcc0a978900091cc7aa8fa443746602cb442
URL: https://github.com/llvm/llvm-project/commit/6421dcc0a978900091cc7aa8fa443746602cb442
DIFF: https://github.com/llvm/llvm-project/commit/6421dcc0a978900091cc7aa8fa443746602cb442.diff
LOG: [NFC] [DSE] Refactor DSE (#100956)
Refactor DSE with MemoryDefWrapper and MemoryLocationWrapper.
Normally, one MemoryDef accesses one MemoryLocation. With the "initializes"
attribute, one MemoryDef (such as a call instruction) could initialize
multiple MemoryLocations.
Refactor DSE in preparation for applying the "initializes" attribute in DSE
in a follow-up PR
(https://github.com/llvm/llvm-project/commit/58dd8a440343055b1a4929d72317218e912c16fd).
Added:
Modified:
llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 992139a95a43d3..a37f295abbd31c 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -806,6 +806,34 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
return false;
}
+// Pairs a MemoryLocation, `MemLoc`, with the MemoryDef, `MemDef`, that defines
+// it, and caches the location's underlying object and `MemDef`'s instruction.
+struct MemoryLocationWrapper {
+ MemoryLocationWrapper(MemoryLocation MemLoc, MemoryDef *MemDef)
+ : MemLoc(MemLoc), MemDef(MemDef) {
+ assert(MemLoc.Ptr && "MemLoc should be not null");
+ UnderlyingObject = getUnderlyingObject(MemLoc.Ptr);
+ DefInst = MemDef->getMemoryInst();
+ }
+
+ MemoryLocation MemLoc;
+ const Value *UnderlyingObject; // getUnderlyingObject(MemLoc.Ptr)
+ MemoryDef *MemDef;
+ Instruction *DefInst; // MemDef->getMemoryInst()
+};
+
+// Wraps a MemoryDef together with the MemoryLocation it defines, if any.
+// `DefinedLocation` stays std::nullopt when no location was supplied.
+struct MemoryDefWrapper {
+ MemoryDefWrapper(MemoryDef *MemDef, std::optional<MemoryLocation> MemLoc) {
+ DefInst = MemDef->getMemoryInst();
+ if (MemLoc.has_value())
+ DefinedLocation = MemoryLocationWrapper(*MemLoc, MemDef);
+ }
+ Instruction *DefInst; // MemDef->getMemoryInst()
+ std::optional<MemoryLocationWrapper> DefinedLocation = std::nullopt;
+};
+
struct DSEState {
Function &F;
AliasAnalysis &AA;
@@ -1119,6 +1147,15 @@ struct DSEState {
return MemoryLocation::getOrNone(I);
}
+ std::optional<MemoryLocation> getLocForInst(Instruction *I) {
+ if (isMemTerminatorInst(I)) {
+ if (auto Loc = getLocForTerminator(I)) {
+ return Loc->first; // Only the `first` member of the result is returned.
+ }
+ }
+ return getLocForWrite(I); // Also reached for a terminator with no location.
+ }
+
/// Assuming this instruction has a dead analyzable write, can we delete
/// this instruction?
bool isRemovable(Instruction *I) {
@@ -2132,182 +2169,196 @@ struct DSEState {
}
return MadeChange;
}
-};
-static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
- DominatorTree &DT, PostDominatorTree &PDT,
- const TargetLibraryInfo &TLI,
- const LoopInfo &LI) {
- bool MadeChange = false;
+ // Try to eliminate dead defs that access `KillingLocWrapper.MemLoc` and are
+ // killed by `KillingLocWrapper.MemDef`. Return whether
+ // any changes were made, and whether `KillingLocWrapper.DefInst` was deleted.
+ std::pair<bool, bool>
+ eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper);
- DSEState State(F, AA, MSSA, DT, PDT, TLI, LI);
- // For each store:
- for (unsigned I = 0; I < State.MemDefs.size(); I++) {
- MemoryDef *KillingDef = State.MemDefs[I];
- if (State.SkipStores.count(KillingDef))
+ // Try to eliminate dead defs killed by `KillingDefWrapper`. Return whether
+ // any changes were made.
+ bool eliminateDeadDefs(const MemoryDefWrapper &KillingDefWrapper);
+};
+
+std::pair<bool, bool>
+DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
+ bool Changed = false;
+ bool DeletedKillingLoc = false;
+ unsigned ScanLimit = MemorySSAScanLimit;
+ unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
+ unsigned PartialLimit = MemorySSAPartialStoreLimit;
+ // Worklist of MemoryAccesses that may be killed by
+ // "KillingLocWrapper.MemDef".
+ SmallSetVector<MemoryAccess *, 8> ToCheck;
+ // Track MemoryAccesses that have been deleted in the loop below, so we can
+ // skip them. Don't use SkipStores for this, which may contain reused
+ // MemoryAccess addresses.
+ SmallPtrSet<MemoryAccess *, 8> Deleted;
+ [[maybe_unused]] unsigned OrigNumSkipStores = SkipStores.size();
+ ToCheck.insert(KillingLocWrapper.MemDef->getDefiningAccess());
+
+ // Check if MemoryAccesses in the worklist are killed by
+ // "KillingLocWrapper.MemDef".
+ for (unsigned I = 0; I < ToCheck.size(); I++) {
+ MemoryAccess *Current = ToCheck[I];
+ if (Deleted.contains(Current))
continue;
- Instruction *KillingI = KillingDef->getMemoryInst();
+ std::optional<MemoryAccess *> MaybeDeadAccess = getDomMemoryDef(
+ KillingLocWrapper.MemDef, Current, KillingLocWrapper.MemLoc,
+ KillingLocWrapper.UnderlyingObject, ScanLimit, WalkerStepLimit,
+ isMemTerminatorInst(KillingLocWrapper.DefInst), PartialLimit);
- std::optional<MemoryLocation> MaybeKillingLoc;
- if (State.isMemTerminatorInst(KillingI)) {
- if (auto KillingLoc = State.getLocForTerminator(KillingI))
- MaybeKillingLoc = KillingLoc->first;
- } else {
- MaybeKillingLoc = State.getLocForWrite(KillingI);
+ if (!MaybeDeadAccess) {
+ LLVM_DEBUG(dbgs() << " finished walk\n");
+ continue;
}
-
- if (!MaybeKillingLoc) {
- LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
- << *KillingI << "\n");
+ MemoryAccess *DeadAccess = *MaybeDeadAccess;
+ LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess);
+ if (isa<MemoryPhi>(DeadAccess)) {
+ LLVM_DEBUG(dbgs() << "\n ... adding incoming values to worklist\n");
+ for (Value *V : cast<MemoryPhi>(DeadAccess)->incoming_values()) {
+ MemoryAccess *IncomingAccess = cast<MemoryAccess>(V);
+ BasicBlock *IncomingBlock = IncomingAccess->getBlock();
+ BasicBlock *PhiBlock = DeadAccess->getBlock();
+
+ // We only consider incoming MemoryAccesses that come before the
+ // MemoryPhi. Otherwise we could discover candidates that do not
+ // strictly dominate our starting def.
+ if (PostOrderNumbers[IncomingBlock] > PostOrderNumbers[PhiBlock])
+ ToCheck.insert(IncomingAccess);
+ }
continue;
}
- MemoryLocation KillingLoc = *MaybeKillingLoc;
- assert(KillingLoc.Ptr && "KillingLoc should not be null");
- const Value *KillingUndObj = getUnderlyingObject(KillingLoc.Ptr);
- LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
- << *KillingDef << " (" << *KillingI << ")\n");
-
- unsigned ScanLimit = MemorySSAScanLimit;
- unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
- unsigned PartialLimit = MemorySSAPartialStoreLimit;
- // Worklist of MemoryAccesses that may be killed by KillingDef.
- SmallSetVector<MemoryAccess *, 8> ToCheck;
- // Track MemoryAccesses that have been deleted in the loop below, so we can
- // skip them. Don't use SkipStores for this, which may contain reused
- // MemoryAccess addresses.
- SmallPtrSet<MemoryAccess *, 8> Deleted;
- [[maybe_unused]] unsigned OrigNumSkipStores = State.SkipStores.size();
- ToCheck.insert(KillingDef->getDefiningAccess());
-
- bool Shortend = false;
- bool IsMemTerm = State.isMemTerminatorInst(KillingI);
- // Check if MemoryAccesses in the worklist are killed by KillingDef.
- for (unsigned I = 0; I < ToCheck.size(); I++) {
- MemoryAccess *Current = ToCheck[I];
- if (Deleted.contains(Current))
- continue;
-
- std::optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef(
- KillingDef, Current, KillingLoc, KillingUndObj, ScanLimit,
- WalkerStepLimit, IsMemTerm, PartialLimit);
-
- if (!MaybeDeadAccess) {
- LLVM_DEBUG(dbgs() << " finished walk\n");
+ MemoryDefWrapper DeadDefWrapper(
+ cast<MemoryDef>(DeadAccess),
+ getLocForInst(cast<MemoryDef>(DeadAccess)->getMemoryInst()));
+ MemoryLocationWrapper &DeadLocWrapper = *DeadDefWrapper.DefinedLocation;
+ LLVM_DEBUG(dbgs() << " (" << *DeadLocWrapper.DefInst << ")\n");
+ ToCheck.insert(DeadLocWrapper.MemDef->getDefiningAccess());
+ NumGetDomMemoryDefPassed++;
+
+ if (!DebugCounter::shouldExecute(MemorySSACounter))
+ continue;
+ if (isMemTerminatorInst(KillingLocWrapper.DefInst)) {
+ if (KillingLocWrapper.UnderlyingObject != DeadLocWrapper.UnderlyingObject)
continue;
+ LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: "
+ << *DeadLocWrapper.DefInst << "\n KILLER: "
+ << *KillingLocWrapper.DefInst << '\n');
+ deleteDeadInstruction(DeadLocWrapper.DefInst, &Deleted);
+ ++NumFastStores;
+ Changed = true;
+ } else {
+ // Check if the killing def overwrites the dead def.
+ int64_t KillingOffset = 0;
+ int64_t DeadOffset = 0;
+ OverwriteResult OR =
+ isOverwrite(KillingLocWrapper.DefInst, DeadLocWrapper.DefInst,
+ KillingLocWrapper.MemLoc, DeadLocWrapper.MemLoc,
+ KillingOffset, DeadOffset);
+ if (OR == OW_MaybePartial) {
+ auto Iter =
+ IOLs.insert(std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
+ DeadLocWrapper.DefInst->getParent(), InstOverlapIntervalsTy()));
+ auto &IOL = Iter.first->second;
+ OR = isPartialOverwrite(KillingLocWrapper.MemLoc, DeadLocWrapper.MemLoc,
+ KillingOffset, DeadOffset,
+ DeadLocWrapper.DefInst, IOL);
}
-
- MemoryAccess *DeadAccess = *MaybeDeadAccess;
- LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess);
- if (isa<MemoryPhi>(DeadAccess)) {
- LLVM_DEBUG(dbgs() << "\n ... adding incoming values to worklist\n");
- for (Value *V : cast<MemoryPhi>(DeadAccess)->incoming_values()) {
- MemoryAccess *IncomingAccess = cast<MemoryAccess>(V);
- BasicBlock *IncomingBlock = IncomingAccess->getBlock();
- BasicBlock *PhiBlock = DeadAccess->getBlock();
-
- // We only consider incoming MemoryAccesses that come before the
- // MemoryPhi. Otherwise we could discover candidates that do not
- // strictly dominate our starting def.
- if (State.PostOrderNumbers[IncomingBlock] >
- State.PostOrderNumbers[PhiBlock])
- ToCheck.insert(IncomingAccess);
+ if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
+ auto *DeadSI = dyn_cast<StoreInst>(DeadLocWrapper.DefInst);
+ auto *KillingSI = dyn_cast<StoreInst>(KillingLocWrapper.DefInst);
+ // We are re-using tryToMergePartialOverlappingStores, which requires
+ // DeadSI to dominate KillingSI.
+ // TODO: implement tryToMergePartialOverlappingStores using MemorySSA.
+ if (DeadSI && KillingSI && DT.dominates(DeadSI, KillingSI)) {
+ if (Constant *Merged = tryToMergePartialOverlappingStores(
+ KillingSI, DeadSI, KillingOffset, DeadOffset, DL, BatchAA,
+ &DT)) {
+
+ // Update stored value of earlier store to merged constant.
+ DeadSI->setOperand(0, Merged);
+ ++NumModifiedStores;
+ Changed = true;
+ DeletedKillingLoc = true;
+
+ // Remove killing store and remove any outstanding overlap
+ // intervals for the updated store.
+ deleteDeadInstruction(KillingSI, &Deleted);
+ auto I = IOLs.find(DeadSI->getParent());
+ if (I != IOLs.end())
+ I->second.erase(DeadSI);
+ break;
+ }
}
- continue;
}
- auto *DeadDefAccess = cast<MemoryDef>(DeadAccess);
- Instruction *DeadI = DeadDefAccess->getMemoryInst();
- LLVM_DEBUG(dbgs() << " (" << *DeadI << ")\n");
- ToCheck.insert(DeadDefAccess->getDefiningAccess());
- NumGetDomMemoryDefPassed++;
-
- if (!DebugCounter::shouldExecute(MemorySSACounter))
- continue;
-
- MemoryLocation DeadLoc = *State.getLocForWrite(DeadI);
-
- if (IsMemTerm) {
- const Value *DeadUndObj = getUnderlyingObject(DeadLoc.Ptr);
- if (KillingUndObj != DeadUndObj)
- continue;
- LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *DeadI
- << "\n KILLER: " << *KillingI << '\n');
- State.deleteDeadInstruction(DeadI, &Deleted);
+ if (OR == OW_Complete) {
+ LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: "
+ << *DeadLocWrapper.DefInst << "\n KILLER: "
+ << *KillingLocWrapper.DefInst << '\n');
+ deleteDeadInstruction(DeadLocWrapper.DefInst, &Deleted);
++NumFastStores;
- MadeChange = true;
- } else {
- // Check if DeadI overwrites KillingI.
- int64_t KillingOffset = 0;
- int64_t DeadOffset = 0;
- OverwriteResult OR = State.isOverwrite(
- KillingI, DeadI, KillingLoc, DeadLoc, KillingOffset, DeadOffset);
- if (OR == OW_MaybePartial) {
- auto Iter = State.IOLs.insert(
- std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
- DeadI->getParent(), InstOverlapIntervalsTy()));
- auto &IOL = Iter.first->second;
- OR = isPartialOverwrite(KillingLoc, DeadLoc, KillingOffset,
- DeadOffset, DeadI, IOL);
- }
-
- if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
- auto *DeadSI = dyn_cast<StoreInst>(DeadI);
- auto *KillingSI = dyn_cast<StoreInst>(KillingI);
- // We are re-using tryToMergePartialOverlappingStores, which requires
- // DeadSI to dominate KillingSI.
- // TODO: implement tryToMergeParialOverlappingStores using MemorySSA.
- if (DeadSI && KillingSI && DT.dominates(DeadSI, KillingSI)) {
- if (Constant *Merged = tryToMergePartialOverlappingStores(
- KillingSI, DeadSI, KillingOffset, DeadOffset, State.DL,
- State.BatchAA, &DT)) {
-
- // Update stored value of earlier store to merged constant.
- DeadSI->setOperand(0, Merged);
- ++NumModifiedStores;
- MadeChange = true;
-
- Shortend = true;
- // Remove killing store and remove any outstanding overlap
- // intervals for the updated store.
- State.deleteDeadInstruction(KillingSI, &Deleted);
- auto I = State.IOLs.find(DeadSI->getParent());
- if (I != State.IOLs.end())
- I->second.erase(DeadSI);
- break;
- }
- }
- }
-
- if (OR == OW_Complete) {
- LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *DeadI
- << "\n KILLER: " << *KillingI << '\n');
- State.deleteDeadInstruction(DeadI, &Deleted);
- ++NumFastStores;
- MadeChange = true;
- }
+ Changed = true;
}
}
+ }
- assert(State.SkipStores.size() - OrigNumSkipStores == Deleted.size() &&
- "SkipStores and Deleted out of sync?");
+ assert(SkipStores.size() - OrigNumSkipStores == Deleted.size() &&
+ "SkipStores and Deleted out of sync?");
- // Check if the store is a no-op.
- if (!Shortend && State.storeIsNoop(KillingDef, KillingUndObj)) {
- LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: " << *KillingI
- << '\n');
- State.deleteDeadInstruction(KillingI);
- NumRedundantStores++;
- MadeChange = true;
- continue;
- }
+ return {Changed, DeletedKillingLoc};
+}
- // Can we form a calloc from a memset/malloc pair?
- if (!Shortend && State.tryFoldIntoCalloc(KillingDef, KillingUndObj)) {
- LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
- << " DEAD: " << *KillingI << '\n');
- State.deleteDeadInstruction(KillingI);
- MadeChange = true;
+bool DSEState::eliminateDeadDefs(const MemoryDefWrapper &KillingDefWrapper) {
+ if (!KillingDefWrapper.DefinedLocation.has_value()) {
+ LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
+ << *KillingDefWrapper.DefInst << "\n");
+ return false;
+ }
+
+ auto &KillingLocWrapper = *KillingDefWrapper.DefinedLocation;
+ LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
+ << *KillingLocWrapper.MemDef << " ("
+ << *KillingLocWrapper.DefInst << ")\n");
+ auto [Changed, DeletedKillingLoc] = eliminateDeadDefs(KillingLocWrapper);
+
+ // Check if the store is a no-op.
+ if (!DeletedKillingLoc && storeIsNoop(KillingLocWrapper.MemDef,
+ KillingLocWrapper.UnderlyingObject)) {
+ LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: "
+ << *KillingLocWrapper.DefInst << '\n');
+ deleteDeadInstruction(KillingLocWrapper.DefInst);
+ NumRedundantStores++;
+ return true;
+ }
+ // Can we form a calloc from a memset/malloc pair?
+ if (!DeletedKillingLoc &&
+ tryFoldIntoCalloc(KillingLocWrapper.MemDef,
+ KillingLocWrapper.UnderlyingObject)) {
+ LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
+ << " DEAD: " << *KillingLocWrapper.DefInst << '\n');
+ deleteDeadInstruction(KillingLocWrapper.DefInst);
+ return true;
+ }
+ return Changed;
+}
+
+static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
+ DominatorTree &DT, PostDominatorTree &PDT,
+ const TargetLibraryInfo &TLI,
+ const LoopInfo &LI) {
+ bool MadeChange = false;
+ DSEState State(F, AA, MSSA, DT, PDT, TLI, LI);
+ // For each store:
+ for (unsigned I = 0; I < State.MemDefs.size(); I++) {
+ MemoryDef *KillingDef = State.MemDefs[I];
+ if (State.SkipStores.count(KillingDef))
continue;
- }
+
+ MemoryDefWrapper KillingDefWrapper(
+ KillingDef, State.getLocForInst(KillingDef->getMemoryInst()));
+ MadeChange |= State.eliminateDeadDefs(KillingDefWrapper);
}
if (EnablePartialOverwriteTracking)
More information about the llvm-commits
mailing list