[llvm] [AArch64][SME] Allow spills of ZT0 around SME ABI routines again (PR #136726)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 24 05:14:01 PDT 2025


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/136726

>From 68b5f0e76615da387d973d663121c3285bc2ce4b Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 22 Apr 2025 15:41:22 +0000
Subject: [PATCH 1/2] [AArch64][SME] Allow spills of ZT0 around SME ABI
 routines again

In #132722 spills of ZT0 were disabled around all SME ABI routines to
avoid a case where ZT0 is spilled before ZA is enabled (resulting in a
crash).

It turns out that the ABI does not promise that routines will preserve
ZT0 (however in practice they do), so generally disabling ZT0 spills for
ABI routines is not correct.

The case where a crash was possible was "aarch64_new_zt0" functions with
ZA disabled on entry and a ZT0 spill around __arm_tpidr2_save. In this
case, ZT0 will be undefined at the call to __arm_tpidr2_save, so we can
mark the call as preserving ZT0 (whether it does or not) to avoid the
ZT0 spills.
---
 llvm/lib/Target/AArch64/SMEABIPass.cpp        | 16 ++++++--
 .../AArch64/Utils/AArch64SMEAttributes.h      |  3 +-
 .../CodeGen/AArch64/sme-new-zt0-function.ll   | 14 +++++++
 llvm/test/CodeGen/AArch64/sme-zt0-state.ll    | 41 +++++++++++++++++--
 4 files changed, 64 insertions(+), 10 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll

diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index bb885d86392fe..440bbb2a941ab 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -54,14 +54,22 @@ FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); }
 //===----------------------------------------------------------------------===//
 
 // Utility function to emit a call to __arm_tpidr2_save and clear TPIDR2_EL0.
-void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
+void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
+  auto &Ctx = M->getContext();
   auto *TPIDR2SaveTy =
       FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false);
-  auto Attrs = AttributeList().addFnAttribute(M->getContext(),
-                                              "aarch64_pstate_sm_compatible");
+  auto Attrs =
+      AttributeList().addFnAttribute(Ctx, "aarch64_pstate_sm_compatible");
   FunctionCallee Callee =
       M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs);
   CallInst *Call = Builder.CreateCall(Callee);
+
+  // If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark
+  // __arm_tpidr2_save as preserving ZT0. This prevents an unnecessary spill of
+  // ZT0 that can occur before ZA is enabled.
+  if (ZT0IsUndef)
+    Call->addFnAttr(Attribute::get(Ctx, "aarch64_preserves_zt0"));
+
   Call->setCallingConv(
       CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);
 
@@ -119,7 +127,7 @@ bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
 
     // Create a call __arm_tpidr2_save, which commits the lazy save.
     Builder.SetInsertPoint(&SaveBB->back());
-    emitTPIDR2Save(M, Builder);
+    emitTPIDR2Save(M, Builder, /*ZT0IsUndef=*/FnAttrs.isNewZT0());
 
     // Enable pstate.za at the start of the function.
     Builder.SetInsertPoint(&OrigBB->front());
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index a3ebf764a6e0c..fb093da70c46b 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -133,8 +133,7 @@ class SMEAttrs {
   bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
   bool requiresPreservingZT0(const SMEAttrs &Callee) const {
     return hasZT0State() && !Callee.sharesZT0() &&
-           !Callee.hasAgnosticZAInterface() &&
-           !(Callee.Bitmask & SME_ABI_Routine);
+           !Callee.hasAgnosticZAInterface();
   }
   bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
     return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface() &&
diff --git a/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll b/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
new file mode 100644
index 0000000000000..715122d0fa4b4
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
@@ -0,0 +1,14 @@
+; RUN: opt -S -mtriple=aarch64-linux-gnu -aarch64-sme-abi %s | FileCheck %s
+
+declare void @callee();
+
+define void @private_za() "aarch64_new_zt0" {
+  call void @callee()
+  ret void
+}
+
+; CHECK: call aarch64_sme_preservemost_from_x0 void @__arm_tpidr2_save() #[[TPIDR2_SAVE_CALL_ATTR:[0-9]+]]
+; CHECK: declare void @__arm_tpidr2_save() #[[TPIDR2_SAVE_DECL_ATTR:[0-9]+]]
+
+; CHECK: attributes #[[TPIDR2_SAVE_DECL_ATTR]] = { "aarch64_pstate_sm_compatible" }
+; CHECK: attributes #[[TPIDR2_SAVE_CALL_ATTR]] = { "aarch64_preserves_zt0" }
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 500fff4eb20db..7361e850d713e 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -167,6 +167,39 @@ define void @zt0_new_caller_zt0_new_callee() "aarch64_new_zt0" nounwind {
   ret void;
 }
 
+; Expect commit of lazy-save if ZA is dormant
+; Expect smstart ZA & clear ZT0
+; No spill & fill of ZT0 around __arm_tpidr2_save
+; Expect spill & fill of ZT0 around __arm_sme_state call
+; Before return, expect smstop ZA
+define i64 @zt0_new_caller_abi_routine_callee() "aarch64_new_zt0" nounwind {
+; CHECK-LABEL: zt0_new_caller_abi_routine_callee:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbz x8, .LBB7_2
+; CHECK-NEXT:  // %bb.1: // %save.za
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:  .LBB7_2:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    mov x19, sp
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #80
+; CHECK-NEXT:    ret
+  %res = call {i64, i64} @__arm_sme_state()
+  %res.0 = extractvalue {i64, i64} %res, 0
+  ret i64 %res.0
+}
+
+declare {i64, i64} @__arm_sme_state()
+
 ;
 ; New-ZA Caller
 ;
@@ -179,11 +212,11 @@ define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
 ; CHECK:       // %bb.0: // %prelude
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
-; CHECK-NEXT:    cbz x8, .LBB7_2
+; CHECK-NEXT:    cbz x8, .LBB8_2
 ; CHECK-NEXT:  // %bb.1: // %save.za
 ; CHECK-NEXT:    bl __arm_tpidr2_save
 ; CHECK-NEXT:    msr TPIDR2_EL0, xzr
-; CHECK-NEXT:  .LBB7_2:
+; CHECK-NEXT:  .LBB8_2:
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    zero { zt0 }
 ; CHECK-NEXT:    bl callee
@@ -202,11 +235,11 @@ define void @new_za_zt0_caller() "aarch64_new_za" "aarch64_new_zt0" nounwind {
 ; CHECK:       // %bb.0: // %prelude
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
-; CHECK-NEXT:    cbz x8, .LBB8_2
+; CHECK-NEXT:    cbz x8, .LBB9_2
 ; CHECK-NEXT:  // %bb.1: // %save.za
 ; CHECK-NEXT:    bl __arm_tpidr2_save
 ; CHECK-NEXT:    msr TPIDR2_EL0, xzr
-; CHECK-NEXT:  .LBB8_2:
+; CHECK-NEXT:  .LBB9_2:
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    zero {za}
 ; CHECK-NEXT:    zero { zt0 }

>From e6357744193c35426a703e97f76f1b100ca602ee Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 24 Apr 2025 12:12:44 +0000
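Subject: [PATCH 2/2] Add new aarch64_zt0_undef attribute

Rather than marking the __arm_tpidr2_save call as "aarch64_preserves_zt0"
(a promise the ABI does not make), add a callsite-only "aarch64_zt0_undef"
attribute stating that ZT0 is undefined at the point of the call.
requiresPreservingZT0() returns false for callees marked this way, so no
ZT0 spill/fill is emitted around them. The verifier rejects the attribute
when it is placed on a function.
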
---
 llvm/lib/IR/Verifier.cpp                      |  3 ++
 llvm/lib/Target/AArch64/SMEABIPass.cpp        |  4 +--
 .../AArch64/Utils/AArch64SMEAttributes.cpp    |  2 ++
 .../AArch64/Utils/AArch64SMEAttributes.h      |  8 +++--
 .../CodeGen/AArch64/sme-new-zt0-function.ll   |  2 +-
 llvm/test/Verifier/sme-attributes.ll          |  3 ++
 .../Target/AArch64/SMEAttributesTest.cpp      | 30 +++++++++++++++++++
 7 files changed, 47 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 8afe360d088bc..6060ab3f76d50 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2859,6 +2859,9 @@ void Verifier::visitFunction(const Function &F) {
   Check(!Attrs.hasAttrSomewhere(Attribute::ElementType),
         "Attribute 'elementtype' can only be applied to a callsite.", &F);
 
+  Check(!Attrs.hasFnAttr("aarch64_zt0_undef"),
+        "Attribute 'aarch64_zt0_undef' can only be applied to a callsite.");
+
   if (Attrs.hasFnAttr(Attribute::Naked))
     for (const Argument &Arg : F.args())
       Check(Arg.use_empty(), "cannot use argument of naked function", &Arg);
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index 440bbb2a941ab..b6685497e1fd1 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -65,10 +65,10 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
   CallInst *Call = Builder.CreateCall(Callee);
 
   // If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark
-  // __arm_tpidr2_save as preserving ZT0. This prevents an unnecessary spill of
+  // that on the __arm_tpidr2_save call. This prevents an unnecessary spill of
   // ZT0 that can occur before ZA is enabled.
   if (ZT0IsUndef)
-    Call->addFnAttr(Attribute::get(Ctx, "aarch64_preserves_zt0"));
+    Call->addFnAttr(Attribute::get(Ctx, "aarch64_zt0_undef"));
 
   Call->setCallingConv(
       CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index bf16acd7f8f7e..d039a61d6b27e 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -95,6 +95,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
     Bitmask |= encodeZT0State(StateValue::Preserved);
   if (Attrs.hasFnAttr("aarch64_new_zt0"))
     Bitmask |= encodeZT0State(StateValue::New);
+  if (Attrs.hasFnAttr("aarch64_zt0_undef"))
+    Bitmask |= encodeZT0State(StateValue::Undef);
 }
 
 bool SMEAttrs::requiresSMChange(const SMEAttrs &Callee) const {
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index fb093da70c46b..ae4d3824373fe 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -32,7 +32,8 @@ class SMEAttrs {
     Out = 2,       // aarch64_out_zt0
     InOut = 3,     // aarch64_inout_zt0
     Preserved = 4, // aarch64_preserves_zt0
-    New = 5        // aarch64_new_zt0
+    New = 5,       // aarch64_new_zt0
+    Undef = 6      // aarch64_zt0_undef
   };
 
   // Enum with bitmasks for each individual SME feature.
@@ -125,6 +126,9 @@ class SMEAttrs {
   bool isPreservesZT0() const {
     return decodeZT0State(Bitmask) == StateValue::Preserved;
   }
+  bool isUndefZT0() const {
+    return decodeZT0State(Bitmask) == StateValue::Undef;
+  }
   bool sharesZT0() const {
     StateValue State = decodeZT0State(Bitmask);
     return State == StateValue::In || State == StateValue::Out ||
@@ -132,7 +136,7 @@ class SMEAttrs {
   }
   bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
   bool requiresPreservingZT0(const SMEAttrs &Callee) const {
-    return hasZT0State() && !Callee.sharesZT0() &&
+    return hasZT0State() && !Callee.isUndefZT0() && !Callee.sharesZT0() &&
            !Callee.hasAgnosticZAInterface();
   }
   bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
diff --git a/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll b/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
index 715122d0fa4b4..94968ab4fd9ac 100644
--- a/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
+++ b/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
@@ -11,4 +11,4 @@ define void @private_za() "aarch64_new_zt0" {
 ; CHECK: declare void @__arm_tpidr2_save() #[[TPIDR2_SAVE_DECL_ATTR:[0-9]+]]
 
 ; CHECK: attributes #[[TPIDR2_SAVE_DECL_ATTR]] = { "aarch64_pstate_sm_compatible" }
-; CHECK: attributes #[[TPIDR2_SAVE_CALL_ATTR]] = { "aarch64_preserves_zt0" }
+; CHECK: attributes #[[TPIDR2_SAVE_CALL_ATTR]] = { "aarch64_zt0_undef" }
diff --git a/llvm/test/Verifier/sme-attributes.ll b/llvm/test/Verifier/sme-attributes.ll
index 4bf5e813daf2f..0ae2b9fd91f52 100644
--- a/llvm/test/Verifier/sme-attributes.ll
+++ b/llvm/test/Verifier/sme-attributes.ll
@@ -68,3 +68,6 @@ declare void @zt0_inout_out() "aarch64_inout_zt0" "aarch64_out_zt0";
 
 declare void @zt0_inout_agnostic() "aarch64_inout_zt0" "aarch64_za_state_agnostic";
 ; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
+
+declare void @zt0_undef_function() "aarch64_zt0_undef";
+; CHECK: Attribute 'aarch64_zt0_undef' can only be applied to a callsite.
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 3af5e24168c8c..5f1811d0c9e8e 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -1,6 +1,7 @@
 #include "Utils/AArch64SMEAttributes.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/SourceMgr.h"
 
@@ -69,6 +70,15 @@ TEST(SMEAttributes, Constructors) {
   ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_new_zt0\"")
                       ->getFunction("foo"))
                   .isNewZT0());
+  ASSERT_TRUE(
+      SA(cast<CallBase>((parseIR("declare void @callee()\n"
+                                 "define void @foo() {"
+                                 "call void @callee() \"aarch64_zt0_undef\"\n"
+                                 "ret void\n}")
+                             ->getFunction("foo")
+                             ->begin()
+                             ->front())))
+          .isUndefZT0());
 
   // Invalid combinations.
   EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible),
@@ -215,6 +225,18 @@ TEST(SMEAttributes, Basics) {
   ASSERT_FALSE(ZT0_New.hasSharedZAInterface());
   ASSERT_TRUE(ZT0_New.hasPrivateZAInterface());
 
+  SA ZT0_Undef = SA(SA::encodeZT0State(SA::StateValue::Undef));
+  ASSERT_FALSE(ZT0_Undef.isNewZT0());
+  ASSERT_FALSE(ZT0_Undef.isInZT0());
+  ASSERT_FALSE(ZT0_Undef.isOutZT0());
+  ASSERT_FALSE(ZT0_Undef.isInOutZT0());
+  ASSERT_FALSE(ZT0_Undef.isPreservesZT0());
+  ASSERT_FALSE(ZT0_Undef.sharesZT0());
+  ASSERT_FALSE(ZT0_Undef.hasZT0State());
+  ASSERT_FALSE(ZT0_Undef.hasSharedZAInterface());
+  ASSERT_TRUE(ZT0_Undef.hasPrivateZAInterface());
+  ASSERT_TRUE(ZT0_Undef.isUndefZT0());
+
   ASSERT_FALSE(SA(SA::Normal).isInZT0());
   ASSERT_FALSE(SA(SA::Normal).isOutZT0());
   ASSERT_FALSE(SA(SA::Normal).isInOutZT0());
@@ -285,6 +307,7 @@ TEST(SMEAttributes, Transitions) {
   SA ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In));
   SA ZA_ZT0_Shared = SA(SA::encodeZAState(SA::StateValue::In) |
                         SA::encodeZT0State(SA::StateValue::In));
+  SA Undef_ZT0 = SA((SA::encodeZT0State(SA::StateValue::Undef)));
 
   // Shared ZA -> Private ZA Interface
   ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
@@ -295,6 +318,13 @@ TEST(SMEAttributes, Transitions) {
   ASSERT_TRUE(ZT0_Shared.requiresPreservingZT0(Private_ZA));
   ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
 
+  // Shared ZT0 -> Private ZA Interface (callsite marked "aarch64_zt0_undef")
+  // Note: "Undef ZT0" is a callsite attribute that means ZT0 is undefined at
+  // the point of the call.
+  ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Undef_ZT0));
+  ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(Undef_ZT0));
+  ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Undef_ZT0));
+
   // Shared ZA & ZT0 -> Private ZA Interface
   ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
   ASSERT_TRUE(ZA_ZT0_Shared.requiresPreservingZT0(Private_ZA));



More information about the llvm-commits mailing list