[llvm] [MemCpyOpt] Merge memset and skip unrelated clobber in one scan (PR #90350)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 29 22:23:23 PDT 2024


================
@@ -350,157 +434,157 @@ static void combineAAMetadata(Instruction *ReplInst, Instruction *I) {
   combineMetadata(ReplInst, I, KnownIDs, true);
 }
 
+/// Decide whether \p I is a store or memset that may take part in memset
+/// merging.
+///
+/// \param I       The instruction to classify.
+/// \param DL      Data layout used to query the stored type.
+/// \param ByteVal Out-parameter. For an accepted memset it receives the memset
+///                value; for a store that passes the type checks it receives
+///                the result of isBytewiseValue(), which is null when the
+///                store is rejected for not being bytewise. It is left
+///                untouched on every other rejection path.
+/// \returns true if \p I is an acceptable store or a constant-length memset.
+static bool isCandidateToMergeIntoMemset(Instruction *I, const DataLayout &DL,
+                                         Value *&ByteVal) {
+  if (auto *MemSet = dyn_cast<MemSetInst>(I)) {
+    // Only memsets whose length is a compile-time constant are accepted.
+    if (!isa<ConstantInt>(MemSet->getLength()))
+      return false;
+
+    ByteVal = MemSet->getValue();
+    return true;
+  }
+
+  auto *Store = dyn_cast<StoreInst>(I);
+  if (!Store)
+    return false;
+
+  Value *Stored = Store->getValueOperand();
+  auto *StoredTy = Stored->getType();
+
+  // Reject, in this order:
+  //  - nontemporal stores, because the merged memset could not preserve the
+  //    nontemporal hint;
+  //  - stores of non-integral pointer types, because a memset stores
+  //    integers;
+  //  - scalable types, because ranges involving them cannot be tracked.
+  if (Store->getMetadata(LLVMContext::MD_nontemporal) ||
+      DL.isNonIntegralPointerType(StoredTy->getScalarType()) ||
+      DL.getTypeStoreSize(StoredTy).isScalable())
+    return false;
+
+  // The store qualifies only if its value is a repetition of a single byte.
+  ByteVal = isBytewiseValue(Stored, DL);
+  return ByteVal != nullptr;
+}
+
+/// Return the pointer that \p I writes through.
+///
+/// \p I must be a StoreInst or a MemSetInst. The visible caller only reaches
+/// this helper after every other instruction kind has been filtered out.
+///
+/// \returns the store's pointer operand, or the memset's destination.
+static Value *getWrittenPtr(Instruction *I) {
+  if (auto *SI = dyn_cast<StoreInst>(I))
+    return SI->getPointerOperand();
+  // A static_assert on a string literal is always true and checks nothing, so
+  // use cast<>, which asserts the instruction kind in assertion-enabled
+  // builds instead of silently handing back a null pointer.
+  return cast<MemSetInst>(I)->getDest();
+}
+
 /// When scanning forward over instructions, we look for some other patterns to
 /// fold away. In particular, this looks for stores to neighboring locations of
 /// memory. If it sees enough consecutive ones, it attempts to merge them
 /// together into a memcpy/memset.
-Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
-                                                 Value *StartPtr,
-                                                 Value *ByteVal) {
-  const DataLayout &DL = StartInst->getModule()->getDataLayout();
+bool MemCpyOptPass::tryMergingIntoMemset(BasicBlock *BB) {
+  MapVector<Value *, MemsetRanges> ObjToRanges;
+  const DataLayout &DL = BB->getModule()->getDataLayout();
+  MemoryUseOrDef *MemInsertPoint = nullptr;
+  BatchAAResults BAA(*AA);
+  bool MadeChanged = false;
 
-  // We can't track scalable types
-  if (auto *SI = dyn_cast<StoreInst>(StartInst))
-    if (DL.getTypeStoreSize(SI->getOperand(0)->getType()).isScalable())
-      return nullptr;
+  // The following code creates memset intrinsics out of thin air. Don't do
+  // this if the corresponding libfunc is not available.
+  if (!(TLI->has(LibFunc_memset) || EnableMemCpyOptWithoutLibcalls))
+    return false;
 
-  // Okay, so we now have a single store that can be splatable.  Scan to find
-  // all subsequent stores of the same value to offset from the same pointer.
-  // Join these together into ranges, so we can decide whether contiguous blocks
-  // are stored.
-  MemsetRanges Ranges(DL);
+  for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE; BI++) {
+    Instruction *I = &*BI;
 
-  BasicBlock::iterator BI(StartInst);
+    if (!I->mayReadOrWriteMemory())
+      continue;
 
-  // Keeps track of the last memory use or def before the insertion point for
-  // the new memset. The new MemoryDef for the inserted memsets will be inserted
-  // after MemInsertPoint.
-  MemoryUseOrDef *MemInsertPoint = nullptr;
-  for (++BI; !BI->isTerminator(); ++BI) {
-    auto *CurrentAcc = cast_or_null<MemoryUseOrDef>(
-        MSSAU->getMemorySSA()->getMemoryAccess(&*BI));
+    auto *CurrentAcc =
+        cast_or_null<MemoryUseOrDef>(MSSAU->getMemorySSA()->getMemoryAccess(I));
     if (CurrentAcc)
       MemInsertPoint = CurrentAcc;
 
     // Calls that only access inaccessible memory do not block merging
     // accessible stores.
-    if (auto *CB = dyn_cast<CallBase>(BI)) {
+    if (auto *CB = dyn_cast<CallBase>(BI))
       if (CB->onlyAccessesInaccessibleMemory())
         continue;
-    }
 
-    if (!isa<StoreInst>(BI) && !isa<MemSetInst>(BI)) {
-      // If the instruction is readnone, ignore it, otherwise bail out.  We
-      // don't even allow readonly here because we don't want something like:
-      // A[1] = 2; strlen(A); A[2] = 2; -> memcpy(A, ...); strlen(A).
-      if (BI->mayWriteToMemory() || BI->mayReadFromMemory())
-        break;
+    if (I->isVolatile() || I->isAtomic()) {
+      // Flush all MemsetRanges if reaching a fence.
+      for (auto [Obj, Ranges] : ObjToRanges)
+        MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+      ObjToRanges.clear();
       continue;
     }
 
-    if (auto *NextStore = dyn_cast<StoreInst>(BI)) {
-      // If this is a store, see if we can merge it in.
-      if (!NextStore->isSimple())
-        break;
+    // Handle other clobbers
+    if (!isa<StoreInst>(I) && !isa<MemSetInst>(I)) {
+      // Flush all may-aliased MemsetRanges.
+      ObjToRanges.remove_if([&](std::pair<Value *, MemsetRanges> &Entry) {
+        auto &Ranges = Entry.second;
+        bool ShouldFlush =
+            isModOrRefSet(BAA.getModRefInfo(I, Ranges.StartPtrLocation));
+        if (ShouldFlush)
+          MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+        return ShouldFlush;
+      });
+      continue;
+    }
 
-      Value *StoredVal = NextStore->getValueOperand();
+    Value *WrittenPtr = getWrittenPtr(I);
+    Value *Obj = getUnderlyingObject(WrittenPtr);
+    Value *ByteVal = nullptr;
 
-      // Don't convert stores of non-integral pointer types to memsets (which
-      // stores integers).
-      if (DL.isNonIntegralPointerType(StoredVal->getType()->getScalarType()))
-        break;
+    // If this is a store, see if we can merge it in.
+    if (ObjToRanges.contains(Obj)) {
----------------
nikic wrote:

Use `find()` here to avoid the double lookup.

https://github.com/llvm/llvm-project/pull/90350


More information about the llvm-commits mailing list