[llvm] b7ede70 - [MemCpyOpt] Use BatchAA when processing one instruction (NFCI)
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 6 01:16:47 PST 2022
Author: Nikita Popov
Date: 2022-12-06T10:16:39+01:00
New Revision: b7ede701d8508874d92df72728e7fcec71fdd84a
URL: https://github.com/llvm/llvm-project/commit/b7ede701d8508874d92df72728e7fcec71fdd84a
DIFF: https://github.com/llvm/llvm-project/commit/b7ede701d8508874d92df72728e7fcec71fdd84a.diff
LOG: [MemCpyOpt] Use BatchAA when processing one instruction (NFCI)
While we can't use a single BatchAA instance for the entire
MemCpyOpt run without further justification, we can use BatchAA
while performing the queries related to a single instruction:
these first perform the AA-based checks and only modify the IR
afterwards, so BatchAA's cached results stay valid for the
lifetime of the batch.
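In pattern form, the change amounts to the sketch below. This is a
simplified illustration, not the actual pass code: the helper name
processOneMemCpy is made up, and the real per-instruction entry
points are processStore, processMemCpy and processByValArgument in
the diff that follows.

    #include "llvm/Analysis/AliasAnalysis.h" // AAResults, BatchAAResults
    #include "llvm/Analysis/MemoryLocation.h"
    #include "llvm/IR/IntrinsicInst.h"       // MemCpyInst
    using namespace llvm;

    // Sketch of the per-instruction scoping (illustrative only).
    static bool processOneMemCpy(AAResults &AA, MemCpyInst *M) {
      // One BatchAAResults per instruction: it caches alias() and
      // getModRefInfo() results, which is only sound while the IR is
      // not modified.
      BatchAAResults BAA(AA);

      // Phase 1: AA-based legality checks, all routed through BAA so
      // that repeated queries hit the cache.
      if (isModSet(BAA.getModRefInfo(M, MemoryLocation::getForSource(M))))
        return false; // the copy may clobber its own source; bail

      // Phase 2: only now rewrite the IR. BAA goes out of scope before
      // the next instruction is visited, so its cache never observes
      // stale IR.
      // (... transform M here ...)
      return true;
    }

A single pass-wide BatchAA instance would instead require showing
that none of the pass's rewrites can invalidate a cached query,
which is the "further justification" the message alludes to.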
Added:
Modified:
llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
index 8103b0a924898..587b782439a3a 100644
--- a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
+++ b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
@@ -20,6 +20,7 @@
namespace llvm {
class AAResults;
+class BatchAAResults;
class AssumptionCache;
class CallBase;
class CallInst;
@@ -61,10 +62,14 @@ class MemCpyOptPass : public PassInfoMixin<MemCpyOptPass> {
bool processMemMove(MemMoveInst *M);
bool performCallSlotOptzn(Instruction *cpyLoad, Instruction *cpyStore,
Value *cpyDst, Value *cpySrc, TypeSize cpyLen,
- Align cpyAlign, std::function<CallInst *()> GetC);
- bool processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep);
- bool processMemSetMemCpyDependence(MemCpyInst *MemCpy, MemSetInst *MemSet);
- bool performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, MemSetInst *MemSet);
+ Align cpyAlign, BatchAAResults &BAA,
+ std::function<CallInst *()> GetC);
+ bool processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep,
+ BatchAAResults &BAA);
+ bool processMemSetMemCpyDependence(MemCpyInst *MemCpy, MemSetInst *MemSet,
+ BatchAAResults &BAA);
+ bool performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, MemSetInst *MemSet,
+ BatchAAResults &BAA);
bool processByValArgument(CallBase &CB, unsigned ArgNo);
Instruction *tryMergingIntoMemset(Instruction *I, Value *StartPtr,
Value *ByteVal);
diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 7ef00ea012d5f..8174761423c0a 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -335,7 +335,7 @@ void MemCpyOptPass::eraseInstruction(Instruction *I) {
// Start and End must be in the same block.
// If SkippedLifetimeStart is provided, skip over one clobbering lifetime.start
// intrinsic and store it inside SkippedLifetimeStart.
-static bool accessedBetween(AliasAnalysis &AA, MemoryLocation Loc,
+static bool accessedBetween(BatchAAResults &AA, MemoryLocation Loc,
const MemoryUseOrDef *Start,
const MemoryUseOrDef *End,
Instruction **SkippedLifetimeStart = nullptr) {
@@ -359,7 +359,7 @@ static bool accessedBetween(AliasAnalysis &AA, MemoryLocation Loc,
// Check for mod of Loc between Start and End, excluding both boundaries.
// Start and End can be in different blocks.
-static bool writtenBetween(MemorySSA *MSSA, AliasAnalysis &AA,
+static bool writtenBetween(MemorySSA *MSSA, BatchAAResults &AA,
MemoryLocation Loc, const MemoryUseOrDef *Start,
const MemoryUseOrDef *End) {
if (isa<MemoryUse>(End)) {
@@ -380,7 +380,7 @@ static bool writtenBetween(MemorySSA *MSSA, AliasAnalysis &AA,
// TODO: Only walk until we hit Start.
MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess(
- End->getDefiningAccess(), Loc);
+ End->getDefiningAccess(), Loc, AA);
return !MSSA->dominates(Clobber, Start);
}
@@ -778,11 +778,12 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
// Detect cases where we're performing call slot forwarding, but
// happen to be using a load-store pair to implement it, rather than
// a memcpy.
+ BatchAAResults BAA(*AA);
auto GetCall = [&]() -> CallInst * {
// We defer this expensive clobber walk until the cheap checks
// have been done on the source inside performCallSlotOptzn.
if (auto *LoadClobber = dyn_cast<MemoryUseOrDef>(
- MSSA->getWalker()->getClobberingMemoryAccess(LI)))
+ MSSA->getWalker()->getClobberingMemoryAccess(LI, BAA)))
return dyn_cast_or_null<CallInst>(LoadClobber->getMemoryInst());
return nullptr;
};
@@ -791,7 +792,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
LI, SI, SI->getPointerOperand()->stripPointerCasts(),
LI->getPointerOperand()->stripPointerCasts(),
DL.getTypeStoreSize(SI->getOperand(0)->getType()),
- std::min(SI->getAlign(), LI->getAlign()), GetCall);
+ std::min(SI->getAlign(), LI->getAlign()), BAA, GetCall);
if (changed) {
eraseInstruction(SI);
eraseInstruction(LI);
@@ -872,7 +873,7 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) {
bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
Instruction *cpyStore, Value *cpyDest,
Value *cpySrc, TypeSize cpySize,
- Align cpyAlign,
+ Align cpyAlign, BatchAAResults &BAA,
std::function<CallInst *()> GetC) {
// The general transformation to keep in mind is
//
@@ -930,7 +931,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
// Check that nothing touches the dest of the copy between
// the call and the store/memcpy.
Instruction *SkippedLifetimeStart = nullptr;
- if (accessedBetween(*AA, DestLoc, MSSA->getMemoryAccess(C),
+ if (accessedBetween(BAA, DestLoc, MSSA->getMemoryAccess(C),
MSSA->getMemoryAccess(cpyStore), &SkippedLifetimeStart)) {
LLVM_DEBUG(dbgs() << "Call Slot: Dest pointer modified after call\n");
return false;
@@ -1058,7 +1059,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
// pointer (we have already checked any direct mod/refs in the loop above).
// Also bail if we hit a terminator, as we don't want to scan into other
// blocks.
- if (isModOrRefSet(AA->getModRefInfo(&I, SrcLoc)) || I.isTerminator())
+ if (isModOrRefSet(BAA.getModRefInfo(&I, SrcLoc)) || I.isTerminator())
return false;
}
}
@@ -1079,10 +1080,11 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
// unexpected manner, for example via a global, which we deduce from
// the use analysis, we also need to know that it does not sneakily
// access dest. We rely on AA to figure this out for us.
- ModRefInfo MR = AA->getModRefInfo(C, cpyDest, LocationSize::precise(srcSize));
+ MemoryLocation DestWithSrcSize(cpyDest, LocationSize::precise(srcSize));
+ ModRefInfo MR = BAA.getModRefInfo(C, DestWithSrcSize);
// If necessary, perform additional analysis.
if (isModOrRefSet(MR))
- MR = AA->callCapturesBefore(C, cpyDest, LocationSize::precise(srcSize), DT);
+ MR = BAA.callCapturesBefore(C, DestWithSrcSize, DT);
if (isModOrRefSet(MR))
return false;
@@ -1146,7 +1148,8 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad,
/// We've found that the (upward scanning) memory dependence of memcpy 'M' is
/// the memcpy 'MDep'. Try to simplify M to copy from MDep's input if we can.
bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
- MemCpyInst *MDep) {
+ MemCpyInst *MDep,
+ BatchAAResults &BAA) {
// We can only transform memcpys where the dest of one is the source of the
// other.
if (M->getSource() != MDep->getDest() || MDep->isVolatile())
@@ -1180,7 +1183,7 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
// then we could still perform the xform by moving M up to the first memcpy.
// TODO: It would be sufficient to check the MDep source up to the memcpy
// size of M, rather than MDep.
- if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep),
+ if (writtenBetween(MSSA, BAA, MemoryLocation::getForSource(MDep),
MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(M)))
return false;
@@ -1190,7 +1193,7 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
// still want to eliminate the intermediate value, but we have to generate a
// memmove instead of memcpy.
bool UseMemMove = false;
- if (isModSet(AA->getModRefInfo(M, MemoryLocation::getForSource(MDep))))
+ if (isModSet(BAA.getModRefInfo(M, MemoryLocation::getForSource(MDep))))
UseMemMove = true;
// If all checks passed, then we can transform M.
@@ -1244,20 +1247,21 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
/// memset(dst + src_size, c, dst_size <= src_size ? 0 : dst_size - src_size);
/// \endcode
bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy,
- MemSetInst *MemSet) {
+ MemSetInst *MemSet,
+ BatchAAResults &BAA) {
// We can only transform memset/memcpy with the same destination.
- if (!AA->isMustAlias(MemSet->getDest(), MemCpy->getDest()))
+ if (!BAA.isMustAlias(MemSet->getDest(), MemCpy->getDest()))
return false;
// Check that src and dst of the memcpy aren't the same. While memcpy
// operands cannot partially overlap, exact equality is allowed.
- if (isModSet(AA->getModRefInfo(MemCpy, MemoryLocation::getForSource(MemCpy))))
+ if (isModSet(BAA.getModRefInfo(MemCpy, MemoryLocation::getForSource(MemCpy))))
return false;
// We know that dst up to src_size is not written. We now need to make sure
// that dst up to dst_size is not accessed. (If we did not move the memset,
// checking for reads would be sufficient.)
- if (accessedBetween(*AA, MemoryLocation::getForDest(MemSet),
+ if (accessedBetween(BAA, MemoryLocation::getForDest(MemSet),
MSSA->getMemoryAccess(MemSet),
MSSA->getMemoryAccess(MemCpy)))
return false;
@@ -1327,7 +1331,7 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy,
/// Determine whether the instruction has undefined content for the given Size,
/// either because it was freshly alloca'd or started its lifetime.
-static bool hasUndefContents(MemorySSA *MSSA, AliasAnalysis *AA, Value *V,
+static bool hasUndefContents(MemorySSA *MSSA, BatchAAResults &AA, Value *V,
MemoryDef *Def, Value *Size) {
if (MSSA->isLiveOnEntryDef(Def))
return isa<AllocaInst>(getUnderlyingObject(V));
@@ -1337,7 +1341,7 @@ static bool hasUndefContents(MemorySSA *MSSA, AliasAnalysis *AA, Value *V,
auto *LTSize = cast<ConstantInt>(II->getArgOperand(0));
if (auto *CSize = dyn_cast<ConstantInt>(Size)) {
- if (AA->isMustAlias(V, II->getArgOperand(1)) &&
+ if (AA.isMustAlias(V, II->getArgOperand(1)) &&
LTSize->getZExtValue() >= CSize->getZExtValue())
return true;
}
@@ -1374,10 +1378,11 @@ static bool hasUndefContents(MemorySSA *MSSA, AliasAnalysis *AA, Value *V,
/// \endcode
/// When dst2_size <= dst1_size.
bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy,
- MemSetInst *MemSet) {
+ MemSetInst *MemSet,
+ BatchAAResults &BAA) {
// Make sure this is a memcpy(..., memset(...), ...), that is, we are memsetting and
// memcpying from the same address. Otherwise it is hard to reason about.
- if (!AA->isMustAlias(MemSet->getRawDest(), MemCpy->getRawSource()))
+ if (!BAA.isMustAlias(MemSet->getRawDest(), MemCpy->getRawSource()))
return false;
Value *MemSetSize = MemSet->getLength();
@@ -1405,9 +1410,9 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy,
bool CanReduceSize = false;
MemoryUseOrDef *MemSetAccess = MSSA->getMemoryAccess(MemSet);
MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess(
- MemSetAccess->getDefiningAccess(), MemCpyLoc);
+ MemSetAccess->getDefiningAccess(), MemCpyLoc, BAA);
if (auto *MD = dyn_cast<MemoryDef>(Clobber))
- if (hasUndefContents(MSSA, AA, MemCpy->getSource(), MD, CopySize))
+ if (hasUndefContents(MSSA, BAA, MemCpy->getSource(), MD, CopySize))
CanReduceSize = true;
if (!CanReduceSize)
@@ -1464,12 +1469,13 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
return true;
}
+ BatchAAResults BAA(*AA);
MemoryUseOrDef *MA = MSSA->getMemoryAccess(M);
// FIXME: Not using getClobberingMemoryAccess() here due to PR54682.
MemoryAccess *AnyClobber = MA->getDefiningAccess();
MemoryLocation DestLoc = MemoryLocation::getForDest(M);
const MemoryAccess *DestClobber =
- MSSA->getWalker()->getClobberingMemoryAccess(AnyClobber, DestLoc);
+ MSSA->getWalker()->getClobberingMemoryAccess(AnyClobber, DestLoc, BAA);
// Try to turn a partially redundant memset + memcpy into
// memcpy + smaller memset. We don't need the memcpy size for this.
@@ -1478,11 +1484,11 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
if (auto *MD = dyn_cast<MemoryDef>(DestClobber))
if (auto *MDep = dyn_cast_or_null<MemSetInst>(MD->getMemoryInst()))
if (DestClobber->getBlock() == M->getParent())
- if (processMemSetMemCpyDependence(M, MDep))
+ if (processMemSetMemCpyDependence(M, MDep, BAA))
return true;
MemoryAccess *SrcClobber = MSSA->getWalker()->getClobberingMemoryAccess(
- AnyClobber, MemoryLocation::getForSource(M));
+ AnyClobber, MemoryLocation::getForSource(M), BAA);
// There are four possible optimizations we can do for memcpy:
// a) memcpy-memcpy xform which exposes redundancy for DSE.
@@ -1499,10 +1505,10 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
// of conservatively taking the minimum?
Align Alignment = std::min(M->getDestAlign().valueOrOne(),
M->getSourceAlign().valueOrOne());
- if (performCallSlotOptzn(
- M, M, M->getDest(), M->getSource(),
- TypeSize::getFixed(CopySize->getZExtValue()), Alignment,
- [C]() -> CallInst * { return C; })) {
+ if (performCallSlotOptzn(M, M, M->getDest(), M->getSource(),
+ TypeSize::getFixed(CopySize->getZExtValue()),
+ Alignment, BAA,
+ [C]() -> CallInst * { return C; })) {
LLVM_DEBUG(dbgs() << "Performed call slot optimization:\n"
<< " call: " << *C << "\n"
<< " memcpy: " << *M << "\n");
@@ -1513,9 +1519,9 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
}
}
if (auto *MDep = dyn_cast<MemCpyInst>(MI))
- return processMemCpyMemCpyDependence(M, MDep);
+ return processMemCpyMemCpyDependence(M, MDep, BAA);
if (auto *MDep = dyn_cast<MemSetInst>(MI)) {
- if (performMemCpyToMemSetOptzn(M, MDep)) {
+ if (performMemCpyToMemSetOptzn(M, MDep, BAA)) {
LLVM_DEBUG(dbgs() << "Converted memcpy to memset\n");
eraseInstruction(M);
++NumCpyToSet;
@@ -1524,7 +1530,7 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {
}
}
- if (hasUndefContents(MSSA, AA, M->getSource(), MD, M->getLength())) {
+ if (hasUndefContents(MSSA, BAA, M->getSource(), MD, M->getLength())) {
LLVM_DEBUG(dbgs() << "Removed memcpy from undef\n");
eraseInstruction(M);
++NumMemCpyInstr;
@@ -1571,8 +1577,9 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) {
if (!CallAccess)
return false;
MemCpyInst *MDep = nullptr;
+ BatchAAResults BAA(*AA);
MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess(
- CallAccess->getDefiningAccess(), Loc);
+ CallAccess->getDefiningAccess(), Loc, BAA);
if (auto *MD = dyn_cast<MemoryDef>(Clobber))
MDep = dyn_cast_or_null<MemCpyInst>(MD->getMemoryInst());
@@ -1613,7 +1620,7 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) {
// *b = 42;
// foo(*a)
// It would be invalid to transform the second memcpy into foo(*b).
- if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep),
+ if (writtenBetween(MSSA, BAA, MemoryLocation::getForSource(MDep),
MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(&CB)))
return false;