[llvm] [AArch64][SME] Refactor MachineSMEABI pass state (NFCI) (PR #156674)
Benjamin Maxwell via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 11 05:53:25 PDT 2025
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/156674
From 8320082a32e713cfa500584227838556361568c9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 15 Jul 2025 11:47:48 +0000
Subject: [PATCH 1/2] [AArch64][SME] Refactor MachineSMEABI pass state (NFCI)
This removes the pass state (aside from target classes) from the
MachineSMEABI class, and instead passes/returns state between functions.
The intention is to make dataflow (and where state is mutated) more
apparent.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 1 +
llvm/lib/Target/AArch64/MachineSMEABIPass.cpp | 284 ++++++++++--------
2 files changed, 159 insertions(+), 126 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5ffaf2c49b4c0..b90b7f50856b0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -9312,6 +9312,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
std::optional<unsigned> ZAMarkerNode;
bool UseNewSMEABILowering = getTM().useNewSMEABILowering();
+
if (UseNewSMEABILowering) {
if (CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingAllZAState())
diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
index c39a5cc2fcb16..fd07c939bbb0c 100644
--- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -110,6 +110,65 @@ struct PhysRegSave {
Register X0Save = AArch64::NoRegister;
};
+/// Contains the needed ZA state (and live registers) at an instruction.
+struct InstInfo {
+ ZAState NeededState{ZAState::ANY};
+ MachineBasicBlock::iterator InsertPt;
+ LiveRegs PhysLiveRegs = LiveRegs::None;
+};
+
+/// Contains the needed ZA state for each instruction in a block. Instructions
+/// that do not require a ZA state are not recorded.
+struct BlockInfo {
+ ZAState FixedEntryState{ZAState::ANY};
+ SmallVector<InstInfo> Insts;
+ LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
+ LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
+};
+
+/// Contains the needed ZA state information for all blocks within a function.
+struct FunctionInfo {
+ SmallVector<BlockInfo> Blocks;
+ std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
+ LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
+};
+
+/// State/helpers that is only needed when emitting code to handle
+/// saving/restoring ZA.
+struct EmitContext {
+ EmitContext() = default;
+
+ /// Get or create a TPIDR2 block in \p MF.
+ int getTPIDR2Block(MachineFunction &MF) {
+ if (TPIDR2BlockFI)
+ return *TPIDR2BlockFI;
+ MachineFrameInfo &MFI = MF.getFrameInfo();
+ TPIDR2BlockFI = MFI.CreateStackObject(16, Align(16), false);
+ return *TPIDR2BlockFI;
+ }
+
+ /// Get or create agnostic ZA buffer pointer in \p MF.
+ Register getAgnosticZABufferPtr(MachineFunction &MF) {
+ if (AgnosticZABufferPtr != AArch64::NoRegister)
+ return AgnosticZABufferPtr;
+ Register BufferPtr =
+ MF.getInfo<AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer();
+ AgnosticZABufferPtr =
+ BufferPtr != AArch64::NoRegister
+ ? BufferPtr
+ : MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+ return AgnosticZABufferPtr;
+ }
+
+ bool needsSaveBuffer() const {
+ return TPIDR2BlockFI || AgnosticZABufferPtr != AArch64::NoRegister;
+ }
+
+private:
+ std::optional<int> TPIDR2BlockFI;
+ Register AgnosticZABufferPtr = AArch64::NoRegister;
+};
+
static bool isLegalEdgeBundleZAState(ZAState State) {
switch (State) {
case ZAState::ACTIVE:
@@ -119,9 +178,6 @@ static bool isLegalEdgeBundleZAState(ZAState State) {
return false;
}
}
-struct TPIDR2State {
- int FrameIndex = -1;
-};
StringRef getZAStateString(ZAState State) {
#define MAKE_CASE(V) \
@@ -192,25 +248,28 @@ struct MachineSMEABI : public MachineFunctionPass {
/// Collects the needed ZA state (and live registers) before each instruction
/// within the machine function.
- void collectNeededZAStates(SMEAttrs);
+ FunctionInfo collectNeededZAStates(SMEAttrs SMEFnAttrs);
/// Assigns each edge bundle a ZA state based on the needed states of blocks
/// that have incoming or outgoing edges in that bundle.
- void assignBundleZAStates();
+ SmallVector<ZAState> assignBundleZAStates(EdgeBundles const &Bundles,
+ FunctionInfo const &FnInfo);
/// Inserts code to handle changes between ZA states within the function.
/// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
- void insertStateChanges();
+ void insertStateChanges(EmitContext &, FunctionInfo const &FnInfo,
+ EdgeBundles const &Bundles,
+ ArrayRef<ZAState> BundleStates);
// Emission routines for private and shared ZA functions (using lazy saves).
void emitNewZAPrologue(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
- void emitRestoreLazySave(MachineBasicBlock &MBB,
+ void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs);
- void emitSetupLazySave(MachineBasicBlock &MBB,
+ void emitSetupLazySave(EmitContext &, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
- void emitAllocateLazySaveBuffer(MachineBasicBlock &MBB,
+ void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
bool ClearTPIDR2);
@@ -222,35 +281,38 @@ struct MachineSMEABI : public MachineFunctionPass {
// Emit a "full" ZA save or restore. It is "full" in the sense that this
// function will emit a call to __arm_sme_save or __arm_sme_restore, which
// handles saving and restoring both ZA and ZT0.
- void emitFullZASaveRestore(MachineBasicBlock &MBB,
+ void emitFullZASaveRestore(EmitContext &, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs, bool IsSave);
- void emitAllocateFullZASaveBuffer(MachineBasicBlock &MBB,
+ void emitAllocateFullZASaveBuffer(EmitContext &, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs);
- void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
- ZAState From, ZAState To, LiveRegs PhysLiveRegs);
+ void emitStateChange(EmitContext &, MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI, ZAState From,
+ ZAState To, LiveRegs PhysLiveRegs);
// Helpers for switching between lazy/full ZA save/restore routines.
- void emitZASave(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
- LiveRegs PhysLiveRegs) {
+ void emitZASave(EmitContext &Context, MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
- return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/true);
- return emitSetupLazySave(MBB, MBBI);
+ return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
+ /*IsSave=*/true);
+ return emitSetupLazySave(Context, MBB, MBBI);
}
- void emitZARestore(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
- LiveRegs PhysLiveRegs) {
+ void emitZARestore(EmitContext &Context, MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
- return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/false);
- return emitRestoreLazySave(MBB, MBBI, PhysLiveRegs);
+ return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
+ /*IsSave=*/false);
+ return emitRestoreLazySave(Context, MBB, MBBI, PhysLiveRegs);
}
- void emitAllocateZASaveBuffer(MachineBasicBlock &MBB,
+ void emitAllocateZASaveBuffer(EmitContext &Context, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs) {
if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
- return emitAllocateFullZASaveBuffer(MBB, MBBI, PhysLiveRegs);
- return emitAllocateLazySaveBuffer(MBB, MBBI);
+ return emitAllocateFullZASaveBuffer(Context, MBB, MBBI, PhysLiveRegs);
+ return emitAllocateLazySaveBuffer(Context, MBB, MBBI);
}
/// Save live physical registers to virtual registers.
@@ -260,40 +322,8 @@ struct MachineSMEABI : public MachineFunctionPass {
void restorePhyRegSave(PhysRegSave const &RegSave, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI, DebugLoc DL);
- /// Get or create a TPIDR2 block in this function.
- TPIDR2State getTPIDR2Block();
-
- Register getAgnosticZABufferPtr();
-
private:
- /// Contains the needed ZA state (and live registers) at an instruction.
- struct InstInfo {
- ZAState NeededState{ZAState::ANY};
- MachineBasicBlock::iterator InsertPt;
- LiveRegs PhysLiveRegs = LiveRegs::None;
- };
-
- /// Contains the needed ZA state for each instruction in a block.
- /// Instructions that do not require a ZA state are not recorded.
- struct BlockInfo {
- ZAState FixedEntryState{ZAState::ANY};
- SmallVector<InstInfo> Insts;
- LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
- LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
- };
-
- // All pass state that must be cleared between functions.
- struct PassState {
- SmallVector<BlockInfo> Blocks;
- SmallVector<ZAState> BundleStates;
- std::optional<TPIDR2State> TPIDR2Block;
- std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
- Register AgnosticZABufferPtr = AArch64::NoRegister;
- LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
- } State;
-
MachineFunction *MF = nullptr;
- EdgeBundles *Bundles = nullptr;
const AArch64Subtarget *Subtarget = nullptr;
const AArch64RegisterInfo *TRI = nullptr;
const AArch64FunctionInfo *AFI = nullptr;
@@ -301,14 +331,18 @@ struct MachineSMEABI : public MachineFunctionPass {
MachineRegisterInfo *MRI = nullptr;
};
-void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
+FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
SMEFnAttrs.hasZAState()) &&
"Expected function to have ZA/ZT0 state!");
- State.Blocks.resize(MF->getNumBlockIDs());
+ SmallVector<BlockInfo> Blocks(MF->getNumBlockIDs());
+ LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
+ std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
+
for (MachineBasicBlock &MBB : *MF) {
- BlockInfo &Block = State.Blocks[MBB.getNumber()];
+ BlockInfo &Block = Blocks[MBB.getNumber()];
+
if (MBB.isEntryBlock()) {
// Entry block:
Block.FixedEntryState = SMEFnAttrs.hasPrivateZAInterface()
@@ -347,8 +381,8 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
// allocation -- which is a safe point for this pass to insert any TPIDR2
// block setup.
if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
- State.AfterSMEProloguePt = MBBI;
- State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
+ AfterSMEProloguePt = MBBI;
+ PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
}
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
auto [NeededState, InsertPt] = getZAStateBeforeInst(
@@ -368,11 +402,18 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
// Reverse vector (as we had to iterate backwards for liveness).
std::reverse(Block.Insts.begin(), Block.Insts.end());
}
+
+ return FunctionInfo{std::move(Blocks), AfterSMEProloguePt,
+ PhysLiveRegsAfterSMEPrologue};
}
-void MachineSMEABI::assignBundleZAStates() {
- State.BundleStates.resize(Bundles->getNumBundles());
- for (unsigned I = 0, E = Bundles->getNumBundles(); I != E; ++I) {
+/// Assigns each edge bundle a ZA state based on the needed states of blocks
+/// that have incoming or outgoing edges in that bundle.
+SmallVector<ZAState>
+MachineSMEABI::assignBundleZAStates(EdgeBundles const &Bundles,
+ FunctionInfo const &FnInfo) {
+ SmallVector<ZAState> BundleStates(Bundles.getNumBundles());
+ for (unsigned I = 0, E = Bundles.getNumBundles(); I != E; ++I) {
LLVM_DEBUG(dbgs() << "Assigning ZA state for edge bundle: " << I << '\n');
// Attempt to assign a ZA state for this bundle that minimizes state
@@ -381,16 +422,16 @@ void MachineSMEABI::assignBundleZAStates() {
// TODO: We should propagate desired incoming/outgoing states through blocks
// that have the "ANY" state first to make better global decisions.
int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
- for (unsigned BlockID : Bundles->getBlocks(I)) {
+ for (unsigned BlockID : Bundles.getBlocks(I)) {
LLVM_DEBUG(dbgs() << "- bb." << BlockID);
- const BlockInfo &Block = State.Blocks[BlockID];
+ const BlockInfo &Block = FnInfo.Blocks[BlockID];
if (Block.Insts.empty()) {
LLVM_DEBUG(dbgs() << " (no state preference)\n");
continue;
}
- bool InEdge = Bundles->getBundle(BlockID, /*Out=*/false) == I;
- bool OutEdge = Bundles->getBundle(BlockID, /*Out=*/true) == I;
+ bool InEdge = Bundles.getBundle(BlockID, /*Out=*/false) == I;
+ bool OutEdge = Bundles.getBundle(BlockID, /*Out=*/true) == I;
ZAState DesiredIncomingState = Block.Insts.front().NeededState;
if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) {
@@ -423,15 +464,20 @@ void MachineSMEABI::assignBundleZAStates() {
dbgs() << "\n\n";
});
- State.BundleStates[I] = BundleState;
+ BundleStates[I] = BundleState;
}
+
+ return BundleStates;
}
-void MachineSMEABI::insertStateChanges() {
+void MachineSMEABI::insertStateChanges(EmitContext &Context,
+ FunctionInfo const &FnInfo,
+ EdgeBundles const &Bundles,
+ ArrayRef<ZAState> BundleStates) {
for (MachineBasicBlock &MBB : *MF) {
- const BlockInfo &Block = State.Blocks[MBB.getNumber()];
- ZAState InState = State.BundleStates[Bundles->getBundle(MBB.getNumber(),
- /*Out=*/false)];
+ const BlockInfo &Block = FnInfo.Blocks[MBB.getNumber()];
+ ZAState InState = BundleStates[Bundles.getBundle(MBB.getNumber(),
+ /*Out=*/false)];
ZAState CurrentState = Block.FixedEntryState;
if (CurrentState == ZAState::ANY)
@@ -439,8 +485,8 @@ void MachineSMEABI::insertStateChanges() {
for (auto &Inst : Block.Insts) {
if (CurrentState != Inst.NeededState)
- emitStateChange(MBB, Inst.InsertPt, CurrentState, Inst.NeededState,
- Inst.PhysLiveRegs);
+ emitStateChange(Context, MBB, Inst.InsertPt, CurrentState,
+ Inst.NeededState, Inst.PhysLiveRegs);
CurrentState = Inst.NeededState;
}
@@ -448,21 +494,13 @@ void MachineSMEABI::insertStateChanges() {
continue;
ZAState OutState =
- State.BundleStates[Bundles->getBundle(MBB.getNumber(), /*Out=*/true)];
+ BundleStates[Bundles.getBundle(MBB.getNumber(), /*Out=*/true)];
if (CurrentState != OutState)
- emitStateChange(MBB, MBB.getFirstTerminator(), CurrentState, OutState,
- Block.PhysLiveRegsAtExit);
+ emitStateChange(Context, MBB, MBB.getFirstTerminator(), CurrentState,
+ OutState, Block.PhysLiveRegsAtExit);
}
}
-TPIDR2State MachineSMEABI::getTPIDR2Block() {
- if (State.TPIDR2Block)
- return *State.TPIDR2Block;
- MachineFrameInfo &MFI = MF->getFrameInfo();
- State.TPIDR2Block = TPIDR2State{MFI.CreateStackObject(16, Align(16), false)};
- return *State.TPIDR2Block;
-}
-
static DebugLoc getDebugLoc(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI) {
if (MBBI != MBB.end())
@@ -470,7 +508,8 @@ static DebugLoc getDebugLoc(MachineBasicBlock &MBB,
return DebugLoc();
}
-void MachineSMEABI::emitSetupLazySave(MachineBasicBlock &MBB,
+void MachineSMEABI::emitSetupLazySave(EmitContext &Context,
+ MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI) {
DebugLoc DL = getDebugLoc(MBB, MBBI);
@@ -478,7 +517,7 @@ void MachineSMEABI::emitSetupLazySave(MachineBasicBlock &MBB,
Register TPIDR2 = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);
Register TPIDR2Ptr = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
- .addFrameIndex(getTPIDR2Block().FrameIndex)
+ .addFrameIndex(Context.getTPIDR2Block(*MF))
.addImm(0)
.addImm(0);
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), TPIDR2Ptr)
@@ -528,7 +567,8 @@ void MachineSMEABI::restorePhyRegSave(PhysRegSave const &RegSave,
.addReg(RegSave.X0Save);
}
-void MachineSMEABI::emitRestoreLazySave(MachineBasicBlock &MBB,
+void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
+ MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs) {
auto *TLI = Subtarget->getTargetLowering();
@@ -548,7 +588,7 @@ void MachineSMEABI::emitRestoreLazySave(MachineBasicBlock &MBB,
.addImm(AArch64SysReg::TPIDR2_EL0);
// Get pointer to TPIDR2 block.
BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
- .addFrameIndex(getTPIDR2Block().FrameIndex)
+ .addFrameIndex(Context.getTPIDR2Block(*MF))
.addImm(0)
.addImm(0);
// (Conditionally) restore ZA state.
@@ -582,7 +622,8 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
}
void MachineSMEABI::emitAllocateLazySaveBuffer(
- MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) {
+ EmitContext &Context, MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI) {
MachineFrameInfo &MFI = MF->getFrameInfo();
DebugLoc DL = getDebugLoc(MBB, MBBI);
Register SP = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
@@ -630,7 +671,7 @@ void MachineSMEABI::emitAllocateLazySaveBuffer(
BuildMI(MBB, MBBI, DL, TII->get(AArch64::STPXi))
.addReg(Buffer)
.addReg(SVL)
- .addFrameIndex(getTPIDR2Block().FrameIndex)
+ .addFrameIndex(Context.getTPIDR2Block(*MF))
.addImm(0);
}
}
@@ -662,18 +703,8 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
.addImm(1);
}
-Register MachineSMEABI::getAgnosticZABufferPtr() {
- if (State.AgnosticZABufferPtr != AArch64::NoRegister)
- return State.AgnosticZABufferPtr;
- Register BufferPtr = AFI->getEarlyAllocSMESaveBuffer();
- State.AgnosticZABufferPtr =
- BufferPtr != AArch64::NoRegister
- ? BufferPtr
- : MF->getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
- return State.AgnosticZABufferPtr;
-}
-
-void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
+void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
+ MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
LiveRegs PhysLiveRegs, bool IsSave) {
auto *TLI = Subtarget->getTargetLowering();
@@ -684,7 +715,7 @@ void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
// Copy the buffer pointer into X0.
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
- .addReg(getAgnosticZABufferPtr());
+ .addReg(Context.getAgnosticZABufferPtr(*MF));
// Call __arm_sme_save/__arm_sme_restore.
BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
@@ -699,14 +730,14 @@ void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
}
void MachineSMEABI::emitAllocateFullZASaveBuffer(
- MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
- LiveRegs PhysLiveRegs) {
+ EmitContext &Context, MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
// Buffer already allocated in SelectionDAG.
if (AFI->getEarlyAllocSMESaveBuffer())
return;
DebugLoc DL = getDebugLoc(MBB, MBBI);
- Register BufferPtr = getAgnosticZABufferPtr();
+ Register BufferPtr = Context.getAgnosticZABufferPtr(*MF);
Register BufferSize = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
@@ -742,11 +773,11 @@ void MachineSMEABI::emitAllocateFullZASaveBuffer(
restorePhyRegSave(RegSave, MBB, MBBI, DL);
}
-void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
+void MachineSMEABI::emitStateChange(EmitContext &Context,
+ MachineBasicBlock &MBB,
MachineBasicBlock::iterator InsertPt,
ZAState From, ZAState To,
LiveRegs PhysLiveRegs) {
-
// ZA not used.
if (From == ZAState::ANY || To == ZAState::ANY)
return;
@@ -774,9 +805,9 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
}
if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
- emitZASave(MBB, InsertPt, PhysLiveRegs);
+ emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
- emitZARestore(MBB, InsertPt, PhysLiveRegs);
+ emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
else if (To == ZAState::OFF) {
assert(From != ZAState::CALLER_DORMANT &&
"CALLER_DORMANT to OFF should have already been handled");
@@ -807,32 +838,33 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
- // Reset pass state.
- State = PassState{};
this->MF = &MF;
- Bundles = &getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
Subtarget = &MF.getSubtarget<AArch64Subtarget>();
TII = Subtarget->getInstrInfo();
TRI = Subtarget->getRegisterInfo();
MRI = &MF.getRegInfo();
- collectNeededZAStates(SMEFnAttrs);
- assignBundleZAStates();
- insertStateChanges();
+ EdgeBundles const &Bundles =
+ getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
+
+ FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);
+ SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo);
+
+ EmitContext Context;
+ insertStateChanges(Context, FnInfo, Bundles, BundleStates);
- // Allocate save buffer (if needed).
- if (State.AgnosticZABufferPtr != AArch64::NoRegister || State.TPIDR2Block) {
- if (State.AfterSMEProloguePt) {
+ if (Context.needsSaveBuffer()) {
+ if (FnInfo.AfterSMEProloguePt) {
// Note: With inline stack probes the AfterSMEProloguePt may not be in the
// entry block (due to the probing loop).
- emitAllocateZASaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
- *State.AfterSMEProloguePt,
- State.PhysLiveRegsAfterSMEPrologue);
+ MachineBasicBlock::iterator MBBI = *FnInfo.AfterSMEProloguePt;
+ emitAllocateZASaveBuffer(Context, *MBBI->getParent(), MBBI,
+ FnInfo.PhysLiveRegsAfterSMEPrologue);
} else {
MachineBasicBlock &EntryBlock = MF.front();
emitAllocateZASaveBuffer(
- EntryBlock, EntryBlock.getFirstNonPHI(),
- State.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry);
+ Context, EntryBlock, EntryBlock.getFirstNonPHI(),
+ FnInfo.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry);
}
}
From 093f7123744ebe22c9e1107c6f64e36aca47af11 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 11 Sep 2025 12:51:14 +0000
Subject: [PATCH 2/2] Fixups
Change-Id: I05e79e16397a3e035f5d35df30c4999987a029d4
---
llvm/lib/Target/AArch64/MachineSMEABIPass.cpp | 32 +++++++++++--------
1 file changed, 19 insertions(+), 13 deletions(-)
diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
index fd07c939bbb0c..cced0faa28889 100644
--- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -110,7 +110,8 @@ struct PhysRegSave {
Register X0Save = AArch64::NoRegister;
};
-/// Contains the needed ZA state (and live registers) at an instruction.
+/// Contains the needed ZA state (and live registers) at an instruction. That is
+/// the state ZA must be in _before_ "InsertPt".
struct InstInfo {
ZAState NeededState{ZAState::ANY};
MachineBasicBlock::iterator InsertPt;
@@ -135,7 +136,8 @@ struct FunctionInfo {
/// State/helpers that is only needed when emitting code to handle
/// saving/restoring ZA.
-struct EmitContext {
+class EmitContext {
+public:
EmitContext() = default;
/// Get or create a TPIDR2 block in \p MF.
@@ -160,7 +162,11 @@ struct EmitContext {
return AgnosticZABufferPtr;
}
+ /// Returns true if the function must allocate a ZA save buffer on entry. This
+ /// will be the case if, at any point in the function, a ZA save was emitted.
bool needsSaveBuffer() const {
+ assert(!(TPIDR2BlockFI && AgnosticZABufferPtr) &&
+ "Cannot have both a TPIDR2 block and agnostic ZA buffer");
return TPIDR2BlockFI || AgnosticZABufferPtr != AArch64::NoRegister;
}
@@ -252,13 +258,13 @@ struct MachineSMEABI : public MachineFunctionPass {
/// Assigns each edge bundle a ZA state based on the needed states of blocks
/// that have incoming or outgoing edges in that bundle.
- SmallVector<ZAState> assignBundleZAStates(EdgeBundles const &Bundles,
- FunctionInfo const &FnInfo);
+ SmallVector<ZAState> assignBundleZAStates(const EdgeBundles &Bundles,
+ const FunctionInfo &FnInfo);
/// Inserts code to handle changes between ZA states within the function.
/// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
- void insertStateChanges(EmitContext &, FunctionInfo const &FnInfo,
- EdgeBundles const &Bundles,
+ void insertStateChanges(EmitContext &, const FunctionInfo &FnInfo,
+ const EdgeBundles &Bundles,
ArrayRef<ZAState> BundleStates);
// Emission routines for private and shared ZA functions (using lazy saves).
@@ -319,7 +325,7 @@ struct MachineSMEABI : public MachineFunctionPass {
PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI, DebugLoc DL);
/// Restore physical registers from a save of their previous values.
- void restorePhyRegSave(PhysRegSave const &RegSave, MachineBasicBlock &MBB,
+ void restorePhyRegSave(const PhysRegSave &RegSave, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI, DebugLoc DL);
private:
@@ -410,8 +416,8 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
/// Assigns each edge bundle a ZA state based on the needed states of blocks
/// that have incoming or outgoing edges in that bundle.
SmallVector<ZAState>
-MachineSMEABI::assignBundleZAStates(EdgeBundles const &Bundles,
- FunctionInfo const &FnInfo) {
+MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
+ const FunctionInfo &FnInfo) {
SmallVector<ZAState> BundleStates(Bundles.getNumBundles());
for (unsigned I = 0, E = Bundles.getNumBundles(); I != E; ++I) {
LLVM_DEBUG(dbgs() << "Assigning ZA state for edge bundle: " << I << '\n');
@@ -471,8 +477,8 @@ MachineSMEABI::assignBundleZAStates(EdgeBundles const &Bundles,
}
void MachineSMEABI::insertStateChanges(EmitContext &Context,
- FunctionInfo const &FnInfo,
- EdgeBundles const &Bundles,
+ const FunctionInfo &FnInfo,
+ const EdgeBundles &Bundles,
ArrayRef<ZAState> BundleStates) {
for (MachineBasicBlock &MBB : *MF) {
const BlockInfo &Block = FnInfo.Blocks[MBB.getNumber()];
@@ -551,7 +557,7 @@ PhysRegSave MachineSMEABI::createPhysRegSave(LiveRegs PhysLiveRegs,
return RegSave;
}
-void MachineSMEABI::restorePhyRegSave(PhysRegSave const &RegSave,
+void MachineSMEABI::restorePhyRegSave(const PhysRegSave &RegSave,
MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
DebugLoc DL) {
@@ -844,7 +850,7 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
TRI = Subtarget->getRegisterInfo();
MRI = &MF.getRegInfo();
- EdgeBundles const &Bundles =
+ const EdgeBundles &Bundles =
getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);
More information about the llvm-commits
mailing list