[llvm] [MemCpyOpt] Merge memset and skip unrelated clobber in one scan (PR #90350)
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 29 22:23:24 PDT 2024
================
@@ -350,157 +434,157 @@ static void combineAAMetadata(Instruction *ReplInst, Instruction *I) {
combineMetadata(ReplInst, I, KnownIDs, true);
}
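+/// Check whether \p I is a store or constant-length memset of a
+/// byte-splattable value; if so, return that byte value in \p ByteVal.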
+static bool isCandidateToMergeIntoMemset(Instruction *I, const DataLayout &DL,
+ Value *&ByteVal) {
+
+ if (auto *SI = dyn_cast<StoreInst>(I)) {
+ Value *StoredVal = SI->getValueOperand();
+
+ // Avoid merging nontemporal stores since the resulting
+ // memcpy/memset would not be able to preserve the nontemporal hint.
+ if (SI->getMetadata(LLVMContext::MD_nontemporal))
+ return false;
+ // Don't convert stores of non-integral pointer types to memsets (which
+ // stores integers).
+ if (DL.isNonIntegralPointerType(StoredVal->getType()->getScalarType()))
+ return false;
+
+ // We can't track ranges involving scalable types.
+ if (DL.getTypeStoreSize(StoredVal->getType()).isScalable())
+ return false;
+
+ ByteVal = isBytewiseValue(StoredVal, DL);
+ if (!ByteVal)
+ return false;
+
+ return true;
+ }
+
+ if (auto *MSI = dyn_cast<MemSetInst>(I)) {
+ if (!isa<ConstantInt>(MSI->getLength()))
+ return false;
+
+ ByteVal = MSI->getValue();
+ return true;
+ }
+
+ return false;
+}
+
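+/// Return the pointer written to by a store, or the destination of a memset.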
+static Value *getWrittenPtr(Instruction *I) {
+ if (auto *SI = dyn_cast<StoreInst>(I))
+ return SI->getPointerOperand();
+ if (auto *MSI = dyn_cast<MemSetInst>(I))
+ return MSI->getDest();
+ llvm_unreachable("Only support store and memset");
+}
+
/// When scanning forward over instructions, we look for some other patterns to
/// fold away. In particular, this looks for stores to neighboring locations of
/// memory. If it sees enough consecutive ones, it attempts to merge them
/// together into a memcpy/memset.
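/// For example, "p[0] = 0; p[1] = 0; p[2] = 0; p[3] = 0" can become a single
/// memset(p, 0, 4) when profitable.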
-Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
- Value *StartPtr,
- Value *ByteVal) {
- const DataLayout &DL = StartInst->getModule()->getDataLayout();
+bool MemCpyOptPass::tryMergingIntoMemset(BasicBlock *BB) {
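+ // Candidate ranges being built, keyed by the underlying object they write.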
+ MapVector<Value *, MemsetRanges> ObjToRanges;
+ const DataLayout &DL = BB->getModule()->getDataLayout();
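+ // Keeps track of the last memory use or def before the insertion point for
+ // the new memset. The new MemoryDef for the inserted memsets will be
+ // inserted after MemInsertPoint.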
+ MemoryUseOrDef *MemInsertPoint = nullptr;
+ BatchAAResults BAA(*AA);
+ bool MadeChanged = false;
- // We can't track scalable types
- if (auto *SI = dyn_cast<StoreInst>(StartInst))
- if (DL.getTypeStoreSize(SI->getOperand(0)->getType()).isScalable())
- return nullptr;
+ // The following code creates memset intrinsics out of thin air. Don't do
+ // this if the corresponding libfunc is not available.
+ if (!(TLI->has(LibFunc_memset) || EnableMemCpyOptWithoutLibcalls))
+ return false;
- // Okay, so we now have a single store that can be splatable. Scan to find
- // all subsequent stores of the same value to offset from the same pointer.
- // Join these together into ranges, so we can decide whether contiguous blocks
- // are stored.
- MemsetRanges Ranges(DL);
+ for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE; BI++) {
+ Instruction *I = &*BI;
- BasicBlock::iterator BI(StartInst);
+ if (!I->mayReadOrWriteMemory())
+ continue;
- // Keeps track of the last memory use or def before the insertion point for
- // the new memset. The new MemoryDef for the inserted memsets will be inserted
- // after MemInsertPoint.
- MemoryUseOrDef *MemInsertPoint = nullptr;
- for (++BI; !BI->isTerminator(); ++BI) {
- auto *CurrentAcc = cast_or_null<MemoryUseOrDef>(
- MSSAU->getMemorySSA()->getMemoryAccess(&*BI));
+ auto *CurrentAcc =
+ cast_or_null<MemoryUseOrDef>(MSSAU->getMemorySSA()->getMemoryAccess(I));
if (CurrentAcc)
MemInsertPoint = CurrentAcc;
// Calls that only access inaccessible memory do not block merging
// accessible stores.
- if (auto *CB = dyn_cast<CallBase>(BI)) {
+ if (auto *CB = dyn_cast<CallBase>(BI))
if (CB->onlyAccessesInaccessibleMemory())
continue;
- }
- if (!isa<StoreInst>(BI) && !isa<MemSetInst>(BI)) {
- // If the instruction is readnone, ignore it, otherwise bail out. We
- // don't even allow readonly here because we don't want something like:
- // A[1] = 2; strlen(A); A[2] = 2; -> memcpy(A, ...); strlen(A).
- if (BI->mayWriteToMemory() || BI->mayReadFromMemory())
- break;
+ if (I->isVolatile() || I->isAtomic()) {
+ // Volatile and atomic accesses act as fences: flush all pending MemsetRanges.
+ for (auto &[Obj, Ranges] : ObjToRanges)
+ MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+ ObjToRanges.clear();
continue;
}
- if (auto *NextStore = dyn_cast<StoreInst>(BI)) {
- // If this is a store, see if we can merge it in.
- if (!NextStore->isSimple())
- break;
+ // Handle other clobbers
+ if (!isa<StoreInst>(I) && !isa<MemSetInst>(I)) {
+ // Flush all may-aliased MemsetRanges.
+ ObjToRanges.remove_if([&](std::pair<Value *, MemsetRanges> &Entry) {
+ auto &Ranges = Entry.second;
+ bool ShouldFlush =
+ isModOrRefSet(BAA.getModRefInfo(I, Ranges.StartPtrLocation));
+ if (ShouldFlush)
+ MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+ return ShouldFlush;
+ });
+ continue;
+ }
- Value *StoredVal = NextStore->getValueOperand();
+ Value *WrittenPtr = getWrittenPtr(I);
+ Value *Obj = getUnderlyingObject(WrittenPtr);
+ Value *ByteVal = nullptr;
- // Don't convert stores of non-integral pointer types to memsets (which
- // stores integers).
- if (DL.isNonIntegralPointerType(StoredVal->getType()->getScalarType()))
- break;
+ // If we already track ranges for this object, see if we can merge the
+ // access in.
+ if (ObjToRanges.contains(Obj)) {
+ MemsetRanges &Ranges = ObjToRanges[Obj];
- // We can't track ranges involving scalable types.
- if (DL.getTypeStoreSize(StoredVal->getType()).isScalable())
- break;
+ if (!isCandidateToMergeIntoMemset(I, DL, ByteVal)) {
+ MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+ ObjToRanges.erase(Obj);
+ continue;
+ }
- // Check to see if this stored value is of the same byte-splattable value.
- Value *StoredByte = isBytewiseValue(StoredVal, DL);
- if (isa<UndefValue>(ByteVal) && StoredByte)
- ByteVal = StoredByte;
- if (ByteVal != StoredByte)
- break;
+ if (isa<UndefValue>(Ranges.ByteVal))
+ Ranges.ByteVal = ByteVal;
- // Check to see if this store is to a constant offset from the start ptr.
std::optional<int64_t> Offset =
- NextStore->getPointerOperand()->getPointerOffsetFrom(StartPtr, DL);
- if (!Offset)
- break;
-
- Ranges.addStore(*Offset, NextStore);
- } else {
- auto *MSI = cast<MemSetInst>(BI);
+ WrittenPtr->getPointerOffsetFrom(Ranges.StartPtr, DL);
- if (MSI->isVolatile() || ByteVal != MSI->getValue() ||
- !isa<ConstantInt>(MSI->getLength()))
- break;
+ // For unmergeable stores/memsets, we create a new MemsetRanges.
+ if (!Offset || Ranges.ByteVal != ByteVal) {
+ MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+ ObjToRanges[Obj] = MemsetRanges(&DL, I, WrittenPtr, ByteVal);
+ continue;
+ }
- // Check to see if this store is to a constant offset from the start ptr.
- std::optional<int64_t> Offset =
- MSI->getDest()->getPointerOffsetFrom(StartPtr, DL);
- if (!Offset)
- break;
+ Ranges.addInst(*Offset, I);
+ } else {
+ // Flush all may-aliased MemsetRanges.
+ ObjToRanges.remove_if([&](std::pair<Value *, MemsetRanges> &Entry) {
+ auto &Ranges = Entry.second;
+ bool ShouldFlush =
+ isModOrRefSet(BAA.getModRefInfo(I, Ranges.StartPtrLocation));
+ if (ShouldFlush)
+ MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+ return ShouldFlush;
+ });
- Ranges.addMemSet(*Offset, MSI);
+ // Create a new MemsetRanges.
+ if (isCandidateToMergeIntoMemset(I, DL, ByteVal))
+ ObjToRanges.insert({Obj, MemsetRanges(&DL, I, WrittenPtr, ByteVal)});
}
}
- // If we have no ranges, then we just had a single store with nothing that
- // could be merged in. This is a very common case of course.
- if (Ranges.empty())
- return nullptr;
-
- // If we had at least one store that could be merged in, add the starting
- // store as well. We try to avoid this unless there is at least something
- // interesting as a small compile-time optimization.
- Ranges.addInst(0, StartInst);
-
- // If we create any memsets, we put it right before the first instruction that
- // isn't part of the memset block. This ensure that the memset is dominated
- // by any addressing instruction needed by the start of the block.
- IRBuilder<> Builder(&*BI);
-
- // Now that we have full information about ranges, loop over the ranges and
- // emit memset's for anything big enough to be worthwhile.
- Instruction *AMemSet = nullptr;
- for (const MemsetRange &Range : Ranges) {
- if (Range.TheStores.size() == 1)
- continue;
-
- // If it is profitable to lower this range to memset, do so now.
- if (!Range.isProfitableToUseMemset(DL))
- continue;
-
- // Otherwise, we do want to transform this! Create a new memset.
- // Get the starting pointer of the block.
- StartPtr = Range.StartPtr;
-
- AMemSet = Builder.CreateMemSet(StartPtr, ByteVal, Range.End - Range.Start,
- Range.Alignment);
- AMemSet->mergeDIAssignID(Range.TheStores);
-
- LLVM_DEBUG(dbgs() << "Replace stores:\n"; for (Instruction *SI
- : Range.TheStores) dbgs()
- << *SI << '\n';
- dbgs() << "With: " << *AMemSet << '\n');
- if (!Range.TheStores.empty())
- AMemSet->setDebugLoc(Range.TheStores[0]->getDebugLoc());
-
- auto *NewDef = cast<MemoryDef>(
- MemInsertPoint->getMemoryInst() == &*BI
- ? MSSAU->createMemoryAccessBefore(AMemSet, nullptr, MemInsertPoint)
- : MSSAU->createMemoryAccessAfter(AMemSet, nullptr, MemInsertPoint));
- MSSAU->insertDef(NewDef, /*RenameUses=*/true);
- MemInsertPoint = NewDef;
-
- // Zap all the stores.
- for (Instruction *SI : Range.TheStores)
- eraseInstruction(SI);
-
- ++NumMemSetInfer;
+ for (auto [Obj, Ranges] : ObjToRanges) {
----------------
nikic wrote:
```suggestion
for (auto &[Obj, Ranges] : ObjToRanges) {
```
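For readers following along: the suggestion changes the structured binding to bind by reference. A minimal standalone sketch (not LLVM code; `Ranges` and `Map` are made-up stand-ins) of the difference:

```cpp
#include <cassert>
#include <utility>
#include <vector>

struct Ranges {
  bool Flushed = false;
  void flush() { Flushed = true; } // stand-in for MemsetRanges::flush
};

int main() {
  std::vector<std::pair<int, Ranges>> Map{{0, {}}, {1, {}}};

  for (auto [Obj, R] : Map) // by value: flush() mutates a per-iteration copy
    R.flush();
  assert(!Map[0].second.Flushed);

  for (auto &[Obj, R] : Map) // by reference: flush() mutates the stored pair
    R.flush();
  assert(Map[1].second.Flushed);
  return 0;
}
```

In the PR's loops, the by-value binding at minimum copies each MemsetRanges per iteration, which is presumably what the suggestion addresses.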
https://github.com/llvm/llvm-project/pull/90350
More information about the llvm-commits mailing list