[llvm] [AArch64][SME2] Preserve ZT0 state around function calls (PR #78321)
Kerry McLaughlin via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 18 06:33:26 PST 2024
https://github.com/kmclaughlin-arm updated https://github.com/llvm/llvm-project/pull/78321
>From 9729355b9e7cb23ba15ff47a34f7cc6f3cf24ce8 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Thu, 11 Jan 2024 17:46:00 +0000
Subject: [PATCH 1/5] [SME2][Clang] Add tests with ZT0 state
---
llvm/test/CodeGen/AArch64/sme-zt0-state.ll | 125 +++++++++++++++++++++
1 file changed, 125 insertions(+)
create mode 100644 llvm/test/CodeGen/AArch64/sme-zt0-state.ll
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
new file mode 100644
index 000000000000000..ff560681665f8bc
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -0,0 +1,125 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
+
+; Callee with no ZT state
+declare void @no_state_callee();
+
+; Callees with ZT0 state
+declare void @zt0_shared_callee() "aarch64_in_zt0";
+
+; Callees with ZA state
+
+declare void @za_shared_callee() "aarch64_pstate_za_shared";
+declare void @za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+
+;
+; Private-ZA Callee
+;
+
+; Expect spill & fill of ZT0 around call
+; Expect smstop/smstart za around call
+define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: zt0_in_caller_no_state_callee:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: bl no_state_callee
+; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ call void @no_state_callee();
+ ret void;
+}
+
+; Expect spill & fill of ZT0 around call
+; Expect setup and restore lazy-save around call
+; Expect smstart za after call
+define void @za_zt0_shared_caller_no_state_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: za_zt0_shared_caller_no_state_callee:
+; CHECK: // %bb.0:
+; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT: mov x29, sp
+; CHECK-NEXT: sub sp, sp, #16
+; CHECK-NEXT: rdsvl x8, #1
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: msub x9, x8, x8, x9
+; CHECK-NEXT: mov sp, x9
+; CHECK-NEXT: sub x10, x29, #16
+; CHECK-NEXT: stur wzr, [x29, #-4]
+; CHECK-NEXT: sturh wzr, [x29, #-6]
+; CHECK-NEXT: stur x9, [x29, #-16]
+; CHECK-NEXT: sturh w8, [x29, #-8]
+; CHECK-NEXT: msr TPIDR2_EL0, x10
+; CHECK-NEXT: bl no_state_callee
+; CHECK-NEXT: smstart za
+; CHECK-NEXT: mrs x8, TPIDR2_EL0
+; CHECK-NEXT: sub x0, x29, #16
+; CHECK-NEXT: cbnz x8, .LBB1_2
+; CHECK-NEXT: // %bb.1:
+; CHECK-NEXT: bl __arm_tpidr2_restore
+; CHECK-NEXT: .LBB1_2:
+; CHECK-NEXT: msr TPIDR2_EL0, xzr
+; CHECK-NEXT: mov sp, x29
+; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT: ret
+ call void @no_state_callee();
+ ret void;
+}
+
+;
+; Shared-ZA Callee
+;
+
+; Caller and callee have shared ZT0 state, no spill/fill of ZT0 required
+define void @zt0_shared_caller_zt0_shared_callee() "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: zt0_shared_caller_zt0_shared_callee:
+; CHECK: // %bb.0:
+; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: bl zt0_shared_callee
+; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: ret
+ call void @zt0_shared_callee();
+ ret void;
+}
+
+; Expect spill & fill of ZT0 around call
+define void @za_zt0_shared_caller_za_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: za_zt0_shared_caller_za_shared_callee:
+; CHECK: // %bb.0:
+; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT: mov x29, sp
+; CHECK-NEXT: sub sp, sp, #16
+; CHECK-NEXT: rdsvl x8, #1
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: msub x8, x8, x8, x9
+; CHECK-NEXT: mov sp, x8
+; CHECK-NEXT: stur wzr, [x29, #-4]
+; CHECK-NEXT: sturh wzr, [x29, #-6]
+; CHECK-NEXT: stur x8, [x29, #-16]
+; CHECK-NEXT: bl za_shared_callee
+; CHECK-NEXT: mov sp, x29
+; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT: ret
+ call void @za_shared_callee();
+ ret void;
+}
+
+; Caller and callee have shared ZA & ZT0
+define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: za_zt0_shared_caller_za_zt0_shared_callee:
+; CHECK: // %bb.0:
+; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT: mov x29, sp
+; CHECK-NEXT: sub sp, sp, #16
+; CHECK-NEXT: rdsvl x8, #1
+; CHECK-NEXT: mov x9, sp
+; CHECK-NEXT: msub x8, x8, x8, x9
+; CHECK-NEXT: mov sp, x8
+; CHECK-NEXT: stur wzr, [x29, #-4]
+; CHECK-NEXT: sturh wzr, [x29, #-6]
+; CHECK-NEXT: stur x8, [x29, #-16]
+; CHECK-NEXT: bl za_zt0_shared_callee
+; CHECK-NEXT: mov sp, x29
+; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT: ret
+ call void @za_zt0_shared_callee();
+ ret void;
+}
>From 1fdb15e50ef6010bee53344c761ab6b7942b8e3e Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Fri, 12 Jan 2024 09:30:31 +0000
Subject: [PATCH 2/5] [AArch64][SME2] Preserve ZT0 state around function calls
If a function has ZT0 state and calls a function which does not
preserve ZT0, the caller must save and restore ZT0 around the call.
If the caller shares ZT0 state and the callee does not share ZA, we
must additionally call SMSTOP/SMSTART ZA around the call.
This patch adds new AArch64ISDNodes for spilling & filling of ZT0.
Where requiresPreservingZT0 is true, ZT0 state will be preserved
across a call.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 56 ++++++++++++++++---
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 2 +
.../lib/Target/AArch64/AArch64SMEInstrInfo.td | 10 +++-
.../AArch64/Utils/AArch64SMEAttributes.h | 3 +
llvm/test/CodeGen/AArch64/sme-zt0-state.ll | 33 ++++++++---
5 files changed, 87 insertions(+), 17 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 620872790ed8db0..6df4e075e1ca500 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2341,6 +2341,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::SMSTART)
MAKE_CASE(AArch64ISD::SMSTOP)
MAKE_CASE(AArch64ISD::RESTORE_ZA)
+ MAKE_CASE(AArch64ISD::RESTORE_ZT)
+ MAKE_CASE(AArch64ISD::SAVE_ZT)
MAKE_CASE(AArch64ISD::CALL)
MAKE_CASE(AArch64ISD::ADRP)
MAKE_CASE(AArch64ISD::ADR)
@@ -7664,6 +7666,32 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
});
}
+ SDValue ZTFrameIdx;
+ MachineFrameInfo &MFI = MF.getFrameInfo();
+ bool PreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
+
+ // If the caller has ZT0 state which will not be preserved by the callee,
+ // spill ZT0 before the call.
+ if (PreserveZT0) {
+ unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
+ ZTFrameIdx = DAG.getFrameIndex(
+ ZTObj,
+ DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+
+ Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
+ {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+ }
+
+ // If caller shares ZT0 but the callee is not shared ZA, we need to stop
+ // PSTATE.ZA before the call if there is no lazy-save active.
+ bool ToggleZA = !RequiresLazySave && CallerAttrs.sharesZT0() &&
+ CalleeAttrs.hasPrivateZAInterface();
+ if (ToggleZA)
+ Chain = DAG.getNode(
+ AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
+ DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
+ DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+
// Adjust the stack pointer for the new arguments...
// These operations are automatically eliminated by the prolog/epilog pass
if (!IsSibCall)
@@ -8074,14 +8102,20 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
PStateSM, false);
}
+ if ((RequiresLazySave && !CalleeAttrs.preservesZA()) || ToggleZA)
+ // Unconditionally resume ZA.
+ Result = DAG.getNode(
+ AArch64ISD::SMSTART, DL, MVT::Other, Result,
+ DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
+ DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+
+ if (PreserveZT0)
+ Result =
+ DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
+ {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+
if (RequiresLazySave) {
if (!CalleeAttrs.preservesZA()) {
- // Unconditionally resume ZA.
- Result = DAG.getNode(
- AArch64ISD::SMSTART, DL, MVT::Other, Result,
- DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
- DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
-
// Conditionally restore the lazy save using a pseudo node.
unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
SDValue RegMask = DAG.getRegisterMask(
@@ -8110,7 +8144,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
DAG.getConstant(0, DL, MVT::i64));
}
- if (RequiresSMChange || RequiresLazySave) {
+ if (RequiresSMChange || RequiresLazySave || PreserveZT0) {
for (unsigned I = 0; I < InVals.size(); ++I) {
// The smstart/smstop is chained as part of the call, but when the
// resulting chain is discarded (which happens when the call is not part
@@ -23979,6 +24013,14 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return DAG.getMergeValues(
{A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
}
+ case Intrinsic::aarch64_sme_ldr_zt:
+ return DAG.getNode(AArch64ISD::RESTORE_ZT, SDLoc(N),
+ DAG.getVTList(MVT::Other), N->getOperand(0),
+ N->getOperand(2), N->getOperand(3));
+ case Intrinsic::aarch64_sme_str_zt:
+ return DAG.getNode(AArch64ISD::SAVE_ZT, SDLoc(N),
+ DAG.getVTList(MVT::Other), N->getOperand(0),
+ N->getOperand(2), N->getOperand(3));
default:
break;
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 1fd639b4f7ee8f4..bffee867fdf2940 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -61,6 +61,8 @@ enum NodeType : unsigned {
SMSTART,
SMSTOP,
RESTORE_ZA,
+ RESTORE_ZT,
+ SAVE_ZT,
// Produces the full sequence of instructions for getting the thread pointer
// offset of a variable into X0, using the TLSDesc model.
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 380f6e1fcfdaefc..eeae5303a3f8987 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -22,6 +22,12 @@ def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
[SDTCisInt<0>, SDTCisPtrTy<1>]>,
[SDNPHasChain, SDNPSideEffect, SDNPVariadic,
SDNPOptInGlue]>;
+def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
+ [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+ [SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
+def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
+ [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+ [SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;
//===----------------------------------------------------------------------===//
// Instruction naming conventions.
@@ -543,8 +549,8 @@ defm UMOPS_MPPZZ_HtoS : sme2_int_mopx_tile<"umops", 0b101, int_aarch64_sme_umops
defm ZERO_T : sme2_zero_zt<"zero", 0b0001>;
-defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, int_aarch64_sme_ldr_zt>;
-defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, int_aarch64_sme_str_zt>;
+defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, AArch64_restore_zt>;
+defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, AArch64_save_zt>;
def MOVT_XTI : sme2_movt_zt_to_scalar<"movt", 0b0011111>;
def MOVT_TIX : sme2_movt_scalar_to_zt<"movt", 0b0011111>;
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index af2854856fb9796..417dec3432a0088 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -119,6 +119,9 @@ class SMEAttrs {
State == StateValue::InOut || State == StateValue::Preserved;
}
bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
+ bool requiresPreservingZT0(const SMEAttrs &Callee) const {
+ return hasZT0State() && !Callee.sharesZT0();
+ }
};
} // namespace llvm
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index ff560681665f8bc..289794d12be171e 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -21,9 +21,16 @@ declare void @za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0"
define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
; CHECK-LABEL: zt0_in_caller_no_state_callee:
; CHECK: // %bb.0:
-; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT: sub sp, sp, #80
+; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT: mov x19, sp
+; CHECK-NEXT: str zt0, [x19]
+; CHECK-NEXT: smstop za
; CHECK-NEXT: bl no_state_callee
-; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT: smstart za
+; CHECK-NEXT: ldr zt0, [x19]
+; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
call void @no_state_callee();
ret void;
@@ -35,21 +42,25 @@ define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
define void @za_zt0_shared_caller_no_state_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
; CHECK-LABEL: za_zt0_shared_caller_no_state_callee:
; CHECK: // %bb.0:
-; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT: str x19, [sp, #16] // 8-byte Folded Spill
; CHECK-NEXT: mov x29, sp
-; CHECK-NEXT: sub sp, sp, #16
+; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
; CHECK-NEXT: sub x10, x29, #16
+; CHECK-NEXT: sub x19, x29, #80
; CHECK-NEXT: stur wzr, [x29, #-4]
; CHECK-NEXT: sturh wzr, [x29, #-6]
; CHECK-NEXT: stur x9, [x29, #-16]
; CHECK-NEXT: sturh w8, [x29, #-8]
; CHECK-NEXT: msr TPIDR2_EL0, x10
+; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: bl no_state_callee
; CHECK-NEXT: smstart za
+; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: mrs x8, TPIDR2_EL0
; CHECK-NEXT: sub x0, x29, #16
; CHECK-NEXT: cbnz x8, .LBB1_2
@@ -58,7 +69,8 @@ define void @za_zt0_shared_caller_no_state_callee() "aarch64_pstate_za_shared" "
; CHECK-NEXT: .LBB1_2:
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: mov sp, x29
-; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT: ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload
; CHECK-NEXT: ret
call void @no_state_callee();
ret void;
@@ -84,19 +96,24 @@ define void @zt0_shared_caller_zt0_shared_callee() "aarch64_in_zt0" nounwind {
define void @za_zt0_shared_caller_za_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
; CHECK-LABEL: za_zt0_shared_caller_za_shared_callee:
; CHECK: // %bb.0:
-; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT: str x19, [sp, #16] // 8-byte Folded Spill
; CHECK-NEXT: mov x29, sp
-; CHECK-NEXT: sub sp, sp, #16
+; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x8, x8, x8, x9
; CHECK-NEXT: mov sp, x8
+; CHECK-NEXT: sub x19, x29, #80
; CHECK-NEXT: stur wzr, [x29, #-4]
; CHECK-NEXT: sturh wzr, [x29, #-6]
; CHECK-NEXT: stur x8, [x29, #-16]
+; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: bl za_shared_callee
+; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: mov sp, x29
-; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT: ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload
; CHECK-NEXT: ret
call void @za_shared_callee();
ret void;
>From f8204634af793659942935bfa7907874881718cb Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Thu, 18 Jan 2024 10:55:53 +0000
Subject: [PATCH 3/5] - Added requiresZAToggle() to AArch64SMEAttributes -
Added a test for an aarch64_in_zt0 caller -> aarch64_new_zt0 callee
---
.../Target/AArch64/AArch64ISelLowering.cpp | 7 ++++--
.../AArch64/Utils/AArch64SMEAttributes.h | 4 ++++
llvm/test/CodeGen/AArch64/sme-zt0-state.ll | 23 +++++++++++++++++++
3 files changed, 32 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6df4e075e1ca500..27b81faccfce360 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7684,8 +7684,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
// PSTATE.ZA before the call if there is no lazy-save active.
- bool ToggleZA = !RequiresLazySave && CallerAttrs.sharesZT0() &&
- CalleeAttrs.hasPrivateZAInterface();
+ bool ToggleZA = CallerAttrs.requiresZAToggle(CalleeAttrs);
+
+ assert((!ToggleZA || !RequiresLazySave) &&
+ "Lazy-save should have PSTATE.SM=1 on entry to the function");
+
if (ToggleZA)
Chain = DAG.getNode(
AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 417dec3432a0088..cedc683e9de4dbc 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -122,6 +122,10 @@ class SMEAttrs {
bool requiresPreservingZT0(const SMEAttrs &Callee) const {
return hasZT0State() && !Callee.sharesZT0();
}
+ bool requiresZAToggle(const SMEAttrs &Callee) const {
+ return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface();
+
+ }
};
} // namespace llvm
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 289794d12be171e..7df9df4ec9a3826 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -6,6 +6,7 @@ declare void @no_state_callee();
; Callees with ZT0 state
declare void @zt0_shared_callee() "aarch64_in_zt0";
+declare void @zt0_new_callee() "aarch64_new_zt0";
; Callees with ZA state
@@ -140,3 +141,25 @@ define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_pstate_za_shar
call void @za_zt0_shared_callee();
ret void;
}
+
+; New-ZA Callee
+
+; Expect spill & fill of ZT0 around call
+; Expect smstop/smstart za around call
+define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: zt0_in_caller_zt0_new_callee:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sub sp, sp, #80
+; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT: mov x19, sp
+; CHECK-NEXT: str zt0, [x19]
+; CHECK-NEXT: smstop za
+; CHECK-NEXT: bl zt0_new_callee
+; CHECK-NEXT: smstart za
+; CHECK-NEXT: ldr zt0, [x19]
+; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT: add sp, sp, #80
+; CHECK-NEXT: ret
+ call void @zt0_new_callee();
+ ret void;
+}
>From f825d80408bead7b3c6b5e2186c273168b6e9aec Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Thu, 18 Jan 2024 11:25:09 +0000
Subject: [PATCH 4/5] - Run clang-format
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2 +-
llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h | 1 -
2 files changed, 1 insertion(+), 2 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 27b81faccfce360..a07fcdbd8afafc7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7687,7 +7687,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
bool ToggleZA = CallerAttrs.requiresZAToggle(CalleeAttrs);
assert((!ToggleZA || !RequiresLazySave) &&
- "Lazy-save should have PSTATE.SM=1 on entry to the function");
+ "Lazy-save should have PSTATE.SM=1 on entry to the function");
if (ToggleZA)
Chain = DAG.getNode(
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index cedc683e9de4dbc..c9585cbb00268f6 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -124,7 +124,6 @@ class SMEAttrs {
}
bool requiresZAToggle(const SMEAttrs &Callee) const {
return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface();
-
}
};
>From 273210906d745186c55e5b80f3406157d52e95a8 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Thu, 18 Jan 2024 13:38:45 +0000
Subject: [PATCH 5/5] - Split requiresPreservingZT0 into
requiresDisablingZABeforeCall/requiresEnablingZAAfterCall - Renamed
PreserveZT0 to ShouldPreserveZT0 - Added unittests for requiresPreservingZT0,
requiresDisablingZABeforeCall requiresEnablingZAAfterCall
---
.../Target/AArch64/AArch64ISelLowering.cpp | 17 ++--
.../AArch64/Utils/AArch64SMEAttributes.h | 6 +-
.../Target/AArch64/SMEAttributesTest.cpp | 87 +++++++++++++++++++
3 files changed, 100 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a07fcdbd8afafc7..20f51a69e6d68c7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7668,11 +7668,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
SDValue ZTFrameIdx;
MachineFrameInfo &MFI = MF.getFrameInfo();
- bool PreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
+ bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
// If the caller has ZT0 state which will not be preserved by the callee,
// spill ZT0 before the call.
- if (PreserveZT0) {
+ if (ShouldPreserveZT0) {
unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
ZTFrameIdx = DAG.getFrameIndex(
ZTObj,
@@ -7684,12 +7684,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
// PSTATE.ZA before the call if there is no lazy-save active.
- bool ToggleZA = CallerAttrs.requiresZAToggle(CalleeAttrs);
-
- assert((!ToggleZA || !RequiresLazySave) &&
+ bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
+ assert((!DisableZA || !RequiresLazySave) &&
"Lazy-save should have PSTATE.SM=1 on entry to the function");
- if (ToggleZA)
+ if (DisableZA)
Chain = DAG.getNode(
AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
@@ -8105,14 +8104,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
PStateSM, false);
}
- if ((RequiresLazySave && !CalleeAttrs.preservesZA()) || ToggleZA)
+ if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
// Unconditionally resume ZA.
Result = DAG.getNode(
AArch64ISD::SMSTART, DL, MVT::Other, Result,
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
- if (PreserveZT0)
+ if (ShouldPreserveZT0)
Result =
DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
{Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
@@ -8147,7 +8146,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
DAG.getConstant(0, DL, MVT::i64));
}
- if (RequiresSMChange || RequiresLazySave || PreserveZT0) {
+ if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
for (unsigned I = 0; I < InVals.size(); ++I) {
// The smstart/smstop is chained as part of the call, but when the
// resulting chain is discarded (which happens when the call is not part
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index c9585cbb00268f6..f1852bf2f6fcf23 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -122,9 +122,13 @@ class SMEAttrs {
bool requiresPreservingZT0(const SMEAttrs &Callee) const {
return hasZT0State() && !Callee.sharesZT0();
}
- bool requiresZAToggle(const SMEAttrs &Callee) const {
+ bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface();
}
+ bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const {
+ return (requiresLazySave(Callee) && !Callee.preservesZA()) ||
+ requiresDisablingZABeforeCall(Callee);
+ }
};
} // namespace llvm
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 2f7201464ba2f23..35918e05996752d 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -191,6 +191,9 @@ TEST(SMEAttributes, Basics) {
TEST(SMEAttributes, Transitions) {
// Normal -> Normal
ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal)));
+ ASSERT_FALSE(SA(SA::Normal).requiresPreservingZT0(SA(SA::Normal)));
+ ASSERT_FALSE(SA(SA::Normal).requiresDisablingZABeforeCall(SA(SA::Normal)));
+ ASSERT_FALSE(SA(SA::Normal).requiresEnablingZAAfterCall(SA(SA::Normal)));
// Normal -> Normal + LocallyStreaming
ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal | SA::SM_Body)));
ASSERT_EQ(*SA(SA::Normal)
@@ -275,4 +278,88 @@ TEST(SMEAttributes, Transitions) {
.requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body),
/*BodyOverridesInterface=*/true),
true);
+
+ SA ZT0_In = SA(SA::encodeZT0State(SA::StateValue::In));
+ SA ZT0_InOut = SA(SA::encodeZT0State(SA::StateValue::InOut));
+ SA ZT0_Out = SA(SA::encodeZT0State(SA::StateValue::Out));
+ SA ZT0_Preserved = SA(SA::encodeZT0State(SA::StateValue::Preserved));
+ SA ZT0_New = SA(SA::encodeZT0State(SA::StateValue::New));
+
+ // ZT0 New -> Normal
+ ASSERT_TRUE(ZT0_New.requiresPreservingZT0(SA(SA::Normal)));
+ ASSERT_TRUE(ZT0_New.requiresDisablingZABeforeCall(SA(SA::Normal)));
+ ASSERT_TRUE(ZT0_New.requiresEnablingZAAfterCall(SA(SA::Normal)));
+
+ // ZT0 New -> ZT0 New
+ ASSERT_TRUE(ZT0_New.requiresPreservingZT0(ZT0_New));
+ ASSERT_TRUE(ZT0_New.requiresDisablingZABeforeCall(ZT0_New));
+ ASSERT_TRUE(ZT0_New.requiresEnablingZAAfterCall(ZT0_New));
+
+ // ZT0 New -> ZT0 Shared
+ ASSERT_FALSE(ZT0_New.requiresPreservingZT0(ZT0_In));
+ ASSERT_FALSE(ZT0_New.requiresDisablingZABeforeCall(ZT0_In));
+ ASSERT_FALSE(ZT0_New.requiresEnablingZAAfterCall(ZT0_In));
+
+ ASSERT_FALSE(ZT0_New.requiresPreservingZT0(ZT0_InOut));
+ ASSERT_FALSE(ZT0_New.requiresDisablingZABeforeCall(ZT0_InOut));
+ ASSERT_FALSE(ZT0_New.requiresEnablingZAAfterCall(ZT0_InOut));
+
+ ASSERT_FALSE(ZT0_New.requiresPreservingZT0(ZT0_Out));
+ ASSERT_FALSE(ZT0_New.requiresDisablingZABeforeCall(ZT0_Out));
+ ASSERT_FALSE(ZT0_New.requiresEnablingZAAfterCall(ZT0_Out));
+
+ ASSERT_FALSE(ZT0_New.requiresPreservingZT0(ZT0_Preserved));
+ ASSERT_FALSE(ZT0_New.requiresDisablingZABeforeCall(ZT0_Preserved));
+ ASSERT_FALSE(ZT0_New.requiresEnablingZAAfterCall(ZT0_Preserved));
+
+ // ZT0 Shared -> Normal
+ ASSERT_TRUE(ZT0_In.requiresPreservingZT0(SA(SA::Normal)));
+ ASSERT_TRUE(ZT0_In.requiresDisablingZABeforeCall(SA(SA::Normal)));
+ ASSERT_TRUE(ZT0_In.requiresEnablingZAAfterCall(SA(SA::Normal)));
+
+ ASSERT_TRUE(ZT0_InOut.requiresPreservingZT0(SA(SA::Normal)));
+ ASSERT_TRUE(ZT0_InOut.requiresDisablingZABeforeCall(SA(SA::Normal)));
+ ASSERT_TRUE(ZT0_InOut.requiresEnablingZAAfterCall(SA(SA::Normal)));
+
+ ASSERT_TRUE(ZT0_Out.requiresPreservingZT0(SA(SA::Normal)));
+ ASSERT_TRUE(ZT0_Out.requiresDisablingZABeforeCall(SA(SA::Normal)));
+ ASSERT_TRUE(ZT0_Out.requiresEnablingZAAfterCall(SA(SA::Normal)));
+
+ ASSERT_TRUE(ZT0_Preserved.requiresPreservingZT0(SA(SA::Normal)));
+ ASSERT_TRUE(ZT0_Preserved.requiresDisablingZABeforeCall(SA(SA::Normal)));
+ ASSERT_TRUE(ZT0_Preserved.requiresEnablingZAAfterCall(SA(SA::Normal)));
+
+ // ZT0 Shared -> ZT0 Shared
+ ASSERT_FALSE(ZT0_In.requiresPreservingZT0(ZT0_In));
+ ASSERT_FALSE(ZT0_In.requiresDisablingZABeforeCall(ZT0_In));
+ ASSERT_FALSE(ZT0_In.requiresEnablingZAAfterCall(ZT0_In));
+
+ ASSERT_FALSE(ZT0_InOut.requiresPreservingZT0(ZT0_In));
+ ASSERT_FALSE(ZT0_InOut.requiresDisablingZABeforeCall(ZT0_In));
+ ASSERT_FALSE(ZT0_InOut.requiresEnablingZAAfterCall(ZT0_In));
+
+ ASSERT_FALSE(ZT0_Out.requiresPreservingZT0(ZT0_In));
+ ASSERT_FALSE(ZT0_Out.requiresDisablingZABeforeCall(ZT0_In));
+ ASSERT_FALSE(ZT0_Out.requiresEnablingZAAfterCall(ZT0_In));
+
+ ASSERT_FALSE(ZT0_Preserved.requiresPreservingZT0(ZT0_In));
+ ASSERT_FALSE(ZT0_Preserved.requiresDisablingZABeforeCall(ZT0_In));
+ ASSERT_FALSE(ZT0_Preserved.requiresEnablingZAAfterCall(ZT0_In));
+
+ // ZT0 Shared -> ZT0 New
+ ASSERT_TRUE(ZT0_In.requiresPreservingZT0(ZT0_New));
+ ASSERT_TRUE(ZT0_In.requiresPreservingZT0(ZT0_New));
+ ASSERT_TRUE(ZT0_In.requiresEnablingZAAfterCall(ZT0_New));
+
+ ASSERT_TRUE(ZT0_InOut.requiresPreservingZT0(ZT0_New));
+ ASSERT_TRUE(ZT0_InOut.requiresPreservingZT0(ZT0_New));
+ ASSERT_TRUE(ZT0_InOut.requiresEnablingZAAfterCall(ZT0_New));
+
+ ASSERT_TRUE(ZT0_Out.requiresPreservingZT0(ZT0_New));
+ ASSERT_TRUE(ZT0_Out.requiresPreservingZT0(ZT0_New));
+ ASSERT_TRUE(ZT0_Out.requiresEnablingZAAfterCall(ZT0_New));
+
+ ASSERT_TRUE(ZT0_Preserved.requiresPreservingZT0(ZT0_New));
+ ASSERT_TRUE(ZT0_Preserved.requiresPreservingZT0(ZT0_New));
+ ASSERT_TRUE(ZT0_Preserved.requiresEnablingZAAfterCall(ZT0_New));
}
More information about the llvm-commits
mailing list