[llvm] [AArch64] Don't tail call memset if it would convert to a bzero. (PR #98969)

Amara Emerson via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 16 10:28:55 PDT 2024


https://github.com/aemerson updated https://github.com/llvm/llvm-project/pull/98969

>From f9b342324834480eb3403892af205f202300f1b1 Mon Sep 17 00:00:00 2001
From: Amara Emerson <amara at apple.com>
Date: Mon, 15 Jul 2024 14:48:08 -0700
Subject: [PATCH 1/3] [AArch64] Don't tail call memset if it would convert to a
 bzero.

Well, not quite that simple. We can tail call memset since it returns its first
argument, but bzero has no return value, so tail calling bzero when the caller
returns that argument can end up miscompiling.

rdar://131419786
---
 llvm/lib/CodeGen/Analysis.cpp                 |  7 ++++++-
 .../AArch64/no-tail-call-bzero-from-memset.ll | 20 +++++++++++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll

diff --git a/llvm/lib/CodeGen/Analysis.cpp b/llvm/lib/CodeGen/Analysis.cpp
index 7fc18639e5852..2a3015866da5f 100644
--- a/llvm/lib/CodeGen/Analysis.cpp
+++ b/llvm/lib/CodeGen/Analysis.cpp
@@ -677,6 +677,8 @@ bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
   // will be expanded as memcpy in libc, which returns the first
   // argument. On other platforms like arm-none-eabi, memcpy may be
   // expanded as library call without return value, like __aeabi_memcpy.
+  // Similarly, llvm.memset can be expanded to bzero, which doesn't have a
+  // return value either.
   const CallInst *Call = cast<CallInst>(I);
   if (Function *F = Call->getCalledFunction()) {
     Intrinsic::ID IID = F->getIntrinsicID();
@@ -685,7 +687,10 @@ bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
          (IID == Intrinsic::memmove &&
           TLI.getLibcallName(RTLIB::MEMMOVE) == StringRef("memmove")) ||
          (IID == Intrinsic::memset &&
-          TLI.getLibcallName(RTLIB::MEMSET) == StringRef("memset"))) &&
+          TLI.getLibcallName(RTLIB::MEMSET) == StringRef("memset") &&
+          (!isa<ConstantInt>(Call->getOperand(1)) ||
+           !cast<ConstantInt>(Call->getOperand(1))->isZero() ||
+           !TLI.getLibcallName(RTLIB::BZERO)))) &&
         (RetVal == Call->getArgOperand(0) ||
          isPointerBitcastEqualTo(RetVal, Call->getArgOperand(0))))
       return true;
diff --git a/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll b/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll
new file mode 100644
index 0000000000000..90e641cd4fe3d
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll
@@ -0,0 +1,20 @@
+; RUN: llc -o - %s | FileCheck %s
+; RUN: llc -global-isel -o - %s | FileCheck %s
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+target triple = "arm64-apple-macosx15.0.0"
+
+define ptr @test()  {
+; CHECK-LABEL: test:
+; CHECK-NOT: b _bzero
+  %1 = tail call ptr @fn(i32 noundef 1) #3
+  tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(1000) %1, i8 noundef 0, i64 noundef 1000, i1 noundef false) #3
+  ret ptr %1
+}
+
+declare ptr @fn(i32 noundef)
+
+; Function Attrs: nocallback nofree nounwind willreturn memory(argmem: write)
+declare void @llvm.memset.p0.i64(ptr nocapture writeonly, i8, i64, i1 immarg) #2
+
+attributes #2 = { nocallback nofree nounwind willreturn memory(argmem: write) }
+attributes #3 = { nounwind optsize }

>From a316dcf44c9738f39c49c556195df1e231296b0f Mon Sep 17 00:00:00 2001
From: Amara Emerson <amara at apple.com>
Date: Mon, 15 Jul 2024 15:29:29 -0700
Subject: [PATCH 2/3] Change to positive check in test.

---
 llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll b/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll
index 90e641cd4fe3d..34c6c63cc1798 100644
--- a/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll
+++ b/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll
@@ -5,7 +5,7 @@ target triple = "arm64-apple-macosx15.0.0"
 
 define ptr @test()  {
 ; CHECK-LABEL: test:
-; CHECK-NOT: b _bzero
+; CHECK: bl _bzero
   %1 = tail call ptr @fn(i32 noundef 1) #3
   tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(1000) %1, i8 noundef 0, i64 noundef 1000, i1 noundef false) #3
   ret ptr %1

