[llvm] Refactor DSE (PR #100956)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Jul 28 18:31:09 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Haopeng Liu (haopliu)
<details>
<summary>Changes</summary>
Refactor DSE with MemoryDefWrapper and MemoryLocationWrapper.
Normally, one MemoryDef accesses one MemoryLocation. With the "initializes" attribute, one MemoryDef (such as a call instruction) could initialize multiple MemoryLocations.
Refactor DSE in preparation for applying the "initializes" attribute in DSE in a follow-up PR.
---
Patch is 26.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/100956.diff
1 Files Affected:
- (modified) llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp (+245-191)
``````````diff
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 931606c6f8fe1..e64b7da818a98 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -806,6 +806,63 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
return false;
}
+/// Returns true if \p I is a memory terminator instruction like
+/// llvm.lifetime.end or free.
+bool isMemTerminatorInst(Instruction *I, const TargetLibraryInfo &TLI) {
+ auto *CB = dyn_cast<CallBase>(I);
+ return CB && (CB->getIntrinsicID() == Intrinsic::lifetime_end ||
+ getFreedOperand(CB, &TLI) != nullptr);
+}
+
+struct DSEState;
+enum class ChangeStateEnum : uint8_t {
+ NoChange,
+ DeleteByMemTerm,
+ CompleteDeleteByNonMemTerm,
+ PartiallyDeleteByNonMemTerm,
+};
+
+// A memory location wrapper that represents a MemoryLocation, `MemLoc`,
+// defined by `MemDef`.
+class MemoryLocationWrapper {
+public:
+ MemoryLocationWrapper(MemoryLocation MemLoc, DSEState &State,
+ MemoryDef *MemDef)
+ : MemLoc(MemLoc), State(State), MemDef(MemDef) {
+ assert(MemLoc.Ptr && "MemLoc should be not null");
+ UnderlyingObject = getUnderlyingObject(MemLoc.Ptr);
+ DefInst = MemDef->getMemoryInst();
+ }
+
+ // Try to eliminate dead defs killed by this MemoryLocation and return the
+ // change state.
+ ChangeStateEnum eliminateDeadDefs();
+ MemoryAccess *GetDefiningAccess() const {
+ return MemDef->getDefiningAccess();
+ }
+
+ MemoryLocation MemLoc;
+ const Value *UnderlyingObject;
+ DSEState &State;
+ MemoryDef *MemDef;
+ Instruction *DefInst;
+};
+
+// A memory def wrapper that represents a MemoryDef and the MemoryLocation(s)
+// defined by this MemoryDef.
+class MemoryDefWrapper {
+public:
+ MemoryDefWrapper(MemoryDef *MemDef, DSEState &State);
+ // Try to eliminate dead defs killed by this MemoryDef and return the
+ // change state.
+ bool eliminateDeadDefs();
+
+ MemoryDef *MemDef;
+ Instruction *DefInst;
+ DSEState &State;
+ SmallVector<MemoryLocationWrapper, 1> DefinedLocations;
+};
+
struct DSEState {
Function &F;
AliasAnalysis &AA;
@@ -883,7 +940,7 @@ struct DSEState {
auto *MD = dyn_cast_or_null<MemoryDef>(MA);
if (MD && MemDefs.size() < MemorySSADefsPerBlockLimit &&
- (getLocForWrite(&I) || isMemTerminatorInst(&I)))
+ (getLocForWrite(&I) || isMemTerminatorInst(&I, TLI)))
MemDefs.push_back(MD);
}
}
@@ -1225,14 +1282,6 @@ struct DSEState {
return std::nullopt;
}
- /// Returns true if \p I is a memory terminator instruction like
- /// llvm.lifetime.end or free.
- bool isMemTerminatorInst(Instruction *I) const {
- auto *CB = dyn_cast<CallBase>(I);
- return CB && (CB->getIntrinsicID() == Intrinsic::lifetime_end ||
- getFreedOperand(CB, &TLI) != nullptr);
- }
-
/// Returns true if \p MaybeTerm is a memory terminator for \p Loc from
/// instruction \p AccessI.
bool isMemTerminator(const MemoryLocation &Loc, Instruction *AccessI,
@@ -1325,17 +1374,15 @@ struct DSEState {
// (completely) overwrite \p KillingLoc. Currently we bail out when we
// encounter an aliasing MemoryUse (read).
std::optional<MemoryAccess *>
- getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *StartAccess,
- const MemoryLocation &KillingLoc, const Value *KillingUndObj,
+ getDomMemoryDef(MemoryLocationWrapper &KillingLoc, MemoryAccess *StartAccess,
unsigned &ScanLimit, unsigned &WalkerStepLimit,
- bool IsMemTerm, unsigned &PartialLimit) {
+ unsigned &PartialLimit) {
if (ScanLimit == 0 || WalkerStepLimit == 0) {
LLVM_DEBUG(dbgs() << "\n ... hit scan limit\n");
return std::nullopt;
}
MemoryAccess *Current = StartAccess;
- Instruction *KillingI = KillingDef->getMemoryInst();
LLVM_DEBUG(dbgs() << " trying to get dominating access\n");
// Only optimize defining access of KillingDef when directly starting at its
@@ -1344,8 +1391,8 @@ struct DSEState {
// it should be sufficient to disable optimizations for instructions that
// also read from memory.
bool CanOptimize = OptimizeMemorySSA &&
- KillingDef->getDefiningAccess() == StartAccess &&
- !KillingI->mayReadFromMemory();
+ KillingLoc.GetDefiningAccess() == StartAccess &&
+ !KillingLoc.DefInst->mayReadFromMemory();
// Find the next clobbering Mod access for DefLoc, starting at StartAccess.
std::optional<MemoryLocation> CurrentLoc;
@@ -1361,15 +1408,15 @@ struct DSEState {
// Reached TOP.
if (MSSA.isLiveOnEntryDef(Current)) {
LLVM_DEBUG(dbgs() << " ... found LiveOnEntryDef\n");
- if (CanOptimize && Current != KillingDef->getDefiningAccess())
+ if (CanOptimize && Current != KillingLoc.GetDefiningAccess())
// The first clobbering def is... none.
- KillingDef->setOptimized(Current);
+ KillingLoc.MemDef->setOptimized(Current);
return std::nullopt;
}
// Cost of a step. Accesses in the same block are more likely to be valid
// candidates for elimination, hence consider them cheaper.
- unsigned StepCost = KillingDef->getBlock() == Current->getBlock()
+ unsigned StepCost = KillingLoc.MemDef->getBlock() == Current->getBlock()
? MemorySSASameBBStepCost
: MemorySSAOtherBBStepCost;
if (WalkerStepLimit <= StepCost) {
@@ -1390,21 +1437,23 @@ struct DSEState {
MemoryDef *CurrentDef = cast<MemoryDef>(Current);
Instruction *CurrentI = CurrentDef->getMemoryInst();
- if (canSkipDef(CurrentDef, !isInvisibleToCallerOnUnwind(KillingUndObj))) {
+ if (canSkipDef(CurrentDef, !isInvisibleToCallerOnUnwind(
+ KillingLoc.UnderlyingObject))) {
CanOptimize = false;
continue;
}
// Before we try to remove anything, check for any extra throwing
// instructions that block us from DSEing
- if (mayThrowBetween(KillingI, CurrentI, KillingUndObj)) {
+ if (mayThrowBetween(KillingLoc.DefInst, CurrentI,
+ KillingLoc.UnderlyingObject)) {
LLVM_DEBUG(dbgs() << " ... skip, may throw!\n");
return std::nullopt;
}
// Check for anything that looks like it will be a barrier to further
// removal
- if (isDSEBarrier(KillingUndObj, CurrentI)) {
+ if (isDSEBarrier(KillingLoc.UnderlyingObject, CurrentI)) {
LLVM_DEBUG(dbgs() << " ... skip, barrier\n");
return std::nullopt;
}
@@ -1413,14 +1462,16 @@ struct DSEState {
// clobber, bail out, as the path is not profitable. We skip this check
// for intrinsic calls, because the code knows how to handle memcpy
// intrinsics.
- if (!isa<IntrinsicInst>(CurrentI) && isReadClobber(KillingLoc, CurrentI))
+ if (!isa<IntrinsicInst>(CurrentI) &&
+ isReadClobber(KillingLoc.MemLoc, CurrentI))
return std::nullopt;
// Quick check if there are direct uses that are read-clobbers.
if (any_of(Current->uses(), [this, &KillingLoc, StartAccess](Use &U) {
if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(U.getUser()))
return !MSSA.dominates(StartAccess, UseOrDef) &&
- isReadClobber(KillingLoc, UseOrDef->getMemoryInst());
+ isReadClobber(KillingLoc.MemLoc,
+ UseOrDef->getMemoryInst());
return false;
})) {
LLVM_DEBUG(dbgs() << " ... found a read clobber\n");
@@ -1438,32 +1489,33 @@ struct DSEState {
// AliasAnalysis does not account for loops. Limit elimination to
// candidates for which we can guarantee they always store to the same
// memory location and not located in different loops.
- if (!isGuaranteedLoopIndependent(CurrentI, KillingI, *CurrentLoc)) {
+ if (!isGuaranteedLoopIndependent(CurrentI, KillingLoc.DefInst,
+ *CurrentLoc)) {
LLVM_DEBUG(dbgs() << " ... not guaranteed loop independent\n");
CanOptimize = false;
continue;
}
- if (IsMemTerm) {
+ if (isMemTerminatorInst(KillingLoc.DefInst, TLI)) {
// If the killing def is a memory terminator (e.g. lifetime.end), check
// the next candidate if the current Current does not write the same
// underlying object as the terminator.
- if (!isMemTerminator(*CurrentLoc, CurrentI, KillingI)) {
+ if (!isMemTerminator(*CurrentLoc, CurrentI, KillingLoc.DefInst)) {
CanOptimize = false;
continue;
}
} else {
int64_t KillingOffset = 0;
int64_t DeadOffset = 0;
- auto OR = isOverwrite(KillingI, CurrentI, KillingLoc, *CurrentLoc,
- KillingOffset, DeadOffset);
+ auto OR = isOverwrite(KillingLoc.DefInst, CurrentI, KillingLoc.MemLoc,
+ *CurrentLoc, KillingOffset, DeadOffset);
if (CanOptimize) {
// CurrentDef is the earliest write clobber of KillingDef. Use it as
// optimized access. Do not optimize if CurrentDef is already the
// defining access of KillingDef.
- if (CurrentDef != KillingDef->getDefiningAccess() &&
+ if (CurrentDef != KillingLoc.GetDefiningAccess() &&
(OR == OW_Complete || OR == OW_MaybePartial))
- KillingDef->setOptimized(CurrentDef);
+ KillingLoc.MemDef->setOptimized(CurrentDef);
// Once a may-aliasing def is encountered do not set an optimized
// access.
@@ -1496,7 +1548,7 @@ struct DSEState {
// the blocks with killing (=completely overwriting MemoryDefs) and check if
// they cover all paths from MaybeDeadAccess to any function exit.
SmallPtrSet<Instruction *, 16> KillingDefs;
- KillingDefs.insert(KillingDef->getMemoryInst());
+ KillingDefs.insert(KillingLoc.DefInst);
MemoryAccess *MaybeDeadAccess = Current;
MemoryLocation MaybeDeadLoc = *CurrentLoc;
Instruction *MaybeDeadI = cast<MemoryDef>(MaybeDeadAccess)->getMemoryInst();
@@ -1558,7 +1610,8 @@ struct DSEState {
continue;
}
- if (UseInst->mayThrow() && !isInvisibleToCallerOnUnwind(KillingUndObj)) {
+ if (UseInst->mayThrow() &&
+ !isInvisibleToCallerOnUnwind(KillingLoc.UnderlyingObject)) {
LLVM_DEBUG(dbgs() << " ... found throwing instruction\n");
return std::nullopt;
}
@@ -1582,7 +1635,7 @@ struct DSEState {
// if it reads the memory location.
// TODO: It would probably be better to check for self-reads before
// calling the function.
- if (KillingDef == UseAccess || MaybeDeadAccess == UseAccess) {
+ if (KillingLoc.MemDef == UseAccess || MaybeDeadAccess == UseAccess) {
LLVM_DEBUG(dbgs() << " ... skipping killing def/dom access\n");
continue;
}
@@ -1602,7 +1655,7 @@ struct DSEState {
BasicBlock *MaybeKillingBlock = UseInst->getParent();
if (PostOrderNumbers.find(MaybeKillingBlock)->second <
PostOrderNumbers.find(MaybeDeadAccess->getBlock())->second) {
- if (!isInvisibleToCallerAfterRet(KillingUndObj)) {
+ if (!isInvisibleToCallerAfterRet(KillingLoc.UnderlyingObject)) {
LLVM_DEBUG(dbgs()
<< " ... found killing def " << *UseInst << "\n");
KillingDefs.insert(UseInst);
@@ -1620,7 +1673,7 @@ struct DSEState {
// For accesses to locations visible after the function returns, make sure
// that the location is dead (=overwritten) along all paths from
// MaybeDeadAccess to the exit.
- if (!isInvisibleToCallerAfterRet(KillingUndObj)) {
+ if (!isInvisibleToCallerAfterRet(KillingLoc.UnderlyingObject)) {
SmallPtrSet<BasicBlock *, 16> KillingBlocks;
for (Instruction *KD : KillingDefs)
KillingBlocks.insert(KD->getParent());
@@ -2134,180 +2187,181 @@ struct DSEState {
}
};
-static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
- DominatorTree &DT, PostDominatorTree &PDT,
- const TargetLibraryInfo &TLI,
- const LoopInfo &LI) {
- bool MadeChange = false;
+MemoryDefWrapper::MemoryDefWrapper(MemoryDef *MemDef, DSEState &State)
+ : MemDef(MemDef), State(State) {
+ DefInst = MemDef->getMemoryInst();
- DSEState State(F, AA, MSSA, DT, PDT, TLI, LI);
- // For each store:
- for (unsigned I = 0; I < State.MemDefs.size(); I++) {
- MemoryDef *KillingDef = State.MemDefs[I];
- if (State.SkipStores.count(KillingDef))
+ if (isMemTerminatorInst(DefInst, State.TLI)) {
+ if (auto KillingLoc = State.getLocForTerminator(DefInst)) {
+ DefinedLocations.push_back(
+ MemoryLocationWrapper(KillingLoc->first, State, MemDef));
+ }
+ return;
+ }
+
+ if (auto KillingLoc = State.getLocForWrite(DefInst)) {
+ DefinedLocations.push_back(
+ MemoryLocationWrapper(*KillingLoc, State, MemDef));
+ }
+}
+
+ChangeStateEnum MemoryLocationWrapper::eliminateDeadDefs() {
+ ChangeStateEnum ChangeState = ChangeStateEnum::NoChange;
+ unsigned ScanLimit = MemorySSAScanLimit;
+ unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
+ unsigned PartialLimit = MemorySSAPartialStoreLimit;
+ // Worklist of MemoryAccesses that may be killed by KillingDef.
+ SmallSetVector<MemoryAccess *, 8> ToCheck;
+ ToCheck.insert(GetDefiningAccess());
+
+ // Check if MemoryAccesses in the worklist are killed by KillingDef.
+ for (unsigned I = 0; I < ToCheck.size(); I++) {
+ MemoryAccess *Current = ToCheck[I];
+ if (State.SkipStores.count(Current))
continue;
- Instruction *KillingI = KillingDef->getMemoryInst();
+ std::optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef(
+ *this, Current, ScanLimit, WalkerStepLimit, PartialLimit);
- std::optional<MemoryLocation> MaybeKillingLoc;
- if (State.isMemTerminatorInst(KillingI)) {
- if (auto KillingLoc = State.getLocForTerminator(KillingI))
- MaybeKillingLoc = KillingLoc->first;
- } else {
- MaybeKillingLoc = State.getLocForWrite(KillingI);
+ if (!MaybeDeadAccess) {
+ LLVM_DEBUG(dbgs() << " finished walk\n");
+ continue;
}
-
- if (!MaybeKillingLoc) {
- LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
- << *KillingI << "\n");
+ MemoryAccess *DeadAccess = *MaybeDeadAccess;
+ LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess);
+ if (isa<MemoryPhi>(DeadAccess)) {
+ LLVM_DEBUG(dbgs() << "\n ... adding incoming values to worklist\n");
+ for (Value *V : cast<MemoryPhi>(DeadAccess)->incoming_values()) {
+ MemoryAccess *IncomingAccess = cast<MemoryAccess>(V);
+ BasicBlock *IncomingBlock = IncomingAccess->getBlock();
+ BasicBlock *PhiBlock = DeadAccess->getBlock();
+
+ // We only consider incoming MemoryAccesses that come before the
+ // MemoryPhi. Otherwise we could discover candidates that do not
+ // strictly dominate our starting def.
+ if (State.PostOrderNumbers[IncomingBlock] >
+ State.PostOrderNumbers[PhiBlock])
+ ToCheck.insert(IncomingAccess);
+ }
continue;
}
- MemoryLocation KillingLoc = *MaybeKillingLoc;
- assert(KillingLoc.Ptr && "KillingLoc should not be null");
- const Value *KillingUndObj = getUnderlyingObject(KillingLoc.Ptr);
- LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
- << *KillingDef << " (" << *KillingI << ")\n");
-
- unsigned ScanLimit = MemorySSAScanLimit;
- unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit;
- unsigned PartialLimit = MemorySSAPartialStoreLimit;
- // Worklist of MemoryAccesses that may be killed by KillingDef.
- SmallSetVector<MemoryAccess *, 8> ToCheck;
- // Track MemoryAccesses that have been deleted in the loop below, so we can
- // skip them. Don't use SkipStores for this, which may contain reused
- // MemoryAccess addresses.
- SmallPtrSet<MemoryAccess *, 8> Deleted;
- [[maybe_unused]] unsigned OrigNumSkipStores = State.SkipStores.size();
- ToCheck.insert(KillingDef->getDefiningAccess());
-
- bool Shortend = false;
- bool IsMemTerm = State.isMemTerminatorInst(KillingI);
- // Check if MemoryAccesses in the worklist are killed by KillingDef.
- for (unsigned I = 0; I < ToCheck.size(); I++) {
- MemoryAccess *Current = ToCheck[I];
- if (Deleted.contains(Current))
- continue;
+ MemoryDefWrapper DeadDefWrapper(cast<MemoryDef>(DeadAccess), State);
+ MemoryLocationWrapper &DeadLoc = DeadDefWrapper.DefinedLocations.front();
+ LLVM_DEBUG(dbgs() << " (" << *DeadDefWrapper.DefInst << ")\n");
+ ToCheck.insert(DeadLoc.GetDefiningAccess());
+ NumGetDomMemoryDefPassed++;
- std::optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef(
- KillingDef, Current, KillingLoc, KillingUndObj, ScanLimit,
- WalkerStepLimit, IsMemTerm, PartialLimit);
-
- if (!MaybeDeadAccess) {
- LLVM_DEBUG(dbgs() << " finished walk\n");
+ if (!DebugCounter::shouldExecute(MemorySSACounter))
+ continue;
+ if (isMemTerminatorInst(DefInst, State.TLI)) {
+ if (!(UnderlyingObject == DeadLoc.UnderlyingObject))
continue;
+ LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: "
+ << *DeadDefWrapper.DefInst << "\n KILLER: " << *DefInst
+ << '\n');
+ State.deleteDeadInstruction(DeadDefWrapper.DefInst);
+ ++NumFastStores;
+ ChangeState = ChangeStateEnum::DeleteByMemTerm;
+ } else {
+ // Check if DeadI overwrites KillingI.
+ int64_t KillingOffset = 0;
+ int64_t DeadOffset = 0;
+ OverwriteResult OR =
+ State.isOverwrite(DefInst, DeadDefWrapper.DefInst, MemLoc,
+ DeadLoc.MemLoc, KillingOffset, DeadOffset);
+ if (OR == OW_MaybePartial) {
+ auto Iter = State.IOLs.insert(
+ std::make_pair<BasicBlock *, InstOverlapIntervalsTy>(
+ DeadDefWrapper.DefInst->getParent(), InstOverlapIntervalsTy()));
+ auto &IOL = Iter.first->second;
+ OR = isPartialOverwrite(MemLoc, DeadLoc.MemLoc, KillingOffset,
+ DeadOffset, DeadDefWrapper.DefInst, IOL);
}
-
- MemoryAccess *DeadAccess = *MaybeDeadAccess;
- LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess);
- if (isa<MemoryPhi>(DeadAccess)) {
- LLVM_DEBUG(dbgs() << "\n ... adding incoming values to worklist\n");
- for (Value *V : cast<MemoryPhi>(DeadAccess)->incoming_values()) {
- MemoryAccess *IncomingAccess = cast<MemoryAccess>(V);
- BasicBlock *IncomingBlock = IncomingAccess->getBlock();
- BasicBlock *PhiBlock = DeadAccess->getBlock();
-
- // We only consider incoming MemoryAccesses that come before the
- // MemoryPhi. Otherwise we could discover candidates that do not
- // strictly dominate our starting def.
- if (State.PostOrderNumbers[IncomingBlock] >
- State.PostOrderNumbers[PhiBlock])
- ToCheck.insert(IncomingAccess);
+ if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
+ auto *DeadSI = dyn_cast<StoreInst>(DeadDefWrapper.DefInst);
+ auto *KillingSI = dyn_cast<StoreInst>(DefInst);
+ // We are re-using tryToMergePartialOverlappingStores, which requires
+ // DeadSI to dominate KillingSI.
+ // TODO: implement tryToMergeParialOverlappingStores using MemorySSA.
+ if (DeadSI && KillingSI && State.DT.dominates(DeadSI, KillingSI)) {
+ if (Constant *Merged = tryToMergePartialOverlappingStores(
+ KillingSI, DeadSI, KillingOffset, DeadOffset, State.DL,
+ State.BatchAA, &State.DT)) {
+
+ // Update stored value of earlier store to merged constant.
+ D...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/100956
More information about the llvm-commits
mailing list