[llvm] [DSE] Fold malloc/store pair into calloc (PR #87048)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 29 02:56:57 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: XChy (XChy)

<details>
<summary>Changes</summary>

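Resolve the TODO in DSE's `tryFoldIntoCalloc` ("Could handle zero store to small allocation as well"). A zero-initializing memset could already be folded into `calloc`. Now a store of a zero constant is folded as well, provided its store size equals the constant size passed to `malloc`.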

---
Full diff: https://github.com/llvm/llvm-project/pull/87048.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp (+38-25) 
- (added) llvm/test/Transforms/DeadStoreElimination/malloc-store.ll (+98) 
- (modified) llvm/test/Transforms/DeadStoreElimination/simple.ll (+2-4) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index bfc8bd5970bf27..04ff6faa0db6bd 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -1842,20 +1842,39 @@ struct DSEState {
     return MadeChange;
   }
 
-  /// If we have a zero initializing memset following a call to malloc,
+  /// If we have a zero initializing memset/store following a call to malloc,
   /// try folding it into a call to calloc.
   bool tryFoldIntoCalloc(MemoryDef *Def, const Value *DefUO) {
     Instruction *DefI = Def->getMemoryInst();
-    MemSetInst *MemSet = dyn_cast<MemSetInst>(DefI);
-    if (!MemSet)
-      // TODO: Could handle zero store to small allocation as well.
+    Constant *StoredConstant;
+    Value *PtrToStore;
+
+    auto *Malloc = const_cast<CallInst *>(dyn_cast<CallInst>(DefUO));
+
+    if (!Malloc)
       return false;
-    Constant *StoredConstant = dyn_cast<Constant>(MemSet->getValue());
+
+    if (MemSetInst *MemSet = dyn_cast<MemSetInst>(DefI)) {
+      StoredConstant = dyn_cast<Constant>(MemSet->getValue());
+      PtrToStore = MemSet->getArgOperand(0);
+      if (Malloc->getOperand(0) != MemSet->getLength())
+        return false;
+    } else if (StoreInst *SI = dyn_cast<StoreInst>(DefI)) {
+      StoredConstant = dyn_cast<Constant>(SI->getValueOperand());
+      PtrToStore = SI->getPointerOperand();
+      if (SI->getAccessType()->isScalableTy())
+        return false;
+      uint64_t StoreSize = DL.getTypeStoreSize(SI->getAccessType());
+      if (!match(Malloc->getOperand(0), m_SpecificInt(StoreSize)))
+        return false;
+    } else
+      return false;
+
     if (!StoredConstant || !StoredConstant->isNullValue())
       return false;
 
     if (!isRemovable(DefI))
-      // The memset might be volatile..
+      // The memset/store might be volatile..
       return false;
 
     if (F.hasFnAttribute(Attribute::SanitizeMemory) ||
@@ -1863,9 +1882,7 @@ struct DSEState {
         F.hasFnAttribute(Attribute::SanitizeHWAddress) ||
         F.getName() == "calloc")
       return false;
-    auto *Malloc = const_cast<CallInst *>(dyn_cast<CallInst>(DefUO));
-    if (!Malloc)
-      return false;
+
     auto *InnerCallee = Malloc->getCalledFunction();
     if (!InnerCallee)
       return false;
@@ -1878,30 +1895,27 @@ struct DSEState {
     if (!MallocDef)
       return false;
 
-    auto shouldCreateCalloc = [](CallInst *Malloc, CallInst *Memset) {
+    auto shouldCreateCalloc = [](CallInst *Malloc, Instruction *DefI,
+                                 Value *Ptr) {
       // Check for br(icmp ptr, null), truebb, falsebb) pattern at the end
       // of malloc block
-      auto *MallocBB = Malloc->getParent(),
-        *MemsetBB = Memset->getParent();
-      if (MallocBB == MemsetBB)
+      auto *MallocBB = Malloc->getParent(), *DefBB = DefI->getParent();
+      if (MallocBB == DefBB)
         return true;
-      auto *Ptr = Memset->getArgOperand(0);
       auto *TI = MallocBB->getTerminator();
       ICmpInst::Predicate Pred;
       BasicBlock *TrueBB, *FalseBB;
       if (!match(TI, m_Br(m_ICmp(Pred, m_Specific(Ptr), m_Zero()), TrueBB,
                           FalseBB)))
         return false;
-      if (Pred != ICmpInst::ICMP_EQ || MemsetBB != FalseBB)
+      if (Pred != ICmpInst::ICMP_EQ || DefBB != FalseBB)
         return false;
       return true;
     };
 
-    if (Malloc->getOperand(0) != MemSet->getLength())
-      return false;
-    if (!shouldCreateCalloc(Malloc, MemSet) ||
-        !DT.dominates(Malloc, MemSet) ||
-        !memoryIsNotModifiedBetween(Malloc, MemSet, BatchAA, DL, &DT))
+    if (!shouldCreateCalloc(Malloc, DefI, PtrToStore) ||
+        !DT.dominates(Malloc, DefI) ||
+        !memoryIsNotModifiedBetween(Malloc, DefI, BatchAA, DL, &DT))
       return false;
     IRBuilder<> IRB(Malloc);
     Type *SizeTTy = Malloc->getArgOperand(0)->getType();
@@ -1911,9 +1925,8 @@ struct DSEState {
       return false;
 
     MemorySSAUpdater Updater(&MSSA);
-    auto *NewAccess =
-      Updater.createMemoryAccessAfter(cast<Instruction>(Calloc), nullptr,
-                                      MallocDef);
+    auto *NewAccess = Updater.createMemoryAccessAfter(cast<Instruction>(Calloc),
+                                                      nullptr, MallocDef);
     auto *NewAccessMD = cast<MemoryDef>(NewAccess);
     Updater.insertDef(NewAccessMD, /*RenameUses=*/true);
     Malloc->replaceAllUsesWith(Calloc);
@@ -2288,9 +2301,9 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
       continue;
     }
 
-    // Can we form a calloc from a memset/malloc pair?
+    // Can we form a calloc from a memset/malloc or store/malloc pair?
     if (!Shortend && State.tryFoldIntoCalloc(KillingDef, KillingUndObj)) {
-      LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
+      LLVM_DEBUG(dbgs() << "DSE: Remove memset/store after forming calloc:\n"
                         << "  DEAD: " << *KillingI << '\n');
       State.deleteDeadInstruction(KillingI);
       MadeChange = true;
diff --git a/llvm/test/Transforms/DeadStoreElimination/malloc-store.ll b/llvm/test/Transforms/DeadStoreElimination/malloc-store.ll
new file mode 100644
index 00000000000000..70938b60df36e0
--- /dev/null
+++ b/llvm/test/Transforms/DeadStoreElimination/malloc-store.ll
@@ -0,0 +1,98 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=dse -S | FileCheck %s
+
+declare noalias ptr @malloc(i64) willreturn allockind("alloc,uninitialized")
+declare void @use(ptr)
+
+define ptr @basic() {
+; CHECK-LABEL: define ptr @basic() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 4)
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store i32 0, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @vec_type() {
+; CHECK-LABEL: define ptr @vec_type() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 4)
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store <2 x i16> zeroinitializer, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @clobber() {
+; CHECK-LABEL: define ptr @clobber() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 4)
+; CHECK-NEXT:    [[L:%.*]] = load i8, ptr [[PTR]], align 1
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  %l = load i8, ptr %ptr
+  store i32 0, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @wrong_size() {
+; CHECK-LABEL: define ptr @wrong_size() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
+; CHECK-NEXT:    store i8 0, ptr [[PTR]], align 1
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store i8 0, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @bigstore() {
+; CHECK-LABEL: define ptr @bigstore() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 8096)
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 8096)
+  store <8096 x i8> zeroinitializer, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @nonconstant1(i64 %l) {
+; CHECK-LABEL: define ptr @nonconstant1(
+; CHECK-SAME: i64 [[L:%.*]]) {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 [[L]])
+; CHECK-NEXT:    store i32 0, ptr [[PTR]], align 4
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 %l)
+  store i32 0, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @nonconstant2(i32 %v) {
+; CHECK-LABEL: define ptr @nonconstant2(
+; CHECK-SAME: i32 [[V:%.*]]) {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
+; CHECK-NEXT:    store i32 [[V]], ptr [[PTR]], align 4
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store i32 %v, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @clobber_fail(i32 %a) {
+; CHECK-LABEL: define ptr @clobber_fail(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
+; CHECK-NEXT:    store i32 [[A]], ptr [[PTR]], align 4
+; CHECK-NEXT:    call void @use(ptr [[PTR]])
+; CHECK-NEXT:    store i32 0, ptr [[PTR]], align 4
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store i32 %a, ptr %ptr
+  call void @use(ptr %ptr)
+  store i32 0, ptr %ptr
+  ret ptr %ptr
+}
diff --git a/llvm/test/Transforms/DeadStoreElimination/simple.ll b/llvm/test/Transforms/DeadStoreElimination/simple.ll
index e5d3dd09fa148d..4a80fc5b141c18 100644
--- a/llvm/test/Transforms/DeadStoreElimination/simple.ll
+++ b/llvm/test/Transforms/DeadStoreElimination/simple.ll
@@ -204,9 +204,8 @@ define void @test_matrix_store(i64 %stride) {
 declare void @may_unwind()
 define ptr @test_malloc_no_escape_before_return() {
 ; CHECK-LABEL: @test_malloc_no_escape_before_return(
-; CHECK-NEXT:    [[PTR:%.*]] = tail call ptr @malloc(i64 4)
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 4)
 ; CHECK-NEXT:    call void @may_unwind()
-; CHECK-NEXT:    store i32 0, ptr [[PTR]], align 4
 ; CHECK-NEXT:    ret ptr [[PTR]]
 ;
   %ptr = tail call ptr @malloc(i64 4)
@@ -236,10 +235,9 @@ define ptr @test_custom_malloc_no_escape_before_return() {
 
 define ptr addrspace(1) @test13_addrspacecast() {
 ; CHECK-LABEL: @test13_addrspacecast(
-; CHECK-NEXT:    [[P:%.*]] = tail call ptr @malloc(i64 4)
+; CHECK-NEXT:    [[P:%.*]] = call ptr @calloc(i64 1, i64 4)
 ; CHECK-NEXT:    [[P_AC:%.*]] = addrspacecast ptr [[P]] to ptr addrspace(1)
 ; CHECK-NEXT:    call void @may_unwind()
-; CHECK-NEXT:    store i32 0, ptr addrspace(1) [[P_AC]], align 4
 ; CHECK-NEXT:    ret ptr addrspace(1) [[P_AC]]
 ;
   %p = tail call ptr @malloc(i64 4)

``````````

</details>


https://github.com/llvm/llvm-project/pull/87048


More information about the llvm-commits mailing list