[llvm] d4d81ac - [AArch64][SME2] Extend SMEABIPass to handle functions with new ZT0 state (#78848)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 22 08:30:47 PST 2024


Author: Kerry McLaughlin
Date: 2024-01-22T16:30:43Z
New Revision: d4d81acb52bd44681210001c148ac86df8a344f0

URL: https://github.com/llvm/llvm-project/commit/d4d81acb52bd44681210001c148ac86df8a344f0
DIFF: https://github.com/llvm/llvm-project/commit/d4d81acb52bd44681210001c148ac86df8a344f0.diff

LOG: [AArch64][SME2] Extend SMEABIPass to handle functions with new ZT0 state (#78848)

updateNewZAFunctions is extended to generate the following on entry to a
function with either the "aarch64_pstate_za_new" or "aarch64_new_zt0"
attribute:
- Private-ZA interface: commit any active lazy-saves & enable PSTATE.ZA.
  - "aarch64_pstate_za_new": zero ZA.
  - "aarch64_new_zt0": zero ZT0.

Additionally, PSTATE.ZA should be disabled before returning if the function
has a private-ZA interface.

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/SMEABIPass.cpp
    llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
    llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
    llvm/test/CodeGen/AArch64/sme-zt0-state.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index 3315171798d9f1..0247488ce93f1d 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -40,7 +40,8 @@ struct SMEABI : public FunctionPass {
   bool runOnFunction(Function &F) override;
 
 private:
-  bool updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder);
+  bool updateNewStateFunctions(Module *M, Function *F, IRBuilder<> &Builder,
+                               SMEAttrs FnAttrs);
 };
 } // end anonymous namespace
 
@@ -76,56 +77,87 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
                      Builder.getInt64(0));
 }
 
-/// This function generates code to commit a lazy save at the beginning of a
-/// function marked with `aarch64_pstate_za_new`. If the value read from
-/// TPIDR2_EL0 is not null on entry to the function then the lazy-saving scheme
-/// is active and we should call __arm_tpidr2_save to commit the lazy save.
-/// Additionally, PSTATE.ZA should be enabled at the beginning of the function
-/// and disabled before returning.
-bool SMEABI::updateNewZAFunctions(Module *M, Function *F,
-                                  IRBuilder<> &Builder) {
+/// This function generates code at the beginning and end of a function marked
+/// with either `aarch64_pstate_za_new` or `aarch64_new_zt0`.
+/// At the beginning of the function, the following code is generated:
+///  - Commit lazy-save if active   [Private-ZA Interface*]
+///  - Enable PSTATE.ZA             [Private-ZA Interface]
+///  - Zero ZA                      [Has New ZA State]
+///  - Zero ZT0                     [Has New ZT0 State]
+///
+/// * A function with new ZT0 state will not change ZA, so committing the
+/// lazy-save is not strictly necessary. However, the lazy-save mechanism
+/// may be active on entry to the function, with PSTATE.ZA set to 1. If
+/// the new ZT0 function calls a function that does not share ZT0, we will
+/// need to conditionally SMSTOP ZA before the call, setting PSTATE.ZA to 0.
+/// For this reason, it's easier to always commit the lazy-save at the
+/// beginning of the function regardless of whether it has ZA state.
+///
+/// At the end of the function, PSTATE.ZA is disabled if the function has a
+/// Private-ZA Interface. A function is considered to have a Private-ZA
+/// interface if it does not share ZA or ZT0.
+///
+bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
+                                     IRBuilder<> &Builder, SMEAttrs FnAttrs) {
   LLVMContext &Context = F->getContext();
   BasicBlock *OrigBB = &F->getEntryBlock();
-
-  // Create the new blocks for reading TPIDR2_EL0 & enabling ZA state.
-  auto *SaveBB = OrigBB->splitBasicBlock(OrigBB->begin(), "save.za", true);
-  auto *PreludeBB = BasicBlock::Create(Context, "prelude", F, SaveBB);
-
-  // Read TPIDR2_EL0 in PreludeBB & branch to SaveBB if not 0.
-  Builder.SetInsertPoint(PreludeBB);
-  Function *TPIDR2Intr =
-      Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_get_tpidr2);
-  auto *TPIDR2 = Builder.CreateCall(TPIDR2Intr->getFunctionType(), TPIDR2Intr,
-                                    {}, "tpidr2");
-  auto *Cmp =
-      Builder.CreateCmp(ICmpInst::ICMP_NE, TPIDR2, Builder.getInt64(0), "cmp");
-  Builder.CreateCondBr(Cmp, SaveBB, OrigBB);
-
-  // Create a call __arm_tpidr2_save, which commits the lazy save.
-  Builder.SetInsertPoint(&SaveBB->back());
-  emitTPIDR2Save(M, Builder);
-
-  // Enable pstate.za at the start of the function.
   Builder.SetInsertPoint(&OrigBB->front());