>From 5aad864dec2a452a1367f118bf8c722317c7cceb Mon Sep 17 00:00:00 2001
From: Amara Emerson <amara at apple.com>
Date: Tue, 16 Jul 2024 10:28:32 -0700
Subject: [PATCH 3/3] Refactor.

---
 llvm/include/llvm/CodeGen/Analysis.h          | 14 ++++++--
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  2 +-
 llvm/lib/CodeGen/Analysis.cpp                 | 36 +++++++++++--------
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 18 +++++++---
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  8 ++---
 llvm/lib/Target/X86/X86SelectionDAGInfo.cpp   |  2 +-
 6 files changed, 53 insertions(+), 27 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/Analysis.h b/llvm/include/llvm/CodeGen/Analysis.h
index 6f7ed22b8ac71..9825e1ddd0bdf 100644
--- a/llvm/include/llvm/CodeGen/Analysis.h
+++ b/llvm/include/llvm/CodeGen/Analysis.h
@@ -126,7 +126,8 @@ ICmpInst::Predicate getICmpCondCode(ISD::CondCode Pred);
 /// between it and the return.
 ///
 /// This function only tests target-independent requirements.
-bool isInTailCallPosition(const CallBase &Call, const TargetMachine &TM);
+bool isInTailCallPosition(const CallBase &Call, const TargetMachine &TM,
+                          bool ReturnsFirstArg = false);
 
 /// Test if given that the input instruction is in the tail call position, if
 /// there is an attribute mismatch between the caller and the callee that will
@@ -144,7 +145,16 @@ bool attributesPermitTailCall(const Function *F, const Instruction *I,
 /// optimization.
 bool returnTypeIsEligibleForTailCall(const Function *F, const Instruction *I,
                                      const ReturnInst *Ret,
-                                     const TargetLoweringBase &TLI);
+                                     const TargetLoweringBase &TLI,
+                                     bool ReturnsFirstArg = false);
+
+/// Check whether B is a bitcast of a pointer type to another pointer type,
+/// which is equal to A.
+bool isPointerBitcastEqualTo(const Value *A, const Value *B);
+
+/// Returns true if the parent of \p CI returns CI's first argument after
+/// calling \p CI.
+bool funcReturnsFirstArgOfCall(const CallInst &CI);
 
 DenseMap<const MachineBasicBlock *, int>
 getEHScopeMembership(const MachineFunction &MF);
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 8e189e9e8bf86..5484921973642 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1199,7 +1199,7 @@ class SelectionDAG {
 
   SDValue getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst, SDValue Src,
                     SDValue Size, Align Alignment, bool isVol,
-                    bool AlwaysInline, bool isTailCall,
+                    bool AlwaysInline, const CallInst *CI,
                     MachinePointerInfo DstPtrInfo,
                     const AAMDNodes &AAInfo = AAMDNodes());
 
