[llvm-branch-commits] [llvm] release/20.x: [AArch64][SME] Prevent spills of ZT0 when ZA is not enabled (PR #137683)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Apr 28 11:32:45 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
This cherry-picks https://github.com/llvm/llvm-project/pull/132722 and https://github.com/llvm/llvm-project/pull/136726 (the latter is based on the former).
These patches are needed to prevent invalid codegen as attempting to store ZT0 without ZA enabled results in a SIGILL.
---
Full diff: https://github.com/llvm/llvm-project/pull/137683.diff
9 Files Affected:
- (modified) llvm/lib/IR/Verifier.cpp (+3)
- (modified) llvm/lib/Target/AArch64/SMEABIPass.cpp (+12-4)
- (modified) llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp (+2)
- (modified) llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h (+5-3)
- (modified) llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll (+2-7)
- (added) llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll (+14)
- (modified) llvm/test/CodeGen/AArch64/sme-zt0-state.ll (+75-19)
- (modified) llvm/test/Verifier/sme-attributes.ll (+3)
- (modified) llvm/unittests/Target/AArch64/SMEAttributesTest.cpp (+30)
``````````diff
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 8432779c107de..551c00a518b8f 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2818,6 +2818,9 @@ void Verifier::visitFunction(const Function &F) {
Check(!Attrs.hasAttrSomewhere(Attribute::ElementType),
"Attribute 'elementtype' can only be applied to a callsite.", &F);
+ Check(!Attrs.hasFnAttr("aarch64_zt0_undef"),
+ "Attribute 'aarch64_zt0_undef' can only be applied to a callsite.");
+
if (Attrs.hasFnAttr(Attribute::Naked))
for (const Argument &Arg : F.args())
Check(Arg.use_empty(), "cannot use argument of naked function", &Arg);
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index bb885d86392fe..b6685497e1fd1 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -54,14 +54,22 @@ FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); }
//===----------------------------------------------------------------------===//
// Utility function to emit a call to __arm_tpidr2_save and clear TPIDR2_EL0.
-void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
+void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
+ auto &Ctx = M->getContext();
auto *TPIDR2SaveTy =
FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false);
- auto Attrs = AttributeList().addFnAttribute(M->getContext(),
- "aarch64_pstate_sm_compatible");
+ auto Attrs =
+ AttributeList().addFnAttribute(Ctx, "aarch64_pstate_sm_compatible");
FunctionCallee Callee =
M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs);
CallInst *Call = Builder.CreateCall(Callee);
+
+ // If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark
+ // that on the __arm_tpidr2_save call. This prevents an unnecessary spill of
+ // ZT0 that can occur before ZA is enabled.
+ if (ZT0IsUndef)
+ Call->addFnAttr(Attribute::get(Ctx, "aarch64_zt0_undef"));
+
Call->setCallingConv(
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);
@@ -119,7 +127,7 @@ bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
// Create a call __arm_tpidr2_save, which commits the lazy save.
Builder.SetInsertPoint(&SaveBB->back());
- emitTPIDR2Save(M, Builder);
+ emitTPIDR2Save(M, Builder, /*ZT0IsUndef=*/FnAttrs.isNewZT0());
// Enable pstate.za at the start of the function.
Builder.SetInsertPoint(&OrigBB->front());
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index bf16acd7f8f7e..76d2ac6a601e5 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -75,6 +75,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= SM_Body;
if (Attrs.hasFnAttr("aarch64_za_state_agnostic"))
Bitmask |= ZA_State_Agnostic;
+ if (Attrs.hasFnAttr("aarch64_zt0_undef"))
+ Bitmask |= ZT0_Undef;
if (Attrs.hasFnAttr("aarch64_in_za"))
Bitmask |= encodeZAState(StateValue::In);
if (Attrs.hasFnAttr("aarch64_out_za"))
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index fb093da70c46b..1691d4fec8b68 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -43,9 +43,10 @@ class SMEAttrs {
SM_Body = 1 << 2, // aarch64_pstate_sm_body
SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
ZA_State_Agnostic = 1 << 4,
- ZA_Shift = 5,
+ ZT0_Undef = 1 << 5, // Used to mark ZT0 as undef to avoid spills
+ ZA_Shift = 6,
ZA_Mask = 0b111 << ZA_Shift,
- ZT0_Shift = 8,
+ ZT0_Shift = 9,
ZT0_Mask = 0b111 << ZT0_Shift
};
@@ -125,6 +126,7 @@ class SMEAttrs {
bool isPreservesZT0() const {
return decodeZT0State(Bitmask) == StateValue::Preserved;
}
+ bool isUndefZT0() const { return Bitmask & ZT0_Undef; }
bool sharesZT0() const {
StateValue State = decodeZT0State(Bitmask);
return State == StateValue::In || State == StateValue::Out ||
@@ -132,7 +134,7 @@ class SMEAttrs {
}
bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
bool requiresPreservingZT0(const SMEAttrs &Callee) const {
- return hasZT0State() && !Callee.sharesZT0() &&
+ return hasZT0State() && !Callee.isUndefZT0() && !Callee.sharesZT0() &&
!Callee.hasAgnosticZAInterface();
}
bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
diff --git a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
index 33d08beae2ca7..4a52bf27a7591 100644
--- a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
+++ b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
@@ -475,16 +475,12 @@ declare double @zt0_shared_callee(double) "aarch64_inout_zt0"
define double @zt0_new_caller_to_zt0_shared_callee(double %x) nounwind noinline optnone "aarch64_new_zt0" {
; CHECK-COMMON-LABEL: zt0_new_caller_to_zt0_shared_callee:
; CHECK-COMMON: // %bb.0: // %prelude
-; CHECK-COMMON-NEXT: sub sp, sp, #80
-; CHECK-COMMON-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-COMMON-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-COMMON-NEXT: mrs x8, TPIDR2_EL0
; CHECK-COMMON-NEXT: cbz x8, .LBB13_2
; CHECK-COMMON-NEXT: b .LBB13_1
; CHECK-COMMON-NEXT: .LBB13_1: // %save.za
-; CHECK-COMMON-NEXT: mov x8, sp
-; CHECK-COMMON-NEXT: str zt0, [x8]
; CHECK-COMMON-NEXT: bl __arm_tpidr2_save
-; CHECK-COMMON-NEXT: ldr zt0, [x8]
; CHECK-COMMON-NEXT: msr TPIDR2_EL0, xzr
; CHECK-COMMON-NEXT: b .LBB13_2
; CHECK-COMMON-NEXT: .LBB13_2: // %entry
@@ -495,8 +491,7 @@ define double @zt0_new_caller_to_zt0_shared_callee(double %x) nounwind noinline
; CHECK-COMMON-NEXT: fmov d1, x8
; CHECK-COMMON-NEXT: fadd d0, d0, d1
; CHECK-COMMON-NEXT: smstop za
-; CHECK-COMMON-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
-; CHECK-COMMON-NEXT: add sp, sp, #80
+; CHECK-COMMON-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-COMMON-NEXT: ret
entry:
%call = call double @zt0_shared_callee(double %x)
diff --git a/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll b/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
new file mode 100644
index 0000000000000..94968ab4fd9ac
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
@@ -0,0 +1,14 @@
+; RUN: opt -S -mtriple=aarch64-linux-gnu -aarch64-sme-abi %s | FileCheck %s
+
+declare void @callee();
+
+define void @private_za() "aarch64_new_zt0" {
+ call void @callee()
+ ret void
+}
+
+; CHECK: call aarch64_sme_preservemost_from_x0 void @__arm_tpidr2_save() #[[TPIDR2_SAVE_CALL_ATTR:[0-9]+]]
+; CHECK: declare void @__arm_tpidr2_save() #[[TPIDR2_SAVE_DECL_ATTR:[0-9]+]]
+
+; CHECK: attributes #[[TPIDR2_SAVE_DECL_ATTR]] = { "aarch64_pstate_sm_compatible" }
+; CHECK: attributes #[[TPIDR2_SAVE_CALL_ATTR]] = { "aarch64_zt0_undef" }
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 312537630e77a..7361e850d713e 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -112,7 +112,7 @@ define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_inout_za" "aar
ret void;
}
-; New-ZA Callee
+; New-ZT0 Callee
; Expect spill & fill of ZT0 around call
; Expect smstop/smstart za around call
@@ -134,6 +134,72 @@ define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
ret void;
}
+; New-ZT0 Callee
+
+; Expect commit of lazy-save if ZA is dormant
+; Expect smstart ZA & clear ZT0
+; Expect spill & fill of ZT0 around call
+; Before return, expect smstop ZA
+define void @zt0_new_caller_zt0_new_callee() "aarch64_new_zt0" nounwind {
+; CHECK-LABEL: zt0_new_caller_zt0_new_callee:
+; CHECK: // %bb.0: // %prelude
+; CHECK-NEXT: sub sp, sp, #80
+; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT: mrs x8, TPIDR2_EL0
+; CHECK-NEXT: cbz x8, .LBB6_2
+; CHECK-NEXT: // %bb.1: // %save.za
+; CHECK-NEXT: bl __arm_tpidr2_save
+; CHECK-NEXT: msr TPIDR2_EL0, xzr
+; CHECK-NEXT: .LBB6_2:
+; CHECK-NEXT: smstart za
+; CHECK-NEXT: zero { zt0 }
+; CHECK-NEXT: mov x19, sp
+; CHECK-NEXT: str zt0, [x19]
+; CHECK-NEXT: smstop za
+; CHECK-NEXT: bl callee
+; CHECK-NEXT: smstart za
+; CHECK-NEXT: ldr zt0, [x19]
+; CHECK-NEXT: smstop za
+; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT: add sp, sp, #80
+; CHECK-NEXT: ret
+ call void @callee() "aarch64_new_zt0";
+ ret void;
+}
+
+; Expect commit of lazy-save if ZA is dormant
+; Expect smstart ZA & clear ZT0
+; No spill & fill of ZT0 around __arm_tpidr2_save
+; Expect spill & fill of ZT0 around __arm_sme_state call
+; Before return, expect smstop ZA
+define i64 @zt0_new_caller_abi_routine_callee() "aarch64_new_zt0" nounwind {
+; CHECK-LABEL: zt0_new_caller_abi_routine_callee:
+; CHECK: // %bb.0: // %prelude
+; CHECK-NEXT: sub sp, sp, #80
+; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT: mrs x8, TPIDR2_EL0
+; CHECK-NEXT: cbz x8, .LBB7_2
+; CHECK-NEXT: // %bb.1: // %save.za
+; CHECK-NEXT: bl __arm_tpidr2_save
+; CHECK-NEXT: msr TPIDR2_EL0, xzr
+; CHECK-NEXT: .LBB7_2:
+; CHECK-NEXT: smstart za
+; CHECK-NEXT: zero { zt0 }
+; CHECK-NEXT: mov x19, sp
+; CHECK-NEXT: str zt0, [x19]
+; CHECK-NEXT: bl __arm_sme_state
+; CHECK-NEXT: ldr zt0, [x19]
+; CHECK-NEXT: smstop za
+; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT: add sp, sp, #80
+; CHECK-NEXT: ret
+ %res = call {i64, i64} @__arm_sme_state()
+ %res.0 = extractvalue {i64, i64} %res, 0
+ ret i64 %res.0
+}
+
+declare {i64, i64} @__arm_sme_state()
+
;
; New-ZA Caller
;
@@ -144,23 +210,18 @@ define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
; CHECK-LABEL: zt0_new_caller:
; CHECK: // %bb.0: // %prelude
-; CHECK-NEXT: sub sp, sp, #80
-; CHECK-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: mrs x8, TPIDR2_EL0
-; CHECK-NEXT: cbz x8, .LBB6_2
+; CHECK-NEXT: cbz x8, .LBB8_2
; CHECK-NEXT: // %bb.1: // %save.za
-; CHECK-NEXT: mov x8, sp
-; CHECK-NEXT: str zt0, [x8]
; CHECK-NEXT: bl __arm_tpidr2_save
-; CHECK-NEXT: ldr zt0, [x8]
; CHECK-NEXT: msr TPIDR2_EL0, xzr
-; CHECK-NEXT: .LBB6_2:
+; CHECK-NEXT: .LBB8_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero { zt0 }
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstop za
-; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
-; CHECK-NEXT: add sp, sp, #80
+; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
call void @callee() "aarch64_in_zt0";
ret void;
@@ -172,24 +233,19 @@ define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
define void @new_za_zt0_caller() "aarch64_new_za" "aarch64_new_zt0" nounwind {
; CHECK-LABEL: new_za_zt0_caller:
; CHECK: // %bb.0: // %prelude
-; CHECK-NEXT: sub sp, sp, #80
-; CHECK-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: mrs x8, TPIDR2_EL0
-; CHECK-NEXT: cbz x8, .LBB7_2
+; CHECK-NEXT: cbz x8, .LBB9_2
; CHECK-NEXT: // %bb.1: // %save.za
-; CHECK-NEXT: mov x8, sp
-; CHECK-NEXT: str zt0, [x8]
; CHECK-NEXT: bl __arm_tpidr2_save
-; CHECK-NEXT: ldr zt0, [x8]
; CHECK-NEXT: msr TPIDR2_EL0, xzr
-; CHECK-NEXT: .LBB7_2:
+; CHECK-NEXT: .LBB9_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero {za}
; CHECK-NEXT: zero { zt0 }
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstop za
-; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
-; CHECK-NEXT: add sp, sp, #80
+; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
ret void;
diff --git a/llvm/test/Verifier/sme-attributes.ll b/llvm/test/Verifier/sme-attributes.ll
index 4bf5e813daf2f..0ae2b9fd91f52 100644
--- a/llvm/test/Verifier/sme-attributes.ll
+++ b/llvm/test/Verifier/sme-attributes.ll
@@ -68,3 +68,6 @@ declare void @zt0_inout_out() "aarch64_inout_zt0" "aarch64_out_zt0";
declare void @zt0_inout_agnostic() "aarch64_inout_zt0" "aarch64_za_state_agnostic";
; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
+
+declare void @zt0_undef_function() "aarch64_zt0_undef";
+; CHECK: Attribute 'aarch64_zt0_undef' can only be applied to a callsite.
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 3af5e24168c8c..f8c77fcba19cf 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -1,6 +1,7 @@
#include "Utils/AArch64SMEAttributes.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/SourceMgr.h"
@@ -69,6 +70,15 @@ TEST(SMEAttributes, Constructors) {
ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_new_zt0\"")
->getFunction("foo"))
.isNewZT0());
+ ASSERT_TRUE(
+ SA(cast<CallBase>((parseIR("declare void @callee()\n"
+ "define void @foo() {"
+ "call void @callee() \"aarch64_zt0_undef\"\n"
+ "ret void\n}")
+ ->getFunction("foo")
+ ->begin()
+ ->front())))
+ .isUndefZT0());
// Invalid combinations.
EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible),
@@ -215,6 +225,18 @@ TEST(SMEAttributes, Basics) {
ASSERT_FALSE(ZT0_New.hasSharedZAInterface());
ASSERT_TRUE(ZT0_New.hasPrivateZAInterface());
+ SA ZT0_Undef = SA(SA::ZT0_Undef | SA::encodeZT0State(SA::StateValue::New));
+ ASSERT_TRUE(ZT0_Undef.isNewZT0());
+ ASSERT_FALSE(ZT0_Undef.isInZT0());
+ ASSERT_FALSE(ZT0_Undef.isOutZT0());
+ ASSERT_FALSE(ZT0_Undef.isInOutZT0());
+ ASSERT_FALSE(ZT0_Undef.isPreservesZT0());
+ ASSERT_FALSE(ZT0_Undef.sharesZT0());
+ ASSERT_TRUE(ZT0_Undef.hasZT0State());
+ ASSERT_FALSE(ZT0_Undef.hasSharedZAInterface());
+ ASSERT_TRUE(ZT0_Undef.hasPrivateZAInterface());
+ ASSERT_TRUE(ZT0_Undef.isUndefZT0());
+
ASSERT_FALSE(SA(SA::Normal).isInZT0());
ASSERT_FALSE(SA(SA::Normal).isOutZT0());
ASSERT_FALSE(SA(SA::Normal).isInOutZT0());
@@ -285,6 +307,7 @@ TEST(SMEAttributes, Transitions) {
SA ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In));
SA ZA_ZT0_Shared = SA(SA::encodeZAState(SA::StateValue::In) |
SA::encodeZT0State(SA::StateValue::In));
+ SA Undef_ZT0 = SA(SA::ZT0_Undef);
// Shared ZA -> Private ZA Interface
ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
@@ -295,6 +318,13 @@ TEST(SMEAttributes, Transitions) {
ASSERT_TRUE(ZT0_Shared.requiresPreservingZT0(Private_ZA));
ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
+ // Shared Undef ZT0 -> Private ZA Interface
+ // Note: "Undef ZT0" is a callsite attribute that means ZT0 is undefined at
+ // the point of the call.
+ ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Undef_ZT0));
+ ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(Undef_ZT0));
+ ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Undef_ZT0));
+
// Shared ZA & ZT0 -> Private ZA Interface
ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
ASSERT_TRUE(ZA_ZT0_Shared.requiresPreservingZT0(Private_ZA));
``````````
</details>
https://github.com/llvm/llvm-project/pull/137683
More information about the llvm-branch-commits
mailing list