[llvm] d297399 - [AArch64][SME] Enable TPIDR2 lazy-save for za_preserved

Matt Devereau via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 20 06:35:02 PDT 2023


Author: Matt Devereau
Date: 2023-09-20T13:34:41Z
New Revision: d297399b359a77e522c54e865aebc24f252d67e9

URL: https://github.com/llvm/llvm-project/commit/d297399b359a77e522c54e865aebc24f252d67e9
DIFF: https://github.com/llvm/llvm-project/commit/d297399b359a77e522c54e865aebc24f252d67e9.diff

LOG: [AArch64][SME] Enable TPIDR2 lazy-save for za_preserved

This change makes callees with the __arm_preserves_za
type attribute comply with the dormant state requirements
when their caller has the __arm_shared_za type attribute.
Several external SME functions also do not need to lazy
save.

https://github.com/ARM-software/abi-aa/blob/5e67092434b50c04f8ad178a9c272ce3c6ada7fd/aapcs64/aapcs64.rst?plain=1#L1381

Differential Revision: https://reviews.llvm.org/D159186

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
    llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
    llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 576339cfbd00a92..ad01a206c93fb39 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4823,17 +4823,6 @@ SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
                      Mask);
 }
 
-static std::optional<SMEAttrs> getCalleeAttrsFromExternalFunction(SDValue V) {
-  if (auto *ES = dyn_cast<ExternalSymbolSDNode>(V)) {
-    StringRef S(ES->getSymbol());
-    if (S == "__arm_sme_state" || S == "__arm_tpidr2_save")
-      return SMEAttrs(SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved);
-    if (S == "__arm_tpidr2_restore")
-      return SMEAttrs(SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared);
-  }
-  return std::nullopt;
-}
-
 SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
                                                    SelectionDAG &DAG) const {
   unsigned IntNo = Op.getConstantOperandVal(1);
@@ -7375,28 +7364,31 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
   if (CLI.CB)
     CalleeAttrs = SMEAttrs(*CLI.CB);
-  else if (std::optional<SMEAttrs> Attrs =
-               getCalleeAttrsFromExternalFunction(CLI.Callee))
-    CalleeAttrs = *Attrs;
+  else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
+    CalleeAttrs = SMEAttrs(ES->getSymbol());
 
   bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
-
-  MachineFrameInfo &MFI = MF.getFrameInfo();
   if (RequiresLazySave) {
-    // Set up a lazy save mechanism by storing the runtime live slices
-    // (worst-case N*N) to the TPIDR2 stack object.
-    SDValue N = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
-                            DAG.getConstant(1, DL, MVT::i32));
-    SDValue NN = DAG.getNode(ISD::MUL, DL, MVT::i64, N, N);
-    unsigned TPIDR2Obj = FuncInfo->getLazySaveTPIDR2Obj();
+    SDValue NumZaSaveSlices;
+    if (!CalleeAttrs.preservesZA()) {
+      // Set up a lazy save mechanism by storing the runtime live slices
+      // (worst-case SVL*SVL) to the TPIDR2 stack object.
+      SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                                DAG.getConstant(1, DL, MVT::i32));
+      NumZaSaveSlices = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
+    } else if (CalleeAttrs.preservesZA()) {
+      NumZaSaveSlices = DAG.getConstant(0, DL, MVT::i64);
+    }
 
+    unsigned TPIDR2Obj = FuncInfo->getLazySaveTPIDR2Obj();
     MachinePointerInfo MPI = MachinePointerInfo::getStack(MF, TPIDR2Obj);
     SDValue TPIDR2ObjAddr = DAG.getFrameIndex(TPIDR2Obj,
         DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
-    SDValue BufferPtrAddr =
+    SDValue NumZaSaveSlicesAddr =
         DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr,
                     DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType()));
