[llvm] [MemCpyOpt] Merge memset and skip unrelated clobber in one scan (PR #90350)

via llvm-commits llvm-commits at lists.llvm.org
Sat Apr 27 09:50:34 PDT 2024


https://github.com/XChy created https://github.com/llvm/llvm-project/pull/90350

Alternative to #89550.

The procedure is:
- maintain a map from each underlying object to its corresponding MemsetRanges
- visit the instructions one by one
- once a clobber appears
   - for a store/memset: if the map already contains its underlying object, try to insert it into that object's MemsetRanges. Otherwise, flush all may-aliased MemsetRanges and insert a new MemsetRanges into the map.
   - for other clobbers, flush the may-aliased ranges in the map; this requires checking every MemsetRanges in the map.

This approach avoids rescanning stores to the same object, as well as instructions that don't read or write memory.


>From 63f5ca70df10a603e260126d14f78b39695d0703 Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 22 Apr 2024 00:46:18 +0800
Subject: [PATCH 1/2] [MemCpyOpt] Precommit tests for merging into memset (NFC)

---
 .../Transforms/MemCpyOpt/merge-into-memset.ll | 174 +++++++++++++++++-
 1 file changed, 172 insertions(+), 2 deletions(-)

diff --git a/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll b/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll
index 78aa769982404a..ca2ffd72818496 100644
--- a/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll
+++ b/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll
@@ -36,5 +36,175 @@ exit:
   ret void
 }
 
-declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1)
-declare void @llvm.memset.p0.i64(ptr, i8, i64, i1)
+define void @memset_clobber_no_alias(ptr %p) {
+; CHECK-LABEL: @memset_clobber_no_alias(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[P:%.*]], i8 0, i64 16, i1 false)
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  call void @llvm.memset.p0.i64(ptr %p, i8 0, i64 16, i1 false)
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @store_clobber_no_alias1(i64 %a, ptr %p) {
+; CHECK-LABEL: @store_clobber_no_alias1(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    store i64 [[A:%.*]], ptr [[P:%.*]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  store i64 %a, ptr %p, align 8
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @store_clobber_no_alias2(i64 %a, ptr %p) {
+; CHECK-LABEL: @store_clobber_no_alias2(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    store i64 [[A:%.*]], ptr [[P:%.*]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  store i64 %a, ptr %p, align 8
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @store_clobber_no_alias_precise_fail(i64 %a) {
+; CHECK-LABEL: @store_clobber_no_alias_precise_fail(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    store i64 [[A:%.*]], ptr [[STACK]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  store i64 %a, ptr %stack, align 8
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @store_clobber_may_alias_fail(ptr %p, ptr %p1) {
+; CHECK-LABEL: @store_clobber_may_alias_fail(
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK:%.*]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    store i64 0, ptr [[P1:%.*]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack1 = getelementptr inbounds i8, ptr %p, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  store i64 0, ptr %p1, align 8
+  %stack2 = getelementptr inbounds i8, ptr %p, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @load_clobber_no_alias(ptr %p, ptr %p1) {
+; CHECK-LABEL: @load_clobber_no_alias(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    [[A:%.*]] = load i64, ptr [[P:%.*]], align 8
+; CHECK-NEXT:    store i64 [[A]], ptr [[P1:%.*]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  %a = load i64, ptr %p, align 8
+  store i64 %a, ptr %p1, align 8
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @load_clobber_alias_fail(ptr %p, ptr %p1) {
+; CHECK-LABEL: @load_clobber_alias_fail(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    [[A:%.*]] = load i64, ptr [[STACK]], align 8
+; CHECK-NEXT:    store i64 [[A]], ptr [[P1:%.*]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  %a = load i64, ptr %stack, align 8
+  store i64 %a, ptr %p1, align 8
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @memset_volatile_fail(ptr %p) {
+; CHECK-LABEL: @memset_volatile_fail(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[P:%.*]], i8 0, i64 16, i1 true)
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  call void @llvm.memset.p0.i64(ptr %p, i8 0, i64 16, i1 true)
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}
+
+define void @store_volatile_fail(i64 %a, ptr %p) {
+; CHECK-LABEL: @store_volatile_fail(
+; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
+; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
+; CHECK-NEXT:    store volatile i64 [[A:%.*]], ptr [[P:%.*]], align 8
+; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %stack = alloca <256 x i8>, align 8
+  %stack1 = getelementptr inbounds i8, ptr %stack, i64 8
+  call void @llvm.memset.p0.i64(ptr %stack1, i8 0, i64 136, i1 false)
+  store volatile i64 %a, ptr %p
+  %stack2 = getelementptr inbounds i8, ptr %stack, i64 24
+  call void @llvm.memset.p0.i64(ptr %stack2, i8 0, i64 24, i1 false)
+  ret void
+}

>From cc8eac48365bb8d6ca7b5b8d26ac93427b0dc144 Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 22 Apr 2024 00:29:53 +0800
Subject: [PATCH 2/2] [MemCpyOpt] Continue merge memset with unrelated clobber

---
 .../llvm/Transforms/Scalar/MemCpyOptimizer.h  |   3 +-
 .../lib/Transforms/Scalar/MemCpyOptimizer.cpp | 343 +++++++++++-------
 llvm/test/Transforms/MemCpyOpt/form-memset.ll |   4 +-
 .../Transforms/MemCpyOpt/merge-into-memset.ll |  12 +-
 4 files changed, 214 insertions(+), 148 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
index 6c809bc881d050..00ed731e0cda00 100644
--- a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
+++ b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h
@@ -78,8 +78,7 @@ class MemCpyOptPass : public PassInfoMixin<MemCpyOptPass> {
                                   BatchAAResults &BAA);
   bool processByValArgument(CallBase &CB, unsigned ArgNo);
   bool processImmutArgument(CallBase &CB, unsigned ArgNo);
-  Instruction *tryMergingIntoMemset(Instruction *I, Value *StartPtr,
-                                    Value *ByteVal);
+  bool tryMergingIntoMemset(BasicBlock *BB);
   bool moveUp(StoreInst *SI, Instruction *P, const LoadInst *LI);
   bool performStackMoveOptzn(Instruction *Load, Instruction *Store,
                              AllocaInst *DestAlloca, AllocaInst *SrcAlloca,
diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 7ef5dceffec0d2..fa4be64ec9e8ba 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -158,11 +158,16 @@ class MemsetRanges {
 
   /// A sorted list of the memset ranges.
   SmallVector<MemsetRange, 8> Ranges;
-
-  const DataLayout &DL;
+  const DataLayout *DL = nullptr;
 
 public:
-  MemsetRanges(const DataLayout &DL) : DL(DL) {}
+  MemsetRanges() = default;
+  MemsetRanges(const DataLayout *DL) : DL(DL) {}
+  MemsetRanges(const DataLayout *DL, Instruction *I, Value *StartPtr,
+               Value *ByteVal)
+      : DL(DL), StartInst(I), StartPtr(StartPtr),
+        StartPtrLocation(MemoryLocation::getBeforeOrAfter(StartPtr)),
+        ByteVal(ByteVal) {}
 
   using const_iterator = SmallVectorImpl<MemsetRange>::const_iterator;
 
@@ -178,7 +183,7 @@ class MemsetRanges {
   }
 
   void addStore(int64_t OffsetFromFirst, StoreInst *SI) {
-    TypeSize StoreSize = DL.getTypeStoreSize(SI->getOperand(0)->getType());
+    TypeSize StoreSize = DL->getTypeStoreSize(SI->getOperand(0)->getType());
     assert(!StoreSize.isScalable() && "Can't track scalable-typed stores");
     addRange(OffsetFromFirst, StoreSize.getFixedValue(),
              SI->getPointerOperand(), SI->getAlign(), SI);
@@ -191,6 +196,18 @@ class MemsetRanges {
 
   void addRange(int64_t Start, int64_t Size, Value *Ptr, MaybeAlign Alignment,
                 Instruction *Inst);
+
+  /// Emit memsets before InsertPoint; returns true if any memset was created.
+  bool flush(MemorySSAUpdater *MSSAU, Instruction *InsertPoint,
+             MemoryUseOrDef *&MemInsertPoint);
+  /// The first store/memset instruction of the group.
+  Instruction *StartInst = nullptr;
+  /// The pointer written by StartInst; offsets are relative to it.
+  Value *StartPtr = nullptr;
+  /// Unbounded location around StartPtr (MemoryLocation::getBeforeOrAfter).
+  MemoryLocation StartPtrLocation;
+  /// The splat byte value written by the merged stores (may start as undef).
+  Value *ByteVal = nullptr;
 };
 
 } // end anonymous namespace
@@ -255,6 +272,73 @@ void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr,
   }
 }
 
+/// Flush this group: emit memsets before InsertPoint and erase merged stores.
+bool MemsetRanges::flush(MemorySSAUpdater *MSSAU, Instruction *InsertPoint,
+                         MemoryUseOrDef *&MemInsertPoint) {
+  // If we have no ranges, then we just had a single store with nothing that
+  // could be merged in.  This is a very common case of course.
+  if (empty())
+    return false;
+
+  // If we had at least one store that could be merged in, add the starting
+  // store as well.  We try to avoid this unless there is at least something
+  // interesting as a small compile-time optimization.
+  addInst(0, StartInst);
+
+  // Any memsets we create go right before InsertPoint (the instruction that
+  // ended this group, or the block terminator). All merged stores precede it,
+  // so their addressing instructions dominate the new memsets.
+  IRBuilder<> Builder(InsertPoint);
+
+  // Now that we have full information about ranges, loop over the ranges and
+  // emit memset's for anything big enough to be worthwhile.
+  Instruction *AMemSet = nullptr;
+
+  // MemInsertPoint is the latest MemoryUseOrDef at or before InsertPoint. Each
+  // new MemoryDef goes next to it (before it if it is InsertPoint's access)
+  // and then becomes the new MemInsertPoint, keeping the memsets in order.
+  for (const MemsetRange &Range : *this) {
+    if (Range.TheStores.size() == 1)
+      continue;
+
+    // If it is profitable to lower this range to memset, do so now.
+    if (!Range.isProfitableToUseMemset(*DL))
+      continue;
+
+    // Otherwise, we do want to transform this!  Create a new memset.
+    // Get the starting pointer of the block.
+    Value *CurrentStart = Range.StartPtr;
+
+    AMemSet = Builder.CreateMemSet(CurrentStart, ByteVal,
+                                   Range.End - Range.Start, Range.Alignment);
+    AMemSet->mergeDIAssignID(Range.TheStores);
+
+    LLVM_DEBUG(dbgs() << "Replace stores:\n"; for (Instruction *SI
+                                                   : Range.TheStores) dbgs()
+                                              << *SI << '\n';
+               dbgs() << "With: " << *AMemSet << '\n');
+    if (!Range.TheStores.empty())
+      AMemSet->setDebugLoc(Range.TheStores[0]->getDebugLoc());
+
+    auto *NewDef = cast<MemoryDef>(
+        MemInsertPoint->getMemoryInst() == InsertPoint
+            ? MSSAU->createMemoryAccessBefore(AMemSet, nullptr, MemInsertPoint)
+            : MSSAU->createMemoryAccessAfter(AMemSet, nullptr, MemInsertPoint));
+    MSSAU->insertDef(NewDef, /*RenameUses=*/true);
+    MemInsertPoint = NewDef;
+
+    // Zap all the stores.
+    for (Instruction *SI : Range.TheStores) {
+      MSSAU->removeMemoryAccess(SI);
+      SI->eraseFromParent();
+    }
+
+    ++NumMemSetInfer;
+  }
+
+  return AMemSet != nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 //                         MemCpyOptLegacyPass Pass
 //===----------------------------------------------------------------------===//
@@ -350,157 +434,156 @@ static void combineAAMetadata(Instruction *ReplInst, Instruction *I) {
   combineMetadata(ReplInst, I, KnownIDs, true);
 }
 
+/// Return true if \p I is a store or memset that can join a memset range;
+/// on success, \p ByteVal holds the byte value that \p I writes.
+static bool isCandidateToMergeIntoMemset(Instruction *I, const DataLayout &DL,
+                                         Value *&ByteVal) {
+  if (auto *SI = dyn_cast<StoreInst>(I)) {
+    // Volatile and atomic stores must not be merged.
+    if (!SI->isSimple())
+      return false;
+    // Avoid merging nontemporal stores since the resulting
+    // memcpy/memset would not be able to preserve the nontemporal hint.
+    if (SI->getMetadata(LLVMContext::MD_nontemporal))
+      return false;
+
+    Value *StoredVal = SI->getValueOperand();
+    // Don't convert stores of non-integral pointer types to memsets (which
+    // stores integers).
+    if (DL.isNonIntegralPointerType(StoredVal->getType()->getScalarType()))
+      return false;
+
+    // We can't track ranges involving scalable types.
+    if (DL.getTypeStoreSize(StoredVal->getType()).isScalable())
+      return false;
+
+    ByteVal = isBytewiseValue(StoredVal, DL);
+    return ByteVal != nullptr;
+  }
+
+  if (auto *MSI = dyn_cast<MemSetInst>(I)) {
+    if (MSI->isVolatile() || !isa<ConstantInt>(MSI->getLength()))
+      return false;
+    ByteVal = MSI->getValue();
+    return true;
+  }
+
+  return false;
+}
+
+/// Return the pointer written by \p I, which must be a store or memset.
+static Value *getWrittenPtr(Instruction *I) {
+  if (auto *SI = dyn_cast<StoreInst>(I))
+    return SI->getPointerOperand();
+  if (auto *MSI = dyn_cast<MemSetInst>(I))
+    return MSI->getDest();
+  llvm_unreachable("Only support store and memset");
+}
+
 /// When scanning forward over instructions, we look for some other patterns to
 /// fold away. In particular, this looks for stores to neighboring locations of
 /// memory. If it sees enough consecutive ones, it attempts to merge them
 /// together into a memcpy/memset.
-Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
-                                                 Value *StartPtr,
-                                                 Value *ByteVal) {
-  const DataLayout &DL = StartInst->getModule()->getDataLayout();
+bool MemCpyOptPass::tryMergingIntoMemset(BasicBlock *BB) {
+  MapVector<Value *, MemsetRanges> ObjToRanges;
+  const DataLayout &DL = BB->getModule()->getDataLayout();
+  MemoryUseOrDef *MemInsertPoint = nullptr;
+  bool MadeChanged = false;
 
-  // We can't track scalable types
-  if (auto *SI = dyn_cast<StoreInst>(StartInst))
-    if (DL.getTypeStoreSize(SI->getOperand(0)->getType()).isScalable())
-      return nullptr;
+  // The following code creates memset intrinsics out of thin air. Don't do
+  // this if the corresponding libfunc is not available.
+  if (!(TLI->has(LibFunc_memset) || EnableMemCpyOptWithoutLibcalls))
+    return false;
 
-  // Okay, so we now have a single store that can be splatable.  Scan to find
-  // all subsequent stores of the same value to offset from the same pointer.
-  // Join these together into ranges, so we can decide whether contiguous blocks
-  // are stored.
-  MemsetRanges Ranges(DL);
+  for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE; BI++) {
+    BatchAAResults BAA(*AA);
+    Instruction *I = &*BI;
 
-  BasicBlock::iterator BI(StartInst);
+    if (!I->mayReadOrWriteMemory())
+      continue;
 
-  // Keeps track of the last memory use or def before the insertion point for
-  // the new memset. The new MemoryDef for the inserted memsets will be inserted
-  // after MemInsertPoint.
-  MemoryUseOrDef *MemInsertPoint = nullptr;
-  for (++BI; !BI->isTerminator(); ++BI) {
-    auto *CurrentAcc = cast_or_null<MemoryUseOrDef>(
-        MSSAU->getMemorySSA()->getMemoryAccess(&*BI));
+    auto *CurrentAcc =
+        cast_or_null<MemoryUseOrDef>(MSSAU->getMemorySSA()->getMemoryAccess(I));
     if (CurrentAcc)
       MemInsertPoint = CurrentAcc;
 
     // Calls that only access inaccessible memory do not block merging
     // accessible stores.
-    if (auto *CB = dyn_cast<CallBase>(BI)) {
+    if (auto *CB = dyn_cast<CallBase>(BI))
       if (CB->onlyAccessesInaccessibleMemory())
         continue;
-    }
 
-    if (!isa<StoreInst>(BI) && !isa<MemSetInst>(BI)) {
-      // If the instruction is readnone, ignore it, otherwise bail out.  We
-      // don't even allow readonly here because we don't want something like:
-      // A[1] = 2; strlen(A); A[2] = 2; -> memcpy(A, ...); strlen(A).
-      if (BI->mayWriteToMemory() || BI->mayReadFromMemory())
-        break;
+    if (I->isVolatile() || I->isAtomic()) {
+      // Flush all MemsetRanges if reaching a fence.
+      for (auto [Obj, Ranges] : ObjToRanges)
+        MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+      ObjToRanges.clear();
       continue;
     }
 
-    if (auto *NextStore = dyn_cast<StoreInst>(BI)) {
-      // If this is a store, see if we can merge it in.
-      if (!NextStore->isSimple())
-        break;
+    // Handle other clobbers
+    if (!isa<StoreInst>(I) && !isa<MemSetInst>(I)) {
+      // Flush all may-aliased MemsetRanges.
+      ObjToRanges.remove_if([&](std::pair<Value *, MemsetRanges> &Entry) {
+        auto &Ranges = Entry.second;
+        bool ShouldFlush =
+            isModOrRefSet(BAA.getModRefInfo(I, Ranges.StartPtrLocation));
+        if (ShouldFlush)
+          MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+        return ShouldFlush;
+      });
+      continue;
+    }
 
-      Value *StoredVal = NextStore->getValueOperand();
+    Value *WrittenPtr = getWrittenPtr(I);
+    Value *Obj = getUnderlyingObject(WrittenPtr);
+    Value *ByteVal = nullptr;
 
-      // Don't convert stores of non-integral pointer types to memsets (which
-      // stores integers).
-      if (DL.isNonIntegralPointerType(StoredVal->getType()->getScalarType()))
-        break;
+    // If this is a store, see if we can merge it in.
+    if (ObjToRanges.contains(Obj)) {
+      MemsetRanges &Ranges = ObjToRanges[Obj];
 
-      // We can't track ranges involving scalable types.
-      if (DL.getTypeStoreSize(StoredVal->getType()).isScalable())
-        break;
+      if (!isCandidateToMergeIntoMemset(I, DL, ByteVal)) {
+        MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+        ObjToRanges.erase(Obj);
+        continue;
+      }
 
-      // Check to see if this stored value is of the same byte-splattable value.
-      Value *StoredByte = isBytewiseValue(StoredVal, DL);
-      if (isa<UndefValue>(ByteVal) && StoredByte)
-        ByteVal = StoredByte;
-      if (ByteVal != StoredByte)
-        break;
+      if (isa<UndefValue>(Ranges.ByteVal))
+        Ranges.ByteVal = ByteVal;
 
-      // Check to see if this store is to a constant offset from the start ptr.
       std::optional<int64_t> Offset =
-          NextStore->getPointerOperand()->getPointerOffsetFrom(StartPtr, DL);
-      if (!Offset)
-        break;
-
-      Ranges.addStore(*Offset, NextStore);
-    } else {
-      auto *MSI = cast<MemSetInst>(BI);
+          WrittenPtr->getPointerOffsetFrom(Ranges.StartPtr, DL);
 
-      if (MSI->isVolatile() || ByteVal != MSI->getValue() ||
-          !isa<ConstantInt>(MSI->getLength()))
-        break;
+      if (!Offset || Ranges.ByteVal != ByteVal) {
+        MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+        ObjToRanges[Obj] = MemsetRanges(&DL, I, WrittenPtr, ByteVal);
+        continue;
+      }
 
-      // Check to see if this store is to a constant offset from the start ptr.
-      std::optional<int64_t> Offset =
-          MSI->getDest()->getPointerOffsetFrom(StartPtr, DL);
-      if (!Offset)
-        break;
+      Ranges.addInst(*Offset, I);
+    } else {
+      // Flush all may-aliased MemsetRanges.
+      ObjToRanges.remove_if([&](std::pair<Value *, MemsetRanges> &Entry) {
+        auto &Ranges = Entry.second;
+        bool ShouldFlush =
+            isModOrRefSet(BAA.getModRefInfo(I, Ranges.StartPtrLocation));
+        if (ShouldFlush)
+          MadeChanged |= Ranges.flush(MSSAU, I, MemInsertPoint);
+        return ShouldFlush;
+      });
 
-      Ranges.addMemSet(*Offset, MSI);
+      // Create a new MemsetRanges.
+      if (isCandidateToMergeIntoMemset(I, DL, ByteVal))
+        ObjToRanges.insert({Obj, MemsetRanges(&DL, I, WrittenPtr, ByteVal)});
     }
   }
 
-  // If we have no ranges, then we just had a single store with nothing that
-  // could be merged in.  This is a very common case of course.
-  if (Ranges.empty())
-    return nullptr;
-
-  // If we had at least one store that could be merged in, add the starting
-  // store as well.  We try to avoid this unless there is at least something
-  // interesting as a small compile-time optimization.
-  Ranges.addInst(0, StartInst);
-
-  // If we create any memsets, we put it right before the first instruction that
-  // isn't part of the memset block.  This ensure that the memset is dominated
-  // by any addressing instruction needed by the start of the block.
-  IRBuilder<> Builder(&*BI);
-
-  // Now that we have full information about ranges, loop over the ranges and
-  // emit memset's for anything big enough to be worthwhile.
-  Instruction *AMemSet = nullptr;
-  for (const MemsetRange &Range : Ranges) {
-    if (Range.TheStores.size() == 1)
-      continue;
-
-    // If it is profitable to lower this range to memset, do so now.
-    if (!Range.isProfitableToUseMemset(DL))
-      continue;
-
-    // Otherwise, we do want to transform this!  Create a new memset.
-    // Get the starting pointer of the block.
-    StartPtr = Range.StartPtr;
-
-    AMemSet = Builder.CreateMemSet(StartPtr, ByteVal, Range.End - Range.Start,
-                                   Range.Alignment);
-    AMemSet->mergeDIAssignID(Range.TheStores);
-
-    LLVM_DEBUG(dbgs() << "Replace stores:\n"; for (Instruction *SI
-                                                   : Range.TheStores) dbgs()
-                                              << *SI << '\n';
-               dbgs() << "With: " << *AMemSet << '\n');
-    if (!Range.TheStores.empty())
-      AMemSet->setDebugLoc(Range.TheStores[0]->getDebugLoc());
-
-    auto *NewDef = cast<MemoryDef>(
-        MemInsertPoint->getMemoryInst() == &*BI
-            ? MSSAU->createMemoryAccessBefore(AMemSet, nullptr, MemInsertPoint)
-            : MSSAU->createMemoryAccessAfter(AMemSet, nullptr, MemInsertPoint));
-    MSSAU->insertDef(NewDef, /*RenameUses=*/true);
-    MemInsertPoint = NewDef;
-
-    // Zap all the stores.
-    for (Instruction *SI : Range.TheStores)
-      eraseInstruction(SI);
-
-    ++NumMemSetInfer;
+  for (auto [Obj, Ranges] : ObjToRanges) {
+    MadeChanged |= Ranges.flush(MSSAU, &BB->back(), MemInsertPoint);
   }
 
-  return AMemSet;
+  return MadeChanged;
 }
 
 // This method try to lift a store instruction before position P.
@@ -797,12 +880,6 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
   // 0xA0A0A0A0 and 0.0.
   auto *V = SI->getOperand(0);
   if (Value *ByteVal = isBytewiseValue(V, DL)) {
-    if (Instruction *I =
-            tryMergingIntoMemset(SI, SI->getPointerOperand(), ByteVal)) {
-      BBI = I->getIterator(); // Don't invalidate iterator.
-      return true;
-    }
-
     // If we have an aggregate, we try to promote it to memset regardless
     // of opportunity for merging as it can expose optimization opportunities
     // in subsequent passes.
@@ -835,14 +912,6 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
 }
 
 bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) {
-  // See if there is another memset or store neighboring this memset which
-  // allows us to widen out the memset to do a single larger store.
-  if (isa<ConstantInt>(MSI->getLength()) && !MSI->isVolatile())
-    if (Instruction *I =
-            tryMergingIntoMemset(MSI, MSI->getDest(), MSI->getValue())) {
-      BBI = I->getIterator(); // Don't invalidate iterator.
-      return true;
-    }
-  return false;
+  return false; // Merging is now done per block by tryMergingIntoMemset.
 }
 
@@ -1991,6 +2060,8 @@ bool MemCpyOptPass::iterateOnFunction(Function &F) {
     if (!DT->isReachableFromEntry(&BB))
       continue;
 
+    MadeChange |= tryMergingIntoMemset(&BB);
+
     for (BasicBlock::iterator BI = BB.begin(), BE = BB.end(); BI != BE;) {
       // Avoid invalidating the iterator.
       Instruction *I = &*BI++;
diff --git a/llvm/test/Transforms/MemCpyOpt/form-memset.ll b/llvm/test/Transforms/MemCpyOpt/form-memset.ll
index 020a72183e9ea1..edcf27e48f3bda 100644
--- a/llvm/test/Transforms/MemCpyOpt/form-memset.ll
+++ b/llvm/test/Transforms/MemCpyOpt/form-memset.ll
@@ -97,7 +97,6 @@ define void @test2() nounwind  {
 ; CHECK-NEXT:    [[TMP38:%.*]] = getelementptr [8 x i8], ptr [[REF_IDX]], i32 0, i32 1
 ; CHECK-NEXT:    [[TMP41:%.*]] = getelementptr [8 x i8], ptr [[REF_IDX]], i32 0, i32 0
 ; CHECK-NEXT:    [[TMP43:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 7, i32 0
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 1 [[TMP41]], i8 -1, i64 8, i1 false)
 ; CHECK-NEXT:    [[TMP46:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 7, i32 1
 ; CHECK-NEXT:    [[TMP57:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 6, i32 0
 ; CHECK-NEXT:    [[TMP60:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 6, i32 1
@@ -114,7 +113,6 @@ define void @test2() nounwind  {
 ; CHECK-NEXT:    [[TMP141:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 0, i32 0
 ; CHECK-NEXT:    [[TMP144:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 0, i32 1
 ; CHECK-NEXT:    [[TMP148:%.*]] = getelementptr [8 x %struct.MV], ptr [[LEFT_MVD]], i32 0, i32 7, i32 0
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 8 [[TMP141]], i8 0, i64 32, i1 false)
 ; CHECK-NEXT:    [[TMP151:%.*]] = getelementptr [8 x %struct.MV], ptr [[LEFT_MVD]], i32 0, i32 7, i32 1
 ; CHECK-NEXT:    [[TMP162:%.*]] = getelementptr [8 x %struct.MV], ptr [[LEFT_MVD]], i32 0, i32 6, i32 0
 ; CHECK-NEXT:    [[TMP165:%.*]] = getelementptr [8 x %struct.MV], ptr [[LEFT_MVD]], i32 0, i32 6, i32 1
@@ -132,6 +130,8 @@ define void @test2() nounwind  {
 ; CHECK-NEXT:    [[TMP249:%.*]] = getelementptr [8 x %struct.MV], ptr [[LEFT_MVD]], i32 0, i32 0, i32 1
 ; CHECK-NEXT:    [[UP_MVD252:%.*]] = getelementptr [8 x %struct.MV], ptr [[UP_MVD]], i32 0, i32 0
 ; CHECK-NEXT:    [[LEFT_MVD253:%.*]] = getelementptr [8 x %struct.MV], ptr [[LEFT_MVD]], i32 0, i32 0
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 1 [[TMP41]], i8 -1, i64 8, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 8 [[TMP141]], i8 0, i64 32, i1 false)
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr align 8 [[TMP246]], i8 0, i64 32, i1 false)
 ; CHECK-NEXT:    call void @foo(ptr [[UP_MVD252]], ptr [[LEFT_MVD253]], ptr [[TMP41]]) #[[ATTR0]]
 ; CHECK-NEXT:    ret void
diff --git a/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll b/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll
index ca2ffd72818496..7e2d617d1ed275 100644
--- a/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll
+++ b/llvm/test/Transforms/MemCpyOpt/merge-into-memset.ll
@@ -40,10 +40,9 @@ define void @memset_clobber_no_alias(ptr %p) {
 ; CHECK-LABEL: @memset_clobber_no_alias(
 ; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
 ; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[P:%.*]], i8 0, i64 16, i1 false)
 ; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %stack = alloca <256 x i8>, align 8
@@ -59,10 +58,9 @@ define void @store_clobber_no_alias1(i64 %a, ptr %p) {
 ; CHECK-LABEL: @store_clobber_no_alias1(
 ; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
 ; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    store i64 [[A:%.*]], ptr [[P:%.*]], align 8
 ; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %stack = alloca <256 x i8>, align 8
@@ -78,10 +76,9 @@ define void @store_clobber_no_alias2(i64 %a, ptr %p) {
 ; CHECK-LABEL: @store_clobber_no_alias2(
 ; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
 ; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    store i64 [[A:%.*]], ptr [[P:%.*]], align 8
 ; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %stack = alloca <256 x i8>, align 8
@@ -133,11 +130,10 @@ define void @load_clobber_no_alias(ptr %p, ptr %p1) {
 ; CHECK-LABEL: @load_clobber_no_alias(
 ; CHECK-NEXT:    [[STACK:%.*]] = alloca <256 x i8>, align 8
 ; CHECK-NEXT:    [[STACK1:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 8
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    [[A:%.*]] = load i64, ptr [[P:%.*]], align 8
 ; CHECK-NEXT:    store i64 [[A]], ptr [[P1:%.*]], align 8
 ; CHECK-NEXT:    [[STACK2:%.*]] = getelementptr inbounds i8, ptr [[STACK]], i64 24
-; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK2]], i8 0, i64 24, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[STACK1]], i8 0, i64 136, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %stack = alloca <256 x i8>, align 8



More information about the llvm-commits mailing list