[llvm] 914a399 - [ARM] Avoid clobbering byval arguments when passing to tail-calls
Oliver Stannard via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 25 01:35:02 PDT 2024
Author: Oliver Stannard
Date: 2024-10-25T09:34:09+01:00
New Revision: 914a3990d1a055f5f1848c6979f621ceff97fa7c
URL: https://github.com/llvm/llvm-project/commit/914a3990d1a055f5f1848c6979f621ceff97fa7c
DIFF: https://github.com/llvm/llvm-project/commit/914a3990d1a055f5f1848c6979f621ceff97fa7c.diff
LOG: [ARM] Avoid clobbering byval arguments when passing to tail-calls
When passing byval arguments to tail-calls, we need to store them into
the stack memory in which the caller received its arguments. If any of
the outgoing arguments are forwarded from incoming byval arguments,
then the source of the copy lies in that same stack memory. This can
result in the copy corrupting a value which has not yet been read.
The fix is to first make a copy of the outgoing byval arguments in local
stack space, and then copy them to their final location. This fixes the
correctness issue, but results in extra copying, which could be
optimised.
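
For illustration, the new swap_byvals test added below exercises the
pattern that went wrong: both outgoing byval arguments are forwarded
from the incoming byval arguments in swapped order, so storing the
first outgoing argument into the incoming-argument area would
overwrite memory that is still needed as the source of the second. A
minimal reproducer along the lines of that test (assuming
%twenty_bytes is a 20-byte aggregate such as { [5 x i32] }, as
defined elsewhere in the existing test file):

  %twenty_bytes = type { [5 x i32] }

  declare void @two_byvals_callee(%twenty_bytes* byval(%twenty_bytes) align 4, %twenty_bytes* byval(%twenty_bytes) align 4)

  define void @swap_byvals(%twenty_bytes* byval(%twenty_bytes) align 4 %a, %twenty_bytes* byval(%twenty_bytes) align 4 %b) {
  entry:
    musttail call void @two_byvals_callee(%twenty_bytes* byval(%twenty_bytes) align 4 %b, %twenty_bytes* byval(%twenty_bytes) align 4 %a)
    ret void
  }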
Added:
Modified:
llvm/lib/Target/ARM/ARMISelLowering.cpp
llvm/test/CodeGen/ARM/musttail.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 78d2b98b0a84a1..6cc0b681193f55 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -2380,6 +2380,7 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
MachineFunction &MF = DAG.getMachineFunction();
ARMFunctionInfo *AFI = MF.getInfo<ARMFunctionInfo>();
+ MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
MachineFunction::CallSiteInfo CSInfo;
bool isStructRet = (Outs.empty()) ? false : Outs[0].Flags.isSRet();
bool isThisReturn = false;
@@ -2492,6 +2493,45 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
RegsToPassVector RegsToPass;
SmallVector<SDValue, 8> MemOpChains;
+ // If we are doing a tail-call, any byval arguments will be written to stack
+ // space which was used for incoming arguments. If any the values being used
+ // are incoming byval arguments to this function, then they might be
+ // overwritten by the stores of the outgoing arguments. To avoid this, we
+ // need to make a temporary copy of them in local stack space, then copy back
+ // to the argument area.
+ // TODO This could be optimised to skip byvals which are already being copied
+ // from local stack space, or which are copied from the incoming stack at the
+ // exact same location.
+ DenseMap<unsigned, SDValue> ByValTemporaries;
+ SDValue ByValTempChain;
+ if (isTailCall) {
+ for (unsigned ArgIdx = 0, e = OutVals.size(); ArgIdx != e; ++ArgIdx) {
+ SDValue Arg = OutVals[ArgIdx];
+ ISD::ArgFlagsTy Flags = Outs[ArgIdx].Flags;
+
+ if (Flags.isByVal()) {
+ int FrameIdx = MFI.CreateStackObject(
+ Flags.getByValSize(), Flags.getNonZeroByValAlign(), false);
+ SDValue Dst =
+ DAG.getFrameIndex(FrameIdx, getPointerTy(DAG.getDataLayout()));
+
+ SDValue SizeNode = DAG.getConstant(Flags.getByValSize(), dl, MVT::i32);
+ SDValue AlignNode =
+ DAG.getConstant(Flags.getNonZeroByValAlign().value(), dl, MVT::i32);
+
+ SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
+ SDValue Ops[] = { Chain, Dst, Arg, SizeNode, AlignNode};
+ MemOpChains.push_back(DAG.getNode(ARMISD::COPY_STRUCT_BYVAL, dl, VTs,
+ Ops));
+ ByValTemporaries[ArgIdx] = Dst;
+ }
+ }
+ if (!MemOpChains.empty()) {
+ ByValTempChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, MemOpChains);
+ MemOpChains.clear();
+ }
+ }
+
// During a tail call, stores to the argument area must happen after all of
// the function's incoming arguments have been loaded because they may alias.
// This is done by folding in a TokenFactor from LowerFormalArguments, but
@@ -2529,6 +2569,9 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
if (isTailCall && VA.isMemLoc() && !AfterFormalArgLoads) {
Chain = DAG.getStackArgumentTokenFactor(Chain);
+ if (ByValTempChain)
+ Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Chain,
+ ByValTempChain);
AfterFormalArgLoads = true;
}
@@ -2600,6 +2643,12 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
unsigned ByValArgsCount = CCInfo.getInRegsParamsCount();
unsigned CurByValIdx = CCInfo.getInRegsParamsProcessed();
+ SDValue ByValSrc;
+ if (ByValTemporaries.contains(realArgIdx))
+ ByValSrc = ByValTemporaries[realArgIdx];
+ else
+ ByValSrc = Arg;
+
if (CurByValIdx < ByValArgsCount) {
unsigned RegBegin, RegEnd;
@@ -2610,7 +2659,7 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
unsigned int i, j;
for (i = 0, j = RegBegin; j < RegEnd; i++, j++) {
SDValue Const = DAG.getConstant(4*i, dl, MVT::i32);
- SDValue AddArg = DAG.getNode(ISD::ADD, dl, PtrVT, Arg, Const);
+ SDValue AddArg = DAG.getNode(ISD::ADD, dl, PtrVT, ByValSrc, Const);
SDValue Load =
DAG.getLoad(PtrVT, dl, Chain, AddArg, MachinePointerInfo(),
DAG.InferPtrAlign(AddArg));
@@ -2632,7 +2681,7 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
std::tie(Dst, DstInfo) =
computeAddrForCallArg(dl, DAG, VA, StackPtr, isTailCall, SPDiff);
SDValue SrcOffset = DAG.getIntPtrConstant(4*offset, dl);
- SDValue Src = DAG.getNode(ISD::ADD, dl, PtrVT, Arg, SrcOffset);
+ SDValue Src = DAG.getNode(ISD::ADD, dl, PtrVT, ByValSrc, SrcOffset);
SDValue SizeNode = DAG.getConstant(Flags.getByValSize() - 4*offset, dl,
MVT::i32);
SDValue AlignNode =
diff --git a/llvm/test/CodeGen/ARM/musttail.ll b/llvm/test/CodeGen/ARM/musttail.ll
index aecf8e4579b5df..3c577134f696d8 100644
--- a/llvm/test/CodeGen/ARM/musttail.ll
+++ b/llvm/test/CodeGen/ARM/musttail.ll
@@ -139,13 +139,28 @@ define void @large_caller(%twenty_bytes* byval(%twenty_bytes) align 4 %a) {
; CHECK-NEXT: sub sp, sp, #16
; CHECK-NEXT: .save {r4, lr}
; CHECK-NEXT: push {r4, lr}
-; CHECK-NEXT: add r12, sp, #8
-; CHECK-NEXT: add lr, sp, #24
+; CHECK-NEXT: .pad #20
+; CHECK-NEXT: sub sp, sp, #20
+; CHECK-NEXT: add r12, sp, #28
+; CHECK-NEXT: add lr, sp, #44
; CHECK-NEXT: stm r12, {r0, r1, r2, r3}
-; CHECK-NEXT: add r12, sp, #8
-; CHECK-NEXT: add r12, r12, #16
+; CHECK-NEXT: add r0, sp, #28
+; CHECK-NEXT: mov r1, sp
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: add r12, r1, #16
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldm sp, {r0, r1, r2, r3}
; CHECK-NEXT: ldr r4, [r12], #4
; CHECK-NEXT: str r4, [lr], #4
+; CHECK-NEXT: add sp, sp, #20
; CHECK-NEXT: pop {r4, lr}
; CHECK-NEXT: add sp, sp, #16
; CHECK-NEXT: b large_callee
@@ -163,17 +178,30 @@ define void @large_caller_check_regs(%twenty_bytes* byval(%twenty_bytes) align 4
; CHECK-NEXT: sub sp, sp, #16
; CHECK-NEXT: .save {r4, lr}
; CHECK-NEXT: push {r4, lr}
-; CHECK-NEXT: add r12, sp, #8
-; CHECK-NEXT: add lr, sp, #24
+; CHECK-NEXT: .pad #20
+; CHECK-NEXT: sub sp, sp, #20
+; CHECK-NEXT: add r12, sp, #28
+; CHECK-NEXT: add lr, sp, #44
; CHECK-NEXT: stm r12, {r0, r1, r2, r3}
; CHECK-NEXT: @APP
; CHECK-NEXT: @NO_APP
-; CHECK-NEXT: add r3, sp, #8
-; CHECK-NEXT: add r0, sp, #8
-; CHECK-NEXT: add r12, r0, #16
-; CHECK-NEXT: ldm r3, {r0, r1, r2, r3}
+; CHECK-NEXT: add r0, sp, #28
+; CHECK-NEXT: mov r1, sp
+; CHECK-NEXT: add r12, r1, #16
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldm sp, {r0, r1, r2, r3}
; CHECK-NEXT: ldr r4, [r12], #4
; CHECK-NEXT: str r4, [lr], #4
+; CHECK-NEXT: add sp, sp, #20
; CHECK-NEXT: pop {r4, lr}
; CHECK-NEXT: add sp, sp, #16
; CHECK-NEXT: b large_callee
@@ -190,30 +218,44 @@ entry:
define void @large_caller_new_value(%twenty_bytes* byval(%twenty_bytes) align 4 %a) {
; CHECK-LABEL: large_caller_new_value:
; CHECK: @ %bb.0: @ %entry
-; CHECK-NEXT: .pad #36
-; CHECK-NEXT: sub sp, sp, #36
-; CHECK-NEXT: add r12, sp, #20
+; CHECK-NEXT: .pad #16
+; CHECK-NEXT: sub sp, sp, #16
+; CHECK-NEXT: .save {r4, lr}
+; CHECK-NEXT: push {r4, lr}
+; CHECK-NEXT: .pad #40
+; CHECK-NEXT: sub sp, sp, #40
+; CHECK-NEXT: add r12, sp, #48
+; CHECK-NEXT: add lr, sp, #64
; CHECK-NEXT: stm r12, {r0, r1, r2, r3}
; CHECK-NEXT: mov r0, #4
-; CHECK-NEXT: add r1, sp, #36
-; CHECK-NEXT: str r0, [sp, #16]
+; CHECK-NEXT: mov r1, sp
+; CHECK-NEXT: str r0, [sp, #36]
; CHECK-NEXT: mov r0, #3
-; CHECK-NEXT: str r0, [sp, #12]
+; CHECK-NEXT: str r0, [sp, #32]
; CHECK-NEXT: mov r0, #2
-; CHECK-NEXT: str r0, [sp, #8]
+; CHECK-NEXT: str r0, [sp, #28]
; CHECK-NEXT: mov r0, #1
-; CHECK-NEXT: str r0, [sp, #4]
+; CHECK-NEXT: str r0, [sp, #24]
; CHECK-NEXT: mov r0, #0
-; CHECK-NEXT: str r0, [sp]
-; CHECK-NEXT: mov r0, sp
-; CHECK-NEXT: add r0, r0, #16
-; CHECK-NEXT: mov r3, #3
+; CHECK-NEXT: str r0, [sp, #20]
+; CHECK-NEXT: add r0, sp, #20
+; CHECK-NEXT: add r12, r1, #16
; CHECK-NEXT: ldr r2, [r0], #4
; CHECK-NEXT: str r2, [r1], #4
-; CHECK-NEXT: mov r0, #0
-; CHECK-NEXT: mov r1, #1
-; CHECK-NEXT: mov r2, #2
-; CHECK-NEXT: add sp, sp, #36
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldr r2, [r0], #4
+; CHECK-NEXT: str r2, [r1], #4
+; CHECK-NEXT: ldm sp, {r0, r1, r2, r3}
+; CHECK-NEXT: ldr r4, [r12], #4
+; CHECK-NEXT: str r4, [lr], #4
+; CHECK-NEXT: add sp, sp, #40
+; CHECK-NEXT: pop {r4, lr}
+; CHECK-NEXT: add sp, sp, #16
; CHECK-NEXT: b large_callee
entry:
%y = alloca %twenty_bytes, align 4
@@ -229,3 +271,67 @@ entry:
musttail call void @large_callee(%twenty_bytes* byval(%twenty_bytes) align 4 %y)
ret void
}
+
+declare void @two_byvals_callee(%twenty_bytes* byval(%twenty_bytes) align 4, %twenty_bytes* byval(%twenty_bytes) align 4)
+define void @swap_byvals(%twenty_bytes* byval(%twenty_bytes) align 4 %a, %twenty_bytes* byval(%twenty_bytes) align 4 %b) {
+; CHECK-LABEL: swap_byvals:
+; CHECK: @ %bb.0: @ %entry
+; CHECK-NEXT: .pad #16
+; CHECK-NEXT: sub sp, sp, #16
+; CHECK-NEXT: .save {r4, r5, r11, lr}
+; CHECK-NEXT: push {r4, r5, r11, lr}
+; CHECK-NEXT: .pad #40
+; CHECK-NEXT: sub sp, sp, #40
+; CHECK-NEXT: add r12, sp, #56
+; CHECK-NEXT: add lr, sp, #20
+; CHECK-NEXT: stm r12, {r0, r1, r2, r3}
+; CHECK-NEXT: add r0, sp, #56
+; CHECK-NEXT: mov r12, sp
+; CHECK-NEXT: ldr r1, [r0], #4
+; CHECK-NEXT: mov r2, r12
+; CHECK-NEXT: str r1, [r2], #4
+; CHECK-NEXT: add r3, sp, #20
+; CHECK-NEXT: ldr r1, [r0], #4
+; CHECK-NEXT: add r4, sp, #76
+; CHECK-NEXT: str r1, [r2], #4
+; CHECK-NEXT: ldr r1, [r0], #4
+; CHECK-NEXT: str r1, [r2], #4
+; CHECK-NEXT: ldr r1, [r0], #4
+; CHECK-NEXT: str r1, [r2], #4
+; CHECK-NEXT: ldr r1, [r0], #4
+; CHECK-NEXT: add r0, sp, #76
+; CHECK-NEXT: str r1, [r2], #4
+; CHECK-NEXT: mov r2, lr
+; CHECK-NEXT: ldr r1, [r0], #4
+; CHECK-NEXT: str r1, [r2], #4
+; CHECK-NEXT: ldr r1, [r0], #4
+; CHECK-NEXT: str r1, [r2], #4
+; CHECK-NEXT: ldr r1, [r0], #4
+; CHECK-NEXT: str r1, [r2], #4
+; CHECK-NEXT: ldr r1, [r0], #4
+; CHECK-NEXT: str r1, [r2], #4
+; CHECK-NEXT: ldr r1, [r0], #4
+; CHECK-NEXT: str r1, [r2], #4
+; CHECK-NEXT: ldm r3, {r0, r1, r2, r3}
+; CHECK-NEXT: ldr r5, [r12], #4
+; CHECK-NEXT: str r5, [r4], #4
+; CHECK-NEXT: ldr r5, [r12], #4
+; CHECK-NEXT: str r5, [r4], #4
+; CHECK-NEXT: ldr r5, [r12], #4
+; CHECK-NEXT: str r5, [r4], #4
+; CHECK-NEXT: ldr r5, [r12], #4
+; CHECK-NEXT: str r5, [r4], #4
+; CHECK-NEXT: ldr r5, [r12], #4
+; CHECK-NEXT: str r5, [r4], #4
+; CHECK-NEXT: add r5, lr, #16
+; CHECK-NEXT: add r12, sp, #72
+; CHECK-NEXT: ldr r4, [r5], #4
+; CHECK-NEXT: str r4, [r12], #4
+; CHECK-NEXT: add sp, sp, #40
+; CHECK-NEXT: pop {r4, r5, r11, lr}
+; CHECK-NEXT: add sp, sp, #16
+; CHECK-NEXT: b two_byvals_callee
+entry:
+ musttail call void @two_byvals_callee(%twenty_bytes* byval(%twenty_bytes) align 4 %b, %twenty_bytes* byval(%twenty_bytes) align 4 %a)
+ ret void
+}
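
The updated test can be exercised with llc and FileCheck in the usual
way; a sketch only, since the file's RUN line is not shown in this
diff and the ARM triple here is an assumption:

  llc -mtriple=armv7a-none-eabi llvm/test/CodeGen/ARM/musttail.ll -o - \
    | FileCheck llvm/test/CodeGen/ARM/musttail.ll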