[llvm] [AArch64][SME] Store SME attributes in AArch64FunctionInfo (NFC) (PR #142362)
Benjamin Maxwell via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 2 04:11:14 PDT 2025
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/142362
The SMEAttrs class is tiny (simply a wrapper around a bitmask), but constructing SMEAttrs from an llvm::Function is relatively expensive, as every SME attribute has to be re-checked each time. So let's construct the SMEAttrs once as part of AArch64FunctionInfo and reuse the parsed attributes where possible.
>From 1153997b9f51af49095d3d8cf32758c3ce45efd3 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 2 Jun 2025 10:58:09 +0000
Subject: [PATCH] [AArch64][SME] Store SME attributes in AArch64FunctionInfo
(NFC)
The SMEAttrs class is tiny (simply a wrapper around a bitmask), but
constructing SMEAttrs from an llvm::Function is relatively expensive,
as every SME attribute has to be re-checked each time. So let's
construct the SMEAttrs once as part of AArch64FunctionInfo and reuse
the parsed attributes where possible.
---
llvm/lib/Target/AArch64/AArch64FastISel.cpp | 3 ++-
.../Target/AArch64/AArch64FrameLowering.cpp | 26 +++++++++----------
.../Target/AArch64/AArch64ISelLowering.cpp | 8 +++---
.../AArch64/AArch64MachineFunctionInfo.cpp | 3 +++
.../AArch64/AArch64MachineFunctionInfo.h | 6 +++++
.../Target/AArch64/AArch64RegisterInfo.cpp | 5 ++--
.../AArch64/AArch64SelectionDAGInfo.cpp | 11 +++++---
.../AArch64/GISel/AArch64CallLowering.cpp | 2 +-
8 files changed, 38 insertions(+), 26 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
index 5ddf83f45ac69..bb7e6b662f80e 100644
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -5198,7 +5198,8 @@ bool AArch64FastISel::fastSelectInstruction(const Instruction *I) {
FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
const TargetLibraryInfo *LibInfo) {
- SMEAttrs CallerAttrs(*FuncInfo.Fn);
+ SMEAttrs CallerAttrs =
+ FuncInfo.MF->getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
if (CallerAttrs.hasZAState() || CallerAttrs.hasZT0State() ||
CallerAttrs.hasStreamingInterfaceOrBody() ||
CallerAttrs.hasStreamingCompatibleInterface() ||
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
index 0f33e77d4eecc..c22dbb9bf0067 100644
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -595,7 +595,7 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
MachineFunction &MF = *MBB.getParent();
MachineFrameInfo &MFI = MF.getFrameInfo();
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
- SMEAttrs Attrs(MF.getFunction());
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
bool LocallyStreaming =
Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface();
@@ -2887,7 +2887,7 @@ bool enableMultiVectorSpillFill(const AArch64Subtarget &Subtarget,
if (DisableMultiVectorSpillFill)
return false;
- SMEAttrs FuncAttrs(MF.getFunction());
+ SMEAttrs FuncAttrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
bool IsLocallyStreaming =
FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface();
@@ -3210,7 +3210,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
// Find an available register to store value of VG to.
Reg1 = findScratchNonCalleeSaveRegister(&MBB);
assert(Reg1 != AArch64::NoRegister);
- SMEAttrs Attrs(MF.getFunction());
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface() &&
AFI->getStreamingVGIdx() == std::numeric_limits<int>::max()) {
@@ -3539,12 +3539,13 @@ static std::optional<int> getLdStFrameID(const MachineInstr &MI,
void AArch64FrameLowering::determineStackHazardSlot(
MachineFunction &MF, BitVector &SavedRegs) const {
unsigned StackHazardSize = getStackHazardSize(MF);
+ auto *AFI = MF.getInfo<AArch64FunctionInfo>();
if (StackHazardSize == 0 || StackHazardSize % 16 != 0 ||
- MF.getInfo<AArch64FunctionInfo>()->hasStackHazardSlotIndex())
+ AFI->hasStackHazardSlotIndex())
return;
// Stack hazards are only needed in streaming functions.
- SMEAttrs Attrs(MF.getFunction());
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (!StackHazardInNonStreaming && Attrs.hasNonStreamingInterfaceAndBody())
return;
@@ -3581,7 +3582,7 @@ void AArch64FrameLowering::determineStackHazardSlot(
int ID = MFI.CreateStackObject(StackHazardSize, Align(16), false);
LLVM_DEBUG(dbgs() << "Created Hazard slot at " << ID << " size "
<< StackHazardSize << "\n");
- MF.getInfo<AArch64FunctionInfo>()->setStackHazardSlotIndex(ID);
+ AFI->setStackHazardSlotIndex(ID);
}
}
@@ -3734,8 +3735,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
// changes, as we will need to spill the value of the VG register.
// For locally streaming functions, we spill both the streaming and
// non-streaming VG value.
- const Function &F = MF.getFunction();
- SMEAttrs Attrs(F);
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (requiresSaveVG(MF)) {
if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
CSStackSize += 16;
@@ -3892,7 +3892,7 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots(
// Insert VG into the list of CSRs, immediately before LR if saved.
if (requiresSaveVG(MF)) {
std::vector<CalleeSavedInfo> VGSaves;
- SMEAttrs Attrs(MF.getFunction());
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
auto VGInfo = CalleeSavedInfo(AArch64::VG);
VGInfo.setRestored(false);
@@ -4909,10 +4909,10 @@ static void emitVGSaveRestore(MachineBasicBlock::iterator II,
MI.getOpcode() != AArch64::VGRestorePseudo)
return;
- SMEAttrs FuncAttrs(MF->getFunction());
+ auto *AFI = MF->getInfo<AArch64FunctionInfo>();
+ SMEAttrs FuncAttrs = AFI->getSMEFnAttrs();
bool LocallyStreaming =
FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface();
- const AArch64FunctionInfo *AFI = MF->getInfo<AArch64FunctionInfo>();
int64_t VGFrameIdx =
LocallyStreaming ? AFI->getStreamingVGIdx() : AFI->getVGIdx();
@@ -5402,8 +5402,8 @@ static inline raw_ostream &operator<<(raw_ostream &OS, const StackAccess &SA) {
void AArch64FrameLowering::emitRemarks(
const MachineFunction &MF, MachineOptimizationRemarkEmitter *ORE) const {
- SMEAttrs Attrs(MF.getFunction());
- if (Attrs.hasNonStreamingInterfaceAndBody())
+ auto *AFI = MF.getInfo<AArch64FunctionInfo>();
+ if (AFI->getSMEFnAttrs().hasNonStreamingInterfaceAndBody())
return;
unsigned StackHazardSize = getStackHazardSize(MF);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ae34e6b7dcc3c..4dd9c513120bb 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7751,7 +7751,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
(void)Res;
}
- SMEAttrs Attrs(MF.getFunction());
+ SMEAttrs Attrs = FuncInfo->getSMEFnAttrs();
bool IsLocallyStreaming =
!Attrs.hasStreamingInterface() && Attrs.hasStreamingBody();
assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
@@ -8105,7 +8105,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
// Create a 16 Byte TPIDR2 object. The dynamic buffer
// will be expanded and stored in the static object later using a pseudonode.
- if (SMEAttrs(MF.getFunction()).hasZAState()) {
+ if (Attrs.hasZAState()) {
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
@@ -8125,7 +8125,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
Chain = DAG.getNode(
AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
{/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
- } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
+ } else if (Attrs.hasAgnosticZAInterface()) {
// Call __arm_sme_state_size().
SDValue BufferSize =
DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
@@ -9610,7 +9610,7 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
// Emit SMSTOP before returning from a locally streaming function
- SMEAttrs FuncAttrs(MF.getFunction());
+ SMEAttrs FuncAttrs = FuncInfo->getSMEFnAttrs();
if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
if (FuncAttrs.hasStreamingCompatibleInterface()) {
Register Reg = FuncInfo->getPStateSMReg();
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
index 5bcff61cef4b1..4b04b80121ffa 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
@@ -100,6 +100,9 @@ AArch64FunctionInfo::AArch64FunctionInfo(const Function &F,
BranchTargetEnforcement = F.hasFnAttribute("branch-target-enforcement");
BranchProtectionPAuthLR = F.hasFnAttribute("branch-protection-pauth-lr");
+ // Parse the SME function attributes.
+ SMEFnAttrs = SMEAttrs(F);
+
// The default stack probe size is 4096 if the function has no
// stack-probe-size attribute. This is a safe default because it is the
// smallest possible guard page size.
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index d3026ca45c349..361d5ec3f2b22 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -14,6 +14,7 @@
#define LLVM_LIB_TARGET_AARCH64_AARCH64MACHINEFUNCTIONINFO_H
#include "AArch64Subtarget.h"
+#include "Utils/AArch64SMEAttributes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
@@ -245,6 +246,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
int64_t VGIdx = std::numeric_limits<int>::max();
int64_t StreamingVGIdx = std::numeric_limits<int>::max();
+ // Holds the SME function attributes (streaming mode, ZA/ZT0 state).
+ SMEAttrs SMEFnAttrs;
+
public:
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
@@ -449,6 +453,8 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
StackHazardCSRSlotIndex = Index;
}
+ SMEAttrs getSMEFnAttrs() const { return SMEFnAttrs; }
+
unsigned getSRetReturnReg() const { return SRetReturnReg; }
void setSRetReturnReg(unsigned Reg) { SRetReturnReg = Reg; }
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
index 1afe23e637e8d..2310c19356968 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -648,9 +648,8 @@ bool AArch64RegisterInfo::hasBasePointer(const MachineFunction &MF) const {
// Since hasBasePointer() is called before we know if we have hazard padding
// or an emergency spill slot we need to enable the basepointer
// conservatively.
- if (AFI->hasStackHazardSlotIndex() ||
- (ST.getStreamingHazardSize() &&
- !SMEAttrs(MF.getFunction()).hasNonStreamingInterfaceAndBody())) {
+ if (ST.getStreamingHazardSize() &&
+ !AFI->getSMEFnAttrs().hasNonStreamingInterfaceAndBody()) {
return true;
}
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
index 0d368b7c280c8..90f6fc2ea664b 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
@@ -11,8 +11,8 @@
//===----------------------------------------------------------------------===//
#include "AArch64SelectionDAGInfo.h"
+#include "AArch64MachineFunctionInfo.h"
#include "AArch64TargetMachine.h"
-#include "Utils/AArch64SMEAttributes.h"
#define GET_SDNODE_DESC
#include "AArch64GenSDNodeInfo.inc"
@@ -227,7 +227,8 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
return EmitMOPS(AArch64::MOPSMemoryCopyPseudo, DAG, DL, Chain, Dst, Src,
Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
- SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+ auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
return EmitStreamingCompatibleMemLibCall(DAG, DL, Chain, Dst, Src, Size,
RTLIB::MEMCPY);
@@ -246,7 +247,8 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset(
Size, Alignment, isVolatile, DstPtrInfo,
MachinePointerInfo{});
- SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+ auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
RTLIB::MEMSET);
@@ -264,7 +266,8 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemmove(
return EmitMOPS(AArch64::MOPSMemoryMovePseudo, DAG, dl, Chain, Dst, Src,
Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
- SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+ auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
RTLIB::MEMMOVE);
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
index 9bef102e8abf1..fd77571fe1c52 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
@@ -539,7 +539,7 @@ bool AArch64CallLowering::fallBackToDAGISel(const MachineFunction &MF) const {
return true;
}
- SMEAttrs Attrs(F);
+ SMEAttrs Attrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
if (Attrs.hasZAState() || Attrs.hasZT0State() ||
Attrs.hasStreamingInterfaceOrBody() ||
Attrs.hasStreamingCompatibleInterface())
More information about the llvm-commits mailing list