[llvm] [DSE] Fold malloc/store pair into calloc (PR #87048)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 29 02:56:31 PDT 2024


https://github.com/XChy created https://github.com/llvm/llvm-project/pull/87048

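Resolve the TODO in DSE's `tryFoldIntoCalloc`: besides a zeroing memset, a store of a zero constant that covers an entire constant-size malloc allocation is now also folded into a call to calloc.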

>From 16f006252e6bc9e61fdcee22b75d35acd8ce77d7 Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Fri, 29 Mar 2024 17:19:34 +0800
Subject: [PATCH 1/2] [DSE] Precommit tests

---
 .../DeadStoreElimination/malloc-store.ll      | 102 ++++++++++++++++++
 1 file changed, 102 insertions(+)
 create mode 100644 llvm/test/Transforms/DeadStoreElimination/malloc-store.ll

diff --git a/llvm/test/Transforms/DeadStoreElimination/malloc-store.ll b/llvm/test/Transforms/DeadStoreElimination/malloc-store.ll
new file mode 100644
index 00000000000000..65e0656c700a94
--- /dev/null
+++ b/llvm/test/Transforms/DeadStoreElimination/malloc-store.ll
@@ -0,0 +1,102 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=dse -S | FileCheck %s
+
+declare noalias ptr @malloc(i64) willreturn allockind("alloc,uninitialized")
+declare void @use(ptr)
+
+define ptr @basic() {
+; CHECK-LABEL: define ptr @basic() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
+; CHECK-NEXT:    store i32 0, ptr [[PTR]], align 4
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store i32 0, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @vec_type() {
+; CHECK-LABEL: define ptr @vec_type() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
+; CHECK-NEXT:    store <2 x i16> zeroinitializer, ptr [[PTR]], align 4
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store <2 x i16> zeroinitializer, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @clobber() {
+; CHECK-LABEL: define ptr @clobber() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
+; CHECK-NEXT:    [[L:%.*]] = load i8, ptr [[PTR]], align 1
+; CHECK-NEXT:    store i32 0, ptr [[PTR]], align 4
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  %l = load i8, ptr %ptr
+  store i32 0, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @wrong_size() {
+; CHECK-LABEL: define ptr @wrong_size() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
+; CHECK-NEXT:    store i8 0, ptr [[PTR]], align 1
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store i8 0, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @bigstore() {
+; CHECK-LABEL: define ptr @bigstore() {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 8096)
+; CHECK-NEXT:    store <8096 x i8> zeroinitializer, ptr [[PTR]], align 8192
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 8096)
+  store <8096 x i8> zeroinitializer, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @nonconstant1(i64 %l) {
+; CHECK-LABEL: define ptr @nonconstant1(
+; CHECK-SAME: i64 [[L:%.*]]) {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 [[L]])
+; CHECK-NEXT:    store i32 0, ptr [[PTR]], align 4
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 %l)
+  store i32 0, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @nonconstant2(i32 %v) {
+; CHECK-LABEL: define ptr @nonconstant2(
+; CHECK-SAME: i32 [[V:%.*]]) {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
+; CHECK-NEXT:    store i32 [[V]], ptr [[PTR]], align 4
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store i32 %v, ptr %ptr
+  ret ptr %ptr
+}
+
+define ptr @clobber_fail(i32 %a) {
+; CHECK-LABEL: define ptr @clobber_fail(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
+; CHECK-NEXT:    store i32 [[A]], ptr [[PTR]], align 4
+; CHECK-NEXT:    call void @use(ptr [[PTR]])
+; CHECK-NEXT:    store i32 0, ptr [[PTR]], align 4
+; CHECK-NEXT:    ret ptr [[PTR]]
+;
+  %ptr = call ptr @malloc(i64 4)
+  store i32 %a, ptr %ptr
+  call void @use(ptr %ptr)
+  store i32 0, ptr %ptr
+  ret ptr %ptr
+}

>From 34c06d9f104d93f06a68b988e73aeaa834b996db Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Fri, 29 Mar 2024 16:51:42 +0800
Subject: [PATCH 2/2] [DSE] Fold malloc/store pair into calloc

---
 .../Scalar/DeadStoreElimination.cpp           | 63 +++++++++++--------
 .../DeadStoreElimination/malloc-store.ll      | 12 ++--
 .../Transforms/DeadStoreElimination/simple.ll |  6 +-
 3 files changed, 44 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index bfc8bd5970bf27..04ff6faa0db6bd 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -1842,20 +1842,39 @@ struct DSEState {
     return MadeChange;
   }
 
-  /// If we have a zero initializing memset following a call to malloc,
+  /// If we have a zero initializing memset/store following a call to malloc,
   /// try folding it into a call to calloc.
   bool tryFoldIntoCalloc(MemoryDef *Def, const Value *DefUO) {
     Instruction *DefI = Def->getMemoryInst();
-    MemSetInst *MemSet = dyn_cast<MemSetInst>(DefI);
-    if (!MemSet)
-      // TODO: Could handle zero store to small allocation as well.
+    Constant *StoredConstant;
+    Value *PtrToStore;
+
+    auto *Malloc = const_cast<CallInst *>(dyn_cast<CallInst>(DefUO));
+
+    if (!Malloc)
       return false;
-    Constant *StoredConstant = dyn_cast<Constant>(MemSet->getValue());
+
+    if (MemSetInst *MemSet = dyn_cast<MemSetInst>(DefI)) {
+      StoredConstant = dyn_cast<Constant>(MemSet->getValue());
+      PtrToStore = MemSet->getArgOperand(0);
+      if (Malloc->getOperand(0) != MemSet->getLength())
+        return false;
+    } else if (StoreInst *SI = dyn_cast<StoreInst>(DefI)) {
+      StoredConstant = dyn_cast<Constant>(SI->getValueOperand());
+      PtrToStore = SI->getPointerOperand();
+      if (SI->getAccessType()->isScalableTy())
+        return false;
+      uint64_t StoreSize = DL.getTypeStoreSize(SI->getAccessType());
+      if (!match(Malloc->getOperand(0), m_SpecificInt(StoreSize)))
+        return false;
+    } else
+      return false;
+
     if (!StoredConstant || !StoredConstant->isNullValue())
       return false;
 
     if (!isRemovable(DefI))
-      // The memset might be volatile..
+      // The memset/store might be volatile..
       return false;
 
     if (F.hasFnAttribute(Attribute::SanitizeMemory) ||
@@ -1863,9 +1882,7 @@ struct DSEState {
         F.hasFnAttribute(Attribute::SanitizeHWAddress) ||
         F.getName() == "calloc")
       return false;
-    auto *Malloc = const_cast<CallInst *>(dyn_cast<CallInst>(DefUO));
-    if (!Malloc)
-      return false;
+
     auto *InnerCallee = Malloc->getCalledFunction();
     if (!InnerCallee)
       return false;
@@ -1878,30 +1895,27 @@ struct DSEState {
     if (!MallocDef)
       return false;
 
-    auto shouldCreateCalloc = [](CallInst *Malloc, CallInst *Memset) {
+    auto shouldCreateCalloc = [](CallInst *Malloc, Instruction *DefI,
+                                 Value *Ptr) {
       // Check for br(icmp ptr, null), truebb, falsebb) pattern at the end
       // of malloc block
-      auto *MallocBB = Malloc->getParent(),
-        *MemsetBB = Memset->getParent();
-      if (MallocBB == MemsetBB)
+      auto *MallocBB = Malloc->getParent(), *DefBB = DefI->getParent();
+      if (MallocBB == DefBB)
         return true;
-      auto *Ptr = Memset->getArgOperand(0);
       auto *TI = MallocBB->getTerminator();
       ICmpInst::Predicate Pred;
       BasicBlock *TrueBB, *FalseBB;
       if (!match(TI, m_Br(m_ICmp(Pred, m_Specific(Ptr), m_Zero()), TrueBB,
                           FalseBB)))
         return false;
-      if (Pred != ICmpInst::ICMP_EQ || MemsetBB != FalseBB)
+      if (Pred != ICmpInst::ICMP_EQ || DefBB != FalseBB)
         return false;
       return true;
     };
 
-    if (Malloc->getOperand(0) != MemSet->getLength())
-      return false;
-    if (!shouldCreateCalloc(Malloc, MemSet) ||
-        !DT.dominates(Malloc, MemSet) ||
-        !memoryIsNotModifiedBetween(Malloc, MemSet, BatchAA, DL, &DT))
+    if (!shouldCreateCalloc(Malloc, DefI, PtrToStore) ||
+        !DT.dominates(Malloc, DefI) ||
+        !memoryIsNotModifiedBetween(Malloc, DefI, BatchAA, DL, &DT))
       return false;
     IRBuilder<> IRB(Malloc);
     Type *SizeTTy = Malloc->getArgOperand(0)->getType();
@@ -1911,9 +1925,8 @@ struct DSEState {
       return false;
 
     MemorySSAUpdater Updater(&MSSA);
-    auto *NewAccess =
-      Updater.createMemoryAccessAfter(cast<Instruction>(Calloc), nullptr,
-                                      MallocDef);
+    auto *NewAccess = Updater.createMemoryAccessAfter(cast<Instruction>(Calloc),
+                                                      nullptr, MallocDef);
     auto *NewAccessMD = cast<MemoryDef>(NewAccess);
     Updater.insertDef(NewAccessMD, /*RenameUses=*/true);
     Malloc->replaceAllUsesWith(Calloc);
@@ -2288,9 +2301,9 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
       continue;
     }
 
-    // Can we form a calloc from a memset/malloc pair?
+    // Can we form a calloc from a memset/malloc or store/malloc pair?
     if (!Shortend && State.tryFoldIntoCalloc(KillingDef, KillingUndObj)) {
-      LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
+      LLVM_DEBUG(dbgs() << "DSE: Remove memset/store after forming calloc:\n"
                         << "  DEAD: " << *KillingI << '\n');
       State.deleteDeadInstruction(KillingI);
       MadeChange = true;
diff --git a/llvm/test/Transforms/DeadStoreElimination/malloc-store.ll b/llvm/test/Transforms/DeadStoreElimination/malloc-store.ll
index 65e0656c700a94..70938b60df36e0 100644
--- a/llvm/test/Transforms/DeadStoreElimination/malloc-store.ll
+++ b/llvm/test/Transforms/DeadStoreElimination/malloc-store.ll
@@ -6,8 +6,7 @@ declare void @use(ptr)
 
 define ptr @basic() {
 ; CHECK-LABEL: define ptr @basic() {
-; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
-; CHECK-NEXT:    store i32 0, ptr [[PTR]], align 4
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 4)
 ; CHECK-NEXT:    ret ptr [[PTR]]
 ;
   %ptr = call ptr @malloc(i64 4)
@@ -17,8 +16,7 @@ define ptr @basic() {
 
 define ptr @vec_type() {
 ; CHECK-LABEL: define ptr @vec_type() {
-; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
-; CHECK-NEXT:    store <2 x i16> zeroinitializer, ptr [[PTR]], align 4
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 4)
 ; CHECK-NEXT:    ret ptr [[PTR]]
 ;
   %ptr = call ptr @malloc(i64 4)
@@ -28,9 +26,8 @@ define ptr @vec_type() {
 
 define ptr @clobber() {
 ; CHECK-LABEL: define ptr @clobber() {
-; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 4)
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 4)
 ; CHECK-NEXT:    [[L:%.*]] = load i8, ptr [[PTR]], align 1
-; CHECK-NEXT:    store i32 0, ptr [[PTR]], align 4
 ; CHECK-NEXT:    ret ptr [[PTR]]
 ;
   %ptr = call ptr @malloc(i64 4)
@@ -52,8 +49,7 @@ define ptr @wrong_size() {
 
 define ptr @bigstore() {
 ; CHECK-LABEL: define ptr @bigstore() {
-; CHECK-NEXT:    [[PTR:%.*]] = call ptr @malloc(i64 8096)
-; CHECK-NEXT:    store <8096 x i8> zeroinitializer, ptr [[PTR]], align 8192
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 8096)
 ; CHECK-NEXT:    ret ptr [[PTR]]
 ;
   %ptr = call ptr @malloc(i64 8096)
diff --git a/llvm/test/Transforms/DeadStoreElimination/simple.ll b/llvm/test/Transforms/DeadStoreElimination/simple.ll
index e5d3dd09fa148d..4a80fc5b141c18 100644
--- a/llvm/test/Transforms/DeadStoreElimination/simple.ll
+++ b/llvm/test/Transforms/DeadStoreElimination/simple.ll
@@ -204,9 +204,8 @@ define void @test_matrix_store(i64 %stride) {
 declare void @may_unwind()
 define ptr @test_malloc_no_escape_before_return() {
 ; CHECK-LABEL: @test_malloc_no_escape_before_return(
-; CHECK-NEXT:    [[PTR:%.*]] = tail call ptr @malloc(i64 4)
+; CHECK-NEXT:    [[PTR:%.*]] = call ptr @calloc(i64 1, i64 4)
 ; CHECK-NEXT:    call void @may_unwind()
-; CHECK-NEXT:    store i32 0, ptr [[PTR]], align 4
 ; CHECK-NEXT:    ret ptr [[PTR]]
 ;
   %ptr = tail call ptr @malloc(i64 4)
@@ -236,10 +235,9 @@ define ptr @test_custom_malloc_no_escape_before_return() {
 
 define ptr addrspace(1) @test13_addrspacecast() {
 ; CHECK-LABEL: @test13_addrspacecast(
-; CHECK-NEXT:    [[P:%.*]] = tail call ptr @malloc(i64 4)
+; CHECK-NEXT:    [[P:%.*]] = call ptr @calloc(i64 1, i64 4)
 ; CHECK-NEXT:    [[P_AC:%.*]] = addrspacecast ptr [[P]] to ptr addrspace(1)
 ; CHECK-NEXT:    call void @may_unwind()
-; CHECK-NEXT:    store i32 0, ptr addrspace(1) [[P_AC]], align 4
 ; CHECK-NEXT:    ret ptr addrspace(1) [[P_AC]]
 ;
   %p = tail call ptr @malloc(i64 4)



More information about the llvm-commits mailing list