[llvm] [MemCpyOpt] Merge memset and skip unrelated clobber in one scan (PR #90350)
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 29 22:23:23 PDT 2024
================
@@ -350,157 +434,157 @@ static void combineAAMetadata(Instruction *ReplInst, Instruction *I) {
combineMetadata(ReplInst, I, KnownIDs, true);
}
+/// Returns true if \p I is a store or memset intrinsic that qualifies for
+/// merging into a memset. When true is returned, \p ByteVal holds the value
+/// taken from the instruction; it is not meaningful when false is returned.
+static bool isCandidateToMergeIntoMemset(Instruction *I, const DataLayout &DL,
+                                         Value *&ByteVal) {
+  if (auto *Store = dyn_cast<StoreInst>(I)) {
+    Value *Src = Store->getValueOperand();
+    auto *SrcTy = Src->getType();
+
+    // Reject, in this order:
+    //  - nontemporal stores, since the resulting memcpy/memset would not be
+    //    able to preserve the nontemporal hint;
+    //  - stores of non-integral pointer types, since memset stores integers;
+    //  - scalable types, since their ranges can't be tracked.
+    if (Store->getMetadata(LLVMContext::MD_nontemporal) ||
+        DL.isNonIntegralPointerType(SrcTy->getScalarType()) ||
+        DL.getTypeStoreSize(SrcTy).isScalable())
+      return false;
+
+    // A null result from isBytewiseValue disqualifies the store.
+    ByteVal = isBytewiseValue(Src, DL);
+    return ByteVal != nullptr;
+  }
+
+  // Anything that is neither a store nor a memset is not a candidate, and a
+  // memset only qualifies when its length is a ConstantInt.
+  auto *Set = dyn_cast<MemSetInst>(I);
+  if (!Set || !isa<ConstantInt>(Set->getLength()))
+    return false;
+
+  ByteVal = Set->getValue();
+  return true;
+}
+
+/// Returns the pointer that \p I writes through: the pointer operand of a
+/// store, or the destination of a memset intrinsic. \p I must be one of these
+/// two instruction kinds.
+static Value *getWrittenPtr(Instruction *I) {
+  if (auto *SI = dyn_cast<StoreInst>(I))
+    return SI->getPointerOperand();
+  if (auto *MSI = dyn_cast<MemSetInst>(I))
+    return MSI->getDest();
+  // The caller filters out every other instruction kind before asking for the
+  // written pointer, so reaching this point is a programming error. Report it
+  // explicitly rather than silently handing back a null pointer.
+  llvm_unreachable("Only support store and memset");
+}
+
/// When scanning forward over instructions, we look for some other patterns to
/// fold away. In particular, this looks for stores to neighboring locations of
/// memory. If it sees enough consecutive ones, it attempts to merge them
/// together into a memcpy/memset.
-Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
- Value *StartPtr,
- Value *ByteVal) {
- const DataLayout &DL = StartInst->getModule()->getDataLayout();
+bool MemCpyOptPass::tryMergingIntoMemset(BasicBlock *BB) {
+ MapVector<Value *, MemsetRanges> ObjToRanges;
+ const DataLayout &DL = BB->getModule()->getDataLayout();
+ MemoryUseOrDef *MemInsertPoint = nullptr;
+ BatchAAResults BAA(*AA);
+ bool MadeChanged = false;
- // We can't track scalable types
- if (auto *SI = dyn_cast<StoreInst>(StartInst))
- if (DL.getTypeStoreSize(SI->getOperand(0)->getType()).isScalable())
- return nullptr;
+ // The following code creates memset intrinsics out of thin air. Don't do
+ // this if the corresponding libfunc is not available.
+ if (!(TLI->has(LibFunc_memset) || EnableMemCpyOptWithoutLibcalls))
+ return false;
- // Okay, so we now have a single store that can be splatable. Scan to find
- // all subsequent stores of the same value to offset from the same pointer.
- // Join these together into ranges, so we can decide whether contiguous blocks
- // are stored.
- MemsetRanges Ranges(DL);
+ for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE; BI++) {
+ Instruction *I = &*BI;
- BasicBlock::iterator BI(StartInst);
+ if (!I->mayReadOrWriteMemory())
+ continue;
- // Keeps track of the last memory use or def before the insertion point for
- // the new memset. The new MemoryDef for the inserted memsets will be inserted
- // after MemInsertPoint.
- MemoryUseOrDef *MemInsertPoint = nullptr;
- for (++BI; !BI->isTerminator(); ++BI) {
- auto *CurrentAcc = cast_or_null<MemoryUseOrDef>(
- MSSAU->getMemorySSA()->getMemoryAccess(&*BI));
+ auto *CurrentAcc =
+ cast_or_null<MemoryUseOrDef>(MSSAU->getMemorySSA()->getMemoryAccess(I));
if (CurrentAcc)
MemInsertPoint = CurrentAcc;
// Calls that only access inaccessible memory do not block merging
// accessible stores.
- if (auto *CB = dyn_cast<CallBase>(BI)) {
+ if (auto *CB = dyn_cast<CallBase>(BI))
if (CB->onlyAccessesInaccessibleMemory())
continue;
- }
- if (!isa<StoreInst>(BI) && !isa<MemSetInst>(BI)) {
- // If the instruction is readnone, ignore it, otherwise bail out. We
- // don't even allow readonly here because we don't want something like:
- // A[1] = 2; strlen(A); A[2] = 2; -> memcpy(A, ...); strlen(A).
- if (BI->mayWriteToMemory() || BI->mayReadFromMemory())
- break;
+ if (I->isVolatile() || I->isAtomic()) {
+ // Flush all MemsetRanges if reaching a fence.
+ for (auto [Obj, Ranges] : ObjToRanges)
+ MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+ ObjToRanges.clear();
continue;
}
- if (auto *NextStore = dyn_cast<StoreInst>(BI)) {
- // If this is a store, see if we can merge it in.
- if (!NextStore->isSimple())
- break;
+ // Handle other clobbers
+ if (!isa<StoreInst>(I) && !isa<MemSetInst>(I)) {
+ // Flush all may-aliased MemsetRanges.
+ ObjToRanges.remove_if([&](std::pair<Value *, MemsetRanges> &Entry) {
+ auto &Ranges = Entry.second;
+ bool ShouldFlush =
+ isModOrRefSet(BAA.getModRefInfo(I, Ranges.StartPtrLocation));
+ if (ShouldFlush)
+ MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+ return ShouldFlush;
+ });
+ continue;
+ }
- Value *StoredVal = NextStore->getValueOperand();
+ Value *WrittenPtr = getWrittenPtr(I);
+ Value *Obj = getUnderlyingObject(WrittenPtr);
+ Value *ByteVal = nullptr;
- // Don't convert stores of non-integral pointer types to memsets (which
- // stores integers).
- if (DL.isNonIntegralPointerType(StoredVal->getType()->getScalarType()))
- break;
+ // If this is a store, see if we can merge it in.
+ if (ObjToRanges.contains(Obj)) {
----------------
nikic wrote:
Use `find()` here to avoid the double lookup.
https://github.com/llvm/llvm-project/pull/90350
More information about the llvm-commits
mailing list