[llvm] [AArch64] SME implementation for agnostic-ZA functions (PR #120150)

Sander de Smalen via llvm-commits llvm-commits@lists.llvm.org
Mon Dec 23 05:54:43 PST 2024


https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/120150

From 4510b01328b3027451535d3869e6cf8f7c2a653c Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen@arm.com>
Date: Wed, 4 Sep 2024 15:45:45 +0100
Subject: [PATCH 1/4] [AArch64] SME implementation for agnostic-ZA functions

This implements the lowering of calls from agnostic-ZA
functions to non-agnostic-ZA functions, using the ABI routines
`__arm_sme_state_size`, `__arm_sme_save` and `__arm_sme_restore`.

This implements the proposal described in the following PRs:
* https://github.com/ARM-software/acle/pull/336
* https://github.com/ARM-software/abi-aa/pull/264
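
Because the callee may clobber ZA, the caller saves all ZA state to a
buffer before the call and restores it afterwards. As a sketch
(mirroring the new sme-agnostic-za.ll test below; the function names
here are illustrative), a call from an agnostic-ZA function to a
private-ZA function:

  declare i64 @private_za_decl(i64)

  define i64 @agnostic_caller(i64 %v) nounwind "aarch64_za_state_agnostic" {
    %res = call i64 @private_za_decl(i64 %v)
    ret i64 %res
  }

is lowered so that the prologue calls `__arm_sme_state_size` and
allocates a stack buffer of that size, and the call itself is
bracketed by calls to `__arm_sme_save` and `__arm_sme_restore`,
each taking the buffer address in x0.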
---
 llvm/lib/IR/Verifier.cpp                      |  24 ++--
 llvm/lib/Target/AArch64/AArch64FastISel.cpp   |   3 +-
 .../Target/AArch64/AArch64ISelLowering.cpp    | 129 +++++++++++++++++-
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |   6 +
 .../AArch64/AArch64MachineFunctionInfo.h      |  14 ++
 .../lib/Target/AArch64/AArch64SMEInstrInfo.td |  16 +++
 .../AArch64/AArch64TargetTransformInfo.cpp    |   8 +-
 .../AArch64/Utils/AArch64SMEAttributes.cpp    |   9 ++
 .../AArch64/Utils/AArch64SMEAttributes.h      |  17 ++-
 llvm/test/CodeGen/AArch64/sme-agnostic-za.ll  |  84 ++++++++++++
 .../AArch64/sme-disable-gisel-fisel.ll        |  24 ++++
 llvm/test/Verifier/sme-attributes.ll          |  46 ++++---
 12 files changed, 342 insertions(+), 38 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sme-agnostic-za.ll

diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 55de486e90e190..d216b870281cc3 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2264,19 +2264,23 @@ void Verifier::verifyFunctionAttrs(FunctionType *FT, AttributeList Attrs,
   Check((Attrs.hasFnAttr("aarch64_new_za") + Attrs.hasFnAttr("aarch64_in_za") +
          Attrs.hasFnAttr("aarch64_inout_za") +
          Attrs.hasFnAttr("aarch64_out_za") +
-         Attrs.hasFnAttr("aarch64_preserves_za")) <= 1,
+         Attrs.hasFnAttr("aarch64_preserves_za") +
+         Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1,
         "Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
-        "'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive",
+        "'aarch64_inout_za', 'aarch64_preserves_za' and "
+        "'aarch64_za_state_agnostic' are mutually exclusive",
         V);
 
-  Check(
-      (Attrs.hasFnAttr("aarch64_new_zt0") + Attrs.hasFnAttr("aarch64_in_zt0") +
-       Attrs.hasFnAttr("aarch64_inout_zt0") +
-       Attrs.hasFnAttr("aarch64_out_zt0") +
-       Attrs.hasFnAttr("aarch64_preserves_zt0")) <= 1,
-      "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
-      "'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive",
-      V);
+  Check((Attrs.hasFnAttr("aarch64_new_zt0") +
+         Attrs.hasFnAttr("aarch64_in_zt0") +
+         Attrs.hasFnAttr("aarch64_inout_zt0") +
+         Attrs.hasFnAttr("aarch64_out_zt0") +
+         Attrs.hasFnAttr("aarch64_preserves_zt0") +
+         Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1,
+        "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
+        "'aarch64_inout_zt0', 'aarch64_preserves_zt0' and "
+        "'aarch64_za_state_agnostic' are mutually exclusive",
+        V);
 
   if (Attrs.hasFnAttr(Attribute::JumpTable)) {
     const GlobalValue *GV = cast<GlobalValue>(V);
diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
index 9f0f23b6e6a658..738895998c1195 100644
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -5197,7 +5197,8 @@ FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
   SMEAttrs CallerAttrs(*FuncInfo.Fn);
   if (CallerAttrs.hasZAState() || CallerAttrs.hasZT0State() ||
       CallerAttrs.hasStreamingInterfaceOrBody() ||
-      CallerAttrs.hasStreamingCompatibleInterface())
+      CallerAttrs.hasStreamingCompatibleInterface() ||
+      CallerAttrs.hasAgnosticZAInterface())
     return nullptr;
   return new AArch64FastISel(FuncInfo, LibInfo);
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a86ee5a6b64d27..152fb8d7469cd9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2631,6 +2631,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     break;
     MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
     MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
+    MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE)
+    MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER)
     MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
     MAKE_CASE(AArch64ISD::VG_SAVE)
     MAKE_CASE(AArch64ISD::VG_RESTORE)
@@ -3218,6 +3220,39 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
   return BB;
 }
 
+// TODO: Find a way to merge this with EmitAllocateZABuffer.
+MachineBasicBlock *
+AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
+                                                 MachineBasicBlock *BB) const {
+  MachineFunction *MF = BB->getParent();
+  MachineFrameInfo &MFI = MF->getFrameInfo();
+  AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+  assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
+         "Lazy ZA save is not yet supported on Windows");
+
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  if (FuncInfo->getSMESaveBufferUsed()) {
+    // Allocate a lazy-save buffer object of the size given, normally SVL * SVL
+    auto Size = MI.getOperand(1).getReg();
+    auto Dest = MI.getOperand(0).getReg();
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), Dest)
+        .addReg(AArch64::SP)
+        .addReg(Size)
+        .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+            AArch64::SP)
+        .addReg(Dest);
+
+    // We have just allocated a variable sized object, tell this to PEI.
+    MFI.CreateVariableSizedObject(Align(16), nullptr);
+  } else
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF),
+            MI.getOperand(0).getReg());
+
+  BB->remove_instr(&MI);
+  return BB;
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
 
@@ -3252,6 +3287,31 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     return EmitInitTPIDR2Object(MI, BB);
   case AArch64::AllocateZABuffer:
     return EmitAllocateZABuffer(MI, BB);
+  case AArch64::AllocateSMESaveBuffer:
+    return EmitAllocateSMESaveBuffer(MI, BB);
+  case AArch64::GetSMESaveSize: {
+    // If the buffer is used, emit a call to __arm_sme_state_size()
+    MachineFunction *MF = BB->getParent();
+    AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+    const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+    if (FuncInfo->getSMESaveBufferUsed()) {
+      const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
+          .addExternalSymbol("__arm_sme_state_size")
+          .addReg(AArch64::X0, RegState::ImplicitDefine)
+          .addRegMask(TRI->getCallPreservedMask(
+              *MF, CallingConv::
+                       AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::X0);
+    } else
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::XZR);
+    BB->remove_instr(&MI);
+    return BB;
+  }
   case AArch64::F128CSEL:
     return EmitF128CSEL(MI, BB);
   case TargetOpcode::STATEPOINT:
@@ -7651,6 +7711,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
   case CallingConv::AArch64_VectorCall:
   case CallingConv::AArch64_SVE_VectorCall:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
     return CC_AArch64_AAPCS;
   case CallingConv::ARM64EC_Thunk_X64:
@@ -8110,6 +8171,31 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
     Chain = DAG.getNode(
         AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
         {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
+  } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
+    // Call __arm_sme_state_size().
+    SDValue BufferSize =
+        DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
+                    DAG.getVTList(MVT::i64, MVT::Other), Chain);
+    Chain = BufferSize.getValue(1);
+
+    SDValue Buffer;
+    if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+      Buffer =
+          DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
+                      DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
+    } else {
+      // Allocate space dynamically.
+      Buffer = DAG.getNode(
+          ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
+          {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
+      MFI.CreateVariableSizedObject(Align(16), nullptr);
+    }
+
+    // Copy the value to a virtual register, and save that in FuncInfo.
+    Register BufferPtr =
+        MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+    FuncInfo->setSMESaveBufferAddr(BufferPtr);
+    Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
   }
 
   if (CallConv == CallingConv::PreserveNone) {
@@ -8398,6 +8484,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
   auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
   if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
       CallerAttrs.requiresLazySave(CalleeAttrs) ||
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
       CallerAttrs.hasStreamingBody())
     return false;
 
@@ -8722,6 +8809,30 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
   return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
 }
 
+// Emit a call to __arm_sme_save or __arm_sme_restore.
+static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
+                                       SelectionDAG &DAG,
+                                       AArch64FunctionInfo *Info, SDLoc DL,
+                                       SDValue Chain, bool IsSave) {
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.Ty = PointerType::getUnqual(*DAG.getContext());
+  Entry.Node =
+      DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
+  Args.push_back(Entry);
+
+  SDValue Callee =
+      DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
+                            TLI.getPointerTy(DAG.getDataLayout()));
+  auto *RetTy = Type::getVoidTy(*DAG.getContext());
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
+      Callee, std::move(Args));
+
+  return TLI.LowerCallTo(CLI).second;
+}
+
 static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
                                const SMEAttrs &CalleeAttrs) {
   if (!CallerAttrs.hasStreamingCompatibleInterface() ||
@@ -8882,6 +8993,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   };
 
   bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
+  bool RequiresSaveAllZA =
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
+  SDValue ZAStateBuffer;
   if (RequiresLazySave) {
     const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
     MachinePointerInfo MPI =
@@ -8908,6 +9022,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                                                    &MF.getFunction());
       return DescribeCallsite(R) << " sets up a lazy save for ZA";
     });
+  } else if (RequiresSaveAllZA) {
+    assert(!CalleeAttrs.hasSharedZAInterface() &&
+           "Cannot share state that may not exist");
+    Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                    /*IsSave=*/true);
   }
 
   SDValue PStateSM;
@@ -9455,9 +9574,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
         DAG.getConstant(0, DL, MVT::i64));
     TPIDR2.Uses++;
+  } else if (RequiresSaveAllZA) {
+    Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                     /*IsSave=*/false);
+    FuncInfo->setSMESaveBufferUsed();
   }
 
-  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
+  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
+      RequiresSaveAllZA) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
@@ -28063,7 +28187,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
     auto CalleeAttrs = SMEAttrs(*Base);
     if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
         CallerAttrs.requiresLazySave(CalleeAttrs) ||
-        CallerAttrs.requiresPreservingZT0(CalleeAttrs))
+        CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+        CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
       return true;
   }
   return false;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d51b36f7e49946..8621aa81edfb2f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -466,6 +466,10 @@ enum NodeType : unsigned {
   ALLOCATE_ZA_BUFFER,
   INIT_TPIDR2OBJ,
 
+  // Needed for __arm_agnostic("sme_za_state")
+  GET_SME_SAVE_SIZE,
+  ALLOC_SME_SAVE_BUFFER,
+
   // Asserts that a function argument (i32) is zero-extended to i8 by
   // the caller
   ASSERT_ZEXT_BOOL,
@@ -663,6 +667,8 @@ class AArch64TargetLowering : public TargetLowering {
                                           MachineBasicBlock *BB) const;
   MachineBasicBlock *EmitAllocateZABuffer(MachineInstr &MI,
                                           MachineBasicBlock *BB) const;
+  MachineBasicBlock *EmitAllocateSMESaveBuffer(MachineInstr &MI,
+                                               MachineBasicBlock *BB) const;
 
   MachineBasicBlock *
   EmitInstrWithCustomInserter(MachineInstr &MI,
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index a77fdaf19bcf5f..7fd3a6c560329c 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -229,6 +229,14 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
   // on function entry to record the initial pstate of a function.
   Register PStateSMReg = MCRegister::NoRegister;
 
+  // A virtual register holding the address of a buffer that is large
+  // enough to hold all SME ZA state and any additional state required
+  // by the __arm_sme_save/restore support routines.
+  Register SMESaveBufferAddr = MCRegister::NoRegister;
+
+  // True if SMESaveBufferAddr is used.
+  bool SMESaveBufferUsed = false;
+
   // Has the PNReg used to build PTRUE instruction.
   // The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
   unsigned PredicateRegForFillSpill = 0;
@@ -252,6 +260,12 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
     return PredicateRegForFillSpill;
   }
 
+  Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
+  void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; };
+
+  bool getSMESaveBufferUsed() const { return SMESaveBufferUsed; };
+  void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; };
+
   Register getPStateSMReg() const { return PStateSMReg; };
   void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
 
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index fa577cf92e99d1..ac11b048340498 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -52,6 +52,22 @@ let usesCustomInserter = 1 in {
   def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {}
 }
 
+// Nodes to query the size of, and allocate, the SME state save buffer.
+def AArch64SMESaveSize : SDNode<"AArch64ISD::GET_SME_SAVE_SIZE", SDTypeProfile<1, 0,
+                               [SDTCisInt<0>]>, [SDNPHasChain]>;
+let usesCustomInserter = 1, Defs = [X0] in {
+  def GetSMESaveSize : Pseudo<(outs GPR64:$dst), (ins), []>, Sched<[]> {}
+}
+def : Pat<(i64 AArch64SMESaveSize), (GetSMESaveSize)>;
+
+def AArch64AllocateSMESaveBuffer : SDNode<"AArch64ISD::ALLOC_SME_SAVE_BUFFER", SDTypeProfile<1, 1,
+                                          [SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain]>;
+let usesCustomInserter = 1, Defs = [SP] in {
+  def AllocateSMESaveBuffer : Pseudo<(outs GPR64sp:$dst), (ins GPR64:$size), []>, Sched<[WriteI]> {}
+}
+def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
+          (AllocateSMESaveBuffer $size)>;
+
 //===----------------------------------------------------------------------===//
 // Instruction naming conventions.
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 6c2e04c3f8a7c1..98e0cefc3c844b 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -261,7 +261,13 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
 
   if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
       CallerAttrs.requiresSMChange(CalleeAttrs) ||
-      CallerAttrs.requiresPreservingZT0(CalleeAttrs)) {
+      CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
+    if (hasPossibleIncompatibleOps(Callee))
+      return false;
+  }
+
+  if (CalleeAttrs.hasAgnosticZAInterface()) {
     if (hasPossibleIncompatibleOps(Callee))
       return false;
   }
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 015ca4cb92b25e..bf16acd7f8f7e1 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -38,6 +38,10 @@ void SMEAttrs::set(unsigned M, bool Enable) {
                         isPreservesZT0())) &&
       "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
       "'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");
+
+  assert(!(hasAgnosticZAInterface() && hasSharedZAInterface()) &&
+         "Function cannot have a shared-ZA interface and an agnostic-ZA "
+         "interface");
 }
 
 SMEAttrs::SMEAttrs(const CallBase &CB) {
@@ -56,6 +60,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
   if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
       FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
     Bitmask |= SMEAttrs::SM_Compatible;
+  if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
+      FuncName == "__arm_sme_state_size")
+    Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
@@ -66,6 +73,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
     Bitmask |= SM_Compatible;
   if (Attrs.hasFnAttr("aarch64_pstate_sm_body"))
     Bitmask |= SM_Body;
+  if (Attrs.hasFnAttr("aarch64_za_state_agnostic"))
+    Bitmask |= ZA_State_Agnostic;
   if (Attrs.hasFnAttr("aarch64_in_za"))
     Bitmask |= encodeZAState(StateValue::In);
   if (Attrs.hasFnAttr("aarch64_out_za"))
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 4c7c1c9b079538..fb093da70c46b6 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -42,9 +42,10 @@ class SMEAttrs {
     SM_Compatible = 1 << 1,   // aarch64_pstate_sm_compatible
     SM_Body = 1 << 2,         // aarch64_pstate_sm_body
     SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
-    ZA_Shift = 4,
+    ZA_State_Agnostic = 1 << 4,
+    ZA_Shift = 5,
     ZA_Mask = 0b111 << ZA_Shift,
-    ZT0_Shift = 7,
+    ZT0_Shift = 8,
     ZT0_Mask = 0b111 << ZT0_Shift
   };
 
@@ -96,8 +97,11 @@ class SMEAttrs {
     return State == StateValue::In || State == StateValue::Out ||
            State == StateValue::InOut || State == StateValue::Preserved;
   }
+  bool hasAgnosticZAInterface() const { return Bitmask & ZA_State_Agnostic; }
   bool hasSharedZAInterface() const { return sharesZA() || sharesZT0(); }
-  bool hasPrivateZAInterface() const { return !hasSharedZAInterface(); }
+  bool hasPrivateZAInterface() const {
+    return !hasSharedZAInterface() && !hasAgnosticZAInterface();
+  }
   bool hasZAState() const { return isNewZA() || sharesZA(); }
   bool requiresLazySave(const SMEAttrs &Callee) const {
     return hasZAState() && Callee.hasPrivateZAInterface() &&
@@ -128,7 +132,8 @@ class SMEAttrs {
   }
   bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
   bool requiresPreservingZT0(const SMEAttrs &Callee) const {
-    return hasZT0State() && !Callee.sharesZT0();
+    return hasZT0State() && !Callee.sharesZT0() &&
+           !Callee.hasAgnosticZAInterface();
   }
   bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
     return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface() &&
@@ -137,6 +142,10 @@ class SMEAttrs {
   bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const {
     return requiresLazySave(Callee) || requiresDisablingZABeforeCall(Callee);
   }
+  bool requiresPreservingAllZAState(const SMEAttrs &Callee) const {
+    return hasAgnosticZAInterface() && !Callee.hasAgnosticZAInterface() &&
+           !(Callee.Bitmask & SME_ABI_Routine);
+  }
 };
 
 } // namespace llvm
diff --git a/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
new file mode 100644
index 00000000000000..2e613118acbe04
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
@@ -0,0 +1,84 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sme2 < %s | FileCheck %s
+
+target triple = "aarch64"
+
+declare i64 @private_za_decl(i64)
+declare i64 @agnostic_decl(i64) "aarch64_za_state_agnostic"
+
+; No calls. Test that no buffer is allocated.
+define i64 @agnostic_caller_no_callees(ptr %ptr) nounwind "aarch64_za_state_agnostic" {
+; CHECK-LABEL: agnostic_caller_no_callees:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr x0, [x0]
+; CHECK-NEXT:    ret
+  %v = load i64, ptr %ptr
+  ret i64 %v
+}
+
+; agnostic-ZA -> private-ZA
+;
+; Test that a buffer is allocated, that the appropriate save/restore calls
+; are inserted around calls to non-agnostic functions, and that the
+; arg/result registers are preserved by the register allocator.
+define i64 @agnostic_caller_private_za_callee(i64 %v) nounwind "aarch64_za_state_agnostic" {
+; CHECK-LABEL: agnostic_caller_private_za_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    mov x8, x0
+; CHECK-NEXT:    bl __arm_sme_state_size
+; CHECK-NEXT:    sub x19, sp, x0
+; CHECK-NEXT:    mov sp, x19
+; CHECK-NEXT:    mov x0, x19
+; CHECK-NEXT:    bl __arm_sme_save
+; CHECK-NEXT:    mov x0, x8
+; CHECK-NEXT:    bl private_za_decl
+; CHECK-NEXT:    mov x1, x0
+; CHECK-NEXT:    mov x0, x19
+; CHECK-NEXT:    bl __arm_sme_restore
+; CHECK-NEXT:    mov x0, x19
+; CHECK-NEXT:    bl __arm_sme_save
+; CHECK-NEXT:    mov x0, x1
+; CHECK-NEXT:    bl private_za_decl
+; CHECK-NEXT:    mov x1, x0
+; CHECK-NEXT:    mov x0, x19
+; CHECK-NEXT:    bl __arm_sme_restore
+; CHECK-NEXT:    mov x0, x1
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  %res = call i64 @private_za_decl(i64 %v)
+  %res2 = call i64 @private_za_decl(i64 %res)
+  ret i64 %res2
+}
+
+; agnostic-ZA -> agnostic-ZA
+;
+; Should not result in save/restore code.
+define i64 @agnostic_caller_agnostic_callee(i64 %v) nounwind "aarch64_za_state_agnostic" {
+; CHECK-LABEL: agnostic_caller_agnostic_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl agnostic_decl
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  %res = call i64 @agnostic_decl(i64 %v)
+  ret i64 %res
+}
+
+; shared-ZA -> agnostic-ZA
+;
+; Should not result in a lazy save or a save of ZT0.
+define i64 @shared_caller_agnostic_callee(i64 %v) nounwind "aarch64_inout_za" "aarch64_inout_zt0" {
+; CHECK-LABEL: shared_caller_agnostic_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl agnostic_decl
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  %res = call i64 @agnostic_decl(i64 %v)
+  ret i64 %res
+}
diff --git a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
index 42dba22d257089..d9dc2ad841f167 100644
--- a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
+++ b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
@@ -526,3 +526,27 @@ entry:
   %add = fadd double %call, 4.200000e+01
   ret double %add;
 }
+
+define void @agnostic_za_function(ptr %ptr) nounwind "aarch64_za_state_agnostic" {
+; CHECK-COMMON-LABEL: agnostic_za_function:
+; CHECK-COMMON:       // %bb.0:
+; CHECK-COMMON-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-COMMON-NEXT:    stp x20, x19, [sp, #16] // 16-byte Folded Spill
+; CHECK-COMMON-NEXT:    mov x29, sp
+; CHECK-COMMON-NEXT:    mov x8, x0
+; CHECK-COMMON-NEXT:    bl __arm_sme_state_size
+; CHECK-COMMON-NEXT:    sub x20, sp, x0
+; CHECK-COMMON-NEXT:    mov sp, x20
+; CHECK-COMMON-NEXT:    mov x0, x20
+; CHECK-COMMON-NEXT:    bl __arm_sme_save
+; CHECK-COMMON-NEXT:    blr x8
+; CHECK-COMMON-NEXT:    mov x0, x20
+; CHECK-COMMON-NEXT:    bl __arm_sme_restore
+; CHECK-COMMON-NEXT:    mov sp, x29
+; CHECK-COMMON-NEXT:    ldp x20, x19, [sp, #16] // 16-byte Folded Reload
+; CHECK-COMMON-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-COMMON-NEXT:    ret
+  call void %ptr()
+  ret void
+}
+
diff --git a/llvm/test/Verifier/sme-attributes.ll b/llvm/test/Verifier/sme-attributes.ll
index 3d01613ebf2fe1..4bf5e813daf2f3 100644
--- a/llvm/test/Verifier/sme-attributes.ll
+++ b/llvm/test/Verifier/sme-attributes.ll
@@ -4,61 +4,67 @@ declare void @sm_attrs() "aarch64_pstate_sm_enabled" "aarch64_pstate_sm_compatib
 ; CHECK: Attributes 'aarch64_pstate_sm_enabled and aarch64_pstate_sm_compatible' are incompatible!
 
 declare void @za_new_preserved() "aarch64_new_za" "aarch64_preserves_za";
-; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @za_new_in() "aarch64_new_za" "aarch64_in_za";
-; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @za_new_inout() "aarch64_new_za" "aarch64_inout_za";
-; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @za_new_out() "aarch64_new_za" "aarch64_out_za";
-; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @za_preserved_in() "aarch64_preserves_za" "aarch64_in_za";
-; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @za_preserved_inout() "aarch64_preserves_za" "aarch64_inout_za";
-; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @za_preserved_out() "aarch64_preserves_za" "aarch64_out_za";
-; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @za_in_inout() "aarch64_in_za" "aarch64_inout_za";
-; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @za_in_out() "aarch64_in_za" "aarch64_out_za";
-; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @za_inout_out() "aarch64_inout_za" "aarch64_out_za";
-; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive
+
+declare void @za_inout_agnostic() "aarch64_inout_za" "aarch64_za_state_agnostic";
+; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @zt0_new_preserved() "aarch64_new_zt0" "aarch64_preserves_zt0";
-; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @zt0_new_in() "aarch64_new_zt0" "aarch64_in_zt0";
-; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @zt0_new_inout() "aarch64_new_zt0" "aarch64_inout_zt0";
-; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @zt0_new_out() "aarch64_new_zt0" "aarch64_out_zt0";
-; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @zt0_preserved_in() "aarch64_preserves_zt0" "aarch64_in_zt0";
-; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @zt0_preserved_inout() "aarch64_preserves_zt0" "aarch64_inout_zt0";
-; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @zt0_preserved_out() "aarch64_preserves_zt0" "aarch64_out_zt0";
-; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @zt0_in_inout() "aarch64_in_zt0" "aarch64_inout_zt0";
-; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @zt0_in_out() "aarch64_in_zt0" "aarch64_out_zt0";
-; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
 
 declare void @zt0_inout_out() "aarch64_inout_zt0" "aarch64_out_zt0";
-; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive
+; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
+
+declare void @zt0_inout_agnostic() "aarch64_inout_zt0" "aarch64_za_state_agnostic";
+; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive

From 84209dd09eda09a06766e37b4cdd4c5542d869cd Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen@arm.com>
Date: Mon, 23 Dec 2024 09:59:59 +0000
Subject: [PATCH 2/4] Address review comments

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 67 ++++++++++---------
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  2 +
 llvm/test/CodeGen/AArch64/sme-agnostic-za.ll  |  4 +-
 .../AArch64/sme-disable-gisel-fisel.ll        |  4 +-
 4 files changed, 42 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 152fb8d7469cd9..ff46c568e0703f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3232,16 +3232,15 @@ AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
 
   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
   if (FuncInfo->getSMESaveBufferUsed()) {
-    // Allocate a lazy-save buffer object of the size given, normally SVL * SVL
+    // Allocate a buffer object of the size given by MI.getOperand(1).
     auto Size = MI.getOperand(1).getReg();
     auto Dest = MI.getOperand(0).getReg();
-    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), Dest)
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), AArch64::SP)
         .addReg(AArch64::SP)
         .addReg(Size)
         .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
-    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
-            AArch64::SP)
-        .addReg(Dest);
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), Dest)
+        .addReg(AArch64::SP);
 
     // We have just allocated a variable sized object, tell this to PEI.
     MFI.CreateVariableSizedObject(Align(16), nullptr);
@@ -3253,6 +3252,32 @@ AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
   return BB;
 }
 
+MachineBasicBlock *
+AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
+                                          MachineBasicBlock *BB) const {
+  // If the buffer is used, emit a call to __arm_sme_state_size()
+  MachineFunction *MF = BB->getParent();
+  AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  if (FuncInfo->getSMESaveBufferUsed()) {
+    const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
+        .addExternalSymbol("__arm_sme_state_size")
+        .addReg(AArch64::X0, RegState::ImplicitDefine)
+        .addRegMask(TRI->getCallPreservedMask(
+            *MF, CallingConv::
+                     AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+            MI.getOperand(0).getReg())
+        .addReg(AArch64::X0);
+  } else
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+            MI.getOperand(0).getReg())
+        .addReg(AArch64::XZR);
+  BB->remove_instr(&MI);
+  return BB;
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
 
@@ -3289,29 +3314,8 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     return EmitAllocateZABuffer(MI, BB);
   case AArch64::AllocateSMESaveBuffer:
     return EmitAllocateSMESaveBuffer(MI, BB);
-  case AArch64::GetSMESaveSize: {
-    // If the buffer is used, emit a call to __arm_sme_state_size()
-    MachineFunction *MF = BB->getParent();
-    AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
-    const TargetInstrInfo *TII = Subtarget->getInstrInfo();
-    if (FuncInfo->getSMESaveBufferUsed()) {
-      const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
-      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
-          .addExternalSymbol("__arm_sme_state_size")
-          .addReg(AArch64::X0, RegState::ImplicitDefine)
-          .addRegMask(TRI->getCallPreservedMask(
-              *MF, CallingConv::
-                       AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
-      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
-              MI.getOperand(0).getReg())
-          .addReg(AArch64::X0);
-    } else
-      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
-              MI.getOperand(0).getReg())
-          .addReg(AArch64::XZR);
-    BB->remove_instr(&MI);
-    return BB;
-  }
+  case AArch64::GetSMESaveSize:
+    return EmitGetSMESaveSize(MI, BB);
   case AArch64::F128CSEL:
     return EmitF128CSEL(MI, BB);
   case TargetOpcode::STATEPOINT:
@@ -8814,6 +8818,10 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
                                        SelectionDAG &DAG,
                                        AArch64FunctionInfo *Info, SDLoc DL,
                                        SDValue Chain, bool IsSave) {
+  MachineFunction &MF = DAG.getMachineFunction();
+  AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
+  FuncInfo->setSMESaveBufferUsed();
+
   TargetLowering::ArgListTy Args;
   TargetLowering::ArgListEntry Entry;
   Entry.Ty = PointerType::getUnqual(*DAG.getContext());
@@ -8829,7 +8837,6 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
   CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
       CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
       Callee, std::move(Args));
-
   return TLI.LowerCallTo(CLI).second;
 }
 
@@ -8995,7 +9002,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
   bool RequiresSaveAllZA =
       CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
-  SDValue ZAStateBuffer;
   if (RequiresLazySave) {
     const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
     MachinePointerInfo MPI =
@@ -9577,7 +9583,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   } else if (RequiresSaveAllZA) {
     Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
                                      /*IsSave=*/false);
-    FuncInfo->setSMESaveBufferUsed();
   }
 
   if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 8621aa81edfb2f..9e3dad63663586 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -669,6 +669,8 @@ class AArch64TargetLowering : public TargetLowering {
                                           MachineBasicBlock *BB) const;
   MachineBasicBlock *EmitAllocateSMESaveBuffer(MachineInstr &MI,
                                                MachineBasicBlock *BB) const;
+  MachineBasicBlock *EmitGetSMESaveSize(MachineInstr &MI,
+                                        MachineBasicBlock *BB) const;
 
   MachineBasicBlock *
   EmitInstrWithCustomInserter(MachineInstr &MI,
diff --git a/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
index 2e613118acbe04..97522b9a319c09 100644
--- a/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
+++ b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
@@ -29,8 +29,8 @@ define i64 @agnostic_caller_private_za_callee(i64 %v) nounwind "aarch64_za_state
 ; CHECK-NEXT:    mov x29, sp
 ; CHECK-NEXT:    mov x8, x0
 ; CHECK-NEXT:    bl __arm_sme_state_size
-; CHECK-NEXT:    sub x19, sp, x0
-; CHECK-NEXT:    mov sp, x19
+; CHECK-NEXT:    sub sp, sp, x0
+; CHECK-NEXT:    mov x19, sp
 ; CHECK-NEXT:    mov x0, x19
 ; CHECK-NEXT:    bl __arm_sme_save
 ; CHECK-NEXT:    mov x0, x8
diff --git a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
index d9dc2ad841f167..fc0208d605dd71 100644
--- a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
+++ b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
@@ -535,8 +535,8 @@ define void @agnostic_za_function(ptr %ptr) nounwind "aarch64_za_state_agnostic"
 ; CHECK-COMMON-NEXT:    mov x29, sp
 ; CHECK-COMMON-NEXT:    mov x8, x0
 ; CHECK-COMMON-NEXT:    bl __arm_sme_state_size
-; CHECK-COMMON-NEXT:    sub x20, sp, x0
-; CHECK-COMMON-NEXT:    mov sp, x20
+; CHECK-COMMON-NEXT:    sub sp, sp, x0
+; CHECK-COMMON-NEXT:    mov x20, sp
 ; CHECK-COMMON-NEXT:    mov x0, x20
 ; CHECK-COMMON-NEXT:    bl __arm_sme_save
 ; CHECK-COMMON-NEXT:    blr x8

From 8add64a57fb69f7b4ea6c89a9b0ad8a4ebac93a8 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen@arm.com>
Date: Mon, 23 Dec 2024 13:48:44 +0000
Subject: [PATCH 3/4] Add inline asm to the tests, so they actually test what
 they're meant to test (I can precommit this)

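Without a side-effecting operation in their bodies, the callees had no
possibly-incompatible ops (see hasPossibleIncompatibleOps), so the
inline/don't-inline decisions these tests are meant to cover were never
actually exercised. Each callee now contains a marker:

  call void asm sideeffect "; inlineasm", ""()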
---
 .../Inline/AArch64/sme-pstateza-attrs.ll      | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
index 5e638103a2b063..df1ec0d7de2d4e 100644
--- a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
@@ -14,10 +14,12 @@ define void @nonza_callee() {
 ; CHECK-LABEL: define void @nonza_callee
 ; CHECK-SAME: () #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void asm sideeffect "
 ; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
+  call void asm sideeffect "; inlineasm", ""()
   call void @inlined_body()
   ret void
 }
@@ -26,10 +28,12 @@ define void @shared_za_callee() "aarch64_inout_za" {
 ; CHECK-LABEL: define void @shared_za_callee
 ; CHECK-SAME: () #[[ATTR1:[0-9]+]] {
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void asm sideeffect "
 ; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
+  call void asm sideeffect "; inlineasm", ""()
   call void @inlined_body()
   ret void
 }
@@ -37,9 +41,11 @@ entry:
 define void @new_za_callee() "aarch64_new_za" {
 ; CHECK-LABEL: define void @new_za_callee
 ; CHECK-SAME: () #[[ATTR2:[0-9]+]] {
+; CHECK-NEXT:    call void asm sideeffect "
 ; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
+  call void asm sideeffect "; inlineasm", ""()
   call void @inlined_body()
   ret void
 }
@@ -49,7 +55,7 @@ define void @new_za_callee() "aarch64_new_za" {
 ; Test for a number of combinations, where:
 ; N   Not using ZA.
 ; S   Shared ZA interface
-; Z   New ZA interface
+; Z   New ZA with Private-ZA interface
 
 ; [x] N -> N
 ; [ ] N -> S (This combination is invalid)
@@ -58,6 +64,7 @@ define void @nonza_caller_nonza_callee_inline() {
 ; CHECK-LABEL: define void @nonza_caller_nonza_callee_inline
 ; CHECK-SAME: () #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void asm sideeffect "
 ; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
@@ -84,11 +91,11 @@ entry:
 ; [x] Z -> N
 ; [ ] Z -> S
 ; [ ] Z -> Z
-define void @new_za_caller_nonza_callee_inline() "aarch64_new_za" {
-; CHECK-LABEL: define void @new_za_caller_nonza_callee_inline
+define void @new_za_caller_nonza_callee_dont_inline() "aarch64_new_za" {
+; CHECK-LABEL: define void @new_za_caller_nonza_callee_dont_inline
 ; CHECK-SAME: () #[[ATTR2]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @inlined_body()
+; CHECK-NEXT:    call void @nonza_callee()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -103,6 +110,7 @@ define void @new_za_caller_shared_za_callee_inline() "aarch64_new_za" {
 ; CHECK-LABEL: define void @new_za_caller_shared_za_callee_inline
 ; CHECK-SAME: () #[[ATTR2]] {
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void asm sideeffect "
 ; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
@@ -133,7 +141,7 @@ define void @shared_za_caller_nonza_callee_inline() "aarch64_inout_za" {
 ; CHECK-LABEL: define void @shared_za_caller_nonza_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @inlined_body()
+; CHECK-NEXT:    call void @nonza_callee()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -163,6 +171,7 @@ define void @shared_za_caller_shared_za_callee_inline() "aarch64_inout_za" {
 ; CHECK-LABEL: define void @shared_za_caller_shared_za_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void asm sideeffect "
 ; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;

From fc9d150df1fa0b8b88ae67707825ca2ba981dc41 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen@arm.com>
Date: Mon, 23 Dec 2024 13:40:45 +0000
Subject: [PATCH 4/4] Fix inlining

---
 .../AArch64/AArch64TargetTransformInfo.cpp    |   5 -
 .../Inline/AArch64/sme-pstateza-attrs.ll      | 147 +++++++++++++++++-
 2 files changed, 143 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 98e0cefc3c844b..5816d29eef81ee 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -267,11 +267,6 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
       return false;
   }
 
-  if (CalleeAttrs.hasAgnosticZAInterface()) {
-    if (hasPossibleIncompatibleOps(Callee))
-      return false;
-  }
-
   return BaseT::areInlineCompatible(Caller, Callee);
 }
 
diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
index df1ec0d7de2d4e..8ec4f0762117ab 100644
--- a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
@@ -50,16 +50,30 @@ define void @new_za_callee() "aarch64_new_za" {
   ret void
 }
 
+define void @agnostic_za_callee() "aarch64_za_state_agnostic" {
+; CHECK-LABEL: define void @agnostic_za_callee
+; CHECK-SAME: () #[[ATTR3:[0-9]+]] {
+; CHECK-NEXT:    call void asm sideeffect "
+; CHECK-NEXT:    call void @inlined_body()
+; CHECK-NEXT:    ret void
+;
+  call void asm sideeffect "; inlineasm", ""()
+  call void @inlined_body()
+  ret void
+}
+
 ;
 ; Now test that inlining only happens when no lazy-save is needed.
 ; Test for a number of combinations, where:
 ; N   Not using ZA.
 ; S   Shared ZA interface
 ; Z   New ZA with Private-ZA interface
+; A   Agnostic ZA interface
 
 ; [x] N -> N
 ; [ ] N -> S (This combination is invalid)
 ; [ ] N -> Z
+; [ ] N -> A
 define void @nonza_caller_nonza_callee_inline() {
 ; CHECK-LABEL: define void @nonza_caller_nonza_callee_inline
 ; CHECK-SAME: () #[[ATTR0]] {
@@ -76,6 +90,7 @@ entry:
 ; [ ] N -> N
 ; [ ] N -> S (This combination is invalid)
 ; [x] N -> Z
+; [ ] N -> A
 define void @nonza_caller_new_za_callee_dont_inline() {
 ; CHECK-LABEL: define void @nonza_caller_new_za_callee_dont_inline
 ; CHECK-SAME: () #[[ATTR0]] {
@@ -88,9 +103,27 @@ entry:
   ret void
 }
 
+; [ ] N -> N
+; [ ] N -> S (This combination is invalid)
+; [ ] N -> Z
+; [x] N -> A
+define void @nonza_caller_agnostic_za_callee_inline() {
+; CHECK-LABEL: define void @nonza_caller_agnostic_za_callee_inline
+; CHECK-SAME: () #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void asm sideeffect "
+; CHECK-NEXT:    call void @inlined_body()
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @agnostic_za_callee()
+  ret void
+}
+
 ; [x] Z -> N
 ; [ ] Z -> S
 ; [ ] Z -> Z
+; [ ] Z -> A
 define void @new_za_caller_nonza_callee_dont_inline() "aarch64_new_za" {
 ; CHECK-LABEL: define void @new_za_caller_nonza_callee_dont_inline
 ; CHECK-SAME: () #[[ATTR2]] {
@@ -106,6 +139,7 @@ entry:
 ; [ ] Z -> N
 ; [x] Z -> S
 ; [ ] Z -> Z
+; [ ] Z -> A
 define void @new_za_caller_shared_za_callee_inline() "aarch64_new_za" {
 ; CHECK-LABEL: define void @new_za_caller_shared_za_callee_inline
 ; CHECK-SAME: () #[[ATTR2]] {
@@ -122,6 +156,7 @@ entry:
 ; [ ] Z -> N
 ; [ ] Z -> S
 ; [x] Z -> Z
+; [ ] Z -> A
 define void @new_za_caller_new_za_callee_dont_inline() "aarch64_new_za" {
 ; CHECK-LABEL: define void @new_za_caller_new_za_callee_dont_inline
 ; CHECK-SAME: () #[[ATTR2]] {
@@ -134,9 +169,27 @@ entry:
   ret void
 }
 
-; [x] Z -> N
+; [ ] Z -> N
 ; [ ] Z -> S
 ; [ ] Z -> Z
+; [x] Z -> A
+define void @new_za_caller_agnostic_za_callee_inline() "aarch64_new_za" {
+; CHECK-LABEL: define void @new_za_caller_agnostic_za_callee_inline
+; CHECK-SAME: () #[[ATTR2]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void asm sideeffect "
+; CHECK-NEXT:    call void @inlined_body()
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @agnostic_za_callee()
+  ret void
+}
+
+; [x] S -> N
+; [ ] S -> S
+; [ ] S -> Z
+; [ ] S -> A
 define void @shared_za_caller_nonza_callee_inline() "aarch64_inout_za" {
 ; CHECK-LABEL: define void @shared_za_caller_nonza_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
@@ -152,6 +205,7 @@ entry:
 ; [ ] S -> N
 ; [x] S -> Z
 ; [ ] S -> S
+; [ ] S -> A
 define void @shared_za_caller_new_za_callee_dont_inline() "aarch64_inout_za" {
 ; CHECK-LABEL: define void @shared_za_caller_new_za_callee_dont_inline
 ; CHECK-SAME: () #[[ATTR1]] {
@@ -167,6 +221,7 @@ entry:
 ; [ ] S -> N
 ; [ ] S -> Z
 ; [x] S -> S
+; [ ] S -> A
 define void @shared_za_caller_shared_za_callee_inline() "aarch64_inout_za" {
 ; CHECK-LABEL: define void @shared_za_caller_shared_za_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
@@ -180,6 +235,90 @@ entry:
   ret void
 }
 
+; [ ] S -> N
+; [ ] S -> Z
+; [ ] S -> S
+; [x] S -> A
+define void @shared_za_caller_agnostic_za_callee_inline() "aarch64_inout_za" {
+; CHECK-LABEL: define void @shared_za_caller_agnostic_za_callee_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void asm sideeffect "
+; CHECK-NEXT:    call void @inlined_body()
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @agnostic_za_callee()
+  ret void
+}
+
+; [x] A -> N
+; [ ] A -> Z
+; [ ] A -> S
+; [ ] A -> A
+define void @agnostic_za_caller_nonza_callee_dont_inline() "aarch64_za_state_agnostic" {
+; CHECK-LABEL: define void @agnostic_za_caller_nonza_callee_dont_inline
+; CHECK-SAME: () #[[ATTR3]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @nonza_callee()
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @nonza_callee()
+  ret void
+}
+
+; [ ] A -> N
+; [x] A -> Z
+; [ ] A -> S
+; [ ] A -> A
+define void @agnostic_za_caller_new_za_callee_dont_inline() "aarch64_za_state_agnostic" {
+; CHECK-LABEL: define void @agnostic_za_caller_new_za_callee_dont_inline
+; CHECK-SAME: () #[[ATTR3]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @new_za_callee()
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @new_za_callee()
+  ret void
+}
+
+; [ ] A -> N
+; [ ] A -> Z
+; [x] A -> S (invalid)
+; [ ] A -> A
+define void @agnostic_za_caller_shared_za_callee_dont_inline() "aarch64_za_state_agnostic" {
+; CHECK-LABEL: define void @agnostic_za_caller_shared_za_callee_dont_inline
+; CHECK-SAME: () #[[ATTR3]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @shared_za_callee()
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @shared_za_callee()
+  ret void
+}
+
+; [ ] A -> N
+; [ ] A -> Z
+; [ ] A -> S
+; [x] A -> A
+define void @agnostic_za_caller_agnostic_za_callee_dont_inline() "aarch64_za_state_agnostic" {
+; CHECK-LABEL: define void @agnostic_za_caller_agnostic_za_callee_dont_inline
+; CHECK-SAME: () #[[ATTR3]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void asm sideeffect "
+; CHECK-NEXT:    call void @inlined_body()
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @agnostic_za_callee()
+  ret void
+}
+
+
+
 define void @private_za_callee_call_za_disable() {
 ; CHECK-LABEL: define void @private_za_callee_call_za_disable
 ; CHECK-SAME: () #[[ATTR0]] {
@@ -254,7 +393,7 @@ define void @nonzt0_callee() {
 
 define void @shared_zt0_caller_nonzt0_callee_dont_inline() "aarch64_inout_zt0" {
 ; CHECK-LABEL: define void @shared_zt0_caller_nonzt0_callee_dont_inline
-; CHECK-SAME: () #[[ATTR3:[0-9]+]] {
+; CHECK-SAME: () #[[ATTR4:[0-9]+]] {
 ; CHECK-NEXT:    call void @nonzt0_callee()
 ; CHECK-NEXT:    ret void
 ;
@@ -264,7 +403,7 @@ define void @shared_zt0_caller_nonzt0_callee_dont_inline() "aarch64_inout_zt0" {
 
 define void @shared_zt0_callee() "aarch64_inout_zt0" {
 ; CHECK-LABEL: define void @shared_zt0_callee
-; CHECK-SAME: () #[[ATTR3]] {
+; CHECK-SAME: () #[[ATTR4]] {
 ; CHECK-NEXT:    call void asm sideeffect "
 ; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
@@ -276,7 +415,7 @@ define void @shared_zt0_callee() "aarch64_inout_zt0" {
 
 define void @shared_zt0_caller_shared_zt0_callee_inline() "aarch64_inout_zt0" {
 ; CHECK-LABEL: define void @shared_zt0_caller_shared_zt0_callee_inline
-; CHECK-SAME: () #[[ATTR3]] {
+; CHECK-SAME: () #[[ATTR4]] {
 ; CHECK-NEXT:    call void asm sideeffect "
 ; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void