diff --git a/llvm/lib/CodeGen/Analysis.cpp b/llvm/lib/CodeGen/Analysis.cpp
index 2a3015866da5f..7514e33d47626 100644
--- a/llvm/lib/CodeGen/Analysis.cpp
+++ b/llvm/lib/CodeGen/Analysis.cpp
@@ -532,7 +532,8 @@ static bool nextRealType(SmallVectorImpl<Type *> &SubTypes,
 /// between it and the return.
 ///
 /// This function only tests target-independent requirements.
-bool llvm::isInTailCallPosition(const CallBase &Call, const TargetMachine &TM) {
+bool llvm::isInTailCallPosition(const CallBase &Call, const TargetMachine &TM,
+                                bool ReturnsFirstArg) {
   const BasicBlock *ExitBB = Call.getParent();
   const Instruction *Term = ExitBB->getTerminator();
   const ReturnInst *Ret = dyn_cast<ReturnInst>(Term);
@@ -575,7 +576,8 @@ bool llvm::isInTailCallPosition(const CallBase &Call, const TargetMachine &TM) {
 
   const Function *F = ExitBB->getParent();
   return returnTypeIsEligibleForTailCall(
-      F, &Call, Ret, *TM.getSubtargetImpl(*F)->getTargetLowering());
+      F, &Call, Ret, *TM.getSubtargetImpl(*F)->getTargetLowering(),
+      ReturnsFirstArg);
 }
 
 bool llvm::attributesPermitTailCall(const Function *F, const Instruction *I,
@@ -638,9 +640,7 @@ bool llvm::attributesPermitTailCall(const Function *F, const Instruction *I,
   return CallerAttrs == CalleeAttrs;
 }
 
-/// Check whether B is a bitcast of a pointer type to another pointer type,
-/// which is equal to A.
-static bool isPointerBitcastEqualTo(const Value *A, const Value *B) {
+bool llvm::isPointerBitcastEqualTo(const Value *A, const Value *B) {
   assert(A && B && "Expected non-null inputs!");
 
   auto *BitCastIn = dyn_cast<BitCastInst>(B);
@@ -657,7 +657,8 @@ static bool isPointerBitcastEqualTo(const Value *A, const Value *B) {
 bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
                                            const Instruction *I,
                                            const ReturnInst *Ret,
-                                           const TargetLoweringBase &TLI) {
+                                           const TargetLoweringBase &TLI,
+                                           bool ReturnsFirstArg) {
   // If the block ends with a void return or unreachable, it doesn't matter
   // what the call's return type is.
   if (!Ret || Ret->getNumOperands() == 0) return true;
@@ -671,26 +672,23 @@ bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
   if (!attributesPermitTailCall(F, I, Ret, TLI, &AllowDifferingSizes))
     return false;
 
+  // The caller vouches that the return value is the call's first argument.
+  if (ReturnsFirstArg)
+    return true;
+
   const Value *RetVal = Ret->getOperand(0), *CallVal = I;
   // Intrinsic like llvm.memcpy has no return value, but the expanded
   // libcall may or may not have return value. On most platforms, it
   // will be expanded as memcpy in libc, which returns the first
   // argument. On other platforms like arm-none-eabi, memcpy may be
   // expanded as library call without return value, like __aeabi_memcpy.
-  // Similarly, llvm.memset can be expanded to bzero, which doesn't have a
-  // return value either.
   const CallInst *Call = cast<CallInst>(I);
   if (Function *F = Call->getCalledFunction()) {
     Intrinsic::ID IID = F->getIntrinsicID();
     if (((IID == Intrinsic::memcpy &&
           TLI.getLibcallName(RTLIB::MEMCPY) == StringRef("memcpy")) ||
          (IID == Intrinsic::memmove &&
-          TLI.getLibcallName(RTLIB::MEMMOVE) == StringRef("memmove")) ||
-         (IID == Intrinsic::memset &&
-          TLI.getLibcallName(RTLIB::MEMSET) == StringRef("memset") &&
-          (!isa<ConstantInt>(Call->getOperand(1)) ||
-           !cast<ConstantInt>(Call->getOperand(1))->isZero() ||
-           !TLI.getLibcallName(RTLIB::BZERO)))) &&
+          TLI.getLibcallName(RTLIB::MEMMOVE) == StringRef("memmove"))) &&
         (RetVal == Call->getArgOperand(0) ||
          isPointerBitcastEqualTo(RetVal, Call->getArgOperand(0))))
       return true;
@@ -744,6 +742,16 @@ bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
   return true;
 }
 
+bool llvm::funcReturnsFirstArgOfCall(const CallInst &CI) {
+    const ReturnInst *Ret = dyn_cast<ReturnInst>(CI.getParent()->getTerminator());
+    Value *RetVal = Ret ? Ret->getReturnValue() : nullptr;
+    bool ReturnsFirstArg = false;
+    if (RetVal && ((RetVal == CI.getArgOperand(0) ||
+                    isPointerBitcastEqualTo(RetVal, CI.getArgOperand(0)))))
+      ReturnsFirstArg = true;
+    return ReturnsFirstArg;
+}
+
 static void collectEHScopeMembers(
     DenseMap<const MachineBasicBlock *, int> &EHScopeMembership, int EHScope,
     const MachineBasicBlock *MBB) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 897bdc71818f8..fcceadf7f9ce8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8454,7 +8454,8 @@ SDValue SelectionDAG::getAtomicMemmove(SDValue Chain, const SDLoc &dl,
 
 SDValue SelectionDAG::getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst,
                                 SDValue Src, SDValue Size, Align Alignment,
-                                bool isVol, bool AlwaysInline, bool isTailCall,
+                                bool isVol, bool AlwaysInline,
+                                const CallInst *CI,
                                 MachinePointerInfo DstPtrInfo,
                                 const AAMDNodes &AAInfo) {
   // Check to see if we should lower the memset to stores first.
@@ -8514,8 +8515,9 @@ SDValue SelectionDAG::getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst,
     return Entry;
   };
 
+  bool UseBZero = isNullConstant(Src) && BzeroName;
   // If zeroing out and bzero is present, use it.
-  if (isNullConstant(Src) && BzeroName) {
+  if (UseBZero) {
     TargetLowering::ArgListTy Args;
     Args.push_back(CreateEntry(Dst, PointerType::getUnqual(Ctx)));
     Args.push_back(CreateEntry(Size, DL.getIntPtrType(Ctx)));
@@ -8533,8 +8535,16 @@ SDValue SelectionDAG::getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst,
                                        TLI->getPointerTy(DL)),
                      std::move(Args));
   }
-
-  CLI.setDiscardResult().setTailCall(isTailCall);
+  bool LowersToMemset =
+      TLI->getLibcallName(RTLIB::MEMSET) == StringRef("memset");
+  // If we're going to use bzero, make sure not to tail call unless the
+  // subsequent return doesn't need a value, as bzero doesn't return the first
+  // arg unlike memset.
+  bool ReturnsFirstArg = CI && funcReturnsFirstArgOfCall(*CI) && !UseBZero;
+  bool IsTailCall =
+      CI && CI->isTailCall() &&
+      isInTailCallPosition(*CI, getTarget(), ReturnsFirstArg && LowersToMemset);
+  CLI.setDiscardResult().setTailCall(IsTailCall);
 
   std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
   return CallResult.second;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index b0746014daf5a..b17a5815c19b3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6500,11 +6500,10 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     // @llvm.memset defines 0 and 1 to both mean no alignment.
     Align Alignment = MSI.getDestAlign().valueOrOne();
     bool isVol = MSI.isVolatile();
-    bool isTC = I.isTailCall() && isInTailCallPosition(I, DAG.getTarget());
     SDValue Root = isVol ? getRoot() : getMemoryRoot();
     SDValue MS = DAG.getMemset(
-        Root, sdl, Op1, Op2, Op3, Alignment, isVol, /* AlwaysInline */ false,
-        isTC, MachinePointerInfo(I.getArgOperand(0)), I.getAAMetadata());
+        Root, sdl, Op1, Op2, Op3, Alignment, isVol, /* AlwaysInline */ false, &I,
+        MachinePointerInfo(I.getArgOperand(0)), I.getAAMetadata());
     updateDAGForMaybeTailCall(MS);
     return;
   }
@@ -6517,10 +6516,9 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     // @llvm.memset defines 0 and 1 to both mean no alignment.
     Align DstAlign = MSII.getDestAlign().valueOrOne();
     bool isVol = MSII.isVolatile();
-    bool isTC = I.isTailCall() && isInTailCallPosition(I, DAG.getTarget());
     SDValue Root = isVol ? getRoot() : getMemoryRoot();
     SDValue MC = DAG.getMemset(Root, sdl, Dst, Value, Size, DstAlign, isVol,
-                               /* AlwaysInline */ true, isTC,
+                               /* AlwaysInline */ true, &I,
                                MachinePointerInfo(I.getArgOperand(0)),
                                I.getAAMetadata());
     updateDAGForMaybeTailCall(MC);
diff --git a/llvm/lib/Target/X86/X86SelectionDAGInfo.cpp b/llvm/lib/Target/X86/X86SelectionDAGInfo.cpp
index e5f07f230fe6c..8f1acd1f1cd60 100644
--- a/llvm/lib/Target/X86/X86SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/X86/X86SelectionDAGInfo.cpp
@@ -147,7 +147,7 @@ SDValue X86SelectionDAGInfo::EmitTargetCodeForMemset(
                                 DAG.getConstant(Offset, dl, AddrVT)),
                     Val, DAG.getConstant(BytesLeft, dl, SizeVT), Alignment,
                     isVolatile, AlwaysInline,
-                    /* isTailCall */ false, DstPtrInfo.getWithOffset(Offset)));
+                    /* CI */ nullptr, DstPtrInfo.getWithOffset(Offset)));
 
   return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Results);
 }



More information about the llvm-commits mailing list