[llvm] [AArch64][SME] Split SMECallAttrs out of SMEAttrs (NFC) (PR #137239)
Benjamin Maxwell via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 25 04:35:14 PDT 2025
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/137239
>From 68b5f0e76615da387d973d663121c3285bc2ce4b Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 22 Apr 2025 15:41:22 +0000
Subject: [PATCH 1/5] [AArch64][SME] Allow spills of ZT0 around SME ABI
routines again
In #132722 spills of ZT0 were disabled around all SME ABI routines to
avoid a case where ZT0 is spilled before ZA is enabled (resulting in a
crash).
It turns out that the ABI does not promise that routines will preserve
ZT0 (however in practice they do), so generally disabling ZT0 spills for
ABI routines is not correct.
The case where a crash was possible was "aarch64_new_zt0" functions with
ZA disabled on entry and a ZT0 spill around __arm_tpidr2_save. In this
case, ZT0 will be undefined at the call to __arm_tpidr2_save, so we can
mark the call as preserving ZT0 (whether it does or not) to avoid the
ZT0 spills.
---
llvm/lib/Target/AArch64/SMEABIPass.cpp | 16 ++++++--
.../AArch64/Utils/AArch64SMEAttributes.h | 3 +-
.../CodeGen/AArch64/sme-new-zt0-function.ll | 14 +++++++
llvm/test/CodeGen/AArch64/sme-zt0-state.ll | 41 +++++++++++++++++--
4 files changed, 64 insertions(+), 10 deletions(-)
create mode 100644 llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index bb885d86392fe..440bbb2a941ab 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -54,14 +54,22 @@ FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); }
//===----------------------------------------------------------------------===//
// Utility function to emit a call to __arm_tpidr2_save and clear TPIDR2_EL0.
-void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
+void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
+ auto &Ctx = M->getContext();
auto *TPIDR2SaveTy =
FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false);
- auto Attrs = AttributeList().addFnAttribute(M->getContext(),
- "aarch64_pstate_sm_compatible");
+ auto Attrs =
+ AttributeList().addFnAttribute(Ctx, "aarch64_pstate_sm_compatible");
FunctionCallee Callee =
M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs);
CallInst *Call = Builder.CreateCall(Callee);
+
+ // If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark
+ // __arm_tpidr2_save as preserving ZT0. This prevents an unnecessary spill of
+ // ZT0 that can occur before ZA is enabled.
+ if (ZT0IsUndef)
+ Call->addFnAttr(Attribute::get(Ctx, "aarch64_preserves_zt0"));
+
Call->setCallingConv(
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);
@@ -119,7 +127,7 @@ bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
// Create a call __arm_tpidr2_save, which commits the lazy save.
Builder.SetInsertPoint(&SaveBB->back());
- emitTPIDR2Save(M, Builder);
+ emitTPIDR2Save(M, Builder, /*ZT0IsUndef=*/FnAttrs.isNewZT0());
// Enable pstate.za at the start of the function.
Builder.SetInsertPoint(&OrigBB->front());
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index a3ebf764a6e0c..fb093da70c46b 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -133,8 +133,7 @@ class SMEAttrs {
bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
bool requiresPreservingZT0(const SMEAttrs &Callee) const {
return hasZT0State() && !Callee.sharesZT0() &&
- !Callee.hasAgnosticZAInterface() &&
- !(Callee.Bitmask & SME_ABI_Routine);
+ !Callee.hasAgnosticZAInterface();
}
bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface() &&
diff --git a/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll b/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
new file mode 100644
index 0000000000000..715122d0fa4b4
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
@@ -0,0 +1,14 @@
+; RUN: opt -S -mtriple=aarch64-linux-gnu -aarch64-sme-abi %s | FileCheck %s
+
+declare void @callee();
+
+define void @private_za() "aarch64_new_zt0" {
+ call void @callee()
+ ret void
+}
+
+; CHECK: call aarch64_sme_preservemost_from_x0 void @__arm_tpidr2_save() #[[TPIDR2_SAVE_CALL_ATTR:[0-9]+]]
+; CHECK: declare void @__arm_tpidr2_save() #[[TPIDR2_SAVE_DECL_ATTR:[0-9]+]]
+
+; CHECK: attributes #[[TPIDR2_SAVE_DECL_ATTR]] = { "aarch64_pstate_sm_compatible" }
+; CHECK: attributes #[[TPIDR2_SAVE_CALL_ATTR]] = { "aarch64_preserves_zt0" }
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 500fff4eb20db..7361e850d713e 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -167,6 +167,39 @@ define void @zt0_new_caller_zt0_new_callee() "aarch64_new_zt0" nounwind {
ret void;
}
+; Expect commit of lazy-save if ZA is dormant
+; Expect smstart ZA & clear ZT0
+; No spill & fill of ZT0 around __arm_tpidr2_save
+; Expect spill & fill of ZT0 around __arm_sme_state call
+; Before return, expect smstop ZA
+define i64 @zt0_new_caller_abi_routine_callee() "aarch64_new_zt0" nounwind {
+; CHECK-LABEL: zt0_new_caller_abi_routine_callee:
+; CHECK: // %bb.0: // %prelude
+; CHECK-NEXT: sub sp, sp, #80
+; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT: mrs x8, TPIDR2_EL0
+; CHECK-NEXT: cbz x8, .LBB7_2
+; CHECK-NEXT: // %bb.1: // %save.za
+; CHECK-NEXT: bl __arm_tpidr2_save
+; CHECK-NEXT: msr TPIDR2_EL0, xzr
+; CHECK-NEXT: .LBB7_2:
+; CHECK-NEXT: smstart za
+; CHECK-NEXT: zero { zt0 }
+; CHECK-NEXT: mov x19, sp
+; CHECK-NEXT: str zt0, [x19]
+; CHECK-NEXT: bl __arm_sme_state
+; CHECK-NEXT: ldr zt0, [x19]
+; CHECK-NEXT: smstop za
+; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT: add sp, sp, #80
+; CHECK-NEXT: ret
+ %res = call {i64, i64} @__arm_sme_state()
+ %res.0 = extractvalue {i64, i64} %res, 0
+ ret i64 %res.0
+}
+
+declare {i64, i64} @__arm_sme_state()
+
;
; New-ZA Caller
;
@@ -179,11 +212,11 @@ define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: mrs x8, TPIDR2_EL0
-; CHECK-NEXT: cbz x8, .LBB7_2
+; CHECK-NEXT: cbz x8, .LBB8_2
; CHECK-NEXT: // %bb.1: // %save.za
; CHECK-NEXT: bl __arm_tpidr2_save
; CHECK-NEXT: msr TPIDR2_EL0, xzr
-; CHECK-NEXT: .LBB7_2:
+; CHECK-NEXT: .LBB8_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero { zt0 }
; CHECK-NEXT: bl callee
@@ -202,11 +235,11 @@ define void @new_za_zt0_caller() "aarch64_new_za" "aarch64_new_zt0" nounwind {
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: mrs x8, TPIDR2_EL0
-; CHECK-NEXT: cbz x8, .LBB8_2
+; CHECK-NEXT: cbz x8, .LBB9_2
; CHECK-NEXT: // %bb.1: // %save.za
; CHECK-NEXT: bl __arm_tpidr2_save
; CHECK-NEXT: msr TPIDR2_EL0, xzr
-; CHECK-NEXT: .LBB8_2:
+; CHECK-NEXT: .LBB9_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero {za}
; CHECK-NEXT: zero { zt0 }
>From e6357744193c35426a703e97f76f1b100ca602ee Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 24 Apr 2025 12:12:44 +0000
Subject: [PATCH 2/5] Add new aarch64_zt0_undef attribute
---
llvm/lib/IR/Verifier.cpp | 3 ++
llvm/lib/Target/AArch64/SMEABIPass.cpp | 4 +--
.../AArch64/Utils/AArch64SMEAttributes.cpp | 2 ++
.../AArch64/Utils/AArch64SMEAttributes.h | 8 +++--
.../CodeGen/AArch64/sme-new-zt0-function.ll | 2 +-
llvm/test/Verifier/sme-attributes.ll | 3 ++
.../Target/AArch64/SMEAttributesTest.cpp | 30 +++++++++++++++++++
7 files changed, 47 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 8afe360d088bc..6060ab3f76d50 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2859,6 +2859,9 @@ void Verifier::visitFunction(const Function &F) {
Check(!Attrs.hasAttrSomewhere(Attribute::ElementType),
"Attribute 'elementtype' can only be applied to a callsite.", &F);
+ Check(!Attrs.hasFnAttr("aarch64_zt0_undef"),
+ "Attribute 'aarch64_zt0_undef' can only be applied to a callsite.");
+
if (Attrs.hasFnAttr(Attribute::Naked))
for (const Argument &Arg : F.args())
Check(Arg.use_empty(), "cannot use argument of naked function", &Arg);
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index 440bbb2a941ab..b6685497e1fd1 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -65,10 +65,10 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
CallInst *Call = Builder.CreateCall(Callee);
// If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark
- // __arm_tpidr2_save as preserving ZT0. This prevents an unnecessary spill of
+ // that on the __arm_tpidr2_save call. This prevents an unnecessary spill of
// ZT0 that can occur before ZA is enabled.
if (ZT0IsUndef)
- Call->addFnAttr(Attribute::get(Ctx, "aarch64_preserves_zt0"));
+ Call->addFnAttr(Attribute::get(Ctx, "aarch64_zt0_undef"));
Call->setCallingConv(
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index bf16acd7f8f7e..d039a61d6b27e 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -95,6 +95,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= encodeZT0State(StateValue::Preserved);
if (Attrs.hasFnAttr("aarch64_new_zt0"))
Bitmask |= encodeZT0State(StateValue::New);
+ if (Attrs.hasFnAttr("aarch64_zt0_undef"))
+ Bitmask |= encodeZT0State(StateValue::Undef);
}
bool SMEAttrs::requiresSMChange(const SMEAttrs &Callee) const {
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index fb093da70c46b..ae4d3824373fe 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -32,7 +32,8 @@ class SMEAttrs {
Out = 2, // aarch64_out_zt0
InOut = 3, // aarch64_inout_zt0
Preserved = 4, // aarch64_preserves_zt0
- New = 5 // aarch64_new_zt0
+ New = 5, // aarch64_new_zt0
+ Undef = 6 // aarch64_zt0_undef
};
// Enum with bitmasks for each individual SME feature.
@@ -125,6 +126,9 @@ class SMEAttrs {
bool isPreservesZT0() const {
return decodeZT0State(Bitmask) == StateValue::Preserved;
}
+ bool isUndefZT0() const {
+ return decodeZT0State(Bitmask) == StateValue::Undef;
+ }
bool sharesZT0() const {
StateValue State = decodeZT0State(Bitmask);
return State == StateValue::In || State == StateValue::Out ||
@@ -132,7 +136,7 @@ class SMEAttrs {
}
bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
bool requiresPreservingZT0(const SMEAttrs &Callee) const {
- return hasZT0State() && !Callee.sharesZT0() &&
+ return hasZT0State() && !Callee.isUndefZT0() && !Callee.sharesZT0() &&
!Callee.hasAgnosticZAInterface();
}
bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
diff --git a/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll b/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
index 715122d0fa4b4..94968ab4fd9ac 100644
--- a/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
+++ b/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
@@ -11,4 +11,4 @@ define void @private_za() "aarch64_new_zt0" {
; CHECK: declare void @__arm_tpidr2_save() #[[TPIDR2_SAVE_DECL_ATTR:[0-9]+]]
; CHECK: attributes #[[TPIDR2_SAVE_DECL_ATTR]] = { "aarch64_pstate_sm_compatible" }
-; CHECK: attributes #[[TPIDR2_SAVE_CALL_ATTR]] = { "aarch64_preserves_zt0" }
+; CHECK: attributes #[[TPIDR2_SAVE_CALL_ATTR]] = { "aarch64_zt0_undef" }
diff --git a/llvm/test/Verifier/sme-attributes.ll b/llvm/test/Verifier/sme-attributes.ll
index 4bf5e813daf2f..0ae2b9fd91f52 100644
--- a/llvm/test/Verifier/sme-attributes.ll
+++ b/llvm/test/Verifier/sme-attributes.ll
@@ -68,3 +68,6 @@ declare void @zt0_inout_out() "aarch64_inout_zt0" "aarch64_out_zt0";
declare void @zt0_inout_agnostic() "aarch64_inout_zt0" "aarch64_za_state_agnostic";
; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
+
+declare void @zt0_undef_function() "aarch64_zt0_undef";
+; CHECK: Attribute 'aarch64_zt0_undef' can only be applied to a callsite.
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 3af5e24168c8c..5f1811d0c9e8e 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -1,6 +1,7 @@
#include "Utils/AArch64SMEAttributes.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/SourceMgr.h"
@@ -69,6 +70,15 @@ TEST(SMEAttributes, Constructors) {
ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_new_zt0\"")
->getFunction("foo"))
.isNewZT0());
+ ASSERT_TRUE(
+ SA(cast<CallBase>((parseIR("declare void @callee()\n"
+ "define void @foo() {"
+ "call void @callee() \"aarch64_zt0_undef\"\n"
+ "ret void\n}")
+ ->getFunction("foo")
+ ->begin()
+ ->front())))
+ .isUndefZT0());
// Invalid combinations.
EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible),
@@ -215,6 +225,18 @@ TEST(SMEAttributes, Basics) {
ASSERT_FALSE(ZT0_New.hasSharedZAInterface());
ASSERT_TRUE(ZT0_New.hasPrivateZAInterface());
+ SA ZT0_Undef = SA(SA::encodeZT0State(SA::StateValue::Undef));
+ ASSERT_FALSE(ZT0_Undef.isNewZT0());
+ ASSERT_FALSE(ZT0_Undef.isInZT0());
+ ASSERT_FALSE(ZT0_Undef.isOutZT0());
+ ASSERT_FALSE(ZT0_Undef.isInOutZT0());
+ ASSERT_FALSE(ZT0_Undef.isPreservesZT0());
+ ASSERT_FALSE(ZT0_Undef.sharesZT0());
+ ASSERT_FALSE(ZT0_Undef.hasZT0State());
+ ASSERT_FALSE(ZT0_Undef.hasSharedZAInterface());
+ ASSERT_TRUE(ZT0_Undef.hasPrivateZAInterface());
+ ASSERT_TRUE(ZT0_Undef.isUndefZT0());
+
ASSERT_FALSE(SA(SA::Normal).isInZT0());
ASSERT_FALSE(SA(SA::Normal).isOutZT0());
ASSERT_FALSE(SA(SA::Normal).isInOutZT0());
@@ -285,6 +307,7 @@ TEST(SMEAttributes, Transitions) {
SA ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In));
SA ZA_ZT0_Shared = SA(SA::encodeZAState(SA::StateValue::In) |
SA::encodeZT0State(SA::StateValue::In));
+ SA Undef_ZT0 = SA((SA::encodeZT0State(SA::StateValue::Undef)));
// Shared ZA -> Private ZA Interface
ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
@@ -295,6 +318,13 @@ TEST(SMEAttributes, Transitions) {
ASSERT_TRUE(ZT0_Shared.requiresPreservingZT0(Private_ZA));
ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
+ // Shared Undef ZT0 -> Private ZA Interface
+ // Note: "Undef ZT0" is a callsite attribute that means ZT0 is undefined at
+ // the point of the call.
+ ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Undef_ZT0));
+ ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(Undef_ZT0));
+ ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Undef_ZT0));
+
// Shared ZA & ZT0 -> Private ZA Interface
ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
ASSERT_TRUE(ZA_ZT0_Shared.requiresPreservingZT0(Private_ZA));
>From 2590140ff3f450630eec1feacff6352e3bfc8214 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 24 Apr 2025 13:19:34 +0000
Subject: [PATCH 3/5] Separate "ZT0 undef" from StateValue
This allows this to be independent from new/preserves/shares ZT0.
---
.../Target/AArch64/Utils/AArch64SMEAttributes.cpp | 4 ++--
llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h | 12 +++++-------
llvm/unittests/Target/AArch64/SMEAttributesTest.cpp | 8 ++++----
3 files changed, 11 insertions(+), 13 deletions(-)
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index d039a61d6b27e..76d2ac6a601e5 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -75,6 +75,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= SM_Body;
if (Attrs.hasFnAttr("aarch64_za_state_agnostic"))
Bitmask |= ZA_State_Agnostic;
+ if (Attrs.hasFnAttr("aarch64_zt0_undef"))
+ Bitmask |= ZT0_Undef;
if (Attrs.hasFnAttr("aarch64_in_za"))
Bitmask |= encodeZAState(StateValue::In);
if (Attrs.hasFnAttr("aarch64_out_za"))
@@ -95,8 +97,6 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= encodeZT0State(StateValue::Preserved);
if (Attrs.hasFnAttr("aarch64_new_zt0"))
Bitmask |= encodeZT0State(StateValue::New);
- if (Attrs.hasFnAttr("aarch64_zt0_undef"))
- Bitmask |= encodeZT0State(StateValue::Undef);
}
bool SMEAttrs::requiresSMChange(const SMEAttrs &Callee) const {
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index ae4d3824373fe..1691d4fec8b68 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -32,8 +32,7 @@ class SMEAttrs {
Out = 2, // aarch64_out_zt0
InOut = 3, // aarch64_inout_zt0
Preserved = 4, // aarch64_preserves_zt0
- New = 5, // aarch64_new_zt0
- Undef = 6 // aarch64_zt0_undef
+ New = 5 // aarch64_new_zt0
};
// Enum with bitmasks for each individual SME feature.
@@ -44,9 +43,10 @@ class SMEAttrs {
SM_Body = 1 << 2, // aarch64_pstate_sm_body
SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
ZA_State_Agnostic = 1 << 4,
- ZA_Shift = 5,
+ ZT0_Undef = 1 << 5, // Used to mark ZT0 as undef to avoid spills
+ ZA_Shift = 6,
ZA_Mask = 0b111 << ZA_Shift,
- ZT0_Shift = 8,
+ ZT0_Shift = 9,
ZT0_Mask = 0b111 << ZT0_Shift
};
@@ -126,9 +126,7 @@ class SMEAttrs {
bool isPreservesZT0() const {
return decodeZT0State(Bitmask) == StateValue::Preserved;
}
- bool isUndefZT0() const {
- return decodeZT0State(Bitmask) == StateValue::Undef;
- }
+ bool isUndefZT0() const { return Bitmask & ZT0_Undef; }
bool sharesZT0() const {
StateValue State = decodeZT0State(Bitmask);
return State == StateValue::In || State == StateValue::Out ||
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 5f1811d0c9e8e..f8c77fcba19cf 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -225,14 +225,14 @@ TEST(SMEAttributes, Basics) {
ASSERT_FALSE(ZT0_New.hasSharedZAInterface());
ASSERT_TRUE(ZT0_New.hasPrivateZAInterface());
- SA ZT0_Undef = SA(SA::encodeZT0State(SA::StateValue::Undef));
- ASSERT_FALSE(ZT0_Undef.isNewZT0());
+ SA ZT0_Undef = SA(SA::ZT0_Undef | SA::encodeZT0State(SA::StateValue::New));
+ ASSERT_TRUE(ZT0_Undef.isNewZT0());
ASSERT_FALSE(ZT0_Undef.isInZT0());
ASSERT_FALSE(ZT0_Undef.isOutZT0());
ASSERT_FALSE(ZT0_Undef.isInOutZT0());
ASSERT_FALSE(ZT0_Undef.isPreservesZT0());
ASSERT_FALSE(ZT0_Undef.sharesZT0());
- ASSERT_FALSE(ZT0_Undef.hasZT0State());
+ ASSERT_TRUE(ZT0_Undef.hasZT0State());
ASSERT_FALSE(ZT0_Undef.hasSharedZAInterface());
ASSERT_TRUE(ZT0_Undef.hasPrivateZAInterface());
ASSERT_TRUE(ZT0_Undef.isUndefZT0());
@@ -307,7 +307,7 @@ TEST(SMEAttributes, Transitions) {
SA ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In));
SA ZA_ZT0_Shared = SA(SA::encodeZAState(SA::StateValue::In) |
SA::encodeZT0State(SA::StateValue::In));
- SA Undef_ZT0 = SA((SA::encodeZT0State(SA::StateValue::Undef)));
+ SA Undef_ZT0 = SA(SA::ZT0_Undef);
// Shared ZA -> Private ZA Interface
ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
>From 20fc9f1fe1c610bca540820d4455bc0625e061c5 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 25 Apr 2025 07:58:43 +0000
Subject: [PATCH 4/5] [AArch64][SME] Split SMECallAttrs out of SMEAttrs (NFC)
SMECallAttrs is a new helper class that holds all the SMEAttrs for a
call. The interfaces to query actions needed for the call (e.g. change
streaming mode) have been moved to the SMECallAttrs class.
The main motivation for this change is to make the split between caller,
callee, and callsite attributes more apparent. Places that previously
implicitly checked callsite attributes have been updated to make these
checks explicit. Similarly, places known to only check callee or
callsite attributes have also been updated to make this clear.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 74 ++++++------
.../AArch64/AArch64TargetTransformInfo.cpp | 25 +++--
.../AArch64/Utils/AArch64SMEAttributes.cpp | 58 +++++-----
.../AArch64/Utils/AArch64SMEAttributes.h | 104 ++++++++++++-----
.../Target/AArch64/SMEAttributesTest.cpp | 106 +++++++++---------
5 files changed, 206 insertions(+), 161 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 771eee1b3fecf..c445210cc3d4e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8652,6 +8652,16 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
}
}
+static SMECallAttrs
+getSMECallAttrs(const Function &Function,
+ const TargetLowering::CallLoweringInfo &CLI) {
+ if (CLI.CB)
+ return SMECallAttrs(*CLI.CB);
+ if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
+ return SMECallAttrs(SMEAttrs(Function), SMEAttrs(ES->getSymbol()));
+ return SMECallAttrs(SMEAttrs(Function), SMEAttrs(SMEAttrs::Normal));
+}
+
bool AArch64TargetLowering::isEligibleForTailCallOptimization(
const CallLoweringInfo &CLI) const {
CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8670,12 +8680,10 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
// SME Streaming functions are not eligible for TCO as they may require
// the streaming mode or ZA to be restored after returning from the call.
- SMEAttrs CallerAttrs(MF.getFunction());
- auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
- if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
- CallerAttrs.requiresLazySave(CalleeAttrs) ||
- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
- CallerAttrs.hasStreamingBody())
+ SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
+ if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
+ CallAttrs.requiresPreservingAllZAState() ||
+ CallAttrs.caller().hasStreamingBody())
return false;
// Functions using the C or Fast calling convention that have an SVE signature
@@ -8967,14 +8975,13 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
return TLI.LowerCallTo(CLI).second;
}
-static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
- const SMEAttrs &CalleeAttrs) {
- if (!CallerAttrs.hasStreamingCompatibleInterface() ||
- CallerAttrs.hasStreamingBody())
+static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
+ if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
+ CallAttrs.caller().hasStreamingBody())
return AArch64SME::Always;
- if (CalleeAttrs.hasNonStreamingInterface())
+ if (CallAttrs.calleeOrCallsite().hasNonStreamingInterface())
return AArch64SME::IfCallerIsStreaming;
- if (CalleeAttrs.hasStreamingInterface())
+ if (CallAttrs.calleeOrCallsite().hasStreamingInterface())
return AArch64SME::IfCallerIsNonStreaming;
llvm_unreachable("Unsupported attributes");
@@ -9107,11 +9114,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
}
// Determine whether we need any streaming mode changes.
- SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
- if (CLI.CB)
- CalleeAttrs = SMEAttrs(*CLI.CB);
- else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
- CalleeAttrs = SMEAttrs(ES->getSymbol());
+ SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
auto DescribeCallsite =
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9126,9 +9129,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
return R;
};
- bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
- bool RequiresSaveAllZA =
- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
+ bool RequiresLazySave = CallAttrs.requiresLazySave();
+ bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
if (RequiresLazySave) {
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
MachinePointerInfo MPI =
@@ -9156,18 +9158,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
return DescribeCallsite(R) << " sets up a lazy save for ZA";
});
} else if (RequiresSaveAllZA) {
- assert(!CalleeAttrs.hasSharedZAInterface() &&
+ assert(!CallAttrs.calleeOrCallsite().hasSharedZAInterface() &&
"Cannot share state that may not exist");
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
/*IsSave=*/true);
}
SDValue PStateSM;
- bool RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
+ bool RequiresSMChange = CallAttrs.requiresSMChange();
if (RequiresSMChange) {
- if (CallerAttrs.hasStreamingInterfaceOrBody())
+ if (CallAttrs.caller().hasStreamingInterfaceOrBody())
PStateSM = DAG.getConstant(1, DL, MVT::i64);
- else if (CallerAttrs.hasNonStreamingInterface())
+ else if (CallAttrs.caller().hasNonStreamingInterface())
PStateSM = DAG.getConstant(0, DL, MVT::i64);
else
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9184,7 +9186,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
SDValue ZTFrameIdx;
MachineFrameInfo &MFI = MF.getFrameInfo();
- bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
+ bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
// If the caller has ZT0 state which will not be preserved by the callee,
// spill ZT0 before the call.
@@ -9200,7 +9202,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
// PSTATE.ZA before the call if there is no lazy-save active.
- bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
+ bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
assert((!DisableZA || !RequiresLazySave) &&
"Lazy-save should have PSTATE.SM=1 on entry to the function");
@@ -9483,8 +9485,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
}
SDValue NewChain = changeStreamingMode(
- DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
- getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
+ DAG, DL, CallAttrs.calleeOrCallsite().hasStreamingInterface(), Chain,
+ InGlue, getSMCondition(CallAttrs), PStateSM);
Chain = NewChain.getValue(0);
InGlue = NewChain.getValue(1);
}
@@ -9663,8 +9665,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
if (RequiresSMChange) {
assert(PStateSM && "Expected a PStateSM to be set");
Result = changeStreamingMode(
- DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
- getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
+ DAG, DL, !CallAttrs.calleeOrCallsite().hasStreamingInterface(), Result,
+ InGlue, getSMCondition(CallAttrs), PStateSM);
if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
InGlue = Result.getValue(1);
@@ -9674,7 +9676,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
}
}
- if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
+ if (CallAttrs.requiresEnablingZAAfterCall())
// Unconditionally resume ZA.
Result = DAG.getNode(
AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28573,12 +28575,10 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
// Checks to allow the use of SME instructions
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
- auto CallerAttrs = SMEAttrs(*Inst.getFunction());
- auto CalleeAttrs = SMEAttrs(*Base);
- if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
- CallerAttrs.requiresLazySave(CalleeAttrs) ||
- CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
+ auto CallAttrs = SMECallAttrs(*Base);
+ if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
+ CallAttrs.requiresPreservingZT0() ||
+ CallAttrs.requiresPreservingAllZAState())
return true;
}
return false;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 594f1bff5c458..84d33b2b3e128 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -268,22 +268,21 @@ const FeatureBitset AArch64TTIImpl::InlineInverseFeatures = {
bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
const Function *Callee) const {
- SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
+ SMECallAttrs CallAttrs(*Caller, *Callee);
// When inlining, we should consider the body of the function, not the
// interface.
- if (CalleeAttrs.hasStreamingBody()) {
- CalleeAttrs.set(SMEAttrs::SM_Compatible, false);
- CalleeAttrs.set(SMEAttrs::SM_Enabled, true);
+ if (CallAttrs.callee().hasStreamingBody()) {
+ CallAttrs.callee().set(SMEAttrs::SM_Compatible, false);
+ CallAttrs.callee().set(SMEAttrs::SM_Enabled, true);
}
- if (CalleeAttrs.isNewZA() || CalleeAttrs.isNewZT0())
+ if (CallAttrs.callee().isNewZA() || CallAttrs.callee().isNewZT0())
return false;
- if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
- CallerAttrs.requiresSMChange(CalleeAttrs) ||
- CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
+ if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
+ CallAttrs.requiresPreservingZT0() ||
+ CallAttrs.requiresPreservingAllZAState()) {
if (hasPossibleIncompatibleOps(Callee))
return false;
}
@@ -349,12 +348,14 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
// streaming-mode change, and the call to G from F would also require a
// streaming-mode change, then there is benefit to do the streaming-mode
// change only once and avoid inlining of G into F.
+
SMEAttrs FAttrs(*F);
- SMEAttrs CalleeAttrs(Call);
- if (FAttrs.requiresSMChange(CalleeAttrs)) {
+ SMECallAttrs CallAttrs(Call);
+
+ if (SMECallAttrs(FAttrs, CallAttrs.calleeOrCallsite()).requiresSMChange()) {
if (F == Call.getCaller()) // (1)
return CallPenaltyChangeSM * DefaultCallPenalty;
- if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
+ if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)
return InlineCallPenaltyChangeSM * DefaultCallPenalty;
}
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 76d2ac6a601e5..1085d618116eb 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -27,15 +27,14 @@ void SMEAttrs::set(unsigned M, bool Enable) {
"ZA_New and SME_ABI_Routine are mutually exclusive");
assert(
- (!sharesZA() ||
- (isNewZA() ^ isInZA() ^ isInOutZA() ^ isOutZA() ^ isPreservesZA())) &&
+ (isNewZA() + isInZA() + isOutZA() + isInOutZA() + isPreservesZA()) <= 1 &&
"Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
"'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive");
// ZT0 Attrs
assert(
- (!sharesZT0() || (isNewZT0() ^ isInZT0() ^ isInOutZT0() ^ isOutZT0() ^
- isPreservesZT0())) &&
+ (isNewZT0() + isInZT0() + isOutZT0() + isInOutZT0() + isPreservesZT0()) <=
+ 1 &&
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");
@@ -44,27 +43,6 @@ void SMEAttrs::set(unsigned M, bool Enable) {
"interface");
}
-SMEAttrs::SMEAttrs(const CallBase &CB) {
- *this = SMEAttrs(CB.getAttributes());
- if (auto *F = CB.getCalledFunction()) {
- set(SMEAttrs(*F).Bitmask | SMEAttrs(F->getName()).Bitmask);
- }
-}
-
-SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
- if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
- Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
- if (FuncName == "__arm_tpidr2_restore")
- Bitmask |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
- SMEAttrs::SME_ABI_Routine;
- if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
- FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
- Bitmask |= SMEAttrs::SM_Compatible;
- if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
- FuncName == "__arm_sme_state_size")
- Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
-}
-
SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask = 0;
if (Attrs.hasFnAttr("aarch64_pstate_sm_enabled"))
@@ -99,17 +77,39 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= encodeZT0State(StateValue::New);
}
-bool SMEAttrs::requiresSMChange(const SMEAttrs &Callee) const {
- if (Callee.hasStreamingCompatibleInterface())
+void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) {
+ unsigned KnownAttrs = SMEAttrs::Normal;
+ if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
+ KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
+ if (FuncName == "__arm_tpidr2_restore")
+ KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
+ SMEAttrs::SME_ABI_Routine;
+ if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
+ FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
+ KnownAttrs |= SMEAttrs::SM_Compatible;
+ if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
+ FuncName == "__arm_sme_state_size")
+ KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
+ set(KnownAttrs, /*Enable=*/true);
+}
+
+bool SMECallAttrs::requiresSMChange() const {
+ if ((Callsite | Callee).hasStreamingCompatibleInterface())
return false;
// Both non-streaming
- if (hasNonStreamingInterfaceAndBody() && Callee.hasNonStreamingInterface())
+ if (Caller.hasNonStreamingInterfaceAndBody() &&
+ (Callsite | Callee).hasNonStreamingInterface())
return false;
// Both streaming
- if (hasStreamingInterfaceOrBody() && Callee.hasStreamingInterface())
+ if (Caller.hasStreamingInterfaceOrBody() &&
+ (Callsite | Callee).hasStreamingInterface())
return false;
return true;
}
+
+SMECallAttrs::SMECallAttrs(const CallBase &CB)
+ : SMECallAttrs(*CB.getFunction(), CB.getCalledFunction(),
+ CB.getAttributes()) {}
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 1691d4fec8b68..791bb891a18b7 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -18,12 +18,9 @@ class CallBase;
class AttributeList;
/// SMEAttrs is a utility class to parse the SME ACLE attributes on functions.
-/// It helps determine a function's requirements for PSTATE.ZA and PSTATE.SM. It
-/// has interfaces to query whether a streaming mode change or lazy-save
-/// mechanism is required when going from one function to another (e.g. through
-/// a call).
+/// It helps determine a function's requirements for PSTATE.ZA and PSTATE.SM.
class SMEAttrs {
- unsigned Bitmask;
+ unsigned Bitmask = Normal;
public:
enum class StateValue {
@@ -43,18 +40,23 @@ class SMEAttrs {
SM_Body = 1 << 2, // aarch64_pstate_sm_body
SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
ZA_State_Agnostic = 1 << 4,
- ZT0_Undef = 1 << 5, // Use to mark ZT0 as undef to avoid spills
+ ZT0_Undef = 1 << 5, // Use to mark ZT0 as undef to avoid spills
ZA_Shift = 6,
ZA_Mask = 0b111 << ZA_Shift,
ZT0_Shift = 9,
ZT0_Mask = 0b111 << ZT0_Shift
};
- SMEAttrs(unsigned Mask = Normal) : Bitmask(0) { set(Mask); }
- SMEAttrs(const Function &F) : SMEAttrs(F.getAttributes()) {}
- SMEAttrs(const CallBase &CB);
+ SMEAttrs() = default;
+ SMEAttrs(unsigned Mask) { set(Mask); }
+ SMEAttrs(const Function *F)
+ : SMEAttrs(F ? F->getAttributes() : AttributeList()) {
+ if (F)
+ addKnownFunctionAttrs(F->getName());
+ }
+ SMEAttrs(const Function &F) : SMEAttrs(&F) {}
SMEAttrs(const AttributeList &L);
- SMEAttrs(StringRef FuncName);
+ SMEAttrs(StringRef FuncName) { addKnownFunctionAttrs(FuncName); };
void set(unsigned M, bool Enable = true);
@@ -74,10 +76,6 @@ class SMEAttrs {
return hasNonStreamingInterface() && !hasStreamingBody();
}
- /// \return true if a call from Caller -> Callee requires a change in
- /// streaming mode.
- bool requiresSMChange(const SMEAttrs &Callee) const;
-
// Interfaces to query ZA
static StateValue decodeZAState(unsigned Bitmask) {
return static_cast<StateValue>((Bitmask & ZA_Mask) >> ZA_Shift);
@@ -104,10 +102,7 @@ class SMEAttrs {
return !hasSharedZAInterface() && !hasAgnosticZAInterface();
}
bool hasZAState() const { return isNewZA() || sharesZA(); }
- bool requiresLazySave(const SMEAttrs &Callee) const {
- return hasZAState() && Callee.hasPrivateZAInterface() &&
- !(Callee.Bitmask & SME_ABI_Routine);
- }
+ bool isSMEABIRoutine() const { return Bitmask & SME_ABI_Routine; }
// Interfaces to query ZT0 State
static StateValue decodeZT0State(unsigned Bitmask) {
@@ -126,27 +121,76 @@ class SMEAttrs {
bool isPreservesZT0() const {
return decodeZT0State(Bitmask) == StateValue::Preserved;
}
- bool isUndefZT0() const { return Bitmask & ZT0_Undef; }
+ bool hasUndefZT0() const { return Bitmask & ZT0_Undef; }
bool sharesZT0() const {
StateValue State = decodeZT0State(Bitmask);
return State == StateValue::In || State == StateValue::Out ||
State == StateValue::InOut || State == StateValue::Preserved;
}
bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
- bool requiresPreservingZT0(const SMEAttrs &Callee) const {
- return hasZT0State() && !Callee.isUndefZT0() && !Callee.sharesZT0() &&
- !Callee.hasAgnosticZAInterface();
+
+ SMEAttrs operator|(SMEAttrs Other) const {
+ SMEAttrs Merged(*this);
+ Merged.set(Other.Bitmask, /*Enable=*/true);
+ return Merged;
}
- bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
- return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface() &&
- !(Callee.Bitmask & SME_ABI_Routine);
+
+private:
+ void addKnownFunctionAttrs(StringRef FuncName);
+};
+
+/// SMECallAttrs is a utility class to hold the SMEAttrs for a callsite. It has
+/// interfaces to query whether a streaming mode change or lazy-save mechanism
+/// is required when going from one function to another (e.g. through a call).
+class SMECallAttrs {
+ SMEAttrs Caller;
+ SMEAttrs Callee;
+ SMEAttrs Callsite;
+
+public:
+ SMECallAttrs(SMEAttrs Caller, SMEAttrs Callee,
+ SMEAttrs Callsite = SMEAttrs::Normal)
+ : Caller(Caller), Callee(Callee), Callsite(Callsite) {}
+
+ SMECallAttrs(const CallBase &CB);
+
+ SMEAttrs &caller() { return Caller; }
+ SMEAttrs &callee() { return Callee; }
+ SMEAttrs &callsite() { return Callsite; }
+ SMEAttrs const &caller() const { return Caller; }
+ SMEAttrs const &callee() const { return Callee; }
+ SMEAttrs const &callsite() const { return Callsite; }
+ SMEAttrs calleeOrCallsite() const { return Callsite | Callee; }
+
+ /// \return true if a call from Caller -> Callee requires a change in
+ /// streaming mode.
+ bool requiresSMChange() const;
+
+ bool requiresLazySave() const {
+ return Caller.hasZAState() && (Callsite | Callee).hasPrivateZAInterface() &&
+ !Callee.isSMEABIRoutine();
}
- bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const {
- return requiresLazySave(Callee) || requiresDisablingZABeforeCall(Callee);
+
+ bool requiresPreservingZT0() const {
+ return Caller.hasZT0State() && !Callsite.hasUndefZT0() &&
+ !(Callsite | Callee).sharesZT0() &&
+ !(Callsite | Callee).hasAgnosticZAInterface();
}
- bool requiresPreservingAllZAState(const SMEAttrs &Callee) const {
- return hasAgnosticZAInterface() && !Callee.hasAgnosticZAInterface() &&
- !(Callee.Bitmask & SME_ABI_Routine);
+
+ bool requiresDisablingZABeforeCall() const {
+ return Caller.hasZT0State() && !Caller.hasZAState() &&
+ (Callsite | Callee).hasPrivateZAInterface() &&
+ !Callee.isSMEABIRoutine();
+ }
+
+ bool requiresEnablingZAAfterCall() const {
+ return requiresLazySave() || requiresDisablingZABeforeCall();
+ }
+
+ bool requiresPreservingAllZAState() const {
+ return Caller.hasAgnosticZAInterface() &&
+ !(Callsite | Callee).hasAgnosticZAInterface() &&
+ !Callee.isSMEABIRoutine();
}
};
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index f8c77fcba19cf..f13252f3a4c28 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -9,6 +9,7 @@
using namespace llvm;
using SA = SMEAttrs;
+using CA = SMECallAttrs;
std::unique_ptr<Module> parseIR(const char *IR) {
static LLVMContext C;
@@ -70,15 +71,14 @@ TEST(SMEAttributes, Constructors) {
ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_new_zt0\"")
->getFunction("foo"))
.isNewZT0());
- ASSERT_TRUE(
- SA(cast<CallBase>((parseIR("declare void @callee()\n"
- "define void @foo() {"
- "call void @callee() \"aarch64_zt0_undef\"\n"
- "ret void\n}")
- ->getFunction("foo")
- ->begin()
- ->front())))
- .isUndefZT0());
+
+ auto CallModule = parseIR("declare void @callee()\n"
+ "define void @foo() {"
+ "call void @callee() \"aarch64_zt0_undef\"\n"
+ "ret void\n}");
+ CallBase &Call =
+ cast<CallBase>((CallModule->getFunction("foo")->begin()->front()));
+ ASSERT_TRUE(SMECallAttrs(Call).callsite().hasUndefZT0());
// Invalid combinations.
EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible),
@@ -235,7 +235,7 @@ TEST(SMEAttributes, Basics) {
ASSERT_TRUE(ZT0_Undef.hasZT0State());
ASSERT_FALSE(ZT0_Undef.hasSharedZAInterface());
ASSERT_TRUE(ZT0_Undef.hasPrivateZAInterface());
- ASSERT_TRUE(ZT0_Undef.isUndefZT0());
+ ASSERT_TRUE(ZT0_Undef.hasUndefZT0());
ASSERT_FALSE(SA(SA::Normal).isInZT0());
ASSERT_FALSE(SA(SA::Normal).isOutZT0());
@@ -248,59 +248,57 @@ TEST(SMEAttributes, Basics) {
TEST(SMEAttributes, Transitions) {
// Normal -> Normal
- ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal)));
- ASSERT_FALSE(SA(SA::Normal).requiresPreservingZT0(SA(SA::Normal)));
- ASSERT_FALSE(SA(SA::Normal).requiresDisablingZABeforeCall(SA(SA::Normal)));
- ASSERT_FALSE(SA(SA::Normal).requiresEnablingZAAfterCall(SA(SA::Normal)));
+ ASSERT_FALSE(CA(SA::Normal, SA::Normal).requiresSMChange());
+ ASSERT_FALSE(CA(SA::Normal, SA::Normal).requiresPreservingZT0());
+ ASSERT_FALSE(CA(SA::Normal, SA::Normal).requiresDisablingZABeforeCall());
+ ASSERT_FALSE(CA(SA::Normal, SA::Normal).requiresEnablingZAAfterCall());
// Normal -> Normal + LocallyStreaming
- ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal | SA::SM_Body)));
+ ASSERT_FALSE(CA(SA::Normal, SA::Normal | SA::SM_Body).requiresSMChange());
// Normal -> Streaming
- ASSERT_TRUE(SA(SA::Normal).requiresSMChange(SA(SA::SM_Enabled)));
+ ASSERT_TRUE(CA(SA::Normal, SA::SM_Enabled).requiresSMChange());
// Normal -> Streaming + LocallyStreaming
- ASSERT_TRUE(
- SA(SA::Normal).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body)));
+ ASSERT_TRUE(CA(SA::Normal, SA::SM_Enabled | SA::SM_Body).requiresSMChange());
// Normal -> Streaming-compatible
- ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::SM_Compatible)));
+ ASSERT_FALSE(CA(SA::Normal, SA::SM_Compatible).requiresSMChange());
// Normal -> Streaming-compatible + LocallyStreaming
ASSERT_FALSE(
- SA(SA::Normal).requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body)));
+ CA(SA::Normal, SA::SM_Compatible | SA::SM_Body).requiresSMChange());
// Streaming -> Normal
- ASSERT_TRUE(SA(SA::SM_Enabled).requiresSMChange(SA(SA::Normal)));
+ ASSERT_TRUE(CA(SA::SM_Enabled, SA::Normal).requiresSMChange());
// Streaming -> Normal + LocallyStreaming
- ASSERT_TRUE(
- SA(SA::SM_Enabled).requiresSMChange(SA(SA::Normal | SA::SM_Body)));
+ ASSERT_TRUE(CA(SA::SM_Enabled, SA::Normal | SA::SM_Body).requiresSMChange());
// Streaming -> Streaming
- ASSERT_FALSE(SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Enabled)));
+ ASSERT_FALSE(CA(SA::SM_Enabled, SA::SM_Enabled).requiresSMChange());
// Streaming -> Streaming + LocallyStreaming
ASSERT_FALSE(
- SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body)));
+ CA(SA::SM_Enabled, SA::SM_Enabled | SA::SM_Body).requiresSMChange());
// Streaming -> Streaming-compatible
- ASSERT_FALSE(SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Compatible)));
+ ASSERT_FALSE(CA(SA::SM_Enabled, SA::SM_Compatible).requiresSMChange());
// Streaming -> Streaming-compatible + LocallyStreaming
ASSERT_FALSE(
- SA(SA::SM_Enabled).requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body)));
+ CA(SA::SM_Enabled, SA::SM_Compatible | SA::SM_Body).requiresSMChange());
// Streaming-compatible -> Normal
- ASSERT_TRUE(SA(SA::SM_Compatible).requiresSMChange(SA(SA::Normal)));
+ ASSERT_TRUE(CA(SA::SM_Compatible, SA::Normal).requiresSMChange());
ASSERT_TRUE(
- SA(SA::SM_Compatible).requiresSMChange(SA(SA::Normal | SA::SM_Body)));
+ CA(SA::SM_Compatible, SA::Normal | SA::SM_Body).requiresSMChange());
// Streaming-compatible -> Streaming
- ASSERT_TRUE(SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Enabled)));
+ ASSERT_TRUE(CA(SA::SM_Compatible, SA::SM_Enabled).requiresSMChange());
// Streaming-compatible -> Streaming + LocallyStreaming
ASSERT_TRUE(
- SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Enabled | SA::SM_Body)));
+ CA(SA::SM_Compatible, SA::SM_Enabled | SA::SM_Body).requiresSMChange());
// Streaming-compatible -> Streaming-compatible
- ASSERT_FALSE(SA(SA::SM_Compatible).requiresSMChange(SA(SA::SM_Compatible)));
+ ASSERT_FALSE(CA(SA::SM_Compatible, SA::SM_Compatible).requiresSMChange());
// Streaming-compatible -> Streaming-compatible + LocallyStreaming
- ASSERT_FALSE(SA(SA::SM_Compatible)
- .requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body)));
+ ASSERT_FALSE(CA(SA::SM_Compatible, SA::SM_Compatible | SA::SM_Body)
+ .requiresSMChange());
SA Private_ZA = SA(SA::Normal);
SA ZA_Shared = SA(SA::encodeZAState(SA::StateValue::In));
@@ -310,37 +308,39 @@ TEST(SMEAttributes, Transitions) {
SA Undef_ZT0 = SA(SA::ZT0_Undef);
// Shared ZA -> Private ZA Interface
- ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
- ASSERT_TRUE(ZA_Shared.requiresEnablingZAAfterCall(Private_ZA));
+ ASSERT_FALSE(CA(ZA_Shared, Private_ZA).requiresDisablingZABeforeCall());
+ ASSERT_TRUE(CA(ZA_Shared, Private_ZA).requiresEnablingZAAfterCall());
// Shared ZT0 -> Private ZA Interface
- ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
- ASSERT_TRUE(ZT0_Shared.requiresPreservingZT0(Private_ZA));
- ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
+ ASSERT_TRUE(CA(ZT0_Shared, Private_ZA).requiresDisablingZABeforeCall());
+ ASSERT_TRUE(CA(ZT0_Shared, Private_ZA).requiresPreservingZT0());
+ ASSERT_TRUE(CA(ZT0_Shared, Private_ZA).requiresEnablingZAAfterCall());
// Shared Undef ZT0 -> Private ZA Interface
- // Note: "Undef ZT0" is a callsite attribute that means ZT0 is undefined at
- // the point of the call.
- ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Undef_ZT0));
- ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(Undef_ZT0));
- ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Undef_ZT0));
+ ASSERT_TRUE(
+ CA(ZT0_Shared, Private_ZA, Undef_ZT0).requiresDisablingZABeforeCall());
+ ASSERT_FALSE(CA(ZT0_Shared, Private_ZA, Undef_ZT0).requiresPreservingZT0());
+ ASSERT_TRUE(
+ CA(ZT0_Shared, Private_ZA, Undef_ZT0).requiresEnablingZAAfterCall());
// Shared ZA & ZT0 -> Private ZA Interface
- ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
- ASSERT_TRUE(ZA_ZT0_Shared.requiresPreservingZT0(Private_ZA));
- ASSERT_TRUE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
+ ASSERT_FALSE(CA(ZA_ZT0_Shared, Private_ZA).requiresDisablingZABeforeCall());
+ ASSERT_TRUE(CA(ZA_ZT0_Shared, Private_ZA).requiresPreservingZT0());
+ ASSERT_TRUE(CA(ZA_ZT0_Shared, Private_ZA).requiresEnablingZAAfterCall());
// Shared ZA -> Shared ZA Interface
- ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
- ASSERT_FALSE(ZA_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
+ ASSERT_FALSE(CA(ZA_Shared, ZT0_Shared).requiresDisablingZABeforeCall());
+ ASSERT_FALSE(CA(ZA_Shared, ZT0_Shared).requiresEnablingZAAfterCall());
// Shared ZT0 -> Shared ZA Interface
- ASSERT_FALSE(ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
- ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
- ASSERT_FALSE(ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
+ ASSERT_FALSE(CA(ZT0_Shared, ZT0_Shared).requiresDisablingZABeforeCall());
+ ASSERT_FALSE(CA(ZT0_Shared, ZT0_Shared).requiresPreservingZT0());
+ ASSERT_FALSE(CA(ZT0_Shared, ZT0_Shared).requiresEnablingZAAfterCall());
// Shared ZA & ZT0 -> Shared ZA Interface
- ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
- ASSERT_FALSE(ZA_ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
- ASSERT_FALSE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
+ ASSERT_FALSE(CA(ZA_ZT0_Shared, ZT0_Shared).requiresDisablingZABeforeCall());
+ ASSERT_FALSE(CA(ZA_ZT0_Shared, ZT0_Shared).requiresPreservingZT0());
+ ASSERT_FALSE(CA(ZA_ZT0_Shared, ZT0_Shared).requiresEnablingZAAfterCall());
}
>From 6017c7ed14cb6cf37311dba342d2900367364396 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 25 Apr 2025 11:34:59 +0000
Subject: [PATCH 5/5] [AArch64][SME] Disallow SME attributes on direct function
calls
This was only used in a handful of tests (mainly to avoid making
multiple function declarations). These tests can easily be updated to
use indirect calls or attributes on declarations.
This allows us to remove checks that looked at both the "callee" and
"callsite" attributes, which makes the API of SMECallAttrs a clearer
and less error-prone (as you can't accidentally use .callee() when you
should have used .calleeOrCallsite()).
Note: This currently still allows non-conflicting attributes on direct
calls (as clang currently duplicates streaming mode attributes at each
callsite).
---
.../Target/AArch64/AArch64ISelLowering.cpp | 16 ++---
.../AArch64/AArch64TargetTransformInfo.cpp | 2 +-
.../AArch64/Utils/AArch64SMEAttributes.cpp | 20 ++++--
.../AArch64/Utils/AArch64SMEAttributes.h | 52 ++++++++------
.../test/CodeGen/AArch64/sme-peephole-opts.ll | 23 ++++---
llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll | 4 +-
llvm/test/CodeGen/AArch64/sme-zt0-state.ll | 68 +++++++++----------
7 files changed, 101 insertions(+), 84 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c445210cc3d4e..f4e6f7182784e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8979,9 +8979,9 @@ static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
CallAttrs.caller().hasStreamingBody())
return AArch64SME::Always;
- if (CallAttrs.calleeOrCallsite().hasNonStreamingInterface())
+ if (CallAttrs.callee().hasNonStreamingInterface())
return AArch64SME::IfCallerIsStreaming;
- if (CallAttrs.calleeOrCallsite().hasStreamingInterface())
+ if (CallAttrs.callee().hasStreamingInterface())
return AArch64SME::IfCallerIsNonStreaming;
llvm_unreachable("Unsupported attributes");
@@ -9158,7 +9158,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
return DescribeCallsite(R) << " sets up a lazy save for ZA";
});
} else if (RequiresSaveAllZA) {
- assert(!CallAttrs.calleeOrCallsite().hasSharedZAInterface() &&
+ assert(!CallAttrs.callee().hasSharedZAInterface() &&
"Cannot share state that may not exist");
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
/*IsSave=*/true);
@@ -9484,9 +9484,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
InGlue = Chain.getValue(1);
}
- SDValue NewChain = changeStreamingMode(
- DAG, DL, CallAttrs.calleeOrCallsite().hasStreamingInterface(), Chain,
- InGlue, getSMCondition(CallAttrs), PStateSM);
+ SDValue NewChain =
+ changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
+ Chain, InGlue, getSMCondition(CallAttrs), PStateSM);
Chain = NewChain.getValue(0);
InGlue = NewChain.getValue(1);
}
@@ -9665,8 +9665,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
if (RequiresSMChange) {
assert(PStateSM && "Expected a PStateSM to be set");
Result = changeStreamingMode(
- DAG, DL, !CallAttrs.calleeOrCallsite().hasStreamingInterface(), Result,
- InGlue, getSMCondition(CallAttrs), PStateSM);
+ DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
+ getSMCondition(CallAttrs), PStateSM);
if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
InGlue = Result.getValue(1);
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 84d33b2b3e128..374d9e4d7822d 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -352,7 +352,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
SMEAttrs FAttrs(*F);
SMECallAttrs CallAttrs(Call);
- if (SMECallAttrs(FAttrs, CallAttrs.calleeOrCallsite()).requiresSMChange()) {
+ if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
if (F == Call.getCaller()) // (1)
return CallPenaltyChangeSM * DefaultCallPenalty;
if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 1085d618116eb..16ae5434e596a 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -94,22 +94,28 @@ void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) {
}
bool SMECallAttrs::requiresSMChange() const {
- if ((Callsite | Callee).hasStreamingCompatibleInterface())
+ if (callee().hasStreamingCompatibleInterface())
return false;
// Both non-streaming
- if (Caller.hasNonStreamingInterfaceAndBody() &&
- (Callsite | Callee).hasNonStreamingInterface())
+ if (caller().hasNonStreamingInterfaceAndBody() &&
+ callee().hasNonStreamingInterface())
return false;
// Both streaming
- if (Caller.hasStreamingInterfaceOrBody() &&
- (Callsite | Callee).hasStreamingInterface())
+ if (caller().hasStreamingInterfaceOrBody() &&
+ callee().hasStreamingInterface())
return false;
return true;
}
SMECallAttrs::SMECallAttrs(const CallBase &CB)
- : SMECallAttrs(*CB.getFunction(), CB.getCalledFunction(),
- CB.getAttributes()) {}
+ : CallerFn(*CB.getFunction()), CalledFn(CB.getCalledFunction()),
+ Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) {
+ // FIXME: We probably should not allow SME attributes on direct calls but
+ // clang duplicates streaming mode attributes at each callsite.
+ assert((IsIndirect ||
+ ((Callsite.withoutPerCallsiteFlags() | CalledFn) == CalledFn)) &&
+ "SME attributes at callsite do not match declaration");
+}
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 791bb891a18b7..628c55ce3cbaa 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -44,7 +44,8 @@ class SMEAttrs {
ZA_Shift = 6,
ZA_Mask = 0b111 << ZA_Shift,
ZT0_Shift = 9,
- ZT0_Mask = 0b111 << ZT0_Shift
+ ZT0_Mask = 0b111 << ZT0_Shift,
+ Callsite_Flags = ZT0_Undef
};
SMEAttrs() = default;
@@ -135,6 +136,14 @@ class SMEAttrs {
return Merged;
}
+ SMEAttrs withoutPerCallsiteFlags() const {
+ return (Bitmask & ~Callsite_Flags);
+ }
+
+ bool operator==(SMEAttrs const &Other) const {
+ return Bitmask == Other.Bitmask;
+ }
+
private:
void addKnownFunctionAttrs(StringRef FuncName);
};
@@ -143,44 +152,48 @@ class SMEAttrs {
/// interfaces to query whether a streaming mode change or lazy-save mechanism
/// is required when going from one function to another (e.g. through a call).
class SMECallAttrs {
- SMEAttrs Caller;
- SMEAttrs Callee;
+ SMEAttrs CallerFn;
+ SMEAttrs CalledFn;
SMEAttrs Callsite;
+ bool IsIndirect = false;
public:
SMECallAttrs(SMEAttrs Caller, SMEAttrs Callee,
SMEAttrs Callsite = SMEAttrs::Normal)
- : Caller(Caller), Callee(Callee), Callsite(Callsite) {}
+ : CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {}
SMECallAttrs(const CallBase &CB);
- SMEAttrs &caller() { return Caller; }
- SMEAttrs &callee() { return Callee; }
+ SMEAttrs &caller() { return CallerFn; }
+ SMEAttrs &callee() {
+ if (IsIndirect)
+ return Callsite;
+ return CalledFn;
+ }
SMEAttrs &callsite() { return Callsite; }
- SMEAttrs const &caller() const { return Caller; }
- SMEAttrs const &callee() const { return Callee; }
+ SMEAttrs const &caller() const { return CallerFn; }
+ SMEAttrs const &callee() const {
+ return const_cast<SMECallAttrs *>(this)->callee();
+ }
SMEAttrs const &callsite() const { return Callsite; }
- SMEAttrs calleeOrCallsite() const { return Callsite | Callee; }
/// \return true if a call from Caller -> Callee requires a change in
/// streaming mode.
bool requiresSMChange() const;
bool requiresLazySave() const {
- return Caller.hasZAState() && (Callsite | Callee).hasPrivateZAInterface() &&
- !Callee.isSMEABIRoutine();
+ return caller().hasZAState() && callee().hasPrivateZAInterface() &&
+ !callee().isSMEABIRoutine();
}
bool requiresPreservingZT0() const {
- return Caller.hasZT0State() && !Callsite.hasUndefZT0() &&
- !(Callsite | Callee).sharesZT0() &&
- !(Callsite | Callee).hasAgnosticZAInterface();
+ return caller().hasZT0State() && !callsite().hasUndefZT0() &&
+ !callee().sharesZT0() && !callee().hasAgnosticZAInterface();
}
bool requiresDisablingZABeforeCall() const {
- return Caller.hasZT0State() && !Caller.hasZAState() &&
- (Callsite | Callee).hasPrivateZAInterface() &&
- !Callee.isSMEABIRoutine();
+ return caller().hasZT0State() && !caller().hasZAState() &&
+ callee().hasPrivateZAInterface() && !callee().isSMEABIRoutine();
}
bool requiresEnablingZAAfterCall() const {
@@ -188,9 +201,8 @@ class SMECallAttrs {
}
bool requiresPreservingAllZAState() const {
- return Caller.hasAgnosticZAInterface() &&
- !(Callsite | Callee).hasAgnosticZAInterface() &&
- !Callee.isSMEABIRoutine();
+ return caller().hasAgnosticZAInterface() &&
+ !callee().hasAgnosticZAInterface() && !callee().isSMEABIRoutine();
}
};
diff --git a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
index 6ea2267cd22e6..130a316bcc2ba 100644
--- a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
@@ -2,11 +2,12 @@
; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-streaming-hazard-size=0 -mattr=+sve,+sme2 < %s | FileCheck %s
declare void @callee()
+declare void @callee_sm() "aarch64_pstate_sm_enabled"
declare void @callee_farg(float)
declare float @callee_farg_fret(float)
; normal caller -> streaming callees
-define void @test0() nounwind {
+define void @test0(ptr %callee) nounwind {
; CHECK-LABEL: test0:
; CHECK: // %bb.0:
; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
@@ -16,8 +17,8 @@ define void @test0() nounwind {
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
; CHECK-NEXT: stp x30, x9, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: smstart sm
-; CHECK-NEXT: bl callee
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: bl callee_sm
+; CHECK-NEXT: bl callee_sm
; CHECK-NEXT: smstop sm
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
@@ -25,8 +26,8 @@ define void @test0() nounwind {
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
; CHECK-NEXT: ret
- call void @callee() "aarch64_pstate_sm_enabled"
- call void @callee() "aarch64_pstate_sm_enabled"
+ call void @callee_sm()
+ call void @callee_sm()
ret void
}
@@ -118,7 +119,7 @@ define void @test3() nounwind "aarch64_pstate_sm_compatible" {
; CHECK-NEXT: // %bb.1:
; CHECK-NEXT: smstart sm
; CHECK-NEXT: .LBB3_2:
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: bl callee_sm
; CHECK-NEXT: tbnz w19, #0, .LBB3_4
; CHECK-NEXT: // %bb.3:
; CHECK-NEXT: smstop sm
@@ -140,7 +141,7 @@ define void @test3() nounwind "aarch64_pstate_sm_compatible" {
; CHECK-NEXT: // %bb.9:
; CHECK-NEXT: smstart sm
; CHECK-NEXT: .LBB3_10:
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: bl callee_sm
; CHECK-NEXT: tbnz w19, #0, .LBB3_12
; CHECK-NEXT: // %bb.11:
; CHECK-NEXT: smstop sm
@@ -152,9 +153,9 @@ define void @test3() nounwind "aarch64_pstate_sm_compatible" {
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
; CHECK-NEXT: ldp d15, d14, [sp], #96 // 16-byte Folded Reload
; CHECK-NEXT: ret
- call void @callee() "aarch64_pstate_sm_enabled"
+ call void @callee_sm()
call void @callee()
- call void @callee() "aarch64_pstate_sm_enabled"
+ call void @callee_sm()
ret void
}
@@ -342,7 +343,7 @@ define void @test10() "aarch64_pstate_sm_body" {
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart sm
; CHECK-NEXT: .cfi_restore vg
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: bl callee_sm
; CHECK-NEXT: .cfi_offset vg, -24
; CHECK-NEXT: smstop sm
; CHECK-NEXT: bl callee
@@ -363,7 +364,7 @@ define void @test10() "aarch64_pstate_sm_body" {
; CHECK-NEXT: .cfi_restore b15
; CHECK-NEXT: ret
call void @callee()
- call void @callee() "aarch64_pstate_sm_enabled"
+ call void @callee_sm()
call void @callee()
ret void
}
diff --git a/llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll b/llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll
index 17d689d2c9eb5..0853325e449af 100644
--- a/llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll
+++ b/llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll
@@ -1098,11 +1098,11 @@ define void @test_rdsvl_right_after_prologue(i64 %x0) nounwind {
; NO-SVE-CHECK-NEXT: ret
%some_alloc = alloca i64, align 8
%rdsvl = tail call i64 @llvm.aarch64.sme.cntsd()
- call void @bar(i64 %rdsvl, i64 %x0) "aarch64_pstate_sm_enabled"
+ call void @bar(i64 %rdsvl, i64 %x0)
ret void
}
-declare void @bar(i64, i64)
+declare void @bar(i64, i64) "aarch64_pstate_sm_enabled"
; Ensure we still emit async unwind information with -fno-asynchronous-unwind-tables
; if the function contains a streaming-mode change.
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 7361e850d713e..63577e4d217a8 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -1,15 +1,13 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
-declare void @callee();
-
;
; Private-ZA Callee
;
; Expect spill & fill of ZT0 around call
; Expect smstop/smstart za around call
-define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
+define void @zt0_in_caller_no_state_callee(ptr %callee) "aarch64_in_zt0" nounwind {
; CHECK-LABEL: zt0_in_caller_no_state_callee:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #80
@@ -17,20 +15,20 @@ define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: blr x0
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
- call void @callee();
+ call void %callee();
ret void;
}
; Expect spill & fill of ZT0 around call
; Expect setup and restore lazy-save around call
; Expect smstart za after call
-define void @za_zt0_shared_caller_no_state_callee() "aarch64_inout_za" "aarch64_in_zt0" nounwind {
+define void @za_zt0_shared_caller_no_state_callee(ptr %callee) "aarch64_inout_za" "aarch64_in_zt0" nounwind {
; CHECK-LABEL: za_zt0_shared_caller_no_state_callee:
; CHECK: // %bb.0:
; CHECK-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
@@ -49,7 +47,7 @@ define void @za_zt0_shared_caller_no_state_callee() "aarch64_inout_za" "aarch64_
; CHECK-NEXT: sturh w8, [x29, #-8]
; CHECK-NEXT: msr TPIDR2_EL0, x9
; CHECK-NEXT: str zt0, [x19]
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: blr x0
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: mrs x8, TPIDR2_EL0
@@ -63,7 +61,7 @@ define void @za_zt0_shared_caller_no_state_callee() "aarch64_inout_za" "aarch64_
; CHECK-NEXT: ldr x19, [sp, #16] // 8-byte Folded Reload
; CHECK-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload
; CHECK-NEXT: ret
- call void @callee();
+ call void %callee();
ret void;
}
@@ -72,43 +70,43 @@ define void @za_zt0_shared_caller_no_state_callee() "aarch64_inout_za" "aarch64_
;
; Caller and callee have shared ZT0 state, no spill/fill of ZT0 required
-define void @zt0_shared_caller_zt0_shared_callee() "aarch64_in_zt0" nounwind {
+define void @zt0_shared_caller_zt0_shared_callee(ptr %callee) "aarch64_in_zt0" nounwind {
; CHECK-LABEL: zt0_shared_caller_zt0_shared_callee:
; CHECK: // %bb.0:
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: blr x0
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
- call void @callee() "aarch64_in_zt0";
+ call void %callee() "aarch64_in_zt0";
ret void;
}
; Expect spill & fill of ZT0 around call
-define void @za_zt0_shared_caller_za_shared_callee() "aarch64_inout_za" "aarch64_in_zt0" nounwind {
+define void @za_zt0_shared_caller_za_shared_callee(ptr %callee) "aarch64_inout_za" "aarch64_in_zt0" nounwind {
; CHECK-LABEL: za_zt0_shared_caller_za_shared_callee:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: blr x0
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
- call void @callee() "aarch64_inout_za";
+ call void %callee() "aarch64_inout_za";
ret void;
}
; Caller and callee have shared ZA & ZT0
-define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_inout_za" "aarch64_in_zt0" nounwind {
+define void @za_zt0_shared_caller_za_zt0_shared_callee(ptr %callee) "aarch64_inout_za" "aarch64_in_zt0" nounwind {
; CHECK-LABEL: za_zt0_shared_caller_za_zt0_shared_callee:
; CHECK: // %bb.0:
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: blr x0
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
- call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
+ call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
ret void;
}
@@ -116,7 +114,7 @@ define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_inout_za" "aar
; Expect spill & fill of ZT0 around call
; Expect smstop/smstart za around call
-define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
+define void @zt0_in_caller_zt0_new_callee(ptr %callee) "aarch64_in_zt0" nounwind {
; CHECK-LABEL: zt0_in_caller_zt0_new_callee:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #80
@@ -124,13 +122,13 @@ define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: blr x0
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
- call void @callee() "aarch64_new_zt0";
+ call void %callee() "aarch64_new_zt0";
ret void;
}
@@ -140,7 +138,7 @@ define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
; Expect smstart ZA & clear ZT0
; Expect spill & fill of ZT0 around call
; Before return, expect smstop ZA
-define void @zt0_new_caller_zt0_new_callee() "aarch64_new_zt0" nounwind {
+define void @zt0_new_caller_zt0_new_callee(ptr %callee) "aarch64_new_zt0" nounwind {
; CHECK-LABEL: zt0_new_caller_zt0_new_callee:
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: sub sp, sp, #80
@@ -156,14 +154,14 @@ define void @zt0_new_caller_zt0_new_callee() "aarch64_new_zt0" nounwind {
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: blr x0
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
- call void @callee() "aarch64_new_zt0";
+ call void %callee() "aarch64_new_zt0";
ret void;
}
@@ -207,7 +205,7 @@ declare {i64, i64} @__arm_sme_state()
; Expect commit of lazy-save if ZA is dormant
; Expect smstart ZA & clear ZT0
; Before return, expect smstop ZA
-define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
+define void @zt0_new_caller(ptr %callee) "aarch64_new_zt0" nounwind {
; CHECK-LABEL: zt0_new_caller:
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
@@ -219,18 +217,18 @@ define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
; CHECK-NEXT: .LBB8_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero { zt0 }
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: blr x0
; CHECK-NEXT: smstop za
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
- call void @callee() "aarch64_in_zt0";
+ call void %callee() "aarch64_in_zt0";
ret void;
}
; Expect commit of lazy-save if ZA is dormant
; Expect smstart ZA, clear ZA & clear ZT0
; Before return, expect smstop ZA
-define void @new_za_zt0_caller() "aarch64_new_za" "aarch64_new_zt0" nounwind {
+define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" nounwind {
; CHECK-LABEL: new_za_zt0_caller:
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
@@ -243,36 +241,36 @@ define void @new_za_zt0_caller() "aarch64_new_za" "aarch64_new_zt0" nounwind {
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero {za}
; CHECK-NEXT: zero { zt0 }
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: blr x0
; CHECK-NEXT: smstop za
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
- call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
+ call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
ret void;
}
; Expect clear ZA on entry
-define void @new_za_shared_zt0_caller() "aarch64_new_za" "aarch64_in_zt0" nounwind {
+define void @new_za_shared_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_in_zt0" nounwind {
; CHECK-LABEL: new_za_shared_zt0_caller:
; CHECK: // %bb.0:
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: zero {za}
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: blr x0
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
- call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
+ call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
ret void;
}
; Expect clear ZT0 on entry
-define void @shared_za_new_zt0() "aarch64_inout_za" "aarch64_new_zt0" nounwind {
+define void @shared_za_new_zt0(ptr %callee) "aarch64_inout_za" "aarch64_new_zt0" nounwind {
; CHECK-LABEL: shared_za_new_zt0:
; CHECK: // %bb.0:
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: zero { zt0 }
-; CHECK-NEXT: bl callee
+; CHECK-NEXT: blr x0
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
- call void @callee() "aarch64_inout_za" "aarch64_in_zt0";
+ call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
ret void;
}
More information about the llvm-commits
mailing list