[flang-commits] [clang-tools-extra] [clang] [compiler-rt] [llvm] [flang] [libcxx] [lldb] [lld] [libc] [AArch64][SME2] Preserve ZT0 state around function calls (PR #78321)

Kerry McLaughlin via flang-commits flang-commits at lists.llvm.org
Sat Jan 20 03:43:25 PST 2024


https://github.com/kmclaughlin-arm updated https://github.com/llvm/llvm-project/pull/78321

From 11dce217ed307601d0ea1eb5b016b47f80e67786 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Thu, 11 Jan 2024 17:46:00 +0000
Subject: [PATCH 1/7] [SME2][Clang] Add tests with ZT0 state

---
 llvm/test/CodeGen/AArch64/sme-zt0-state.ll | 125 +++++++++++++++++++++
 1 file changed, 125 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/sme-zt0-state.ll

diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
new file mode 100644
index 00000000000000..ff560681665f8b
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -0,0 +1,125 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
+
+; Callee with no ZT state
+declare void @no_state_callee();
+
+; Callees with ZT0 state
+declare void @zt0_shared_callee() "aarch64_in_zt0";
+
+; Callees with ZA state
+
+declare void @za_shared_callee() "aarch64_pstate_za_shared";
+declare void @za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+
+;
+; Private-ZA Callee
+;
+
+; Expect spill & fill of ZT0 around call
+; Expect smstop/smstart za around call
+define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: zt0_in_caller_no_state_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl no_state_callee
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @no_state_callee();
+  ret void;
+}
+
+; Expect spill & fill of ZT0 around call
+; Expect setup and restore lazy-save around call
+; Expect smstart za after call
+define void @za_zt0_shared_caller_no_state_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: za_zt0_shared_caller_no_state_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x9, [x29, #-16]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    bl no_state_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB1_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB1_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @no_state_callee();
+  ret void;
+}
+
+;
+; Shared-ZA Callee
+;
+
+; Caller and callee have shared ZT0 state, no spill/fill of ZT0 required
+define void @zt0_shared_caller_zt0_shared_callee() "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: zt0_shared_caller_zt0_shared_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl zt0_shared_callee
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @zt0_shared_callee();
+  ret void;
+}
+
+; Expect spill & fill of ZT0 around call
+define void @za_zt0_shared_caller_za_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: za_zt0_shared_caller_za_shared_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    bl za_shared_callee
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @za_shared_callee();
+  ret void;
+}
+
+; Caller and callee have shared ZA & ZT0
+define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: za_zt0_shared_caller_za_zt0_shared_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    bl za_zt0_shared_callee
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @za_zt0_shared_callee();
+  ret void;
+}

From eef198b94e7336d54b9e296d90c541826073ea36 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Fri, 12 Jan 2024 09:30:31 +0000
Subject: [PATCH 2/7] [AArch64][SME2] Preserve ZT0 state around function calls

If a function has ZT0 state and calls a function which does not
preserve ZT0, the caller must save and restore ZT0 around the call.
If the caller shares ZT0 state and the callee does not have a
shared-ZA interface, we must additionally emit SMSTOP/SMSTART ZA
around the call.

This patch adds new AArch64ISD nodes for spilling and filling ZT0.
Where requiresPreservingZT0 is true, ZT0 state is preserved across
the call.
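
For intuition, the following is a minimal standalone C++ sketch of the
predicates described above. It uses simplified stand-in types rather than
the real SMEAttrs class (new-ZA state and streaming-mode interactions are
deliberately omitted, and !SharesZA stands in for !hasZAState()), so it
illustrates the decision logic, not the actual LLVM implementation:

#include <cassert>

// Simplified stand-in for SMEAttrs: only the ZT0/ZA properties that the
// new predicates consult. Field names are illustrative.
struct Attrs {
  bool SharesZT0 = false; // "aarch64_in_zt0"
  bool NewZT0 = false;    // "aarch64_new_zt0"
  bool SharesZA = false;  // "aarch64_pstate_za_shared"

  bool hasZT0State() const { return SharesZT0 || NewZT0; }
  // A callee that shares neither ZA nor ZT0 has a private-ZA interface
  // and may clobber ZT0 or turn ZA off.
  bool hasPrivateZAInterface() const { return !SharesZA && !SharesZT0; }

  // Spill/fill ZT0 around the call unless the callee shares ZT0.
  bool requiresPreservingZT0(const Attrs &Callee) const {
    return hasZT0State() && !Callee.SharesZT0;
  }
  // Stop PSTATE.ZA before the call when the caller has ZT0 state but no
  // ZA state and the callee is private-ZA; when the caller also has ZA
  // state, a lazy-save covers ZA instead (not modelled here).
  bool requiresDisablingZABeforeCall(const Attrs &Callee) const {
    return hasZT0State() && !SharesZA && Callee.hasPrivateZAInterface();
  }
};

int main() {
  Attrs ZT0Caller;
  ZT0Caller.SharesZT0 = true;

  Attrs PrivateZACallee; // no ZA/ZT0 attributes at all
  assert(ZT0Caller.requiresPreservingZT0(PrivateZACallee));
  assert(ZT0Caller.requiresDisablingZABeforeCall(PrivateZACallee));

  Attrs SharedZT0Callee;
  SharedZT0Callee.SharesZT0 = true; // callee preserves ZT0 by contract
  assert(!ZT0Caller.requiresPreservingZT0(SharedZT0Callee));
  assert(!ZT0Caller.requiresDisablingZABeforeCall(SharedZT0Callee));
  return 0;
}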
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 46 ++++++++++++++++++-
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  2 +
 .../lib/Target/AArch64/AArch64SMEInstrInfo.td | 10 +++-
 .../AArch64/Utils/AArch64SMEAttributes.h      |  3 ++
 llvm/test/CodeGen/AArch64/sme-zt0-state.ll    | 33 +++++++++----
 5 files changed, 82 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7b9b6a7a428125..6cc839ecf4f66c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2341,6 +2341,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::SMSTART)
     MAKE_CASE(AArch64ISD::SMSTOP)
     MAKE_CASE(AArch64ISD::RESTORE_ZA)
+    MAKE_CASE(AArch64ISD::RESTORE_ZT)
+    MAKE_CASE(AArch64ISD::SAVE_ZT)
     MAKE_CASE(AArch64ISD::CALL)
     MAKE_CASE(AArch64ISD::ADRP)
     MAKE_CASE(AArch64ISD::ADR)
@@ -7655,6 +7657,32 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     });
   }
 
+  SDValue ZTFrameIdx;
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+  bool PreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
+
+  // If the caller has ZT0 state which will not be preserved by the callee,
+  // spill ZT0 before the call.
+  if (PreserveZT0) {
+    unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
+    ZTFrameIdx = DAG.getFrameIndex(
+        ZTObj,
+        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+
+    Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
+                        {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+  }
+
+  // If caller shares ZT0 but the callee is not shared ZA, we need to stop
+  // PSTATE.ZA before the call if there is no lazy-save active.
+  bool ToggleZA = !RequiresLazySave && CallerAttrs.sharesZT0() &&
+                  CalleeAttrs.hasPrivateZAInterface();
+  if (ToggleZA)
+    Chain = DAG.getNode(
+        AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
+        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+
   // Adjust the stack pointer for the new arguments...
   // These operations are automatically eliminated by the prolog/epilog pass
   if (!IsSibCall)
@@ -8065,13 +8093,19 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                                  PStateSM, false);
   }
 
-  if (RequiresLazySave) {
+  if (RequiresLazySave || ToggleZA)
     // Unconditionally resume ZA.
     Result = DAG.getNode(
         AArch64ISD::SMSTART, DL, MVT::Other, Result,
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
         DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
 
+  if (PreserveZT0)
+    Result =
+        DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
+                    {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+
+  if (RequiresLazySave) {
     // Conditionally restore the lazy save using a pseudo node.
     unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
     SDValue RegMask = DAG.getRegisterMask(
@@ -8100,7 +8134,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(0, DL, MVT::i64));
   }
 
-  if (RequiresSMChange || RequiresLazySave) {
+  if (RequiresSMChange || RequiresLazySave || PreserveZT0) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
@@ -23977,6 +24011,14 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
       return DAG.getMergeValues(
           {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
     }
+    case Intrinsic::aarch64_sme_ldr_zt:
+      return DAG.getNode(AArch64ISD::RESTORE_ZT, SDLoc(N),
+                         DAG.getVTList(MVT::Other), N->getOperand(0),
+                         N->getOperand(2), N->getOperand(3));
+    case Intrinsic::aarch64_sme_str_zt:
+      return DAG.getNode(AArch64ISD::SAVE_ZT, SDLoc(N),
+                         DAG.getVTList(MVT::Other), N->getOperand(0),
+                         N->getOperand(2), N->getOperand(3));
     default:
       break;
     }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 6047a3b7b2864a..abecc3560ccbb3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -61,6 +61,8 @@ enum NodeType : unsigned {
   SMSTART,
   SMSTOP,
   RESTORE_ZA,
+  RESTORE_ZT,
+  SAVE_ZT,
 
   // Produces the full sequence of instructions for getting the thread pointer
   // offset of a variable into X0, using the TLSDesc model.
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 380f6e1fcfdaef..eeae5303a3f898 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -22,6 +22,12 @@ def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
                              [SDTCisInt<0>, SDTCisPtrTy<1>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue]>;
+def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
+                                [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                                [SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
+def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
+                             [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                             [SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;
 
 //===----------------------------------------------------------------------===//
 // Instruction naming conventions.
@@ -543,8 +549,8 @@ defm UMOPS_MPPZZ_HtoS : sme2_int_mopx_tile<"umops", 0b101, int_aarch64_sme_umops
 
 defm ZERO_T : sme2_zero_zt<"zero", 0b0001>;
 
-defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, int_aarch64_sme_ldr_zt>;
-defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, int_aarch64_sme_str_zt>;
+defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, AArch64_restore_zt>;
+defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, AArch64_save_zt>;
 
 def MOVT_XTI : sme2_movt_zt_to_scalar<"movt", 0b0011111>;
 def MOVT_TIX : sme2_movt_scalar_to_zt<"movt", 0b0011111>;
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index af2854856fb979..417dec3432a008 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -119,6 +119,9 @@ class SMEAttrs {
            State == StateValue::InOut || State == StateValue::Preserved;
   }
   bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
+  bool requiresPreservingZT0(const SMEAttrs &Callee) const {
+    return hasZT0State() && !Callee.sharesZT0();
+  }
 };
 
 } // namespace llvm
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index ff560681665f8b..289794d12be171 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -21,9 +21,16 @@ declare void @za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0"
 define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: zt0_in_caller_no_state_callee:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    mov x19, sp
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    smstop za
 ; CHECK-NEXT:    bl no_state_callee
-; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #80
 ; CHECK-NEXT:    ret
   call void @no_state_callee();
   ret void;
@@ -35,21 +42,25 @@ define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
 define void @za_zt0_shared_caller_no_state_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: za_zt0_shared_caller_no_state_callee:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
 ; CHECK-NEXT:    mov x29, sp
-; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    sub sp, sp, #80
 ; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    mov x9, sp
 ; CHECK-NEXT:    msub x9, x8, x8, x9
 ; CHECK-NEXT:    mov sp, x9
 ; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
 ; CHECK-NEXT:    stur wzr, [x29, #-4]
 ; CHECK-NEXT:    sturh wzr, [x29, #-6]
 ; CHECK-NEXT:    stur x9, [x29, #-16]
 ; CHECK-NEXT:    sturh w8, [x29, #-8]
 ; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    str zt0, [x19]
 ; CHECK-NEXT:    bl no_state_callee
 ; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
 ; CHECK-NEXT:    sub x0, x29, #16
 ; CHECK-NEXT:    cbnz x8, .LBB1_2
@@ -58,7 +69,8 @@ define void @za_zt0_shared_caller_no_state_callee() "aarch64_pstate_za_shared" "
 ; CHECK-NEXT:  .LBB1_2:
 ; CHECK-NEXT:    msr TPIDR2_EL0, xzr
 ; CHECK-NEXT:    mov sp, x29
-; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
   call void @no_state_callee();
   ret void;
@@ -84,19 +96,24 @@ define void @zt0_shared_caller_zt0_shared_callee() "aarch64_in_zt0" nounwind {
 define void @za_zt0_shared_caller_za_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: za_zt0_shared_caller_za_shared_callee:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
 ; CHECK-NEXT:    mov x29, sp
-; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    sub sp, sp, #80
 ; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    mov x9, sp
 ; CHECK-NEXT:    msub x8, x8, x8, x9
 ; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    sub x19, x29, #80
 ; CHECK-NEXT:    stur wzr, [x29, #-4]
 ; CHECK-NEXT:    sturh wzr, [x29, #-6]
 ; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    str zt0, [x19]
 ; CHECK-NEXT:    bl za_shared_callee
+; CHECK-NEXT:    ldr zt0, [x19]
 ; CHECK-NEXT:    mov sp, x29
-; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
   call void @za_shared_callee();
   ret void;

From 510f1a6af3669c296d4aa3e68f7789dfca4e8ce7 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Thu, 18 Jan 2024 10:55:53 +0000
Subject: [PATCH 3/7] - Added requiresZAToggle() to AArch64SMEAttributes -
 Added a test for an aarch64_in_zt0 caller -> aarch64_new_zt0 callee

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  7 ++++--
 .../AArch64/Utils/AArch64SMEAttributes.h      |  4 ++++
 llvm/test/CodeGen/AArch64/sme-zt0-state.ll    | 23 +++++++++++++++++++
 3 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6cc839ecf4f66c..01500db9d6a58f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7675,8 +7675,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   // If caller shares ZT0 but the callee is not shared ZA, we need to stop
   // PSTATE.ZA before the call if there is no lazy-save active.
-  bool ToggleZA = !RequiresLazySave && CallerAttrs.sharesZT0() &&
-                  CalleeAttrs.hasPrivateZAInterface();
+  bool ToggleZA = CallerAttrs.requiresZAToggle(CalleeAttrs);
+
+  assert((!ToggleZA || !RequiresLazySave) &&
+       "Lazy-save should have PSTATE.SM=1 on entry to the function");
+
   if (ToggleZA)
     Chain = DAG.getNode(
         AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 417dec3432a008..cedc683e9de4db 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -122,6 +122,10 @@ class SMEAttrs {
   bool requiresPreservingZT0(const SMEAttrs &Callee) const {
     return hasZT0State() && !Callee.sharesZT0();
   }
+  bool requiresZAToggle(const SMEAttrs &Callee) const {
+    return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface();
+
+  }
 };
 
 } // namespace llvm
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 289794d12be171..7df9df4ec9a382 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -6,6 +6,7 @@ declare void @no_state_callee();
 
 ; Callees with ZT0 state
 declare void @zt0_shared_callee() "aarch64_in_zt0";
+declare void @zt0_new_callee() "aarch64_new_zt0";
 
 ; Callees with ZA state
 
@@ -140,3 +141,25 @@ define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_pstate_za_shar
   call void @za_zt0_shared_callee();
   ret void;
 }
+
+; New-ZA Callee
+
+; Expect spill & fill of ZT0 around call
+; Expect smstop/smstart za around call
+define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: zt0_in_caller_zt0_new_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    mov x19, sp
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    bl zt0_new_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #80
+; CHECK-NEXT:    ret
+  call void @zt0_new_callee();
+  ret void;
+}

From c9f0e35ca24ea52a8f1d37409cff169e60a7ac24 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Thu, 18 Jan 2024 11:25:09 +0000
Subject: [PATCH 4/7] - Run clang-format

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp      | 2 +-
 llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 01500db9d6a58f..5756a560280a68 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7678,7 +7678,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   bool ToggleZA = CallerAttrs.requiresZAToggle(CalleeAttrs);
 
   assert((!ToggleZA || !RequiresLazySave) &&
-       "Lazy-save should have PSTATE.SM=1 on entry to the function");
+         "Lazy-save should have PSTATE.SM=1 on entry to the function");
 
   if (ToggleZA)
     Chain = DAG.getNode(
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index cedc683e9de4db..c9585cbb00268f 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -124,7 +124,6 @@ class SMEAttrs {
   }
   bool requiresZAToggle(const SMEAttrs &Callee) const {
     return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface();
-
   }
 };
 

From 4f3a3e9d0939a847478308846a8f0605ea789218 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Thu, 18 Jan 2024 13:38:45 +0000
Subject: [PATCH 5/7] - Split requiresPreservingZT0 into
 requiresDisablingZABeforeCall/requiresEnablingZAAfterCall - Renamed
 PreserveZT0 to ShouldPreserveZT0 - Added unittests for requiresPreservingZT0,
 requiresDisablingZABeforeCall and requiresEnablingZAAfterCall
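
A compilable one-function C++ sketch of how the split predicates are meant
to compose at this point in the series (stand-in booleans rather than the
real SMEAttrs/lazy-save machinery; a later revision of this patch drops
the preservesZA special case):

// ZA is disabled on return either because a lazy-save was set up for it
// (and the callee did not promise to preserve ZA), or because the caller
// explicitly stopped ZA before the call; both cases need SMSTART ZA
// afterwards.
bool requiresEnablingZAAfterCall(bool RequiresLazySave, bool CalleePreservesZA,
                                 bool RequiresDisablingZABeforeCall) {
  return (RequiresLazySave && !CalleePreservesZA) ||
         RequiresDisablingZABeforeCall;
}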

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 17 ++--
 .../AArch64/Utils/AArch64SMEAttributes.h      |  6 +-
 .../Target/AArch64/SMEAttributesTest.cpp      | 87 +++++++++++++++++++
 3 files changed, 100 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5756a560280a68..4cfe7b572dfbe8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7659,11 +7659,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   SDValue ZTFrameIdx;
   MachineFrameInfo &MFI = MF.getFrameInfo();
-  bool PreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
+  bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
 
   // If the caller has ZT0 state which will not be preserved by the callee,
   // spill ZT0 before the call.
-  if (PreserveZT0) {
+  if (ShouldPreserveZT0) {
     unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
     ZTFrameIdx = DAG.getFrameIndex(
         ZTObj,
@@ -7675,12 +7675,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   // If caller shares ZT0 but the callee is not shared ZA, we need to stop
   // PSTATE.ZA before the call if there is no lazy-save active.
-  bool ToggleZA = CallerAttrs.requiresZAToggle(CalleeAttrs);
-
-  assert((!ToggleZA || !RequiresLazySave) &&
+  bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
+  assert((!DisableZA || !RequiresLazySave) &&
          "Lazy-save should have PSTATE.SM=1 on entry to the function");
 
-  if (ToggleZA)
+  if (DisableZA)
     Chain = DAG.getNode(
         AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
@@ -8096,14 +8095,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                                  PStateSM, false);
   }
 
-  if (RequiresLazySave || ToggleZA)
+  if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
     // Unconditionally resume ZA.
     Result = DAG.getNode(
         AArch64ISD::SMSTART, DL, MVT::Other, Result,
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
         DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
 
-  if (PreserveZT0)
+  if (ShouldPreserveZT0)
     Result =
         DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
                     {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
@@ -8137,7 +8136,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(0, DL, MVT::i64));
   }
 
-  if (RequiresSMChange || RequiresLazySave || PreserveZT0) {
+  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index c9585cbb00268f..f1852bf2f6fcf2 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -122,9 +122,13 @@ class SMEAttrs {
   bool requiresPreservingZT0(const SMEAttrs &Callee) const {
     return hasZT0State() && !Callee.sharesZT0();
   }
-  bool requiresZAToggle(const SMEAttrs &Callee) const {
+  bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
     return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface();
   }
+  bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const {
+    return (requiresLazySave(Callee) && !Callee.preservesZA()) ||
+           requiresDisablingZABeforeCall(Callee);
+  }
 };
 
 } // namespace llvm
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 2f7201464ba2f2..35918e05996752 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -191,6 +191,9 @@ TEST(SMEAttributes, Basics) {
 TEST(SMEAttributes, Transitions) {
   // Normal -> Normal
   ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal)));
+  ASSERT_FALSE(SA(SA::Normal).requiresPreservingZT0(SA(SA::Normal)));
+  ASSERT_FALSE(SA(SA::Normal).requiresDisablingZABeforeCall(SA(SA::Normal)));
+  ASSERT_FALSE(SA(SA::Normal).requiresEnablingZAAfterCall(SA(SA::Normal)));
   // Normal -> Normal + LocallyStreaming
   ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal | SA::SM_Body)));
   ASSERT_EQ(*SA(SA::Normal)
@@ -275,4 +278,88 @@ TEST(SMEAttributes, Transitions) {
                  .requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body),
                                    /*BodyOverridesInterface=*/true),
             true);
+
+  SA ZT0_In = SA(SA::encodeZT0State(SA::StateValue::In));
+  SA ZT0_InOut = SA(SA::encodeZT0State(SA::StateValue::InOut));
+  SA ZT0_Out = SA(SA::encodeZT0State(SA::StateValue::Out));
+  SA ZT0_Preserved = SA(SA::encodeZT0State(SA::StateValue::Preserved));
+  SA ZT0_New = SA(SA::encodeZT0State(SA::StateValue::New));
+
+  // ZT0 New -> Normal
+  ASSERT_TRUE(ZT0_New.requiresPreservingZT0(SA(SA::Normal)));
+  ASSERT_TRUE(ZT0_New.requiresDisablingZABeforeCall(SA(SA::Normal)));
+  ASSERT_TRUE(ZT0_New.requiresEnablingZAAfterCall(SA(SA::Normal)));
+
+  // ZT0 New -> ZT0 New
+  ASSERT_TRUE(ZT0_New.requiresPreservingZT0(ZT0_New));
+  ASSERT_TRUE(ZT0_New.requiresDisablingZABeforeCall(ZT0_New));
+  ASSERT_TRUE(ZT0_New.requiresEnablingZAAfterCall(ZT0_New));
+
+  // ZT0 New -> ZT0 Shared
+  ASSERT_FALSE(ZT0_New.requiresPreservingZT0(ZT0_In));
+  ASSERT_FALSE(ZT0_New.requiresDisablingZABeforeCall(ZT0_In));
+  ASSERT_FALSE(ZT0_New.requiresEnablingZAAfterCall(ZT0_In));
+
+  ASSERT_FALSE(ZT0_New.requiresPreservingZT0(ZT0_InOut));
+  ASSERT_FALSE(ZT0_New.requiresDisablingZABeforeCall(ZT0_InOut));
+  ASSERT_FALSE(ZT0_New.requiresEnablingZAAfterCall(ZT0_InOut));
+
+  ASSERT_FALSE(ZT0_New.requiresPreservingZT0(ZT0_Out));
+  ASSERT_FALSE(ZT0_New.requiresDisablingZABeforeCall(ZT0_Out));
+  ASSERT_FALSE(ZT0_New.requiresEnablingZAAfterCall(ZT0_Out));
+
+  ASSERT_FALSE(ZT0_New.requiresPreservingZT0(ZT0_Preserved));
+  ASSERT_FALSE(ZT0_New.requiresDisablingZABeforeCall(ZT0_Preserved));
+  ASSERT_FALSE(ZT0_New.requiresEnablingZAAfterCall(ZT0_Preserved));
+
+  // ZT0 Shared -> Normal
+  ASSERT_TRUE(ZT0_In.requiresPreservingZT0(SA(SA::Normal)));
+  ASSERT_TRUE(ZT0_In.requiresDisablingZABeforeCall(SA(SA::Normal)));
+  ASSERT_TRUE(ZT0_In.requiresEnablingZAAfterCall(SA(SA::Normal)));
+
+  ASSERT_TRUE(ZT0_InOut.requiresPreservingZT0(SA(SA::Normal)));
+  ASSERT_TRUE(ZT0_InOut.requiresDisablingZABeforeCall(SA(SA::Normal)));
+  ASSERT_TRUE(ZT0_InOut.requiresEnablingZAAfterCall(SA(SA::Normal)));
+
+  ASSERT_TRUE(ZT0_Out.requiresPreservingZT0(SA(SA::Normal)));
+  ASSERT_TRUE(ZT0_Out.requiresDisablingZABeforeCall(SA(SA::Normal)));
+  ASSERT_TRUE(ZT0_Out.requiresEnablingZAAfterCall(SA(SA::Normal)));
+
+  ASSERT_TRUE(ZT0_Preserved.requiresPreservingZT0(SA(SA::Normal)));
+  ASSERT_TRUE(ZT0_Preserved.requiresDisablingZABeforeCall(SA(SA::Normal)));
+  ASSERT_TRUE(ZT0_Preserved.requiresEnablingZAAfterCall(SA(SA::Normal)));
+
+  // ZT0 Shared -> ZT0 Shared
+  ASSERT_FALSE(ZT0_In.requiresPreservingZT0(ZT0_In));
+  ASSERT_FALSE(ZT0_In.requiresDisablingZABeforeCall(ZT0_In));
+  ASSERT_FALSE(ZT0_In.requiresEnablingZAAfterCall(ZT0_In));
+
+  ASSERT_FALSE(ZT0_InOut.requiresPreservingZT0(ZT0_In));
+  ASSERT_FALSE(ZT0_InOut.requiresDisablingZABeforeCall(ZT0_In));
+  ASSERT_FALSE(ZT0_InOut.requiresEnablingZAAfterCall(ZT0_In));
+
+  ASSERT_FALSE(ZT0_Out.requiresPreservingZT0(ZT0_In));
+  ASSERT_FALSE(ZT0_Out.requiresDisablingZABeforeCall(ZT0_In));
+  ASSERT_FALSE(ZT0_Out.requiresEnablingZAAfterCall(ZT0_In));
+
+  ASSERT_FALSE(ZT0_Preserved.requiresPreservingZT0(ZT0_In));
+  ASSERT_FALSE(ZT0_Preserved.requiresDisablingZABeforeCall(ZT0_In));
+  ASSERT_FALSE(ZT0_Preserved.requiresEnablingZAAfterCall(ZT0_In));
+
+  // ZT0 Shared -> ZT0 New
+  ASSERT_TRUE(ZT0_In.requiresPreservingZT0(ZT0_New));
+  ASSERT_TRUE(ZT0_In.requiresDisablingZABeforeCall(ZT0_New));
+  ASSERT_TRUE(ZT0_In.requiresEnablingZAAfterCall(ZT0_New));
+
+  ASSERT_TRUE(ZT0_InOut.requiresPreservingZT0(ZT0_New));
+  ASSERT_TRUE(ZT0_InOut.requiresDisablingZABeforeCall(ZT0_New));
+  ASSERT_TRUE(ZT0_InOut.requiresEnablingZAAfterCall(ZT0_New));
+
+  ASSERT_TRUE(ZT0_Out.requiresPreservingZT0(ZT0_New));
+  ASSERT_TRUE(ZT0_Out.requiresDisablingZABeforeCall(ZT0_New));
+  ASSERT_TRUE(ZT0_Out.requiresEnablingZAAfterCall(ZT0_New));
+
+  ASSERT_TRUE(ZT0_Preserved.requiresPreservingZT0(ZT0_New));
+  ASSERT_TRUE(ZT0_Preserved.requiresDisablingZABeforeCall(ZT0_New));
+  ASSERT_TRUE(ZT0_Preserved.requiresEnablingZAAfterCall(ZT0_New));
 }

From f901ecf35f8b849ce0b10b7e861586c96e7e925a Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Fri, 19 Jan 2024 11:32:33 +0000
Subject: [PATCH 6/7] - Reduced number of tests added to SMEAttributesTest &
 added some with ZA state

---
 .../AArch64/Utils/AArch64SMEAttributes.h      |   3 +-
 .../Target/AArch64/SMEAttributesTest.cpp      | 115 +++++-------------
 2 files changed, 33 insertions(+), 85 deletions(-)

diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index f1852bf2f6fcf2..d370d518ef3841 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -126,8 +126,7 @@ class SMEAttrs {
     return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface();
   }
   bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const {
-    return (requiresLazySave(Callee) && !Callee.preservesZA()) ||
-           requiresDisablingZABeforeCall(Callee);
+    return requiresLazySave(Callee) || requiresDisablingZABeforeCall(Callee);
   }
 };
 
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 35918e05996752..e121d3690bf035 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -279,87 +279,36 @@ TEST(SMEAttributes, Transitions) {
                                    /*BodyOverridesInterface=*/true),
             true);
 
-  SA ZT0_In = SA(SA::encodeZT0State(SA::StateValue::In));
-  SA ZT0_InOut = SA(SA::encodeZT0State(SA::StateValue::InOut));
-  SA ZT0_Out = SA(SA::encodeZT0State(SA::StateValue::Out));
-  SA ZT0_Preserved = SA(SA::encodeZT0State(SA::StateValue::Preserved));
-  SA ZT0_New = SA(SA::encodeZT0State(SA::StateValue::New));
-
-  // ZT0 New -> Normal
-  ASSERT_TRUE(ZT0_New.requiresPreservingZT0(SA(SA::Normal)));
-  ASSERT_TRUE(ZT0_New.requiresDisablingZABeforeCall(SA(SA::Normal)));
-  ASSERT_TRUE(ZT0_New.requiresEnablingZAAfterCall(SA(SA::Normal)));
-
-  // ZT0 New -> ZT0 New
-  ASSERT_TRUE(ZT0_New.requiresPreservingZT0(ZT0_New));
-  ASSERT_TRUE(ZT0_New.requiresDisablingZABeforeCall(ZT0_New));
-  ASSERT_TRUE(ZT0_New.requiresEnablingZAAfterCall(ZT0_New));
-
-  // ZT0 New -> ZT0 Shared
-  ASSERT_FALSE(ZT0_New.requiresPreservingZT0(ZT0_In));
-  ASSERT_FALSE(ZT0_New.requiresDisablingZABeforeCall(ZT0_In));
-  ASSERT_FALSE(ZT0_New.requiresEnablingZAAfterCall(ZT0_In));
-
-  ASSERT_FALSE(ZT0_New.requiresPreservingZT0(ZT0_InOut));
-  ASSERT_FALSE(ZT0_New.requiresDisablingZABeforeCall(ZT0_InOut));
-  ASSERT_FALSE(ZT0_New.requiresEnablingZAAfterCall(ZT0_InOut));
-
-  ASSERT_FALSE(ZT0_New.requiresPreservingZT0(ZT0_Out));
-  ASSERT_FALSE(ZT0_New.requiresDisablingZABeforeCall(ZT0_Out));
-  ASSERT_FALSE(ZT0_New.requiresEnablingZAAfterCall(ZT0_Out));
-
-  ASSERT_FALSE(ZT0_New.requiresPreservingZT0(ZT0_Preserved));
-  ASSERT_FALSE(ZT0_New.requiresDisablingZABeforeCall(ZT0_Preserved));
-  ASSERT_FALSE(ZT0_New.requiresEnablingZAAfterCall(ZT0_Preserved));
-
-  // ZT0 Shared -> Normal
-  ASSERT_TRUE(ZT0_In.requiresPreservingZT0(SA(SA::Normal)));
-  ASSERT_TRUE(ZT0_In.requiresDisablingZABeforeCall(SA(SA::Normal)));
-  ASSERT_TRUE(ZT0_In.requiresEnablingZAAfterCall(SA(SA::Normal)));
-
-  ASSERT_TRUE(ZT0_InOut.requiresPreservingZT0(SA(SA::Normal)));
-  ASSERT_TRUE(ZT0_InOut.requiresDisablingZABeforeCall(SA(SA::Normal)));
-  ASSERT_TRUE(ZT0_InOut.requiresEnablingZAAfterCall(SA(SA::Normal)));
-
-  ASSERT_TRUE(ZT0_Out.requiresPreservingZT0(SA(SA::Normal)));
-  ASSERT_TRUE(ZT0_Out.requiresDisablingZABeforeCall(SA(SA::Normal)));
-  ASSERT_TRUE(ZT0_Out.requiresEnablingZAAfterCall(SA(SA::Normal)));
-
-  ASSERT_TRUE(ZT0_Preserved.requiresPreservingZT0(SA(SA::Normal)));
-  ASSERT_TRUE(ZT0_Preserved.requiresDisablingZABeforeCall(SA(SA::Normal)));
-  ASSERT_TRUE(ZT0_Preserved.requiresEnablingZAAfterCall(SA(SA::Normal)));
-
-  // ZT0 Shared -> ZT0 Shared
-  ASSERT_FALSE(ZT0_In.requiresPreservingZT0(ZT0_In));
-  ASSERT_FALSE(ZT0_In.requiresDisablingZABeforeCall(ZT0_In));
-  ASSERT_FALSE(ZT0_In.requiresEnablingZAAfterCall(ZT0_In));
-
-  ASSERT_FALSE(ZT0_InOut.requiresPreservingZT0(ZT0_In));
-  ASSERT_FALSE(ZT0_InOut.requiresDisablingZABeforeCall(ZT0_In));
-  ASSERT_FALSE(ZT0_InOut.requiresEnablingZAAfterCall(ZT0_In));
-
-  ASSERT_FALSE(ZT0_Out.requiresPreservingZT0(ZT0_In));
-  ASSERT_FALSE(ZT0_Out.requiresDisablingZABeforeCall(ZT0_In));
-  ASSERT_FALSE(ZT0_Out.requiresEnablingZAAfterCall(ZT0_In));
-
-  ASSERT_FALSE(ZT0_Preserved.requiresPreservingZT0(ZT0_In));
-  ASSERT_FALSE(ZT0_Preserved.requiresDisablingZABeforeCall(ZT0_In));
-  ASSERT_FALSE(ZT0_Preserved.requiresEnablingZAAfterCall(ZT0_In));
-
-  // ZT0 Shared -> ZT0 New
-  ASSERT_TRUE(ZT0_In.requiresPreservingZT0(ZT0_New));
-  ASSERT_TRUE(ZT0_In.requiresDisablingZABeforeCall(ZT0_New));
-  ASSERT_TRUE(ZT0_In.requiresEnablingZAAfterCall(ZT0_New));
-
-  ASSERT_TRUE(ZT0_InOut.requiresPreservingZT0(ZT0_New));
-  ASSERT_TRUE(ZT0_InOut.requiresDisablingZABeforeCall(ZT0_New));
-  ASSERT_TRUE(ZT0_InOut.requiresEnablingZAAfterCall(ZT0_New));
-
-  ASSERT_TRUE(ZT0_Out.requiresPreservingZT0(ZT0_New));
-  ASSERT_TRUE(ZT0_Out.requiresDisablingZABeforeCall(ZT0_New));
-  ASSERT_TRUE(ZT0_Out.requiresEnablingZAAfterCall(ZT0_New));
-
-  ASSERT_TRUE(ZT0_Preserved.requiresPreservingZT0(ZT0_New));
-  ASSERT_TRUE(ZT0_Preserved.requiresDisablingZABeforeCall(ZT0_New));
-  ASSERT_TRUE(ZT0_Preserved.requiresEnablingZAAfterCall(ZT0_New));
+  SA Private_ZA = SA(SA::Normal);
+  SA ZA_Shared = SA(SA::ZA_Shared);
+  SA ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In));
+  SA ZA_ZT0_Shared = SA(SA::ZA_Shared | SA::encodeZT0State(SA::StateValue::In));
+
+  // Shared ZA -> Private ZA Interface
+  ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
+  ASSERT_TRUE(ZA_Shared.requiresEnablingZAAfterCall(Private_ZA));
+
+  // Shared ZT0 -> Private ZA Interface
+  ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
+  ASSERT_TRUE(ZT0_Shared.requiresPreservingZT0(Private_ZA));
+  ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
+
+  // Shared ZA & ZT0 -> Private ZA Interface
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
+  ASSERT_TRUE(ZA_ZT0_Shared.requiresPreservingZT0(Private_ZA));
+  ASSERT_TRUE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
+
+  // Shared ZA -> Shared ZA Interface
+  ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
+  ASSERT_FALSE(ZA_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
+
+  // Shared ZT0 -> Shared ZA Interface
+  ASSERT_FALSE(ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
+  ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
+  ASSERT_FALSE(ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
+
+  // Shared ZA & ZT0 -> Shared ZA Interface
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
 }

From dca9f001b7c1eb47ea865e24484de3c1eca6b56d Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Fri, 19 Jan 2024 14:11:55 +0000
Subject: [PATCH 7/7] - Rebased on main - Used a single callee() function in
 sme-zt0-state.ll and added attributes at the callsite

---
 llvm/test/CodeGen/AArch64/sme-zt0-state.ll | 36 ++++++++--------------
 1 file changed, 13 insertions(+), 23 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 7df9df4ec9a382..88eaf19ec488f3 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -1,17 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
 
-; Callee with no ZT state
-declare void @no_state_callee();
-
-; Callees with ZT0 state
-declare void @zt0_shared_callee() "aarch64_in_zt0";
-declare void @zt0_new_callee() "aarch64_new_zt0";
-
-; Callees with ZA state
-
-declare void @za_shared_callee() "aarch64_pstate_za_shared";
-declare void @za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+declare void @callee();
 
 ;
 ; Private-ZA Callee
@@ -27,13 +17,13 @@ define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
 ; CHECK-NEXT:    mov x19, sp
 ; CHECK-NEXT:    str zt0, [x19]
 ; CHECK-NEXT:    smstop za
-; CHECK-NEXT:    bl no_state_callee
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    ldr zt0, [x19]
 ; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
 ; CHECK-NEXT:    add sp, sp, #80
 ; CHECK-NEXT:    ret
-  call void @no_state_callee();
+  call void @callee();
   ret void;
 }
 
@@ -59,7 +49,7 @@ define void @za_zt0_shared_caller_no_state_callee() "aarch64_pstate_za_shared" "
 ; CHECK-NEXT:    sturh w8, [x29, #-8]
 ; CHECK-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-NEXT:    str zt0, [x19]
-; CHECK-NEXT:    bl no_state_callee
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    ldr zt0, [x19]
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
@@ -73,7 +63,7 @@ define void @za_zt0_shared_caller_no_state_callee() "aarch64_pstate_za_shared" "
 ; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
 ; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void @no_state_callee();
+  call void @callee();
   ret void;
 }
 
@@ -86,10 +76,10 @@ define void @zt0_shared_caller_zt0_shared_callee() "aarch64_in_zt0" nounwind {
 ; CHECK-LABEL: zt0_shared_caller_zt0_shared_callee:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    bl zt0_shared_callee
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void @zt0_shared_callee();
+  call void @callee() "aarch64_in_zt0";
   ret void;
 }
 
@@ -110,13 +100,13 @@ define void @za_zt0_shared_caller_za_shared_callee() "aarch64_pstate_za_shared"
 ; CHECK-NEXT:    sturh wzr, [x29, #-6]
 ; CHECK-NEXT:    stur x8, [x29, #-16]
 ; CHECK-NEXT:    str zt0, [x19]
-; CHECK-NEXT:    bl za_shared_callee
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    ldr zt0, [x19]
 ; CHECK-NEXT:    mov sp, x29
 ; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
 ; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void @za_shared_callee();
+  call void @callee() "aarch64_pstate_za_shared";
   ret void;
 }
 
@@ -134,11 +124,11 @@ define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_pstate_za_shar
 ; CHECK-NEXT:    stur wzr, [x29, #-4]
 ; CHECK-NEXT:    sturh wzr, [x29, #-6]
 ; CHECK-NEXT:    stur x8, [x29, #-16]
-; CHECK-NEXT:    bl za_zt0_shared_callee
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    mov sp, x29
 ; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
-  call void @za_zt0_shared_callee();
+  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
   ret void;
 }
 
@@ -154,12 +144,12 @@ define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
 ; CHECK-NEXT:    mov x19, sp
 ; CHECK-NEXT:    str zt0, [x19]
 ; CHECK-NEXT:    smstop za
-; CHECK-NEXT:    bl zt0_new_callee
+; CHECK-NEXT:    bl callee
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    ldr zt0, [x19]
 ; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
 ; CHECK-NEXT:    add sp, sp, #80
 ; CHECK-NEXT:    ret
-  call void @zt0_new_callee();
+  call void @callee() "aarch64_new_zt0";
   ret void;
 }


