[llvm] [AArch64] Don't tail call memset if it would convert to a bzero. (PR #98969)
Amara Emerson via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 16 11:17:10 PDT 2024
https://github.com/aemerson updated https://github.com/llvm/llvm-project/pull/98969
>From f9b342324834480eb3403892af205f202300f1b1 Mon Sep 17 00:00:00 2001
From: Amara Emerson <amara at apple.com>
Date: Mon, 15 Jul 2024 14:48:08 -0700
Subject: [PATCH 1/4] [AArch64] Don't tail call memset if it would convert to a
bzero.
Well, not quite that simple. We can tail call memset since it returns its first
argument, but bzero doesn't do that, and therefore we can end up miscompiling.
rdar://131419786
---
llvm/lib/CodeGen/Analysis.cpp | 7 ++++++-
.../AArch64/no-tail-call-bzero-from-memset.ll | 20 +++++++++++++++++++
2 files changed, 26 insertions(+), 1 deletion(-)
create mode 100644 llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll
diff --git a/llvm/lib/CodeGen/Analysis.cpp b/llvm/lib/CodeGen/Analysis.cpp
index 7fc18639e5852..2a3015866da5f 100644
--- a/llvm/lib/CodeGen/Analysis.cpp
+++ b/llvm/lib/CodeGen/Analysis.cpp
@@ -677,6 +677,8 @@ bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
// will be expanded as memcpy in libc, which returns the first
// argument. On other platforms like arm-none-eabi, memcpy may be
// expanded as library call without return value, like __aeabi_memcpy.
+ // Similarly, llvm.memset can be expanded to bzero, which doesn't have a
+ // return value either.
const CallInst *Call = cast<CallInst>(I);
if (Function *F = Call->getCalledFunction()) {
Intrinsic::ID IID = F->getIntrinsicID();
@@ -685,7 +687,10 @@ bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
(IID == Intrinsic::memmove &&
TLI.getLibcallName(RTLIB::MEMMOVE) == StringRef("memmove")) ||
(IID == Intrinsic::memset &&
- TLI.getLibcallName(RTLIB::MEMSET) == StringRef("memset"))) &&
+ TLI.getLibcallName(RTLIB::MEMSET) == StringRef("memset") &&
+ (!isa<ConstantInt>(Call->getOperand(1)) ||
+ !cast<ConstantInt>(Call->getOperand(1))->isZero() ||
+ !TLI.getLibcallName(RTLIB::BZERO)))) &&
(RetVal == Call->getArgOperand(0) ||
isPointerBitcastEqualTo(RetVal, Call->getArgOperand(0))))
return true;
diff --git a/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll b/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll
new file mode 100644
index 0000000000000..90e641cd4fe3d
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll
@@ -0,0 +1,20 @@
+; RUN: llc -o - %s | FileCheck %s
+; RUN: llc -global-isel -o - %s | FileCheck %s
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+target triple = "arm64-apple-macosx15.0.0"
+
+define ptr @test() {
+; CHECK-LABEL: test:
+; CHECK-NOT: b _bzero
+ %1 = tail call ptr @fn(i32 noundef 1) #3
+ tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(1000) %1, i8 noundef 0, i64 noundef 1000, i1 noundef false) #3
+ ret ptr %1
+}
+
+declare ptr @fn(i32 noundef)
+
+; Function Attrs: nocallback nofree nounwind willreturn memory(argmem: write)
+declare void @llvm.memset.p0.i64(ptr nocapture writeonly, i8, i64, i1 immarg) #2
+
+attributes #2 = { nocallback nofree nounwind willreturn memory(argmem: write) }
+attributes #3 = { nounwind optsize }
>From a316dcf44c9738f39c49c556195df1e231296b0f Mon Sep 17 00:00:00 2001
From: Amara Emerson <amara at apple.com>
Date: Mon, 15 Jul 2024 15:29:29 -0700
Subject: [PATCH 2/4] Change to positive check in test.
---
llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll b/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll
index 90e641cd4fe3d..34c6c63cc1798 100644
--- a/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll
+++ b/llvm/test/CodeGen/AArch64/no-tail-call-bzero-from-memset.ll
@@ -5,7 +5,7 @@ target triple = "arm64-apple-macosx15.0.0"
define ptr @test() {
; CHECK-LABEL: test:
-; CHECK-NOT: b _bzero
+; CHECK: bl _bzero
%1 = tail call ptr @fn(i32 noundef 1) #3
tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(1000) %1, i8 noundef 0, i64 noundef 1000, i1 noundef false) #3
ret ptr %1
>From 5aad864dec2a452a1367f118bf8c722317c7cceb Mon Sep 17 00:00:00 2001
From: Amara Emerson <amara at apple.com>
Date: Tue, 16 Jul 2024 10:28:32 -0700
Subject: [PATCH 3/4] Refactor.
---
llvm/include/llvm/CodeGen/Analysis.h | 14 ++++++--
llvm/include/llvm/CodeGen/SelectionDAG.h | 2 +-
llvm/lib/CodeGen/Analysis.cpp | 36 +++++++++++--------
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 18 +++++++---
.../SelectionDAG/SelectionDAGBuilder.cpp | 8 ++---
llvm/lib/Target/X86/X86SelectionDAGInfo.cpp | 2 +-
6 files changed, 53 insertions(+), 27 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/Analysis.h b/llvm/include/llvm/CodeGen/Analysis.h
index 6f7ed22b8ac71..9825e1ddd0bdf 100644
--- a/llvm/include/llvm/CodeGen/Analysis.h
+++ b/llvm/include/llvm/CodeGen/Analysis.h
@@ -126,7 +126,8 @@ ICmpInst::Predicate getICmpCondCode(ISD::CondCode Pred);
/// between it and the return.
///
/// This function only tests target-independent requirements.
-bool isInTailCallPosition(const CallBase &Call, const TargetMachine &TM);
+bool isInTailCallPosition(const CallBase &Call, const TargetMachine &TM,
+ bool ReturnsFirstArg = false);
/// Test if given that the input instruction is in the tail call position, if
/// there is an attribute mismatch between the caller and the callee that will
@@ -144,7 +145,16 @@ bool attributesPermitTailCall(const Function *F, const Instruction *I,
/// optimization.
bool returnTypeIsEligibleForTailCall(const Function *F, const Instruction *I,
const ReturnInst *Ret,
- const TargetLoweringBase &TLI);
+ const TargetLoweringBase &TLI,
+ bool ReturnsFirstArg = false);
+
+/// Check whether B is a bitcast of a pointer type to another pointer type,
+/// which is equal to A.
+bool isPointerBitcastEqualTo(const Value *A, const Value *B);
+
+/// Returns true if the parent of \p CI returns CI's first argument after
+/// calling \p CI.
+bool funcReturnsFirstArgOfCall(const CallInst &CI);
DenseMap<const MachineBasicBlock *, int>
getEHScopeMembership(const MachineFunction &MF);
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 8e189e9e8bf86..5484921973642 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1199,7 +1199,7 @@ class SelectionDAG {
SDValue getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst, SDValue Src,
SDValue Size, Align Alignment, bool isVol,
- bool AlwaysInline, bool isTailCall,
+ bool AlwaysInline, const CallInst *CI,
MachinePointerInfo DstPtrInfo,
const AAMDNodes &AAInfo = AAMDNodes());
diff --git a/llvm/lib/CodeGen/Analysis.cpp b/llvm/lib/CodeGen/Analysis.cpp
index 2a3015866da5f..7514e33d47626 100644
--- a/llvm/lib/CodeGen/Analysis.cpp
+++ b/llvm/lib/CodeGen/Analysis.cpp
@@ -532,7 +532,8 @@ static bool nextRealType(SmallVectorImpl<Type *> &SubTypes,
/// between it and the return.
///
/// This function only tests target-independent requirements.
-bool llvm::isInTailCallPosition(const CallBase &Call, const TargetMachine &TM) {
+bool llvm::isInTailCallPosition(const CallBase &Call, const TargetMachine &TM,
+ bool ReturnsFirstArg) {
const BasicBlock *ExitBB = Call.getParent();
const Instruction *Term = ExitBB->getTerminator();
const ReturnInst *Ret = dyn_cast<ReturnInst>(Term);
@@ -575,7 +576,8 @@ bool llvm::isInTailCallPosition(const CallBase &Call, const TargetMachine &TM) {
const Function *F = ExitBB->getParent();
return returnTypeIsEligibleForTailCall(
- F, &Call, Ret, *TM.getSubtargetImpl(*F)->getTargetLowering());
+ F, &Call, Ret, *TM.getSubtargetImpl(*F)->getTargetLowering(),
+ ReturnsFirstArg);
}
bool llvm::attributesPermitTailCall(const Function *F, const Instruction *I,
@@ -638,9 +640,7 @@ bool llvm::attributesPermitTailCall(const Function *F, const Instruction *I,
return CallerAttrs == CalleeAttrs;
}
-/// Check whether B is a bitcast of a pointer type to another pointer type,
-/// which is equal to A.
-static bool isPointerBitcastEqualTo(const Value *A, const Value *B) {
+bool llvm::isPointerBitcastEqualTo(const Value *A, const Value *B) {
assert(A && B && "Expected non-null inputs!");
auto *BitCastIn = dyn_cast<BitCastInst>(B);
@@ -657,7 +657,8 @@ static bool isPointerBitcastEqualTo(const Value *A, const Value *B) {
bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
const Instruction *I,
const ReturnInst *Ret,
- const TargetLoweringBase &TLI) {
+ const TargetLoweringBase &TLI,
+ bool ReturnsFirstArg) {
// If the block ends with a void return or unreachable, it doesn't matter
// what the call's return type is.
if (!Ret || Ret->getNumOperands() == 0) return true;
@@ -671,26 +672,23 @@ bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
if (!attributesPermitTailCall(F, I, Ret, TLI, &AllowDifferingSizes))
return false;
+ // If the return value is the first argument of the call.
+ if (ReturnsFirstArg)
+ return true;
+
const Value *RetVal = Ret->getOperand(0), *CallVal = I;
// Intrinsic like llvm.memcpy has no return value, but the expanded
// libcall may or may not have return value. On most platforms, it
// will be expanded as memcpy in libc, which returns the first
// argument. On other platforms like arm-none-eabi, memcpy may be
// expanded as library call without return value, like __aeabi_memcpy.
- // Similarly, llvm.memset can be expanded to bzero, which doesn't have a
- // return value either.
const CallInst *Call = cast<CallInst>(I);
if (Function *F = Call->getCalledFunction()) {
Intrinsic::ID IID = F->getIntrinsicID();
if (((IID == Intrinsic::memcpy &&
TLI.getLibcallName(RTLIB::MEMCPY) == StringRef("memcpy")) ||
(IID == Intrinsic::memmove &&
- TLI.getLibcallName(RTLIB::MEMMOVE) == StringRef("memmove")) ||
- (IID == Intrinsic::memset &&
- TLI.getLibcallName(RTLIB::MEMSET) == StringRef("memset") &&
- (!isa<ConstantInt>(Call->getOperand(1)) ||
- !cast<ConstantInt>(Call->getOperand(1))->isZero() ||
- !TLI.getLibcallName(RTLIB::BZERO)))) &&
+ TLI.getLibcallName(RTLIB::MEMMOVE) == StringRef("memmove"))) &&
(RetVal == Call->getArgOperand(0) ||
isPointerBitcastEqualTo(RetVal, Call->getArgOperand(0))))
return true;
@@ -744,6 +742,16 @@ bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
return true;
}
+bool llvm::funcReturnsFirstArgOfCall(const CallInst &CI) {
+ const ReturnInst *Ret = dyn_cast<ReturnInst>(CI.getParent()->getTerminator());
+ Value *RetVal = Ret ? Ret->getReturnValue() : nullptr;
+ bool ReturnsFirstArg = false;
+ if (RetVal && ((RetVal == CI.getArgOperand(0) ||
+ isPointerBitcastEqualTo(RetVal, CI.getArgOperand(0)))))
+ ReturnsFirstArg = true;
+ return ReturnsFirstArg;
+}
+
static void collectEHScopeMembers(
DenseMap<const MachineBasicBlock *, int> &EHScopeMembership, int EHScope,
const MachineBasicBlock *MBB) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 897bdc71818f8..fcceadf7f9ce8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8454,7 +8454,8 @@ SDValue SelectionDAG::getAtomicMemmove(SDValue Chain, const SDLoc &dl,
SDValue SelectionDAG::getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst,
SDValue Src, SDValue Size, Align Alignment,
- bool isVol, bool AlwaysInline, bool isTailCall,
+ bool isVol, bool AlwaysInline,
+ const CallInst *CI,
MachinePointerInfo DstPtrInfo,
const AAMDNodes &AAInfo) {
// Check to see if we should lower the memset to stores first.
@@ -8514,8 +8515,9 @@ SDValue SelectionDAG::getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst,
return Entry;
};
+ bool UseBZero = isNullConstant(Src) && BzeroName;
// If zeroing out and bzero is present, use it.
- if (isNullConstant(Src) && BzeroName) {
+ if (UseBZero) {
TargetLowering::ArgListTy Args;
Args.push_back(CreateEntry(Dst, PointerType::getUnqual(Ctx)));
Args.push_back(CreateEntry(Size, DL.getIntPtrType(Ctx)));
@@ -8533,8 +8535,16 @@ SDValue SelectionDAG::getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst,
TLI->getPointerTy(DL)),
std::move(Args));
}
-
- CLI.setDiscardResult().setTailCall(isTailCall);
+ bool LowersToMemset =
+ TLI->getLibcallName(RTLIB::MEMSET) == StringRef("memset");
+ // If we're going to use bzero, make sure not to tail call unless the
+ // subsequent return doesn't need a value, as bzero doesn't return the first
+ // arg unlike memset.
+ bool ReturnsFirstArg = CI && funcReturnsFirstArgOfCall(*CI) && !UseBZero;
+ bool IsTailCall =
+ CI && CI->isTailCall() &&
+ isInTailCallPosition(*CI, getTarget(), ReturnsFirstArg && LowersToMemset);
+ CLI.setDiscardResult().setTailCall(IsTailCall);
std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
return CallResult.second;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index b0746014daf5a..b17a5815c19b3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6500,11 +6500,10 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
// @llvm.memset defines 0 and 1 to both mean no alignment.
Align Alignment = MSI.getDestAlign().valueOrOne();
bool isVol = MSI.isVolatile();
- bool isTC = I.isTailCall() && isInTailCallPosition(I, DAG.getTarget());
SDValue Root = isVol ? getRoot() : getMemoryRoot();
SDValue MS = DAG.getMemset(
- Root, sdl, Op1, Op2, Op3, Alignment, isVol, /* AlwaysInline */ false,
- isTC, MachinePointerInfo(I.getArgOperand(0)), I.getAAMetadata());
+ Root, sdl, Op1, Op2, Op3, Alignment, isVol, /* AlwaysInline */ false, &I,
+ MachinePointerInfo(I.getArgOperand(0)), I.getAAMetadata());
updateDAGForMaybeTailCall(MS);
return;
}
@@ -6517,10 +6516,9 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
// @llvm.memset defines 0 and 1 to both mean no alignment.
Align DstAlign = MSII.getDestAlign().valueOrOne();
bool isVol = MSII.isVolatile();
- bool isTC = I.isTailCall() && isInTailCallPosition(I, DAG.getTarget());
SDValue Root = isVol ? getRoot() : getMemoryRoot();
SDValue MC = DAG.getMemset(Root, sdl, Dst, Value, Size, DstAlign, isVol,
- /* AlwaysInline */ true, isTC,
+ /* AlwaysInline */ true, &I,
MachinePointerInfo(I.getArgOperand(0)),
I.getAAMetadata());
updateDAGForMaybeTailCall(MC);
diff --git a/llvm/lib/Target/X86/X86SelectionDAGInfo.cpp b/llvm/lib/Target/X86/X86SelectionDAGInfo.cpp
index e5f07f230fe6c..8f1acd1f1cd60 100644
--- a/llvm/lib/Target/X86/X86SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/X86/X86SelectionDAGInfo.cpp
@@ -147,7 +147,7 @@ SDValue X86SelectionDAGInfo::EmitTargetCodeForMemset(
DAG.getConstant(Offset, dl, AddrVT)),
Val, DAG.getConstant(BytesLeft, dl, SizeVT), Alignment,
isVolatile, AlwaysInline,
- /* isTailCall */ false, DstPtrInfo.getWithOffset(Offset)));
+ /* CI */ nullptr, DstPtrInfo.getWithOffset(Offset)));
return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Results);
}
>From 4e261c8a9632adc451f6823c3b8ecc90e5039916 Mon Sep 17 00:00:00 2001
From: Amara Emerson <amara at apple.com>
Date: Tue, 16 Jul 2024 11:16:53 -0700
Subject: [PATCH 4/4] clang-format
---
llvm/lib/CodeGen/Analysis.cpp | 14 +++++++-------
.../CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 4 ++--
2 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/CodeGen/Analysis.cpp b/llvm/lib/CodeGen/Analysis.cpp
index 7514e33d47626..a34fc6c28c430 100644
--- a/llvm/lib/CodeGen/Analysis.cpp
+++ b/llvm/lib/CodeGen/Analysis.cpp
@@ -743,13 +743,13 @@ bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
}
bool llvm::funcReturnsFirstArgOfCall(const CallInst &CI) {
- const ReturnInst *Ret = dyn_cast<ReturnInst>(CI.getParent()->getTerminator());
- Value *RetVal = Ret ? Ret->getReturnValue() : nullptr;
- bool ReturnsFirstArg = false;
- if (RetVal && ((RetVal == CI.getArgOperand(0) ||
- isPointerBitcastEqualTo(RetVal, CI.getArgOperand(0)))))
- ReturnsFirstArg = true;
- return ReturnsFirstArg;
+ const ReturnInst *Ret = dyn_cast<ReturnInst>(CI.getParent()->getTerminator());
+ Value *RetVal = Ret ? Ret->getReturnValue() : nullptr;
+ bool ReturnsFirstArg = false;
+ if (RetVal && ((RetVal == CI.getArgOperand(0) ||
+ isPointerBitcastEqualTo(RetVal, CI.getArgOperand(0)))))
+ ReturnsFirstArg = true;
+ return ReturnsFirstArg;
}
static void collectEHScopeMembers(
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index b17a5815c19b3..155d33ce78fa6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6502,8 +6502,8 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
bool isVol = MSI.isVolatile();
SDValue Root = isVol ? getRoot() : getMemoryRoot();
SDValue MS = DAG.getMemset(
- Root, sdl, Op1, Op2, Op3, Alignment, isVol, /* AlwaysInline */ false, &I,
- MachinePointerInfo(I.getArgOperand(0)), I.getAAMetadata());
+ Root, sdl, Op1, Op2, Op3, Alignment, isVol, /* AlwaysInline */ false,
+ &I, MachinePointerInfo(I.getArgOperand(0)), I.getAAMetadata());
updateDAGForMaybeTailCall(MS);
return;
}
More information about the llvm-commits
mailing list