[llvm] [AArch64][SME] Reuse ZT0 spill slot (PR #158593)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 15 03:14:22 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
Previously, we'd allocate a new spill slot each time we needed to spill ZT0, which grew the stack by another 64 bytes for each spill. Saving the spill slot index in FuncInfo lets all ZT0 spills in a function share one slot, and will also allow us to reload ZT0 from that slot on entry to exception handlers.
---
Full diff: https://github.com/llvm/llvm-project/pull/158593.diff
4 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+12-4)
- (modified) llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h (+12)
- (modified) llvm/test/CodeGen/AArch64/sme-peephole-opts.ll (+5-6)
- (modified) llvm/test/CodeGen/AArch64/sme-zt0-state.ll (+40)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c9a756da0078d..2e1e44ec48eb8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8023,6 +8023,17 @@ static bool isPassedInFPR(EVT VT) {
(VT.isFloatingPoint() && !VT.isScalableVector());
}
+/// Returns a frame-index node for the function's ZT0 spill slot. The 64-byte,
+/// 16-byte-aligned slot is created on first use and recorded in FuncInfo, so
+/// every later ZT0 spill in the same function reuses that slot instead of
+/// allocating a new stack object (and growing the frame) per spill.
+static SDValue getZT0FrameIndex(MachineFrameInfo &MFI,
+ AArch64FunctionInfo &FuncInfo,
+ SelectionDAG &DAG) {
+ if (!FuncInfo.hasZT0SpillSlotIndex())
+ FuncInfo.setZT0SpillSlotIndex(MFI.CreateSpillStackObject(64, Align(16)));
+
+ return DAG.getFrameIndex(
+ FuncInfo.getZT0SpillSlotIndex(),
+ DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+}
+
SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
SelectionDAG &DAG) const {
assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
@@ -9427,10 +9438,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// If the caller has ZT0 state which will not be preserved by the callee,
// spill ZT0 before the call.
if (ShouldPreserveZT0) {
- unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
- ZTFrameIdx = DAG.getFrameIndex(
- ZTObj,
- DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+ ZTFrameIdx = getZT0FrameIndex(MFI, *FuncInfo, DAG);
Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
{Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index 98fd018bf33a9..897c7e8539608 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -239,6 +239,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
// support).
Register EarlyAllocSMESaveBuffer = AArch64::NoRegister;
+ // Holds the spill slot for ZT0.
+ int ZT0SpillSlotIndex = std::numeric_limits<int>::max();
+
// Note: The following properties are only used for the old SME ABI lowering:
/// The frame-index for the TPIDR2 object used for lazy saves.
TPIDR2Object TPIDR2;
@@ -265,6 +268,15 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
return EarlyAllocSMESaveBuffer;
}
+ /// Records FI as the frame index of the ZT0 spill slot.
+ void setZT0SpillSlotIndex(int FI) { ZT0SpillSlotIndex = FI; }
+ /// Returns the frame index of the ZT0 spill slot. Asserts that one has
+ /// been recorded (see hasZT0SpillSlotIndex()).
+ int getZT0SpillSlotIndex() const {
+ assert(hasZT0SpillSlotIndex() && "ZT0 spill slot index not set!");
+ return ZT0SpillSlotIndex;
+ }
+ /// Returns true if a ZT0 spill slot index has been recorded.
+ /// std::numeric_limits<int>::max() is the "unset" value; a sentinel is used
+ /// rather than e.g. -1 because frame indices may legitimately be negative.
+ bool hasZT0SpillSlotIndex() const {
+ return ZT0SpillSlotIndex != std::numeric_limits<int>::max();
+ }
+
// Old SME ABI lowering state getters/setters:
Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; };
diff --git a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
index 80827c2547780..062b68e5909f3 100644
--- a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
@@ -224,22 +224,21 @@ define float @test6(float %f) nounwind "aarch64_pstate_sm_enabled" {
define void @test7() nounwind "aarch64_inout_zt0" {
; CHECK-LABEL: test7:
; CHECK: // %bb.0:
-; CHECK-NEXT: sub sp, sp, #144
-; CHECK-NEXT: stp x30, x19, [sp, #128] // 16-byte Folded Spill
-; CHECK-NEXT: add x19, sp, #64
+; CHECK-NEXT: sub sp, sp, #80
+; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
-; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
-; CHECK-NEXT: ldp x30, x19, [sp, #128] // 16-byte Folded Reload
-; CHECK-NEXT: add sp, sp, #144
+; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
call void @callee()
call void @callee()
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 49eb368662b5d..2583a93e514a2 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -386,3 +386,43 @@ define void @shared_za_new_zt0(ptr %callee) "aarch64_inout_za" "aarch64_new_zt0"
call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
ret void;
}
+
+
+; Check that all four private-ZA calls spill and reload ZT0 through the same
+; stack slot ([x20]) rather than allocating a new slot for each call.
+define void @zt0_multiple_private_za_calls(ptr %callee) "aarch64_in_zt0" nounwind {
+; CHECK-COMMON-LABEL: zt0_multiple_private_za_calls:
+; CHECK-COMMON: // %bb.0:
+; CHECK-COMMON-NEXT: sub sp, sp, #96
+; CHECK-COMMON-NEXT: stp x20, x19, [sp, #80] // 16-byte Folded Spill
+; CHECK-COMMON-NEXT: mov x20, sp
+; CHECK-COMMON-NEXT: mov x19, x0
+; CHECK-COMMON-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-COMMON-NEXT: str zt0, [x20]
+; CHECK-COMMON-NEXT: smstop za
+; CHECK-COMMON-NEXT: blr x0
+; CHECK-COMMON-NEXT: smstart za
+; CHECK-COMMON-NEXT: ldr zt0, [x20]
+; CHECK-COMMON-NEXT: str zt0, [x20]
+; CHECK-COMMON-NEXT: smstop za
+; CHECK-COMMON-NEXT: blr x19
+; CHECK-COMMON-NEXT: smstart za
+; CHECK-COMMON-NEXT: ldr zt0, [x20]
+; CHECK-COMMON-NEXT: str zt0, [x20]
+; CHECK-COMMON-NEXT: smstop za
+; CHECK-COMMON-NEXT: blr x19
+; CHECK-COMMON-NEXT: smstart za
+; CHECK-COMMON-NEXT: ldr zt0, [x20]
+; CHECK-COMMON-NEXT: str zt0, [x20]
+; CHECK-COMMON-NEXT: smstop za
+; CHECK-COMMON-NEXT: blr x19
+; CHECK-COMMON-NEXT: smstart za
+; CHECK-COMMON-NEXT: ldr zt0, [x20]
+; CHECK-COMMON-NEXT: ldp x20, x19, [sp, #80] // 16-byte Folded Reload
+; CHECK-COMMON-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-COMMON-NEXT: add sp, sp, #96
+; CHECK-COMMON-NEXT: ret
+ call void %callee()
+ call void %callee()
+ call void %callee()
+ call void %callee()
+ ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/158593
More information about the llvm-commits mailing list