[llvm] a8a3711 - [AArch64][SME2] Preserve ZT0 state around function calls (#78321)

via llvm-commits llvm-commits at lists.llvm.org
Sat Jan 20 04:06:05 PST 2024


Author: Kerry McLaughlin
Date: 2024-01-20T12:06:00Z
New Revision: a8a3711e745286fd26f726b3397dbe5fb03ea465

URL: https://github.com/llvm/llvm-project/commit/a8a3711e745286fd26f726b3397dbe5fb03ea465
DIFF: https://github.com/llvm/llvm-project/commit/a8a3711e745286fd26f726b3397dbe5fb03ea465.diff

LOG: [AArch64][SME2] Preserve ZT0 state around function calls (#78321)

If a function has ZT0 state and calls a function that does not
preserve ZT0, the caller must save and restore ZT0 around the call.
If the caller shares ZT0 state and the callee does not share ZA, we
must additionally emit SMSTOP/SMSTART ZA around the call.

This patch adds new AArch64 ISD nodes (SAVE_ZT and RESTORE_ZT) for
spilling and filling ZT0. Where requiresPreservingZT0 is true, ZT0
state is preserved across the call.
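
As a rough illustration of the new SMEAttrs transition predicates, below is a
minimal, self-contained C++ sketch of the logic this patch adds. The standalone
Attrs struct and its boolean fields are hypothetical stand-ins for the real
attribute bitmask, and requiresEnablingZAAfterCall is simplified to omit the
lazy-save case; see the actual definitions in AArch64SMEAttributes.h below.

    #include <cassert>

    // Hypothetical stand-in for llvm::SMEAttrs: the booleans model the
    // "aarch64_pstate_za_shared", "aarch64_in_zt0" and "aarch64_new_zt0"
    // attributes; the real class decodes these from an attribute bitmask.
    struct Attrs {
      bool SharesZA = false;
      bool SharesZT0 = false;
      bool NewZT0 = false;

      bool hasZAState() const { return SharesZA; } // simplified
      bool hasZT0State() const { return NewZT0 || SharesZT0; }
      bool hasPrivateZAInterface() const { return !SharesZA && !SharesZT0; }

      // ZT0 must be spilled before and filled after a call unless the
      // callee shares ZT0 with the caller.
      bool requiresPreservingZT0(const Attrs &Callee) const {
        return hasZT0State() && !Callee.SharesZT0;
      }
      // With ZT0-only state there is no lazy save covering ZA, so
      // PSTATE.ZA is switched off across a private-ZA call.
      bool requiresDisablingZABeforeCall(const Attrs &Callee) const {
        return hasZT0State() && !hasZAState() &&
               Callee.hasPrivateZAInterface();
      }
      bool requiresEnablingZAAfterCall(const Attrs &Callee) const {
        // The real predicate also ORs in requiresLazySave(Callee).
        return requiresDisablingZABeforeCall(Callee);
      }
    };

    int main() {
      Attrs PrivateZA;    // plain function, private-ZA interface
      Attrs ZT0Shared;    // "aarch64_in_zt0"
      ZT0Shared.SharesZT0 = true;
      Attrs ZAZT0Shared;  // shares both ZA and ZT0
      ZAZT0Shared.SharesZA = ZAZT0Shared.SharesZT0 = true;

      // Shared-ZT0 caller, private-ZA callee: spill/fill ZT0 and
      // toggle PSTATE.ZA around the call.
      assert(ZT0Shared.requiresPreservingZT0(PrivateZA));
      assert(ZT0Shared.requiresDisablingZABeforeCall(PrivateZA));

      // Shared ZA & ZT0 caller: the lazy save covers ZA, so only the
      // ZT0 spill/fill is needed and PSTATE.ZA stays on.
      assert(ZAZT0Shared.requiresPreservingZT0(PrivateZA));
      assert(!ZAZT0Shared.requiresDisablingZABeforeCall(PrivateZA));

      // Shared-ZT0 callee: nothing to do for ZT0.
      assert(!ZT0Shared.requiresPreservingZT0(ZT0Shared));
      return 0;
    }

These cases correspond to the transitions exercised in SMEAttributesTest.cpp
and to the CHECK lines of the new sme-zt0-state.ll test in the diff below.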

Added: 
    llvm/test/CodeGen/AArch64/sme-zt0-state.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
    llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
    llvm/unittests/Target/AArch64/SMEAttributesTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 09e42b72be63cf0..96ea692d03f563e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2341,6 +2341,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::SMSTART)
     MAKE_CASE(AArch64ISD::SMSTOP)
     MAKE_CASE(AArch64ISD::RESTORE_ZA)
+    MAKE_CASE(AArch64ISD::RESTORE_ZT)
+    MAKE_CASE(AArch64ISD::SAVE_ZT)
     MAKE_CASE(AArch64ISD::CALL)
     MAKE_CASE(AArch64ISD::ADRP)
     MAKE_CASE(AArch64ISD::ADR)
@@ -7654,6 +7656,34 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     });
   }
 
+  SDValue ZTFrameIdx;
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+  bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
+
+  // If the caller has ZT0 state which will not be preserved by the callee,
+  // spill ZT0 before the call.
+  if (ShouldPreserveZT0) {
+    unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
+    ZTFrameIdx = DAG.getFrameIndex(
+        ZTObj,
+        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+
+    Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
+                        {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+  }
+
+  // If caller shares ZT0 but the callee is not shared ZA, we need to stop
+  // PSTATE.ZA before the call if there is no lazy-save active.
+  bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
+  assert((!DisableZA || !RequiresLazySave) &&
+         "Lazy-save should have PSTATE.SM=1 on entry to the function");
+
+  if (DisableZA)
+    Chain = DAG.getNode(
+        AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
+        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+
   // Adjust the stack pointer for the new arguments...
   // These operations are automatically eliminated by the prolog/epilog pass
   if (!IsSibCall)
@@ -8065,13 +8095,19 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                                  Result, InGlue, PStateSM, false);
   }
 
-  if (RequiresLazySave) {
+  if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
     // Unconditionally resume ZA.
     Result = DAG.getNode(
         AArch64ISD::SMSTART, DL, MVT::Other, Result,
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
         DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
 
+  if (ShouldPreserveZT0)
+    Result =
+        DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
+                    {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+
+  if (RequiresLazySave) {
     // Conditionally restore the lazy save using a pseudo node.
     unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
     SDValue RegMask = DAG.getRegisterMask(
@@ -8100,7 +8136,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(0, DL, MVT::i64));
   }
 
-  if (RequiresSMChange || RequiresLazySave) {
+  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
@@ -23977,6 +24013,14 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
       return DAG.getMergeValues(
           {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
     }
+    case Intrinsic::aarch64_sme_ldr_zt:
+      return DAG.getNode(AArch64ISD::RESTORE_ZT, SDLoc(N),
+                         DAG.getVTList(MVT::Other), N->getOperand(0),
+                         N->getOperand(2), N->getOperand(3));
+    case Intrinsic::aarch64_sme_str_zt:
+      return DAG.getNode(AArch64ISD::SAVE_ZT, SDLoc(N),
+                         DAG.getVTList(MVT::Other), N->getOperand(0),
+                         N->getOperand(2), N->getOperand(3));
     default:
       break;
     }

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 6047a3b7b2864aa..abecc3560ccbb32 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -61,6 +61,8 @@ enum NodeType : unsigned {
   SMSTART,
   SMSTOP,
   RESTORE_ZA,
+  RESTORE_ZT,
+  SAVE_ZT,
 
   // Produces the full sequence of instructions for getting the thread pointer
   // offset of a variable into X0, using the TLSDesc model.

diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 380f6e1fcfdaefc..eeae5303a3f8987 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -22,6 +22,12 @@ def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
                              [SDTCisInt<0>, SDTCisPtrTy<1>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue]>;
+def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
+                                [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                                [SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
+def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
+                             [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                             [SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;
 
 //===----------------------------------------------------------------------===//
 // Instruction naming conventions.
@@ -543,8 +549,8 @@ defm UMOPS_MPPZZ_HtoS : sme2_int_mopx_tile<"umops", 0b101, int_aarch64_sme_umops
 
 defm ZERO_T : sme2_zero_zt<"zero", 0b0001>;
 
-defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, int_aarch64_sme_ldr_zt>;
-defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, int_aarch64_sme_str_zt>;
+defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, AArch64_restore_zt>;
+defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, AArch64_save_zt>;
 
 def MOVT_XTI : sme2_movt_zt_to_scalar<"movt", 0b0011111>;
 def MOVT_TIX : sme2_movt_scalar_to_zt<"movt", 0b0011111>;

diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 6f622f1996a3a0a..8af219bb361fdcf 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -112,6 +112,15 @@ class SMEAttrs {
            State == StateValue::InOut || State == StateValue::Preserved;
   }
   bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
+  bool requiresPreservingZT0(const SMEAttrs &Callee) const {
+    return hasZT0State() && !Callee.sharesZT0();
+  }
+  bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
+    return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface();
+  }
+  bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const {
+    return requiresLazySave(Callee) || requiresDisablingZABeforeCall(Callee);
+  }
 };
 
 } // namespace llvm

diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
new file mode 100644
index 000000000000000..88eaf19ec488f3d
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -0,0 +1,155 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
+
+declare void @callee();
+
+;
+; Private-ZA Callee
+;
+
+; Expect spill & fill of ZT0 around call
+; Expect smstop/smstart za around call
+define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: zt0_in_caller_no_state_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    mov x19, sp
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #80
+; CHECK-NEXT:    ret
+  call void @callee();
+  ret void;
+}
+
+; Expect spill & fill of ZT0 around call
+; Expect setup and restore lazy-save around call
+; Expect smstart za after call
+define void @za_zt0_shared_caller_no_state_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: za_zt0_shared_caller_no_state_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x9, x8, x8, x9
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x9, [x29, #-16]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    cbnz x8, .LBB1_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB1_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee();
+  ret void;
+}
+
+;
+; Shared-ZA Callee
+;
+
+; Caller and callee have shared ZT0 state, no spill/fill of ZT0 required
+define void @zt0_shared_caller_zt0_shared_callee() "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: zt0_shared_caller_zt0_shared_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_in_zt0";
+  ret void;
+}
+
+; Expect spill & fill of ZT0 around call
+define void @za_zt0_shared_caller_za_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: za_zt0_shared_caller_za_shared_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    sub x19, x29, #80
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_pstate_za_shared";
+  ret void;
+}
+
+; Caller and callee have shared ZA & ZT0
+define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: za_zt0_shared_caller_za_zt0_shared_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+  ret void;
+}
+
+; New-ZA Callee
+
+; Expect spill & fill of ZT0 around call
+; Expect smstop/smstart za around call
+define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: zt0_in_caller_zt0_new_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    mov x19, sp
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #80
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_new_zt0";
+  ret void;
+}

diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 294e5571814246a..044de72449ec890 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -191,6 +191,9 @@ TEST(SMEAttributes, Basics) {
 TEST(SMEAttributes, Transitions) {
   // Normal -> Normal
   ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal)));
+  ASSERT_FALSE(SA(SA::Normal).requiresPreservingZT0(SA(SA::Normal)));
+  ASSERT_FALSE(SA(SA::Normal).requiresDisablingZABeforeCall(SA(SA::Normal)));
+  ASSERT_FALSE(SA(SA::Normal).requiresEnablingZAAfterCall(SA(SA::Normal)));
   // Normal -> Normal + LocallyStreaming
   ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal | SA::SM_Body)));
 
@@ -240,4 +243,37 @@ TEST(SMEAttributes, Transitions) {
   // Streaming-compatible -> Streaming-compatible + LocallyStreaming
   ASSERT_FALSE(SA(SA::SM_Compatible)
                    .requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body)));
+
+  SA Private_ZA = SA(SA::Normal);
+  SA ZA_Shared = SA(SA::ZA_Shared);
+  SA ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In));
+  SA ZA_ZT0_Shared = SA(SA::ZA_Shared | SA::encodeZT0State(SA::StateValue::In));
+
+  // Shared ZA -> Private ZA Interface
+  ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
+  ASSERT_TRUE(ZA_Shared.requiresEnablingZAAfterCall(Private_ZA));
+
+  // Shared ZT0 -> Private ZA Interface
+  ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
+  ASSERT_TRUE(ZT0_Shared.requiresPreservingZT0(Private_ZA));
+  ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
+
+  // Shared ZA & ZT0 -> Private ZA Interface
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
+  ASSERT_TRUE(ZA_ZT0_Shared.requiresPreservingZT0(Private_ZA));
+  ASSERT_TRUE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
+
+  // Shared ZA -> Shared ZA Interface
+  ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
+  ASSERT_FALSE(ZA_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
+
+  // Shared ZT0 -> Shared ZA Interface
+  ASSERT_FALSE(ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
+  ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
+  ASSERT_FALSE(ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
+
+  // Shared ZA & ZT0 -> Shared ZA Interface
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
 }


        

