[llvm] [AArch64][SME2] Preserve ZT0 state around function calls (PR #76968)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 4 07:52:22 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)

<details>
<summary>Changes</summary>

If a function has ZT0 state and calls a function that does not
preserve ZT0, the caller must save and restore ZT0 around the call.
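
As a minimal IR sketch of that case (the function names here are invented for illustration; the behaviour described matches the new sme-zt0-preserve.ll test added below):

```llvm
; The caller shares ZT0, but @normal_callee makes no guarantee about ZT0,
; so the caller is responsible for preserving it across the call.
declare void @normal_callee()

define void @zt0_shared_caller() "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared" nounwind {
  ; With this patch, a 64-byte spill slot is created on the caller's stack;
  ; ZT0 is stored to it before the call ("str zt0") and reloaded after it
  ; ("ldr zt0").
  call void @normal_callee()
  ret void
}
```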

This patch extends SMEAttrs to interpret the following new attributes,
which apply to SME2 only:
  - aarch64_sme_pstate_zt0_new (ZT_New)
  - aarch64_sme_pstate_zt0_shared (ZT_Shared)
  - aarch64_sme_pstate_zt0_preserved (ZT_Preserved)

ZT0 must also be cleared on entry to a function marked with __arm_new_za.
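
A minimal sketch of that entry-block behaviour (function name invented; this corresponds to the `zero { zt0 }` sequence checked in the new test):

```llvm
define void @zt0_new_fn() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new" nounwind {
  ; The SMEABI pass inserts a call to llvm.aarch64.sme.zero.zt(i32 0) on
  ; entry, after PSTATE.ZA is enabled and ZA is zeroed, which lowers to
  ; "zero { zt0 }".
  ret void
}
```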

---

Patch is 26.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76968.diff


9 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64FastISel.cpp (+2-1) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+30-1) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+2) 
- (modified) llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td (+8-2) 
- (modified) llvm/lib/Target/AArch64/SMEABIPass.cpp (+14-4) 
- (modified) llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp (+15-2) 
- (modified) llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h (+19-1) 
- (added) llvm/test/CodeGen/AArch64/sme-zt0-preserve.ll (+306) 
- (modified) llvm/unittests/Target/AArch64/SMEAttributesTest.cpp (+33) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
index e98f6c4984a752..f63cdf8bc4f328 100644
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -5176,7 +5176,8 @@ FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
                                         const TargetLibraryInfo *LibInfo) {
 
   SMEAttrs CallerAttrs(*FuncInfo.Fn);
-  if (CallerAttrs.hasZAState() || CallerAttrs.hasStreamingInterfaceOrBody() ||
+  if (CallerAttrs.hasZAState() || CallerAttrs.hasZTState() ||
+      CallerAttrs.hasStreamingInterfaceOrBody() ||
       CallerAttrs.hasStreamingCompatibleInterface())
     return nullptr;
   return new AArch64FastISel(FuncInfo, LibInfo);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 102fd0c3dae2ab..4121621616b8bd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2338,6 +2338,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::SMSTART)
     MAKE_CASE(AArch64ISD::SMSTOP)
     MAKE_CASE(AArch64ISD::RESTORE_ZA)
+    MAKE_CASE(AArch64ISD::RESTORE_ZT)
+    MAKE_CASE(AArch64ISD::SAVE_ZT)
     MAKE_CASE(AArch64ISD::CALL)
     MAKE_CASE(AArch64ISD::ADRP)
     MAKE_CASE(AArch64ISD::ADR)
@@ -7659,6 +7661,20 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     });
   }
 
+  SDValue ZTFrameIdx;
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+  bool PreserveZT = CallerAttrs.requiresPreservingZT(CalleeAttrs);
+
+  if (PreserveZT) {
+    unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
+    ZTFrameIdx = DAG.getFrameIndex(
+        ZTObj,
+        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+
+    Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
+                        {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+  }
+
   // Adjust the stack pointer for the new arguments...
   // These operations are automatically eliminated by the prolog/epilog pass
   if (!IsSibCall)
@@ -8077,6 +8093,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
           DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
           DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
 
+      if (PreserveZT)
+        Result =
+            DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
+                        {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+
       // Conditionally restore the lazy save using a pseudo node.
       unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
       SDValue RegMask = DAG.getRegisterMask(
@@ -8105,7 +8126,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(0, DL, MVT::i64));
   }
 
-  if (RequiresSMChange || RequiresLazySave) {
+  if (RequiresSMChange || RequiresLazySave || PreserveZT) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
@@ -23953,6 +23974,14 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
       return DAG.getMergeValues(
           {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
     }
+    case Intrinsic::aarch64_sme_ldr_zt:
+      return DAG.getNode(AArch64ISD::RESTORE_ZT, SDLoc(N),
+                         DAG.getVTList(MVT::Other), N->getOperand(0),
+                         N->getOperand(2), N->getOperand(3));
+    case Intrinsic::aarch64_sme_str_zt:
+      return DAG.getNode(AArch64ISD::SAVE_ZT, SDLoc(N),
+                         DAG.getVTList(MVT::Other), N->getOperand(0),
+                         N->getOperand(2), N->getOperand(3));
     default:
       break;
     }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 6ddbcd41dcb769..6c14bc0aa8dc73 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -61,6 +61,8 @@ enum NodeType : unsigned {
   SMSTART,
   SMSTOP,
   RESTORE_ZA,
+  RESTORE_ZT,
+  SAVE_ZT,
 
   // Produces the full sequence of instructions for getting the thread pointer
   // offset of a variable into X0, using the TLSDesc model.
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 380f6e1fcfdaef..eeae5303a3f898 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -22,6 +22,12 @@ def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
                              [SDTCisInt<0>, SDTCisPtrTy<1>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue]>;
+def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
+                                [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                                [SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
+def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
+                             [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                             [SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;
 
 //===----------------------------------------------------------------------===//
 // Instruction naming conventions.
@@ -543,8 +549,8 @@ defm UMOPS_MPPZZ_HtoS : sme2_int_mopx_tile<"umops", 0b101, int_aarch64_sme_umops
 
 defm ZERO_T : sme2_zero_zt<"zero", 0b0001>;
 
-defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, int_aarch64_sme_ldr_zt>;
-defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, int_aarch64_sme_str_zt>;
+defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, AArch64_restore_zt>;
+defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, AArch64_save_zt>;
 
 def MOVT_XTI : sme2_movt_zt_to_scalar<"movt", 0b0011111>;
 def MOVT_TIX : sme2_movt_scalar_to_zt<"movt", 0b0011111>;
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index 3315171798d9f1..4ca0cf648bc147 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -40,7 +40,8 @@ struct SMEABI : public FunctionPass {
   bool runOnFunction(Function &F) override;
 
 private:
-  bool updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder);
+  bool updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder,
+                            bool ClearZTState);
 };
 } // end anonymous namespace
 
@@ -82,8 +83,8 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
 /// is active and we should call __arm_tpidr2_save to commit the lazy save.
 /// Additionally, PSTATE.ZA should be enabled at the beginning of the function
 /// and disabled before returning.
-bool SMEABI::updateNewZAFunctions(Module *M, Function *F,
-                                  IRBuilder<> &Builder) {
+bool SMEABI::updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder,
+                                  bool ClearZTState) {
   LLVMContext &Context = F->getContext();
   BasicBlock *OrigBB = &F->getEntryBlock();
 
@@ -117,6 +118,14 @@ bool SMEABI::updateNewZAFunctions(Module *M, Function *F,
   Builder.CreateCall(ZeroIntr->getFunctionType(), ZeroIntr,
                      Builder.getInt32(0xff));
 
+  // Clear ZT0 on entry to the function if required, after enabling pstate.za
+  if (ClearZTState) {
+    Function *ClearZT0Intr =
+        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero_zt);
+    Builder.CreateCall(ClearZT0Intr->getFunctionType(), ClearZT0Intr,
+                       {Builder.getInt32(0)});
+  }
+
   // Before returning, disable pstate.za
   for (BasicBlock &BB : *F) {
     Instruction *T = BB.getTerminator();
@@ -143,7 +152,8 @@ bool SMEABI::runOnFunction(Function &F) {
   bool Changed = false;
   SMEAttrs FnAttrs(F);
   if (FnAttrs.hasNewZABody())
-    Changed |= updateNewZAFunctions(M, &F, Builder);
+    Changed |= updateNewZAFunctions(M, &F, Builder,
+                                    FnAttrs.requiresPreservingZT(SMEAttrs()));
 
   return Changed;
 }
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 0082b4017986c6..ef3a043a15bcc2 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -18,8 +18,10 @@ void SMEAttrs::set(unsigned M, bool Enable) {
   else
     Bitmask &= ~M;
 
+  // Streaming Mode Attrs
   assert(!(hasStreamingInterface() && hasStreamingCompatibleInterface()) &&
          "SM_Enabled and SM_Compatible are mutually exclusive");
+  // ZA Attrs
   assert(!(hasNewZABody() && hasSharedZAInterface()) &&
          "ZA_New and ZA_Shared are mutually exclusive");
   assert(!(hasNewZABody() && preservesZA()) &&
@@ -28,6 +30,11 @@ void SMEAttrs::set(unsigned M, bool Enable) {
          "ZA_New and ZA_NoLazySave are mutually exclusive");
   assert(!(hasSharedZAInterface() && (Bitmask & ZA_NoLazySave)) &&
          "ZA_Shared and ZA_NoLazySave are mutually exclusive");
+  // ZT Attrs
+  assert(!(hasNewZTBody() && hasSharedZTInterface()) &&
+         "ZT_New and ZT_Shared are mutually exclusive");
+  assert(!(hasNewZTBody() && preservesZT()) &&
+         "ZT_New and ZT_Preserved are mutually exclusive");
 }
 
 SMEAttrs::SMEAttrs(const CallBase &CB) {
@@ -40,10 +47,10 @@ SMEAttrs::SMEAttrs(const CallBase &CB) {
 SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
   if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
     Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved |
-                SMEAttrs::ZA_NoLazySave);
+                SMEAttrs::ZA_NoLazySave | SMEAttrs::ZT_Preserved);
   if (FuncName == "__arm_tpidr2_restore")
     Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
-                SMEAttrs::ZA_NoLazySave);
+                SMEAttrs::ZA_NoLazySave | SMEAttrs::ZT_Shared);
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
@@ -60,6 +67,12 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
     Bitmask |= ZA_New;
   if (Attrs.hasFnAttr("aarch64_pstate_za_preserved"))
     Bitmask |= ZA_Preserved;
+  if (Attrs.hasFnAttr("aarch64_sme_pstate_zt0_shared"))
+    Bitmask |= ZT_Shared;
+  if (Attrs.hasFnAttr("aarch64_sme_pstate_zt0_new"))
+    Bitmask |= ZT_New;
+  if (Attrs.hasFnAttr("aarch64_sme_pstate_zt0_preserved"))
+    Bitmask |= ZT_Preserved;
 }
 
 std::optional<bool>
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index e766b778b54102..3eceaf95a249a2 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -36,7 +36,10 @@ class SMEAttrs {
     ZA_New = 1 << 4,        // aarch64_pstate_sm_new
     ZA_Preserved = 1 << 5,  // aarch64_pstate_sm_preserved
     ZA_NoLazySave = 1 << 6, // Used for SME ABI routines to avoid lazy saves
-    All = ZA_Preserved - 1
+    ZT_New = 1 << 7,        // aarch64_sme_pstate_zt0_new
+    ZT_Shared = 1 << 8,     // aarch64_sme_pstate_zt0_shared
+    ZT_Preserved = 1 << 9,  // aarch64_sme_pstate_zt0_preserved
+    All = ZT_Preserved - 1
   };
 
   SMEAttrs(unsigned Mask = Normal) : Bitmask(0) { set(Mask); }
@@ -74,6 +77,14 @@ class SMEAttrs {
   requiresSMChange(const SMEAttrs &Callee,
                    bool BodyOverridesInterface = false) const;
 
+  /// \return true if a call from Caller -> Callee requires ZT0 state to be
+  /// preserved.
+  /// ZT0 must be preserved if the caller has ZT state and the callee
+  /// does not preserve ZT.
+  bool requiresPreservingZT(const SMEAttrs &Callee) const {
+    return hasZTState() && !Callee.preservesZT();
+  }
+
   // Interfaces to query PSTATE.ZA
   bool hasNewZABody() const { return Bitmask & ZA_New; }
   bool hasSharedZAInterface() const { return Bitmask & ZA_Shared; }
@@ -82,6 +93,13 @@ class SMEAttrs {
   bool hasZAState() const {
     return hasNewZABody() || hasSharedZAInterface();
   }
+
+  // Interfaces to query ZT0 state
+  bool hasNewZTBody() const { return Bitmask & ZT_New; }
+  bool hasSharedZTInterface() const { return Bitmask & ZT_Shared; }
+  bool preservesZT() const { return Bitmask & ZT_Preserved; }
+  bool hasZTState() const { return hasNewZTBody() || hasSharedZTInterface(); }
+
   bool requiresLazySave(const SMEAttrs &Callee) const {
     return hasZAState() && Callee.hasPrivateZAInterface() &&
            !(Callee.Bitmask & ZA_NoLazySave);
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-preserve.ll b/llvm/test/CodeGen/AArch64/sme-zt0-preserve.ll
new file mode 100644
index 00000000000000..bbcfd5cac197b5
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-preserve.ll
@@ -0,0 +1,306 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
+
+; Normal callee, no ZT state
+declare void @normal_callee();
+
+; Callees with ZT state
+declare void @za_shared_callee() "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared";
+declare void @za_new_callee() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new";
+
+; Callee with preserved ZT state
+declare void @za_preserved_callee() "aarch64_pstate_za_preserved" "aarch64_sme_pstate_zt0_preserved";
+
+
+define void @za_zt_new_caller_normal_callee() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new" nounwind {
+; CHECK-LABEL: za_zt_new_caller_normal_callee:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbz x8, .LBB0_2
+; CHECK-NEXT:  // %bb.1: // %save.za
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    zero {za}
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    sub x9, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl normal_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB0_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB0_4:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @normal_callee();
+  ret void;
+}
+
+define void @za_zt_new_caller_za_callee() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new" nounwind {
+; CHECK-LABEL: za_zt_new_caller_za_callee:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #144
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbz x8, .LBB1_2
+; CHECK-NEXT:  // %bb.1: // %save.za
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:  .LBB1_2:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    zero {za}
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    sub x9, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x9
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl za_new_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB1_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB1_4:
+; CHECK-NEXT:    sub x8, x29, #144
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    str zt0, [x8]
+; CHECK-NEXT:    bl za_shared_callee
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @za_new_callee();
+  call void @za_shared_callee();
+  ret void;
+}
+
+define void @za_zt_shared_caller_normal_callee() "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared" nounwind {
+; CHECK-LABEL: za_zt_shared_caller_normal_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x9, [x29, #-16]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl normal_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB2_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB2_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @normal_callee();
+  ret void;
+}
+
+define void @za_zt_shared_caller_za_callee() "aarch64_pstate_za_shared" "aarch64_sme_pstate_zt0_shared" nounwind {
+; CHECK-LABEL: za_zt_shared_caller_za_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #144
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x9, [x29, #-16]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl za_new_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB3_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB3_2:
+; CHECK-NEXT:    sub x8, x29, #144
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    str zt0, [x8]
+; CHECK-NEXT:    bl za_shared_callee
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @za_new_callee();
+  call void @za_shared_callee();
+  ret void;
+}
+
+define void @za_zt_new_caller_za_preserved_callee() "aarch64_pstate_za_new" "aarch64_sme_pstate_zt0_new" nounwind {
+; CHECK-LABEL: za_zt_new_caller_za_preserved_callee:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsv...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/76968