-  Function *EnableZAIntr =
-      Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_enable);
-  Builder.CreateCall(EnableZAIntr->getFunctionType(), EnableZAIntr);
-
-  // ZA state must be zeroed upon entry to a function with NewZA
-  Function *ZeroIntr =
-      Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero);
-  Builder.CreateCall(ZeroIntr->getFunctionType(), ZeroIntr,
-                     Builder.getInt32(0xff));
-
-  // Before returning, disable pstate.za
-  for (BasicBlock &BB : *F) {
-    Instruction *T = BB.getTerminator();
-    if (!T || !isa<ReturnInst>(T))
-      continue;
-    Builder.SetInsertPoint(T);
-    Function *DisableZAIntr =
-        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_disable);
-    Builder.CreateCall(DisableZAIntr->getFunctionType(), DisableZAIntr);
+
+  // Commit any active lazy-saves if this is a Private-ZA function. If the
+  // value read from TPIDR2_EL0 is not null on entry to the function then
+  // the lazy-saving scheme is active and we should call __arm_tpidr2_save
+  // to commit the lazy save.
+  if (FnAttrs.hasPrivateZAInterface()) {
+    // Create the new blocks for reading TPIDR2_EL0 & enabling ZA state.
+    auto *SaveBB = OrigBB->splitBasicBlock(OrigBB->begin(), "save.za", true);
+    auto *PreludeBB = BasicBlock::Create(Context, "prelude", F, SaveBB);
+
+    // Read TPIDR2_EL0 in PreludeBB & branch to SaveBB if not 0.
+    Builder.SetInsertPoint(PreludeBB);
+    Function *TPIDR2Intr =
+        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_get_tpidr2);
+    auto *TPIDR2 = Builder.CreateCall(TPIDR2Intr->getFunctionType(), TPIDR2Intr,
+                                      {}, "tpidr2");
+    auto *Cmp = Builder.CreateCmp(ICmpInst::ICMP_NE, TPIDR2,
+                                  Builder.getInt64(0), "cmp");
+    Builder.CreateCondBr(Cmp, SaveBB, OrigBB);
+
+    // Create a call __arm_tpidr2_save, which commits the lazy save.
+    Builder.SetInsertPoint(&SaveBB->back());
+    emitTPIDR2Save(M, Builder);
+
+    // Enable pstate.za at the start of the function.
+    Builder.SetInsertPoint(&OrigBB->front());
+    Function *EnableZAIntr =
+        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_enable);
+    Builder.CreateCall(EnableZAIntr->getFunctionType(), EnableZAIntr);
+  }
+
+  if (FnAttrs.hasNewZABody()) {
+    Function *ZeroIntr =
+        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero);
+    Builder.CreateCall(ZeroIntr->getFunctionType(), ZeroIntr,
+                       Builder.getInt32(0xff));
+  }
+
+  if (FnAttrs.isNewZT0()) {
+    Function *ClearZT0Intr =
+        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero_zt);
+    Builder.CreateCall(ClearZT0Intr->getFunctionType(), ClearZT0Intr,
+                       {Builder.getInt32(0)});
+  }
+
+  if (FnAttrs.hasPrivateZAInterface()) {
+    // Before returning, disable pstate.za
+    for (BasicBlock &BB : *F) {
+      Instruction *T = BB.getTerminator();
+      if (!T || !isa<ReturnInst>(T))
+        continue;
+      Builder.SetInsertPoint(T);
+      Function *DisableZAIntr =
+          Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_disable);
+      Builder.CreateCall(DisableZAIntr->getFunctionType(), DisableZAIntr);
+    }
   }
 
   F->addFnAttr("aarch64_expanded_pstate_za");
@@ -142,8 +174,8 @@ bool SMEABI::runOnFunction(Function &F) {
 
   bool Changed = false;
   SMEAttrs FnAttrs(F);
-  if (FnAttrs.hasNewZABody())
-    Changed |= updateNewZAFunctions(M, &F, Builder);
+  if (FnAttrs.hasNewZABody() || FnAttrs.isNewZT0())
+    Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs);
 
   return Changed;
 }

diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 9693b6a664be26..3ee54e5df0a13d 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -27,10 +27,8 @@ void SMEAttrs::set(unsigned M, bool Enable) {
          "ZA_New and ZA_Shared are mutually exclusive");
   assert(!(hasNewZABody() && preservesZA()) &&
          "ZA_New and ZA_Preserved are mutually exclusive");
