[llvm] [MemCpyOpt] Merge memset and skip unrelated clobber in one scan (PR #90350)

via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 29 05:07:45 PDT 2024


https://github.com/XChy updated https://github.com/llvm/llvm-project/pull/90350

>From fbcf682fbba3870b2750f1e1cf5185ff6de9969d Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 22 Apr 2024 00:46:18 +0800
Subject: [PATCH 1/2] [MemCpyOpt] Precommit tests for merging into memset (NFC)

---
 .../Transforms/MemCpyOpt/merge-into-memset.ll | 174 +++++++++++++++++-
 1 file changed, 172 insertions(+), 2 deletions(-)

diff --git a/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll b/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll
index 78aa769982404a0..ca2ffd72818496f 100644
--- a/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll
+++ b/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll
@@ -36,5 +36,175 @@ exit:
   ret void
 }
 
-declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1)
-declare void @llvm.memset.p0.i64(ptr, i8, i64, i1)
+define void @memset_clobber_no_alias(ptr %p) {
+; CHECK-LABEL: @memset_clobber_no_alias(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[P:%.*]], i8 0, i64 16, i1 false)
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  call void @llvm.memset.p0.i64(ptr %p, i8 0, i64 16, i1 false)
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @store_clobber_no_alias1(i64 %a, ptr %p) {
+; CHECK-LABEL: @store_clobber_no_alias1(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    store i64 [[A:%.*]], ptr [[P:%.*]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  store i64 %a, ptr %p, align 8
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @store_clobber_no_alias2(i64 %a, ptr %p) {
+; CHECK-LABEL: @store_clobber_no_alias2(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    store i64 [[A:%.*]], ptr [[P:%.*]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  store i64 %a, ptr %p, align 8
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @store_clobber_no_alias_precise_fail(i64 %a) {
+; CHECK-LABEL: @store_clobber_no_alias_precise_fail(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    store i64 [[A:%.*]], ptr [[STACK]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  store i64 %a, ptr %stack, align 8
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @store_clobber_may_alias_fail(ptr %p, ptr %p1) {
+; CHECK-LABEL: @store_clobber_may_alias_fail(
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK:%.*]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    store i64 0, ptr [[P1:%.*]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack1 = getelementptr inbounds i8, ptr %p, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  store i64 0, ptr %p1, align 8
+  %stack2 = getelementptr inbounds i8, ptr %p, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @load_clobber_no_alias(ptr %p, ptr %p1) {
+; CHECK-LABEL: @load_clobber_no_alias(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    [[A:%.*]] = load i64, ptr [[P:%.*]], align 8
+; CHECK-NEXT:    store i64 [[A]], ptr [[P1:%.*]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  %a = load i64, ptr %p, align 8
+  store i64 %a, ptr %p1, align 8
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @load_clobber_alias_fail(ptr %p, ptr %p1) {
+; CHECK-LABEL: @load_clobber_alias_fail(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    [[A:%.*]] = load i64, ptr [[STACK]], align 8
+; CHECK-NEXT:    store i64 [[A]], ptr [[P1:%.*]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  %a = load i64, ptr %stack, align 8
+  store i64 %a, ptr %p1, align 8
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @memset_volatile_fail(ptr %p) {
+; CHECK-LABEL: @memset_volatile_fail(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[P:%.*]], i8 0, i64 16, i1 true)
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  call void @llvm.memset.p0.i64(ptr %p, i8 0, i64 16, i1 true)
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @store_volatile_fail(i64 %a, ptr %p) {
+; CHECK-LABEL: @store_volatile_fail(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    store volatile i64 [[A:%.*]], ptr [[P:%.*]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  store volatile i64 %a, ptr %p
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}

>From cd5a161c5ecaa7868532cee60ae95fca645641ca Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 22 Apr 2024 00:29:53 +0800
Subject: [PATCH 2/2] [MemCpyOpt] Continue merging memsets across unrelated clobbers

---
 .../llvm/Transforms/Scalar/MemCpyOptimizer.h  |   3 +-
 .../lib/Transforms/Scalar/MemCpyOptimizer.cpp | 343 +++++++++++-------
 llvm/test/Transforms/MemCpyOpt/form-memset.ll |   4 +-
 .../Transforms/MemCpyOpt/merge-into-memset.ll |  12 +-
 4 files changed, 214 insertions(+), 148 deletions(-)
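
Note: the rewrite keeps one MemsetRanges accumulator per underlying object
during a single forward scan of the block, and flushes an accumulator only
when an instruction may alias it or when a volatile/atomic operation acts as
a fence. A minimal sketch of the pattern this enables, with hypothetical
names not taken from the tests below; the two memsets on %buf can now merge
even though an unrelated store sits between them:

define void @sketch(i64 %a, ptr noalias %other) {
  %buf = alloca <64 x i8>, align 8
  %buf8 = getelementptr inbounds i8, ptr %buf, i64 8
  ; opens a range covering bytes [8, 24) of %buf
  call void @llvm.memset.p0.i64(ptr %buf8, i8 0, i64 16, i1 false)
  ; writes a provably distinct object, so the range stays open
  store i64 %a, ptr %other, align 8
  %buf24 = getelementptr inbounds i8, ptr %buf, i64 24
  ; extends the range to [8, 32); flushed as one memset before the ret
  call void @llvm.memset.p0.i64(ptr %buf24, i8 0, i64 8, i1 false)
  ret void
}

declare void @llvm.memset.p0.i64(ptr, i8, i64, i1)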

diff --git a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
index 6c809bc881d050d..00ed731e0cda00b 100644
--- a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
+++ b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
@@ -78,8 +78,7 @@ class MemCpyOptPass : public PassInfoMixin<MemCpyOptPass> {
                                   BatchAAResults &BAA);
   bool processByValArgument(CallBase &CB, unsigned ArgNo);
   bool processImmutArgument(CallBase &CB, unsigned ArgNo);
-  Instruction *tryMergingIntoMemset(Instruction *I, Value *StartPtr,
-                                    Value *ByteVal);
+  bool tryMergingIntoMemset(BasicBlock *BB);
   bool moveUp(StoreInst *SI, Instruction *P, const LoadInst *LI);
   bool performStackMoveOptzn(Instruction *Load, Instruction *Store,
                              AllocaInst *DestAlloca, AllocaInst *SrcAlloca,
diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 7ef5dceffec0d2a..fa4be64ec9e8bad 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -158,11 +158,16 @@ class MemsetRanges {
 
   /// A sorted list of the memset ranges.
   SmallVector<MemsetRange, 8> Ranges;
-
-  const DataLayout &DL;
+  const DataLayout *DL = nullptr;
 
 public:
-  MemsetRanges(const DataLayout &DL) : DL(DL) {}
+  MemsetRanges() = default;
+  MemsetRanges(const DataLayout *DL) : DL(DL) {}
+  MemsetRanges(const DataLayout *DL, Instruction *I, Value *StartPtr,
+               Value *ByteVal)
+      : DL(DL), StartInst(I), StartPtr(StartPtr),
+        StartPtrLocation(MemoryLocation::getBeforeOrAfter(StartPtr)),
+        ByteVal(ByteVal) {}
 
   using const_iterator = SmallVectorImpl<MemsetRange>::const_iterator;
 
@@ -178,7 +183,7 @@ class MemsetRanges {
   }
 
   void addStore(int64_t OffsetFromFirst, StoreInst *SI) {
-    TypeSize StoreSize = DL.getTypeStoreSize(SI->getOperand(0)->getType());
+    TypeSize StoreSize = DL->getTypeStoreSize(SI->getOperand(0)->getType());
     assert(!StoreSize.isScalable() && "Can't track scalable-typed stores");
     addRange(OffsetFromFirst, StoreSize.getFixedValue(),
              SI->getPointerOperand(), SI->getAlign(), SI);
@@ -191,6 +196,18 @@ class MemsetRanges {
 
   void addRange(int64_t Start, int64_t Size, Value *Ptr, MaybeAlign Alignment,
                 Instruction *Inst);
+
+  bool flush(MemorySSAUpdater *MSSAU, Instruction *InsertPoint,
+             MemoryUseOrDef *&MemInsertPoint);
+
+  // The first store/memset instruction
+  Instruction *StartInst = nullptr;
+  // The start pointer written by the first instruction
+  Value *StartPtr = nullptr;
+  // The memory location of StartPtr
+  MemoryLocation StartPtrLocation;
+  // The byte value of memset
+  Value *ByteVal = nullptr;
 };
 
 } // end anonymous namespace
@@ -255,6 +272,73 @@ void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr,
   }
 }
 
+bool MemsetRanges::flush(MemorySSAUpdater *MSSAU, Instruction *InsertPoint,
+                         MemoryUseOrDef *&MemInsertPoint) {
+
+  // If we have no ranges, then we just had a single store with nothing that
+  // could be merged in.  This is a very common case of course.
+  if (empty())
+    return false;
+
+  // If we had at least one store that could be merged in, add the starting
+  // store as well.  We try to avoid this unless there is at least something
+  // interesting as a small compile-time optimization.
+  addInst(0, StartInst);
+
+  // If we create any memsets, we put it right before the first instruction that
+  // isn't part of the memset block.  This ensures that the memset is dominated
+  // by any addressing instruction needed by the start of the block.
+  IRBuilder<> Builder(InsertPoint);
+
+  // Now that we have full information about ranges, loop over the ranges and
+  // emit memsets for anything big enough to be worthwhile.
+  Instruction *AMemSet = nullptr;
+
+  // Keeps track of the last memory use or def before the insertion point for
+  // the new memset. The new MemoryDef for the inserted memsets will be inserted
+  // after MemInsertPoint.
+  for (const MemsetRange &Range : *this) {
+    if (Range.TheStores.size() == 1)
+      continue;
+
+    // If it is profitable to lower this range to memset, do so now.
+    if (!Range.isProfitableToUseMemset(*DL))
+      continue;
+
+    // Otherwise, we do want to transform this!  Create a new memset.
+    // Get the starting pointer of the block.
+    Value *CurrentStart = Range.StartPtr;
+
+    AMemSet = Builder.CreateMemSet(CurrentStart, ByteVal,
+                                   Range.End - Range.Start, Range.Alignment);
+    AMemSet->mergeDIAssignID(Range.TheStores);
+
+    LLVM_DEBUG(dbgs() << "Replace stores:\n"; for (Instruction *SI
+                                                   : Range.TheStores) dbgs()
+                                              << *SI << '\n';
+               dbgs() << "With: " << *AMemSet << '\n');
+    if (!Range.TheStores.empty())
+      AMemSet->setDebugLoc(Range.TheStores[0]->getDebugLoc());
+
+    auto *NewDef = cast<MemoryDef>(
+        MemInsertPoint->getMemoryInst() == InsertPoint
+            ? MSSAU->createMemoryAccessBefore(AMemSet, nullptr, MemInsertPoint)
+            : MSSAU->createMemoryAccessAfter(AMemSet, nullptr, MemInsertPoint));
+    MSSAU->insertDef(NewDef, /*RenameUses=*/true);
+    MemInsertPoint = NewDef;
+
+    // Zap all the stores.
+    for (Instruction *SI : Range.TheStores) {
+      MSSAU->removeMemoryAccess(SI);
+      SI->eraseFromParent();
+    }
+
+    ++NumMemSetInfer;
+  }
+
+  return AMemSet != nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 //                         MemCpyOptLegacyPass Pass
 //===----------------------------------------------------------------------===//
@@ -350,157 +434,156 @@ static void combineAAMetadata(Instruction *ReplInst, Instruction *I) {
   combineMetadata(ReplInst, I, KnownIDs, true);
 }
 
+static bool isCandidateToMergeIntoMemset(Instruction *I, const DataLayout &DL,
+                                         Value *&ByteVal) {
+
+  if (auto *SI = dyn_cast<StoreInst>(I)) {
+    Value *StoredVal = SI->getValueOperand();
+
+    // Avoid merging nontemporal stores since the resulting
+    // memcpy/memset would not be able to preserve the nontemporal hint.
+    if (SI->getMetadata(LLVMContext::MD_nontemporal))
+      return false;
+    // Don't convert stores of non-integral pointer types to memsets (which
+    // stores integers).
+    if (DL.isNonIntegralPointerType(StoredVal->getType()->getScalarType()))
+      return false;
+
+    // We can't track ranges involving scalable types.
+    if (DL.getTypeStoreSize(StoredVal->getType()).isScalable())
+      return false;
+
+    ByteVal = isBytewiseValue(StoredVal, DL);
+    if (!ByteVal)
+      return false;
+
+    return true;
+  }
+
+  if (auto *MSI = dyn_cast<MemSetInst>(I)) {
+    if (!isa<ConstantInt>(MSI->getLength()))
+      return false;
+
+    ByteVal = MSI->getValue();
+    return true;
+  }
+
+  return false;
+}
+
+static Value *getWrittenPtr(Instruction *I) {
+  if (auto *SI = dyn_cast<StoreInst>(I))
+    return SI->getPointerOperand();
+  if (auto *MSI = dyn_cast<MemSetInst>(I))
+    return MSI->getDest();
+  llvm_unreachable("Only store and memset are supported");
+  return nullptr;
+}
+
 /// When scanning forward over instructions, we look for some other patterns to
 /// fold away. In particular, this looks for stores to neighboring locations of
 /// memory. If it sees enough consecutive ones, it attempts to merge them
 /// together into a memcpy/memset.
-Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
-                                                 Value *StartPtr,
-                                                 Value *ByteVal) {
-  const DataLayout &DL = StartInst->getModule()->getDataLayout();
+bool MemCpyOptPass::tryMergingIntoMemset(BasicBlock *BB) {
+  MapVector<Value *, MemsetRanges> ObjToRanges;
+  const DataLayout &DL = BB->getModule()->getDataLayout();
+  MemoryUseOrDef *MemInsertPoint = nullptr;
+  bool MadeChanged = false;
 
-  // We can't track scalable types
-  if (auto *SI = dyn_cast<StoreInst>(StartInst))
-    if (DL.getTypeStoreSize(SI->getOperand(0)->getType()).isScalable())
-      return nullptr;
+  // The following code creates memset intrinsics out of thin air. Don't do
+  // this if the corresponding libfunc is not available.
+  if (!(TLI->has(LibFunc_memset) || EnableMemCpyOptWithoutLibcalls))
+    return false;
 
-  // Okay, so we now have a single store that can be splatable.  Scan to find
-  // all subsequent stores of the same value to offset from the same pointer.
-  // Join these together into ranges, so we can decide whether contiguous blocks
-  // are stored.
-  MemsetRanges Ranges(DL);
+  for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE; ++BI) {
+    BatchAAResults BAA(*AA);
+    Instruction *I = &*BI;
 
-  BasicBlock::iterator BI(StartInst);
+    if (!I->mayReadOrWriteMemory())
+      continue;
 
-  // Keeps track of the last memory use or def before the insertion point for
-  // the new memset. The new MemoryDef for the inserted memsets will be inserted
-  // after MemInsertPoint.
-  MemoryUseOrDef *MemInsertPoint = nullptr;
-  for (++BI; !BI->isTerminator(); ++BI) {
-    auto *CurrentAcc = cast_or_null<MemoryUseOrDef>(
-        MSSAU->getMemorySSA()->getMemoryAccess(&*BI));
+    auto *CurrentAcc =
+        cast_or_null<MemoryUseOrDef>(MSSAU->getMemorySSA()->getMemoryAccess(I));
     if (CurrentAcc)
       MemInsertPoint = CurrentAcc;
 
     // Calls that only access inaccessible memory do not block merging
     // accessible stores.
-    if (auto *CB = dyn_cast<CallBase>(BI)) {
+    if (auto *CB = dyn_cast<CallBase>(I))
       if (CB->onlyAccessesInaccessibleMemory())
         continue;
-    }
 
-    if (!isa<StoreInst>(BI) && !isa<MemSetInst>(BI)) {
-      // If the instruction is readnone, ignore it, otherwise bail out.  We
-      // don't even allow readonly here because we don't want something like:
-      // A[1] = 2; strlen(A); A[2] = 2; -> memcpy(A, ...); strlen(A).
-      if (BI->mayWriteToMemory() || BI->mayReadFromMemory())
-        break;
+    if (I->isVolatile() || I->isAtomic()) {
+      // Flush all MemsetRanges when reaching a volatile/atomic fence.
+      for (auto &[Obj, Ranges] : ObjToRanges)
+        MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+      ObjToRanges.clear();
       continue;
     }
 
-    if (auto *NextStore = dyn_cast<StoreInst>(BI)) {
-      // If this is a store, see if we can merge it in.
-      if (!NextStore->isSimple())
-        break;
+    // Handle clobbers other than stores and memsets.
+    if (!isa<StoreInst>(I) && !isa<MemSetInst>(I)) {
+      // Flush all may-aliased MemsetRanges.
+      ObjToRanges.remove_if([&](std::pair<Value *, MemsetRanges> &Entry) {
+        auto &Ranges = Entry.second;
+        bool ShouldFlush =
+            isModOrRefSet(BAA.getModRefInfo(I, Ranges.StartPtrLocation));
+        if (ShouldFlush)
+          MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+        return ShouldFlush;
+      });
+      continue;
+    }
 
-      Value *StoredVal = NextStore->getValueOperand();
+    Value *WrittenPtr = getWrittenPtr(I);
+    Value *Obj = getUnderlyingObject(WrittenPtr);
+    Value *ByteVal = nullptr;
 
-      // Don't convert stores of non-integral pointer types to memsets (which
-      // stores integers).
-      if (DL.isNonIntegralPointerType(StoredVal->getType()->getScalarType()))
-        break;
+    // If we already track ranges for this object, try to merge the access in.
+    if (ObjToRanges.contains(Obj)) {
+      MemsetRanges &Ranges = ObjToRanges[Obj];
 
-      // We can't track ranges involving scalable types.
-      if (DL.getTypeStoreSize(StoredVal->getType()).isScalable())
-        break;
+      if (!isCandidateToMergeIntoMemset(I, DL, ByteVal)) {
+        MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+        ObjToRanges.erase(Obj);
+        continue;
+      }
 
-      // Check to see if this stored value is of the same byte-splattable value.
-      Value *StoredByte = isBytewiseValue(StoredVal, DL);
-      if (isa<UndefValue>(ByteVal) && StoredByte)
-        ByteVal = StoredByte;
-      if (ByteVal != StoredByte)
-        break;
+      if (isa<UndefValue>(Ranges.ByteVal))
+        Ranges.ByteVal = ByteVal;
 
-      // Check to see if this store is to a constant offset from the start ptr.
       std::optional<int64_t> Offset =
-          NextStore->getPointerOperand()->getPointerOffsetFrom(StartPtr, DL);
-      if (!Offset)
-        break;
-
-      Ranges.addStore(*Offset, NextStore);
-    } else {
-      auto *MSI = cast<MemSetInst>(BI);
+          WrittenPtr->getPointerOffsetFrom(Ranges.StartPtr, DL);
 
-      if (MSI->isVolatile() || ByteVal != MSI->getValue() ||
-          !isa<ConstantInt>(MSI->getLength()))
-        break;
+      if (!Offset || Ranges.ByteVal != ByteVal) {
+        MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+        ObjToRanges[Obj] = MemsetRanges(&DL, I, WrittenPtr, ByteVal);
+        continue;
+      }
 
-      // Check to see if this store is to a constant offset from the start ptr.
-      std::optional<int64_t> Offset =
-          MSI->getDest()->getPointerOffsetFrom(StartPtr, DL);
-      if (!Offset)
-        break;
+      Ranges.addInst(*Offset, I);
+    } else {
+      // Flush all may-aliased MemsetRanges.
+      ObjToRanges.remove_if([&](std::pair<Value *, MemsetRanges> &Entry) {
+        auto &Ranges = Entry.second;
+        bool ShouldFlush =
+            isModOrRefSet(BAA.getModRefInfo(I, Ranges.StartPtrLocation));
+        if (ShouldFlush)
+          MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+        return ShouldFlush;
+      });
 
-      Ranges.addMemSet(*Offset, MSI);
+      // Create a new MemsetRanges.
+      if (isCandidateToMergeIntoMemset(I, DL, ByteVal))
+        ObjToRanges.insert({Obj, MemsetRanges(&DL, I, WrittenPtr, ByteVal)});
     }
   }
 
-  // If we have no ranges, then we just had a single store with nothing that
-  // could be merged in.  This is a very common case of course.
-  if (Ranges.empty())
-    return nullptr;
-
-  // If we had at least one store that could be merged in, add the starting
-  // store as well.  We try to avoid this unless there is at least something
-  // interesting as a small compile-time optimization.
-  Ranges.addInst(0, StartInst);
-
-  // If we create any memsets, we put it right before the first instruction that
-  // isn't part of the memset block.  This ensure that the memset is dominated
-  // by any addressing instruction needed by the start of the block.
-  IRBuilder<> Builder(&*BI);
-
-  // Now that we have full information about ranges, loop over the ranges and
-  // emit memset's for anything big enough to be worthwhile.
-  Instruction *AMemSet = nullptr;
-  for (const MemsetRange &Range : Ranges) {
-    if (Range.TheStores.size() == 1)
-      continue;
-
-    // If it is profitable to lower this range to memset, do so now.
-    if (!Range.isProfitableToUseMemset(DL))
-      continue;
-
-    // Otherwise, we do want to transform this!  Create a new memset.
-    // Get the starting pointer of the block.
-    StartPtr = Range.StartPtr;
-
-    AMemSet = Builder.CreateMemSet(StartPtr, ByteVal, Range.End - Range.Start,
-                                   Range.Alignment);
-    AMemSet->mergeDIAssignID(Range.TheStores);
-
-    LLVM_DEBUG(dbgs() << "Replace stores:\n"; for (Instruction *SI
-                                                   : Range.TheStores) dbgs()
-                                              << *SI << '\n';
-               dbgs() << "With: " << *AMemSet << '\n');
-    if (!Range.TheStores.empty())
-      AMemSet->setDebugLoc(Range.TheStores[0]->getDebugLoc());
-
-    auto *NewDef = cast<MemoryDef>(
-        MemInsertPoint->getMemoryInst() == &*BI
-            ? MSSAU->createMemoryAccessBefore(AMemSet, nullptr, MemInsertPoint)
-            : MSSAU->createMemoryAccessAfter(AMemSet, nullptr, MemInsertPoint));
-    MSSAU->insertDef(NewDef, /*RenameUses=*/true);
-    MemInsertPoint = NewDef;
-
-    // Zap all the stores.
-    for (Instruction *SI : Range.TheStores)
-      eraseInstruction(SI);
-
-    ++NumMemSetInfer;
+  for (auto &[Obj, Ranges] : ObjToRanges) {
+    MadeChanged |= Ranges.flush(MSSAU, &BB->back(), MemInsertPoint);
   }
 
-  return AMemSet;
+  return MadeChanged;
 }
 
 // This method try to lift a store instruction before position P.
@@ -797,12 +880,6 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
   // 0xA0A0A0A0 and 0.0.
   auto *V = SI->getOperand(0);
   if (Value *ByteVal = isBytewiseValue(V, DL)) {
-    if (Instruction *I =
-            tryMergingIntoMemset(SI, SI->getPointerOperand(), ByteVal)) {
-      BBI = I->getIterator(); // Don't invalidate iterator.
-      return true;
-    }
-
     // If we have an aggregate, we try to promote it to memset regardless
     // of opportunity for merging as it can expose optimization opportunities
     // in subsequent passes.
@@ -835,14 +912,6 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
 }
 
 bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) {
-  // See if there is another memset or store neighboring this memset which
-  // allows us to widen out the memset to do a single larger store.
-  if (isa<ConstantInt>(MSI->getLength()) && !MSI->isVolatile())
-    if (Instruction *I =
-            tryMergingIntoMemset(MSI, MSI->getDest(), MSI->getValue())) {
-      BBI = I->getIterator(); // Don't invalidate iterator.
-      return true;
-    }
   return false;
 }
 
@@ -1991,6 +2060,8 @@ bool MemCpyOptPass::iterateOnFunction(Function &F) {
     if (!DT->isReachableFromEntry(&BB))
       continue;
 
+    MadeChange |= tryMergingIntoMemset(&BB);
+
     for (BasicBlock::iterator BI = BB.begin(), BE = BB.end(); BI != BE;) {
       // Avoid invalidating the iterator.
       Instruction *I = &*BI++;
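
Note: the form-memset.ll churn below is placement-only. Ranges are now
flushed at the first may-aliasing clobber or at the end of the block, which
lands the synthesized memsets just before the call to @foo instead of at the
points where the old per-store scan stopped. A reduced sketch of what the
per-object accumulators buy (hypothetical names; @use stands in for any call
the allocas escape to): under the old scan, the interleaved stores below
broke each other up because their byte values differ, while the one-scan
version merges both groups and emits both memsets at the flush point.

define void @two_objects() {
  %a = alloca [16 x i8], align 8
  %b = alloca [16 x i8], align 8
  ; interleaved stores: %a and %b each get their own accumulator
  store i64 0, ptr %a, align 8
  store i64 -1, ptr %b, align 8
  %a8 = getelementptr inbounds i8, ptr %a, i64 8
  store i64 0, ptr %a8, align 8
  %b8 = getelementptr inbounds i8, ptr %b, i64 8
  store i64 -1, ptr %b8, align 8
  ; the call may access both allocas, so both accumulators flush here:
  ; memset(%a, 0, 16) and memset(%b, -1, 16) are emitted before it
  call void @use(ptr %a, ptr %b)
  ret void
}

declare void @use(ptr, ptr)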
diff --git a/llvm/test/Transforms/MemCpyOpt/form-memset.ll b/llvm/test/Transforms/MemCpyOpt/form-memset.ll
index 020a72183e9ea1c..edcf27e48f3bda6 100644
--- a/llvm/test/Transforms/MemCpyOpt/form-memset.ll
+++ b/llvm/test/Transforms/MemCpyOpt/form-memset.ll
@@ -97,7 +97,6 @@ define void @test2() nounwind  {
 ; CHECK-NEXT:    [[TMP38:%.*]] = getelementptr [8 x i8], ptr [[REF_IDX]], i32 0, i32 1
 ; CHECK-NEXT:    [[TMP41:%.*]] = getelementptr [8 x i8], ptr [[REF_IDX]], i32 0, i32 0
 ; CHECK-NEXT:    [[TMP43:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 7, i32 0
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 1 [[TMP41]], i8 -1, i64 8, i1 false)
 ; CHECK-NEXT:    [[TMP46:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 7, i32 1
 ; CHECK-NEXT:    [[TMP57:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 6, i32 0
 ; CHECK-NEXT:    [[TMP60:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 6, i32 1
@@ -114,7 +113,6 @@ define void @test2() nounwind  {
 ; CHECK-NEXT:    [[TMP141:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 0, i32 0
 ; CHECK-NEXT:    [[TMP144:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 0, i32 1
 ; CHECK-NEXT:    [[TMP148:%.*]] = getelementptr [8 x %struct.MV], ptr [[LEFT_MVD]], i32 0, i32 7, i32 0
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 8 [[TMP141]], i8 0, i64 32, i1 false)
 ; CHECK-NEXT:    [[TMP151:%.*]] = getelementptr [8 x %struct.MV], ptr [[LEFT_MVD]], i32 0, i32 7, i32 1
 ; CHECK-NEXT:    [[TMP162:%.*]] = getelementptr [8 x %struct.MV], ptr [[LEFT_MVD]], i32 0, i32 6, i32 0
 ; CHECK-NEXT:    [[TMP165:%.*]] = getelementptr [8 x %struct.MV], ptr [[LEFT_MVD]], i32 0, i32 6, i32 1
@@ -132,6 +130,8 @@ define void @test2() nounwind  {
 ; CHECK-NEXT:    [[TMP249:%.*]] = getelementptr [8 x %struct.MV], ptr [[LEFT_MVD]], i32 0, i32 0, i32 1
 ; CHECK-NEXT:    [[UP_MVD252:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 0
 ; CHECK-NEXT:    [[LEFT_MVD253:%.*]] = getelementptr [8 x %struct.MV], ptr [[LEFT_MVD]], i32 0, i32 0
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 1 [[TMP41]], i8 -1, i64 8, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 8 [[TMP141]], i8 0, i64 32, i1 false)
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 8 [[TMP246]], i8 0, i64 32, i1 false)
 ; CHECK-NEXT:    call void @foo(ptr [[UP_MVD252]], ptr [[LEFT_MVD253]], ptr [[TMP41]]) #[[ATTR0]]
 ; CHECK-NEXT:    ret void
diff --git a/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll b/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll
index ca2ffd72818496f..7e2d617d1ed2751 100644
--- a/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll
+++ b/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll
@@ -40,10 +40,9 @@ define void @memset_clobber_no_alias(ptr %p) {
 ; CHECK-LABEL: @memset_clobber_no_alias(
 ; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
 ; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[P:%.*]], i8 0, i64 16, i1 false)
 ; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %stack = alloca <256 x i8>, align 8
@@ -59,10 +58,9 @@ define void @store_clobber_no_alias1(i64 %a, ptr %p) {
 ; CHECK-LABEL: @store_clobber_no_alias1(
 ; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
 ; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    store i64 [[A:%.*]], ptr [[P:%.*]], align 8
 ; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %stack = alloca <256 x i8>, align 8
@@ -78,10 +76,9 @@ define void @store_clobber_no_alias2(i64 %a, ptr %p) {
 ; CHECK-LABEL: @store_clobber_no_alias2(
 ; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
 ; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    store i64 [[A:%.*]], ptr [[P:%.*]], align 8
 ; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %stack = alloca <256 x i8>, align 8
@@ -133,11 +130,10 @@ define void @load_clobber_no_alias(ptr %p, ptr %p1) {
 ; CHECK-LABEL: @load_clobber_no_alias(
 ; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
 ; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    [[A:%.*]] = load i64, ptr [[P:%.*]], align 8
 ; CHECK-NEXT:    store i64 [[A]], ptr [[P1:%.*]], align 8
 ; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %stack = alloca <256 x i8>, align 8
