[llvm] f7f44f0 - [AArch64][SME] Set up a lazy-save/restore around calls.

Kerry McLaughlin via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 5 07:21:18 PDT 2022


Author: Kerry McLaughlin
Date: 2022-10-05T14:36:53+01:00
New Revision: f7f44f018f6ffbd11cd3bd2ffafbdb15725aac86

URL: https://github.com/llvm/llvm-project/commit/f7f44f018f6ffbd11cd3bd2ffafbdb15725aac86
DIFF: https://github.com/llvm/llvm-project/commit/f7f44f018f6ffbd11cd3bd2ffafbdb15725aac86.diff

LOG: [AArch64][SME] Set up a lazy-save/restore around calls.

Setting up the lazy-save mechanism around calls is done in SelectionDAG
because calls to intrinsics may be expanded into actual function calls
(e.g. calls to @llvm.cos() become calls to cosf), and maintaining an
allowed-list in the SMEABI pass is not feasible.
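
For example, at the IR level the following ZA-sharing function (the name is
illustrative; it mirrors the test_lazy_save_expanded_intrinsic test added
below) contains no call that obviously clobbers ZA:

  declare float @llvm.cos.f32(float)

  define float @use_cos(float %a) nounwind "aarch64_pstate_za_shared" {
    ; Looks like a plain intrinsic call here...
    %res = call float @llvm.cos.f32(float %a)
    ret float %res
  }
  ; ...but SelectionDAG legalisation turns the intrinsic into "bl cosf", a call
  ; with a private-ZA interface that must be wrapped in the lazy-save sequence.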

The approach for conditionally restoring the lazy-save based on the runtime
value of TPIDR2_EL0 is similar to how we handle conditional smstart/smstop.
We create a pseudo node which is expanded into a conditional branch and a
call to __arm_tpidr2_restore(%tpidr2_object_ptr).

The lazy-save buffer and TPIDR2 block are allocated only once, at the start
of the function. For each call, the TPIDR2 block is initialised before the
call and a pseudo node (RestoreZA) is emitted after it.
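
Condensed from the sme-lazy-save-call.ll test added below, the sequence
emitted around a single private-ZA call looks roughly as follows (register
choices and the .Ldone label are illustrative; offsets follow the first test
in that file):

  rdsvl   x8, #1                // x8 = SVL.B, the streaming vector length in bytes
  mov     x9, sp
  mul     x8, x8, x8            // worst-case ZA size: SVL.B * SVL.B bytes
  sub     x9, x9, x8
  mov     sp, x9                // lazy-save buffer, allocated once in the prologue
  sub     x10, x29, #16         // x10 = address of the 16-byte TPIDR2 block
  str     x9, [x29]             // store the lazy-save buffer pointer
  sturh   w8, [x29, #-8]        // TPIDR2 block, offset 8: number of live ZA slices
  msr     TPIDR2_EL0, x10       // arm the lazy save
  bl      private_za_callee     // callee with a private-ZA interface
  smstart za                    // unconditionally re-enable ZA
  sub     x0, x29, #16          // x0 = TPIDR2 block address for the restore routine
  mrs     x8, TPIDR2_EL0
  cbnz    x8, .Ldone            // still non-zero: ZA was not clobbered, skip the restore
  bl      __arm_tpidr2_restore  // zero: a callee committed the save, restore ZA
.Ldone:
  msr     TPIDR2_EL0, xzr       // finally, reset TPIDR2_EL0 to zero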

Patch by Sander de Smalen.

Differential Revision: https://reviews.llvm.org/D133900

Added: 
    llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
    llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
    llvm/lib/Target/AArch64/AArch64RegisterInfo.h
    llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
    llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index 318f8b8554db..62a0ae5a264e 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -89,6 +89,8 @@ class AArch64ExpandPseudo : public MachineFunctionPass {
   bool expandCALL_BTI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI);
   bool expandStoreSwiftAsyncContext(MachineBasicBlock &MBB,
                                     MachineBasicBlock::iterator MBBI);
+  MachineBasicBlock *expandRestoreZA(MachineBasicBlock &MBB,
+                                     MachineBasicBlock::iterator MBBI);
   MachineBasicBlock *expandCondSMToggle(MachineBasicBlock &MBB,
                                         MachineBasicBlock::iterator MBBI);
 };
@@ -851,6 +853,48 @@ bool AArch64ExpandPseudo::expandStoreSwiftAsyncContext(
   return true;
 }
 
+MachineBasicBlock *
+AArch64ExpandPseudo::expandRestoreZA(MachineBasicBlock &MBB,
+                                     MachineBasicBlock::iterator MBBI) {
+  MachineInstr &MI = *MBBI;
+  assert((std::next(MBBI) != MBB.end() ||
+          MI.getParent()->successors().begin() !=
+              MI.getParent()->successors().end()) &&
+         "Unexpected unreachable in block that restores ZA");
+
+  // Compare TPIDR2_EL0 value against 0.
+  DebugLoc DL = MI.getDebugLoc();
+  MachineInstrBuilder Cbz = BuildMI(MBB, MBBI, DL, TII->get(AArch64::CBZX))
+                                .add(MI.getOperand(0));
+
+  // Split MBB and create two new blocks:
+  //  - MBB now contains all instructions before RestoreZAPseudo.
+  //  - SMBB contains the RestoreZAPseudo instruction only.
+  //  - EndBB contains all instructions after RestoreZAPseudo.
+  MachineInstr &PrevMI = *std::prev(MBBI);
+  MachineBasicBlock *SMBB = MBB.splitAt(PrevMI, /*UpdateLiveIns*/ true);
+  MachineBasicBlock *EndBB = std::next(MI.getIterator()) == SMBB->end()
+                                 ? *SMBB->successors().begin()
+                                 : SMBB->splitAt(MI, /*UpdateLiveIns*/ true);
+
+  // Add the SMBB label to the CBZ instruction & create a branch to EndBB.
+  Cbz.addMBB(SMBB);
+  BuildMI(&MBB, DL, TII->get(AArch64::B))
+      .addMBB(EndBB);
+  MBB.addSuccessor(EndBB);
+
+  // Replace the pseudo with a call (BL).
+  MachineInstrBuilder MIB =
+      BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::BL));
+  MIB.addReg(MI.getOperand(1).getReg(), RegState::Implicit);
+  for (unsigned I = 2; I < MI.getNumOperands(); ++I)
+    MIB.add(MI.getOperand(I));
+  BuildMI(SMBB, DL, TII->get(AArch64::B)).addMBB(EndBB);
+
+  MI.eraseFromParent();
+  return EndBB;
+}
+
 MachineBasicBlock *
 AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
                                         MachineBasicBlock::iterator MBBI) {
@@ -1371,6 +1415,12 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
      return expandCALL_BTI(MBB, MBBI);
    case AArch64::StoreSwiftAsyncContext:
      return expandStoreSwiftAsyncContext(MBB, MBBI);
+   case AArch64::RestoreZAPseudo: {
+     auto *NewMBB = expandRestoreZA(MBB, MBBI);
+     if (NewMBB != &MBB)
+       NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
+     return true;
+   }
    case AArch64::MSRpstatePseudo: {
      auto *NewMBB = expandCondSMToggle(MBB, MBBI);
      if (NewMBB != &MBB)

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6865be2d04ca..43d9f8c2f9c9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2068,6 +2068,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::OBSCURE_COPY)
     MAKE_CASE(AArch64ISD::SMSTART)
     MAKE_CASE(AArch64ISD::SMSTOP)
+    MAKE_CASE(AArch64ISD::RESTORE_ZA)
     MAKE_CASE(AArch64ISD::CALL)
     MAKE_CASE(AArch64ISD::ADRP)
     MAKE_CASE(AArch64ISD::ADR)
@@ -5944,6 +5945,50 @@ AArch64TargetLowering::CCAssignFnForReturn(CallingConv::ID CC) const {
                                       : RetCC_AArch64_AAPCS;
 }
 
+
+/// Returns true if the Function has ZA state and contains at least one call to
+/// a function that requires setting up a lazy-save buffer.
+static bool requiresBufferForLazySave(const Function &F) {
+  SMEAttrs CallerAttrs(F);
+  if (!CallerAttrs.hasZAState())
+    return false;
+
+  for (const BasicBlock &BB : F)
+    for (const Instruction &I : BB)
+      if (const CallInst *Call = dyn_cast<CallInst>(&I))
+        if (CallerAttrs.requiresLazySave(SMEAttrs(*Call)))
+          return true;
+  return false;
+}
+
+unsigned AArch64TargetLowering::allocateLazySaveBuffer(
+    SDValue &Chain, const SDLoc &DL, SelectionDAG &DAG, Register &Reg) const {
+  MachineFunction &MF = DAG.getMachineFunction();
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+
+  // Allocate a lazy-save buffer object of size SVL.B * SVL.B (worst-case)
+  SDValue N = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                          DAG.getConstant(1, DL, MVT::i32));
+  SDValue NN = DAG.getNode(ISD::MUL, DL, MVT::i64, N, N);
+  SDValue Ops[] = {Chain, NN, DAG.getConstant(1, DL, MVT::i64)};
+  SDVTList VTs = DAG.getVTList(MVT::i64, MVT::Other);
+  SDValue Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL, VTs, Ops);
+  unsigned FI = MFI.CreateVariableSizedObject(Align(1), nullptr);
+  Reg = MF.getRegInfo().createVirtualRegister(getRegClassFor(MVT::i64));
+  Chain = DAG.getCopyToReg(Buffer.getValue(1), DL, Reg, Buffer.getValue(0));
+
+  // Allocate an additional TPIDR2 object on the stack (16 bytes)
+  unsigned TPIDR2Obj = MFI.CreateStackObject(16, Align(16), false);
+
+  // Store the buffer pointer to the TPIDR2 stack object.
+  MachinePointerInfo MPI = MachinePointerInfo::getStack(MF, FI);
+  SDValue Ptr = DAG.getFrameIndex(
+      FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+  Chain = DAG.getStore(Chain, DL, Buffer, Ptr, MPI);
+
+  return TPIDR2Obj;
+}
+
 SDValue AArch64TargetLowering::LowerFormalArguments(
     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
@@ -6287,6 +6332,14 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
   if (Subtarget->hasCustomCallingConv())
     Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
 
+  if (requiresBufferForLazySave(MF.getFunction())) {
+    // Set up a buffer once and store the buffer in the MachineFunctionInfo.
+    Register Reg;
+    unsigned TPIDR2Obj = allocateLazySaveBuffer(Chain, DL, DAG, Reg);
+    FuncInfo->setLazySaveBufferReg(Reg);
+    FuncInfo->setLazySaveTPIDR2Obj(TPIDR2Obj);
+  }
+
   return Chain;
 }
 
@@ -6869,7 +6922,36 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                getCalleeAttrsFromExternalFunction(CLI.Callee))
     CalleeAttrs = *Attrs;
 
-  SDValue InFlag, PStateSM;
+  bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
+
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+  if (RequiresLazySave) {
+    // Set up a lazy save mechanism by storing the runtime live slices
+    // (worst-case N*N) to the TPIDR2 stack object.
+    SDValue N = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                            DAG.getConstant(1, DL, MVT::i32));
+    SDValue NN = DAG.getNode(ISD::MUL, DL, MVT::i64, N, N);
+    unsigned TPIDR2Obj = FuncInfo->getLazySaveTPIDR2Obj();
+
+    if (!TPIDR2Obj) {
+      Register Reg;
+      TPIDR2Obj = allocateLazySaveBuffer(Chain, DL, DAG, Reg);
+    }
+
+    MachinePointerInfo MPI = MachinePointerInfo::getStack(MF, TPIDR2Obj);
+    SDValue TPIDR2ObjAddr = DAG.getFrameIndex(TPIDR2Obj,
+        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+    SDValue BufferPtrAddr =
+        DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr,
+                    DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType()));
+    Chain = DAG.getTruncStore(Chain, DL, NN, BufferPtrAddr, MPI, MVT::i16);
+    Chain = DAG.getNode(
+        ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
+        DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
+        TPIDR2ObjAddr);
+  }
+
+  SDValue PStateSM;
   Optional<bool> RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
   if (RequiresSMChange)
     PStateSM = getPStateSM(DAG, Chain, CallerAttrs, DL, MVT::i64);
@@ -6966,7 +7048,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         StoreSize *= NumParts;
       }
 
-      MachineFrameInfo &MFI = MF.getFrameInfo();
       Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
       Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
       int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
@@ -7130,6 +7211,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   if (!MemOpChains.empty())
     Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOpChains);
 
+  SDValue InFlag;
   if (RequiresSMChange) {
     SDValue NewChain = changeStreamingMode(DAG, DL, *RequiresSMChange, Chain,
                                            InFlag, PStateSM, true);
@@ -7283,6 +7365,46 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     assert(PStateSM && "Expected a PStateSM to be set");
     Result = changeStreamingMode(DAG, DL, !*RequiresSMChange, Result, InFlag,
                                  PStateSM, false);
+  }
+
+  if (RequiresLazySave) {
+    // Unconditionally resume ZA.
+    Result = DAG.getNode(
+        AArch64ISD::SMSTART, DL, MVT::Other, Result,
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
+        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+
+    // Conditionally restore the lazy save using a pseudo node.
+    unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
+    SDValue RegMask = DAG.getRegisterMask(
+        TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
+    SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
+        "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
+    SDValue TPIDR2_EL0 = DAG.getNode(
+        ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
+        DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
+
+    // Copy the address of the TPIDR2 block into X0 before 'calling' the
+    // RESTORE_ZA pseudo.
+    SDValue Glue;
+    SDValue TPIDR2Block = DAG.getFrameIndex(
+        FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+    Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
+    Result = DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
+                         {Result, TPIDR2_EL0,
+                          DAG.getRegister(AArch64::X0, MVT::i64),
+                          RestoreRoutine,
+                          RegMask,
+                          Result.getValue(1)});
+
+    // Finally reset the TPIDR2_EL0 register to 0.
+    Result = DAG.getNode(
+        ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
+        DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
+        DAG.getConstant(0, DL, MVT::i64));
+  }
+
+  if (RequiresSMChange || RequiresLazySave) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 786dd249fa8e..acdb520c52e7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -67,6 +67,7 @@ enum NodeType : unsigned {
   OBSCURE_COPY,
   SMSTART,
   SMSTOP,
+  RESTORE_ZA,
 
   // Produces the full sequence of instructions for getting the thread pointer
   // offset of a variable into X0, using the TLSDesc model.
@@ -903,6 +904,9 @@ class AArch64TargetLowering : public TargetLowering {
   void addDRTypeForNEON(MVT VT);
   void addQRTypeForNEON(MVT VT);
 
+  unsigned allocateLazySaveBuffer(SDValue &Chain, const SDLoc &DL,
+                                  SelectionDAG &DAG, Register &Reg) const;
+
   SDValue LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv,
                                bool isVarArg,
                                const SmallVectorImpl<ISD::InputArg> &Ins,

diff  --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index 85fbf88bfc0e..c11506c898fa 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -184,6 +184,14 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
   /// or return type
   bool IsSVECC = false;
 
+  /// The virtual register that is the pointer to the lazy save buffer.
+  /// This value is used during ISelLowering.
+  Register LazySaveBufferReg = 0;
+
+  /// The frame-index for the TPIDR2 object used for lazy saves.
+  Register LazySaveTPIDR2Obj = 0;
+
+
   /// True if the function need unwind information.
   mutable Optional<bool> NeedsDwarfUnwindInfo;
 
@@ -201,6 +209,12 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
   bool isSVECC() const { return IsSVECC; };
   void setIsSVECC(bool s) { IsSVECC = s; };
 
+  unsigned getLazySaveBufferReg() const { return LazySaveBufferReg; }
+  void setLazySaveBufferReg(unsigned Reg) { LazySaveBufferReg = Reg; }
+
+  unsigned getLazySaveTPIDR2Obj() const { return LazySaveTPIDR2Obj; }
+  void setLazySaveTPIDR2Obj(unsigned Reg) { LazySaveTPIDR2Obj = Reg; }
+
   void initializeBaseYamlFields(const yaml::AArch64FunctionInfo &YamlMFI);
 
   unsigned getBytesInStackArgArea() const { return BytesInStackArgArea; }

diff  --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
index 8c75cb94cd5b..1cb2a4dcc836 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -325,6 +325,11 @@ const uint32_t *AArch64RegisterInfo::getSMStartStopCallPreservedMask() const {
   return CSR_AArch64_SMStartStop_RegMask;
 }
 
+const uint32_t *
+AArch64RegisterInfo::SMEABISupportRoutinesCallPreservedMaskFromX0() const {
+  return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0_RegMask;
+}
+
 const uint32_t *AArch64RegisterInfo::getNoPreservedMask() const {
   return CSR_AArch64_NoRegs_RegMask;
 }

diff  --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.h b/llvm/lib/Target/AArch64/AArch64RegisterInfo.h
index 666642d2f7ed..65636bf0173c 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.h
@@ -69,6 +69,7 @@ class AArch64RegisterInfo final : public AArch64GenRegisterInfo {
   const uint32_t *getTLSCallPreservedMask() const;
 
   const uint32_t *getSMStartStopCallPreservedMask() const;
+  const uint32_t *SMEABISupportRoutinesCallPreservedMaskFromX0() const;
 
   // Funclets on ARM64 Windows don't preserve any registers.
   const uint32_t *getNoPreservedMask() const override;

diff  --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index f1359d9a3737..8932d9ab9b0f 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -18,6 +18,10 @@ def AArch64_smstop  : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 3,
                              [SDTCisInt<0>, SDTCisInt<0>, SDTCisInt<0>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue, SDNPOutGlue]>;
+def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
+                             [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                             [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
+                              SDNPOptInGlue]>;
 
 def AArch64ObscureCopy : SDNode<"AArch64ISD::OBSCURE_COPY", SDTypeProfile<1, 1, []>, []>;
 
@@ -164,6 +168,24 @@ def MSRpstatePseudo :
            (ins svcr_op:$pstatefield, timm0_1:$imm, GPR64:$rtpstate, timm0_1:$expected_pstate, variable_ops), []>,
     Sched<[WriteSys]>;
 
+// Pseudo to conditionally restore ZA state. This expands:
+//
+//   pseudonode tpidr2_el0, tpidr2obj, restore_routine
+//
+// Into:
+//
+//   if (tpidr2_el0 == 0)
+//     BL restore_routine, implicit-use tpidr2obj
+//
+def RestoreZAPseudo :
+  Pseudo<(outs),
+         (ins GPR64:$tpidr2_el0, GPR64sp:$tpidr2obj, i64imm:$restore_routine, variable_ops), []>,
+         Sched<[]>;
+
+def : Pat<(AArch64_restore_za
+            (i64 GPR64:$tpidr2_el0), (i64 GPR64sp:$tpidr2obj), (i64 texternalsym:$restore_routine)),
+          (RestoreZAPseudo GPR64:$tpidr2_el0, GPR64sp:$tpidr2obj, texternalsym:$restore_routine)>;
+
 // Scenario A:
 //
 //   %pstate.before.call = 1

diff  --git a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
new file mode 100644
index 000000000000..c090e7b5848a
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
@@ -0,0 +1,167 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64 -mattr=+sme < %s | FileCheck %s
+
+declare void @private_za_callee()
+declare float @llvm.cos.f32(float)
+
+; Test lazy-save mechanism for a single callee.
+define void @test_lazy_save_1_callee() nounwind "aarch64_pstate_za_shared" {
+; CHECK-LABEL: test_lazy_save_1_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    mul x8, x8, x8
+; CHECK-NEXT:    sub x9, x9, x8
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    str x9, [x29]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    bl private_za_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB0_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @private_za_callee()
+  ret void
+}
+
+; Test lazy-save mechanism for multiple callees.
+define void @test_lazy_save_2_callees() nounwind "aarch64_pstate_za_shared" {
+; CHECK-LABEL: test_lazy_save_2_callees:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp x20, x19, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mul x19, x8, x8
+; CHECK-NEXT:    mov x8, sp
+; CHECK-NEXT:    sub x8, x8, x19
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    sub x20, x29, #16
+; CHECK-NEXT:    str x8, [x29]
+; CHECK-NEXT:    sturh w19, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x20
+; CHECK-NEXT:    bl private_za_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB1_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB1_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    sturh w19, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x20
+; CHECK-NEXT:    bl private_za_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB1_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB1_4:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x20, x19, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @private_za_callee()
+  call void @private_za_callee()
+  ret void
+}
+
+; Test a call of an intrinsic that gets expanded to a library call.
+define float @test_lazy_save_expanded_intrinsic(float %a) nounwind "aarch64_pstate_za_shared" {
+; CHECK-LABEL: test_lazy_save_expanded_intrinsic:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    mul x8, x8, x8
+; CHECK-NEXT:    sub x9, x9, x8
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    str x9, [x29]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    bl cosf
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB2_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB2_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  %res = call float @llvm.cos.f32(float %a)
+  ret float %res
+}
+
+; Test a combination of streaming-compatible -> normal call with lazy-save.
+define void @test_lazy_save_and_conditional_smstart() nounwind "aarch64_pstate_za_shared" "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: test_lazy_save_and_conditional_smstart:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x29, x30, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    add x29, sp, #64
+; CHECK-NEXT:    str x19, [sp, #80] // 8-byte Folded Spill
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    mul x8, x8, x8
+; CHECK-NEXT:    sub x9, x9, x8
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #80
+; CHECK-NEXT:    stur x9, [x29, #-64]
+; CHECK-NEXT:    sturh w8, [x29, #-72]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbz x19, #0, .LBB3_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB3_2:
+; CHECK-NEXT:    bl private_za_callee
+; CHECK-NEXT:    tbz x19, #0, .LBB3_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB3_4:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    sub x0, x29, #80
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB3_6
+; CHECK-NEXT:  // %bb.5:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB3_6:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    sub sp, x29, #64
+; CHECK-NEXT:    ldp x29, x30, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x19, [sp, #80] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @private_za_callee()
+  ret void
+}

diff  --git a/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll b/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
index d276b177c2c0..1c180325326c 100644
--- a/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
+++ b/llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
@@ -4,17 +4,65 @@
 declare void @private_za_callee()
 
 ; Ensure that we don't use tail call optimization when a lazy-save is required.
-;
-; FIXME: The code below if obviously not yet correct, because it should set up
-; a lazy-save buffer before doing the call, and (conditionally) restore it after
-; the call. But this functionality will follow in a future patch.
 define void @disable_tailcallopt() "aarch64_pstate_za_shared" nounwind {
 ; CHECK-LABEL: disable_tailcallopt:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    mul x8, x8, x8
+; CHECK-NEXT:    sub x9, x9, x8
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    str x9, [x29]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-NEXT:    bl private_za_callee
-; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB0_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
   tail call void @private_za_callee()
   ret void
 }
+
+; Ensure we set up and restore the lazy save correctly for instructions which are lowered to lib calls
+define fp128 @f128_call_za(fp128 %a, fp128 %b) "aarch64_pstate_za_shared" nounwind {
+; CHECK-LABEL: f128_call_za:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    mul x8, x8, x8
+; CHECK-NEXT:    sub x9, x9, x8
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    str x9, [x29]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    bl __addtf3
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    add x0, x29, #0
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB1_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB1_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  %res = fadd fp128 %a, %b
+  ret fp128 %res
+}


        