-    Chain = DAG.getTruncStore(Chain, DL, NN, BufferPtrAddr, MPI, MVT::i16);
+    Chain = DAG.getTruncStore(Chain, DL, NumZaSaveSlices, NumZaSaveSlicesAddr,
+                              MPI, MVT::i16);
     Chain = DAG.getNode(
         ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
@@ -7503,6 +7495,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
       Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
       Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
+      MachineFrameInfo &MFI = MF.getFrameInfo();
       int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
       if (isScalable)
         MFI.setStackID(FI, TargetStackID::ScalableVector);
@@ -7819,35 +7812,34 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   }
 
   if (RequiresLazySave) {
-    // Unconditionally resume ZA.
-    Result = DAG.getNode(
-        AArch64ISD::SMSTART, DL, MVT::Other, Result,
-        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
-
-    // Conditionally restore the lazy save using a pseudo node.
-    unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
-    SDValue RegMask = DAG.getRegisterMask(
-        TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
-    SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
-        "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
-    SDValue TPIDR2_EL0 = DAG.getNode(
-        ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
-        DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
-
-    // Copy the address of the TPIDR2 block into X0 before 'calling' the
-    // RESTORE_ZA pseudo.
-    SDValue Glue;
-    SDValue TPIDR2Block = DAG.getFrameIndex(
-        FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
-    Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
-    Result = DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
-                         {Result, TPIDR2_EL0,
-                          DAG.getRegister(AArch64::X0, MVT::i64),
-                          RestoreRoutine,
-                          RegMask,
-                          Result.getValue(1)});
-
+    if (!CalleeAttrs.preservesZA()) {
+      // Unconditionally resume ZA.
+      Result = DAG.getNode(
+          AArch64ISD::SMSTART, DL, MVT::Other, Result,
+          DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
+          DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+
+      // Conditionally restore the lazy save using a pseudo node.
+      unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
+      SDValue RegMask = DAG.getRegisterMask(
+          TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
+      SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
+          "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
+      SDValue TPIDR2_EL0 = DAG.getNode(
+          ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
+          DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
+
+      // Copy the address of the TPIDR2 block into X0 before 'calling' the
+      // RESTORE_ZA pseudo.
+      SDValue Glue;
+      SDValue TPIDR2Block = DAG.getFrameIndex(
+          FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+      Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
+      Result = DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
+                           {Result, TPIDR2_EL0,
+                            DAG.getRegister(AArch64::X0, MVT::i64),
+                            RestoreRoutine, RegMask, Result.getValue(1)});
+    }
     // Finally reset the TPIDR2_EL0 register to 0.
     Result = DAG.getNode(
         ISD::INTRINSIC_VOID, DL, MVT::Other, Result,

diff  --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index b2c126bbc6f3a82..0082b4017986c6f 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -24,12 +24,26 @@ void SMEAttrs::set(unsigned M, bool Enable) {
          "ZA_New and ZA_Shared are mutually exclusive");
   assert(!(hasNewZABody() && preservesZA()) &&
          "ZA_New and ZA_Preserved are mutually exclusive");
+  assert(!(hasNewZABody() && (Bitmask & ZA_NoLazySave)) &&
+         "ZA_New and ZA_NoLazySave are mutually exclusive");
+  assert(!(hasSharedZAInterface() && (Bitmask & ZA_NoLazySave)) &&
+         "ZA_Shared and ZA_NoLazySave are mutually exclusive");
 }
 
 SMEAttrs::SMEAttrs(const CallBase &CB) {
   *this = SMEAttrs(CB.getAttributes());
-  if (auto *F = CB.getCalledFunction())
-    set(SMEAttrs(*F).Bitmask);
+  if (auto *F = CB.getCalledFunction()) {
+    set(SMEAttrs(*F).Bitmask | SMEAttrs(F->getName()).Bitmask);
+  }
+}
+
+SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
+  if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
+    Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved |
+                SMEAttrs::ZA_NoLazySave);
+  if (FuncName == "__arm_tpidr2_restore")
+    Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
+                SMEAttrs::ZA_NoLazySave);
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {

diff  --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 587765a7d63b75b..e766b778b541020 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -35,6 +35,7 @@ class SMEAttrs {
     ZA_Shared = 1 << 3,     // aarch64_pstate_sm_shared
     ZA_New = 1 << 4,        // aarch64_pstate_sm_new
     ZA_Preserved = 1 << 5,  // aarch64_pstate_sm_preserved
+    ZA_NoLazySave = 1 << 6, // Used for SME ABI routines to avoid lazy saves
     All = ZA_Preserved - 1
   };
 
@@ -42,6 +43,7 @@ class SMEAttrs {
   SMEAttrs(const Function &F) : SMEAttrs(F.getAttributes()) {}
   SMEAttrs(const CallBase &CB);
   SMEAttrs(const AttributeList &L);
+  SMEAttrs(StringRef FuncName);
 
   void set(unsigned M, bool Enable = true);
 
@@ -82,7 +84,7 @@ class SMEAttrs {
   }
   bool requiresLazySave(const SMEAttrs &Callee) const {
     return hasZAState() && Callee.hasPrivateZAInterface() &&
-           !Callee.preservesZA();
+           !(Callee.Bitmask & ZA_NoLazySave);
   }
 };
 

diff  --git a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
index 538c403981b6d3f..ad16402a18f8b92 100644
--- a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
+++ b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
@@ -2,6 +2,7 @@
 ; RUN: llc -mtriple=aarch64 -mattr=+sme < %s | FileCheck %s
 
 declare void @private_za_callee()
+declare void @private_za_preserved_callee() "aarch64_pstate_za_preserved"
 declare float @llvm.cos.f32(float)
 
 ; Test lazy-save mechanism for a single callee.
@@ -165,3 +166,49 @@ define void @test_lazy_save_and_conditional_smstart() nounwind "aarch64_pstate_z
   call void @private_za_callee()
   ret void
 }
+
+
+; Test lazy-save mechanism for an aarch64_pstate_za_shared caller
+; calling a callee with aarch64_pstate_za_preserved.
+define void @za_shared_caller_za_preserved_callee() nounwind "aarch64_pstate_za_shared" "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: za_shared_caller_za_preserved_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x29, x30, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    add x29, sp, #64
+; CHECK-NEXT:    str x19, [sp, #80] // 8-byte Folded Spill
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur x8, [x29, #-80]
+; CHECK-NEXT:    sub x8, x29, #80
+; CHECK-NEXT:    sturh wzr, [x29, #-72]
+; CHECK-NEXT:    msr TPIDR2_EL0, x8
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbz x19, #0, .LBB4_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB4_2:
+; CHECK-NEXT:    bl private_za_preserved_callee
+; CHECK-NEXT:    tbz x19, #0, .LBB4_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB4_4:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    sub sp, x29, #64
+; CHECK-NEXT:    ldp x29, x30, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x19, [sp, #80] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @private_za_preserved_callee()
+  ret void
+}


        


More information about the llvm-commits mailing list