-  assert(!(hasNewZABody() && (Bitmask & ZA_NoLazySave)) &&
-         "ZA_New and ZA_NoLazySave are mutually exclusive");
-  assert(!(sharesZA() && (Bitmask & ZA_NoLazySave)) &&
-         "ZA_Shared and ZA_NoLazySave are mutually exclusive");
+  assert(!(hasNewZABody() && (Bitmask & SME_ABI_Routine)) &&
+         "ZA_New and SME_ABI_Routine are mutually exclusive");
 
   // ZT0 Attrs
   assert(
@@ -49,11 +47,10 @@ SMEAttrs::SMEAttrs(const CallBase &CB) {
 
 SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
   if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
-    Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved |
-                SMEAttrs::ZA_NoLazySave);
+    Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
   if (FuncName == "__arm_tpidr2_restore")
     Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
-                SMEAttrs::ZA_NoLazySave);
+                SMEAttrs::SME_ABI_Routine);
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {

diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 8af219bb361fdc..27b7075a0944ff 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -38,13 +38,13 @@ class SMEAttrs {
   // Enum with bitmasks for each individual SME feature.
   enum Mask {
     Normal = 0,
-    SM_Enabled = 1 << 0,    // aarch64_pstate_sm_enabled
-    SM_Compatible = 1 << 1, // aarch64_pstate_sm_compatible
-    SM_Body = 1 << 2,       // aarch64_pstate_sm_body
-    ZA_Shared = 1 << 3,     // aarch64_pstate_sm_shared
-    ZA_New = 1 << 4,        // aarch64_pstate_sm_new
-    ZA_Preserved = 1 << 5,  // aarch64_pstate_sm_preserved
-    ZA_NoLazySave = 1 << 6, // Used for SME ABI routines to avoid lazy saves
+    SM_Enabled = 1 << 0,      // aarch64_pstate_sm_enabled
+    SM_Compatible = 1 << 1,   // aarch64_pstate_sm_compatible
+    SM_Body = 1 << 2,         // aarch64_pstate_sm_body
+    ZA_Shared = 1 << 3,       // aarch64_pstate_sm_shared
+    ZA_New = 1 << 4,          // aarch64_pstate_sm_new
+    ZA_Preserved = 1 << 5,    // aarch64_pstate_sm_preserved
+    SME_ABI_Routine = 1 << 6, // Used for SME ABI routines to avoid lazy saves
     ZT0_Shift = 7,
     ZT0_Mask = 0b111 << ZT0_Shift
   };
@@ -86,7 +86,7 @@ class SMEAttrs {
   bool hasZAState() const { return hasNewZABody() || sharesZA(); }
   bool requiresLazySave(const SMEAttrs &Callee) const {
     return hasZAState() && Callee.hasPrivateZAInterface() &&
-           !(Callee.Bitmask & ZA_NoLazySave);
+           !(Callee.Bitmask & SME_ABI_Routine);
   }
 
   // Interfaces to query ZT0 State
@@ -116,7 +116,8 @@ class SMEAttrs {
     return hasZT0State() && !Callee.sharesZT0();
   }
   bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
-    return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface();
+    return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface() &&
+           !(Callee.Bitmask & SME_ABI_Routine);
   }
   bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const {
     return requiresLazySave(Callee) || requiresDisablingZABeforeCall(Callee);

diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 88eaf19ec488f3..18d1e40bf4d0fd 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -153,3 +153,118 @@ define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
   call void @callee() "aarch64_new_zt0";
   ret void;
 }
+
+;
+; New-ZA Caller
+;
+
+; Expect commit of lazy-save if ZA is dormant
+; Expect smstart ZA & clear ZT0
+; Before return, expect smstop ZA
+define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
+; CHECK-LABEL: zt0_new_caller:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbz x8, .LBB6_2
+; CHECK-NEXT:  // %bb.1: // %save.za
+; CHECK-NEXT:    mov x8, sp
+; CHECK-NEXT:    str zt0, [x8]
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    ldr zt0, [x8]
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:  .LBB6_2:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #80
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_in_zt0";
+  ret void;
+}
+
+; Expect commit of lazy-save if ZA is dormant
+; Expect smstart ZA, clear ZA & clear ZT0
+; Before return, expect smstop ZA
+define void @new_za_zt0_caller() "aarch64_pstate_za_new" "aarch64_new_zt0" nounwind {
+; CHECK-LABEL: new_za_zt0_caller:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbz x8, .LBB7_2
+; CHECK-NEXT:  // %bb.1: // %save.za
+; CHECK-NEXT:    sub x8, x29, #80
+; CHECK-NEXT:    str zt0, [x8]
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    ldr zt0, [x8]
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:  .LBB7_2:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    zero {za}
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+  ret void;
+}
+
+; Expect clear ZA on entry
+define void @new_za_shared_zt0_caller() "aarch64_pstate_za_new" "aarch64_in_zt0" nounwind {
+; CHECK-LABEL: new_za_shared_zt0_caller:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    zero {za}
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+  ret void;
+}
+
+; Expect clear ZT0 on entry
+define void @shared_za_new_zt0() "aarch64_pstate_za_shared" "aarch64_new_zt0" nounwind {
+; CHECK-LABEL: shared_za_new_zt0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur wzr, [x29, #-4]
+; CHECK-NEXT:    sturh wzr, [x29, #-6]
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
+  ret void;
+}


        


More information about the llvm-commits mailing list