[llvm] [AArch64][SME] Refactor MachineSMEABI pass state (NFCI) (PR #156674)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 11 05:52:34 PDT 2025


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/156674

>From de1c14cbb4ba789928d55e36a9a2ec40353b39c3 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 15 Jul 2025 11:47:48 +0000
Subject: [PATCH 1/2] [AArch64][SME] Refactor MachineSMEABI pass state (NFCI)

This removes the per-function pass state from the MachineSMEABI class
(leaving only the cached pointers to the machine function and the target
classes), and instead passes state into, and returns it from, the
functions that need it.

The intention is to make dataflow (and where state is mutated) more
apparent.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    |   1 +
 llvm/lib/Target/AArch64/MachineSMEABIPass.cpp | 284 ++++++++++--------
 2 files changed, 159 insertions(+), 126 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5e11145ecd161..88734eedaca8e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -9312,6 +9312,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   std::optional<unsigned> ZAMarkerNode;
   bool UseNewSMEABILowering = getTM().useNewSMEABILowering();
+
   if (UseNewSMEABILowering) {
     if (CallAttrs.requiresLazySave() ||
         CallAttrs.requiresPreservingAllZAState())
diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
index c39a5cc2fcb16..fd07c939bbb0c 100644
--- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -110,6 +110,65 @@ struct PhysRegSave {
   Register X0Save = AArch64::NoRegister;
 };
 
+/// Contains the needed ZA state (and live registers) at an instruction.
+struct InstInfo {
+  ZAState NeededState{ZAState::ANY};
+  MachineBasicBlock::iterator InsertPt;
+  LiveRegs PhysLiveRegs = LiveRegs::None;
+};
+
+/// Contains the needed ZA state for each instruction in a block. Instructions
+/// that do not require a ZA state are not recorded.
+struct BlockInfo {
+  ZAState FixedEntryState{ZAState::ANY};
+  SmallVector<InstInfo> Insts;
+  LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
+  LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
+};
+
+/// Contains the needed ZA state information for all blocks within a function.
+struct FunctionInfo {
+  SmallVector<BlockInfo> Blocks;
+  std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
+  LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
+};
+
+/// State/helpers that is only needed when emitting code to handle
+/// saving/restoring ZA.
+struct EmitContext {
+  EmitContext() = default;
+
+  /// Get or create a TPIDR2 block in \p MF.
+  int getTPIDR2Block(MachineFunction &MF) {
+    if (TPIDR2BlockFI)
+      return *TPIDR2BlockFI;
+    MachineFrameInfo &MFI = MF.getFrameInfo();
+    TPIDR2BlockFI = MFI.CreateStackObject(16, Align(16), false);
+    return *TPIDR2BlockFI;
+  }
+
+  /// Get or create agnostic ZA buffer pointer in \p MF.
+  Register getAgnosticZABufferPtr(MachineFunction &MF) {
+    if (AgnosticZABufferPtr != AArch64::NoRegister)
+      return AgnosticZABufferPtr;
+    Register BufferPtr =
+        MF.getInfo<AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer();
+    AgnosticZABufferPtr =
+        BufferPtr != AArch64::NoRegister
+            ? BufferPtr
+            : MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+    return AgnosticZABufferPtr;
+  }
+
+  bool needsSaveBuffer() const {
+    return TPIDR2BlockFI || AgnosticZABufferPtr != AArch64::NoRegister;
+  }
+
+private:
+  std::optional<int> TPIDR2BlockFI;
+  Register AgnosticZABufferPtr = AArch64::NoRegister;
+};
+
 static bool isLegalEdgeBundleZAState(ZAState State) {
   switch (State) {
   case ZAState::ACTIVE:
@@ -119,9 +178,6 @@ static bool isLegalEdgeBundleZAState(ZAState State) {
     return false;
   }
 }
-struct TPIDR2State {
-  int FrameIndex = -1;
-};
 
 StringRef getZAStateString(ZAState State) {
 #define MAKE_CASE(V)                                                           \
@@ -192,25 +248,28 @@ struct MachineSMEABI : public MachineFunctionPass {
 
   /// Collects the needed ZA state (and live registers) before each instruction
   /// within the machine function.
-  void collectNeededZAStates(SMEAttrs);
+  FunctionInfo collectNeededZAStates(SMEAttrs SMEFnAttrs);
 
   /// Assigns each edge bundle a ZA state based on the needed states of blocks
   /// that have incoming or outgoing edges in that bundle.
-  void assignBundleZAStates();
+  SmallVector<ZAState> assignBundleZAStates(EdgeBundles const &Bundles,
+                                            FunctionInfo const &FnInfo);
 
   /// Inserts code to handle changes between ZA states within the function.
   /// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
-  void insertStateChanges();
+  void insertStateChanges(EmitContext &, FunctionInfo const &FnInfo,
+                          EdgeBundles const &Bundles,
+                          ArrayRef<ZAState> BundleStates);
 
   // Emission routines for private and shared ZA functions (using lazy saves).
   void emitNewZAPrologue(MachineBasicBlock &MBB,
                          MachineBasicBlock::iterator MBBI);
-  void emitRestoreLazySave(MachineBasicBlock &MBB,
+  void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB,
                            MachineBasicBlock::iterator MBBI,
                            LiveRegs PhysLiveRegs);
-  void emitSetupLazySave(MachineBasicBlock &MBB,
+  void emitSetupLazySave(EmitContext &, MachineBasicBlock &MBB,
                          MachineBasicBlock::iterator MBBI);
-  void emitAllocateLazySaveBuffer(MachineBasicBlock &MBB,
+  void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB,
                                   MachineBasicBlock::iterator MBBI);
   void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                  bool ClearTPIDR2);
@@ -222,35 +281,38 @@ struct MachineSMEABI : public MachineFunctionPass {
   // Emit a "full" ZA save or restore. It is "full" in the sense that this
   // function will emit a call to __arm_sme_save or __arm_sme_restore, which
   // handles saving and restoring both ZA and ZT0.
-  void emitFullZASaveRestore(MachineBasicBlock &MBB,
+  void emitFullZASaveRestore(EmitContext &, MachineBasicBlock &MBB,
                              MachineBasicBlock::iterator MBBI,
                              LiveRegs PhysLiveRegs, bool IsSave);
-  void emitAllocateFullZASaveBuffer(MachineBasicBlock &MBB,
+  void emitAllocateFullZASaveBuffer(EmitContext &, MachineBasicBlock &MBB,
                                     MachineBasicBlock::iterator MBBI,
                                     LiveRegs PhysLiveRegs);
 
-  void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
-                       ZAState From, ZAState To, LiveRegs PhysLiveRegs);
+  void emitStateChange(EmitContext &, MachineBasicBlock &MBB,
+                       MachineBasicBlock::iterator MBBI, ZAState From,
+                       ZAState To, LiveRegs PhysLiveRegs);
 
   // Helpers for switching between lazy/full ZA save/restore routines.
-  void emitZASave(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
-                  LiveRegs PhysLiveRegs) {
+  void emitZASave(EmitContext &Context, MachineBasicBlock &MBB,
+                  MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
     if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
-      return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/true);
-    return emitSetupLazySave(MBB, MBBI);
+      return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
+                                   /*IsSave=*/true);
+    return emitSetupLazySave(Context, MBB, MBBI);
   }
-  void emitZARestore(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
-                     LiveRegs PhysLiveRegs) {
+  void emitZARestore(EmitContext &Context, MachineBasicBlock &MBB,
+                     MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
     if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
-      return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/false);
-    return emitRestoreLazySave(MBB, MBBI, PhysLiveRegs);
+      return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
+                                   /*IsSave=*/false);
+    return emitRestoreLazySave(Context, MBB, MBBI, PhysLiveRegs);
   }
-  void emitAllocateZASaveBuffer(MachineBasicBlock &MBB,
+  void emitAllocateZASaveBuffer(EmitContext &Context, MachineBasicBlock &MBB,
                                 MachineBasicBlock::iterator MBBI,
                                 LiveRegs PhysLiveRegs) {
     if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
-      return emitAllocateFullZASaveBuffer(MBB, MBBI, PhysLiveRegs);
-    return emitAllocateLazySaveBuffer(MBB, MBBI);
+      return emitAllocateFullZASaveBuffer(Context, MBB, MBBI, PhysLiveRegs);
+    return emitAllocateLazySaveBuffer(Context, MBB, MBBI);
   }
 
   /// Save live physical registers to virtual registers.
@@ -260,40 +322,8 @@ struct MachineSMEABI : public MachineFunctionPass {
   void restorePhyRegSave(PhysRegSave const &RegSave, MachineBasicBlock &MBB,
                          MachineBasicBlock::iterator MBBI, DebugLoc DL);
 
-  /// Get or create a TPIDR2 block in this function.
-  TPIDR2State getTPIDR2Block();
-
-  Register getAgnosticZABufferPtr();
-
 private:
-  /// Contains the needed ZA state (and live registers) at an instruction.
-  struct InstInfo {
-    ZAState NeededState{ZAState::ANY};
-    MachineBasicBlock::iterator InsertPt;
-    LiveRegs PhysLiveRegs = LiveRegs::None;
-  };
-
-  /// Contains the needed ZA state for each instruction in a block.
-  /// Instructions that do not require a ZA state are not recorded.
-  struct BlockInfo {
-    ZAState FixedEntryState{ZAState::ANY};
-    SmallVector<InstInfo> Insts;
-    LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
-    LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
-  };
-
-  // All pass state that must be cleared between functions.
-  struct PassState {
-    SmallVector<BlockInfo> Blocks;
-    SmallVector<ZAState> BundleStates;
-    std::optional<TPIDR2State> TPIDR2Block;
-    std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
-    Register AgnosticZABufferPtr = AArch64::NoRegister;
-    LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
-  } State;
-
   MachineFunction *MF = nullptr;
-  EdgeBundles *Bundles = nullptr;
   const AArch64Subtarget *Subtarget = nullptr;
   const AArch64RegisterInfo *TRI = nullptr;
   const AArch64FunctionInfo *AFI = nullptr;
@@ -301,14 +331,18 @@ struct MachineSMEABI : public MachineFunctionPass {
   MachineRegisterInfo *MRI = nullptr;
 };
 
-void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
+FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
   assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
           SMEFnAttrs.hasZAState()) &&
          "Expected function to have ZA/ZT0 state!");
 
-  State.Blocks.resize(MF->getNumBlockIDs());
+  SmallVector<BlockInfo> Blocks(MF->getNumBlockIDs());
+  LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
+  std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
+
   for (MachineBasicBlock &MBB : *MF) {
-    BlockInfo &Block = State.Blocks[MBB.getNumber()];
+    BlockInfo &Block = Blocks[MBB.getNumber()];
+
     if (MBB.isEntryBlock()) {
       // Entry block:
       Block.FixedEntryState = SMEFnAttrs.hasPrivateZAInterface()
@@ -347,8 +381,8 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
       // allocation -- which is a safe point for this pass to insert any TPIDR2
       // block setup.
       if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
-        State.AfterSMEProloguePt = MBBI;
-        State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
+        AfterSMEProloguePt = MBBI;
+        PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
       }
       // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
       auto [NeededState, InsertPt] = getZAStateBeforeInst(
@@ -368,11 +402,18 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
     // Reverse vector (as we had to iterate backwards for liveness).
     std::reverse(Block.Insts.begin(), Block.Insts.end());
   }
+
+  return FunctionInfo{std::move(Blocks), AfterSMEProloguePt,
+                      PhysLiveRegsAfterSMEPrologue};
 }
 
-void MachineSMEABI::assignBundleZAStates() {
-  State.BundleStates.resize(Bundles->getNumBundles());
-  for (unsigned I = 0, E = Bundles->getNumBundles(); I != E; ++I) {
+/// Assigns each edge bundle a ZA state based on the needed states of blocks
+/// that have incoming or outgoing edges in that bundle.
+SmallVector<ZAState>
+MachineSMEABI::assignBundleZAStates(EdgeBundles const &Bundles,
+                                    FunctionInfo const &FnInfo) {
+  SmallVector<ZAState> BundleStates(Bundles.getNumBundles());
+  for (unsigned I = 0, E = Bundles.getNumBundles(); I != E; ++I) {
     LLVM_DEBUG(dbgs() << "Assigning ZA state for edge bundle: " << I << '\n');
 
     // Attempt to assign a ZA state for this bundle that minimizes state
@@ -381,16 +422,16 @@ void MachineSMEABI::assignBundleZAStates() {
     // TODO: We should propagate desired incoming/outgoing states through blocks
     // that have the "ANY" state first to make better global decisions.
     int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
-    for (unsigned BlockID : Bundles->getBlocks(I)) {
+    for (unsigned BlockID : Bundles.getBlocks(I)) {
       LLVM_DEBUG(dbgs() << "- bb." << BlockID);
 
-      const BlockInfo &Block = State.Blocks[BlockID];
+      const BlockInfo &Block = FnInfo.Blocks[BlockID];
       if (Block.Insts.empty()) {
         LLVM_DEBUG(dbgs() << " (no state preference)\n");
         continue;
       }
-      bool InEdge = Bundles->getBundle(BlockID, /*Out=*/false) == I;
-      bool OutEdge = Bundles->getBundle(BlockID, /*Out=*/true) == I;
+      bool InEdge = Bundles.getBundle(BlockID, /*Out=*/false) == I;
+      bool OutEdge = Bundles.getBundle(BlockID, /*Out=*/true) == I;
 
       ZAState DesiredIncomingState = Block.Insts.front().NeededState;
       if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) {
@@ -423,15 +464,20 @@ void MachineSMEABI::assignBundleZAStates() {
       dbgs() << "\n\n";
     });
 
-    State.BundleStates[I] = BundleState;
+    BundleStates[I] = BundleState;
   }
+
+  return BundleStates;
 }
 
-void MachineSMEABI::insertStateChanges() {
+void MachineSMEABI::insertStateChanges(EmitContext &Context,
+                                       FunctionInfo const &FnInfo,
+                                       EdgeBundles const &Bundles,
+                                       ArrayRef<ZAState> BundleStates) {
   for (MachineBasicBlock &MBB : *MF) {
-    const BlockInfo &Block = State.Blocks[MBB.getNumber()];
-    ZAState InState = State.BundleStates[Bundles->getBundle(MBB.getNumber(),
-                                                            /*Out=*/false)];
+    const BlockInfo &Block = FnInfo.Blocks[MBB.getNumber()];
+    ZAState InState = BundleStates[Bundles.getBundle(MBB.getNumber(),
+                                                     /*Out=*/false)];
 
     ZAState CurrentState = Block.FixedEntryState;
     if (CurrentState == ZAState::ANY)
@@ -439,8 +485,8 @@ void MachineSMEABI::insertStateChanges() {
 
     for (auto &Inst : Block.Insts) {
       if (CurrentState != Inst.NeededState)
-        emitStateChange(MBB, Inst.InsertPt, CurrentState, Inst.NeededState,
-                        Inst.PhysLiveRegs);
+        emitStateChange(Context, MBB, Inst.InsertPt, CurrentState,
+                        Inst.NeededState, Inst.PhysLiveRegs);
       CurrentState = Inst.NeededState;
     }
 
@@ -448,21 +494,13 @@ void MachineSMEABI::insertStateChanges() {
       continue;
 
     ZAState OutState =
-        State.BundleStates[Bundles->getBundle(MBB.getNumber(), /*Out=*/true)];
+        BundleStates[Bundles.getBundle(MBB.getNumber(), /*Out=*/true)];
     if (CurrentState != OutState)
-      emitStateChange(MBB, MBB.getFirstTerminator(), CurrentState, OutState,
-                      Block.PhysLiveRegsAtExit);
+      emitStateChange(Context, MBB, MBB.getFirstTerminator(), CurrentState,
+                      OutState, Block.PhysLiveRegsAtExit);
   }
 }
 
-TPIDR2State MachineSMEABI::getTPIDR2Block() {
-  if (State.TPIDR2Block)
-    return *State.TPIDR2Block;
-  MachineFrameInfo &MFI = MF->getFrameInfo();
-  State.TPIDR2Block = TPIDR2State{MFI.CreateStackObject(16, Align(16), false)};
-  return *State.TPIDR2Block;
-}
-
 static DebugLoc getDebugLoc(MachineBasicBlock &MBB,
                             MachineBasicBlock::iterator MBBI) {
   if (MBBI != MBB.end())
@@ -470,7 +508,8 @@ static DebugLoc getDebugLoc(MachineBasicBlock &MBB,
   return DebugLoc();
 }
 
-void MachineSMEABI::emitSetupLazySave(MachineBasicBlock &MBB,
+void MachineSMEABI::emitSetupLazySave(EmitContext &Context,
+                                      MachineBasicBlock &MBB,
                                       MachineBasicBlock::iterator MBBI) {
   DebugLoc DL = getDebugLoc(MBB, MBBI);
 
@@ -478,7 +517,7 @@ void MachineSMEABI::emitSetupLazySave(MachineBasicBlock &MBB,
   Register TPIDR2 = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);
   Register TPIDR2Ptr = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
   BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
-      .addFrameIndex(getTPIDR2Block().FrameIndex)
+      .addFrameIndex(Context.getTPIDR2Block(*MF))
       .addImm(0)
       .addImm(0);
   BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), TPIDR2Ptr)
@@ -528,7 +567,8 @@ void MachineSMEABI::restorePhyRegSave(PhysRegSave const &RegSave,
         .addReg(RegSave.X0Save);
 }
 
-void MachineSMEABI::emitRestoreLazySave(MachineBasicBlock &MBB,
+void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
+                                        MachineBasicBlock &MBB,
                                         MachineBasicBlock::iterator MBBI,
                                         LiveRegs PhysLiveRegs) {
   auto *TLI = Subtarget->getTargetLowering();
@@ -548,7 +588,7 @@ void MachineSMEABI::emitRestoreLazySave(MachineBasicBlock &MBB,
       .addImm(AArch64SysReg::TPIDR2_EL0);
   // Get pointer to TPIDR2 block.
   BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
-      .addFrameIndex(getTPIDR2Block().FrameIndex)
+      .addFrameIndex(Context.getTPIDR2Block(*MF))
       .addImm(0)
       .addImm(0);
   // (Conditionally) restore ZA state.
@@ -582,7 +622,8 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
 }
 
 void MachineSMEABI::emitAllocateLazySaveBuffer(
-    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) {
+    EmitContext &Context, MachineBasicBlock &MBB,
+    MachineBasicBlock::iterator MBBI) {
   MachineFrameInfo &MFI = MF->getFrameInfo();
   DebugLoc DL = getDebugLoc(MBB, MBBI);
   Register SP = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
@@ -630,7 +671,7 @@ void MachineSMEABI::emitAllocateLazySaveBuffer(
     BuildMI(MBB, MBBI, DL, TII->get(AArch64::STPXi))
         .addReg(Buffer)
         .addReg(SVL)
-        .addFrameIndex(getTPIDR2Block().FrameIndex)
+        .addFrameIndex(Context.getTPIDR2Block(*MF))
         .addImm(0);
   }
 }
@@ -662,18 +703,8 @@ void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
       .addImm(1);
 }
 
-Register MachineSMEABI::getAgnosticZABufferPtr() {
-  if (State.AgnosticZABufferPtr != AArch64::NoRegister)
-    return State.AgnosticZABufferPtr;
-  Register BufferPtr = AFI->getEarlyAllocSMESaveBuffer();
-  State.AgnosticZABufferPtr =
-      BufferPtr != AArch64::NoRegister
-          ? BufferPtr
-          : MF->getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
-  return State.AgnosticZABufferPtr;
-}
-
-void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
+void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
+                                          MachineBasicBlock &MBB,
                                           MachineBasicBlock::iterator MBBI,
                                           LiveRegs PhysLiveRegs, bool IsSave) {
   auto *TLI = Subtarget->getTargetLowering();
@@ -684,7 +715,7 @@ void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
 
   // Copy the buffer pointer into X0.
   BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
-      .addReg(getAgnosticZABufferPtr());
+      .addReg(Context.getAgnosticZABufferPtr(*MF));
 
   // Call __arm_sme_save/__arm_sme_restore.
   BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
@@ -699,14 +730,14 @@ void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
 }
 
 void MachineSMEABI::emitAllocateFullZASaveBuffer(
-    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
-    LiveRegs PhysLiveRegs) {
+    EmitContext &Context, MachineBasicBlock &MBB,
+    MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
   // Buffer already allocated in SelectionDAG.
   if (AFI->getEarlyAllocSMESaveBuffer())
     return;
 
   DebugLoc DL = getDebugLoc(MBB, MBBI);
-  Register BufferPtr = getAgnosticZABufferPtr();
+  Register BufferPtr = Context.getAgnosticZABufferPtr(*MF);
   Register BufferSize = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
 
   PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
@@ -742,11 +773,11 @@ void MachineSMEABI::emitAllocateFullZASaveBuffer(
   restorePhyRegSave(RegSave, MBB, MBBI, DL);
 }
 
-void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
+void MachineSMEABI::emitStateChange(EmitContext &Context,
+                                    MachineBasicBlock &MBB,
                                     MachineBasicBlock::iterator InsertPt,
                                     ZAState From, ZAState To,
                                     LiveRegs PhysLiveRegs) {
-
   // ZA not used.
   if (From == ZAState::ANY || To == ZAState::ANY)
     return;
@@ -774,9 +805,9 @@ void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
   }
 
   if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
-    emitZASave(MBB, InsertPt, PhysLiveRegs);
+    emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
   else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
-    emitZARestore(MBB, InsertPt, PhysLiveRegs);
+    emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
   else if (To == ZAState::OFF) {
     assert(From != ZAState::CALLER_DORMANT &&
            "CALLER_DORMANT to OFF should have already been handled");
@@ -807,32 +838,33 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
 
   assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
 
-  // Reset pass state.
-  State = PassState{};
   this->MF = &MF;
-  Bundles = &getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
   Subtarget = &MF.getSubtarget<AArch64Subtarget>();
   TII = Subtarget->getInstrInfo();
   TRI = Subtarget->getRegisterInfo();
   MRI = &MF.getRegInfo();
 
-  collectNeededZAStates(SMEFnAttrs);
-  assignBundleZAStates();
-  insertStateChanges();
+  EdgeBundles const &Bundles =
+      getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
+
+  FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);
+  SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo);
+
+  EmitContext Context;
+  insertStateChanges(Context, FnInfo, Bundles, BundleStates);
 
-  // Allocate save buffer (if needed).
-  if (State.AgnosticZABufferPtr != AArch64::NoRegister || State.TPIDR2Block) {
-    if (State.AfterSMEProloguePt) {
+  if (Context.needsSaveBuffer()) {
+    if (FnInfo.AfterSMEProloguePt) {
       // Note: With inline stack probes the AfterSMEProloguePt may not be in the
       // entry block (due to the probing loop).
-      emitAllocateZASaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
-                               *State.AfterSMEProloguePt,
-                               State.PhysLiveRegsAfterSMEPrologue);
+      MachineBasicBlock::iterator MBBI = *FnInfo.AfterSMEProloguePt;
+      emitAllocateZASaveBuffer(Context, *MBBI->getParent(), MBBI,
+                               FnInfo.PhysLiveRegsAfterSMEPrologue);
     } else {
       MachineBasicBlock &EntryBlock = MF.front();
       emitAllocateZASaveBuffer(
-          EntryBlock, EntryBlock.getFirstNonPHI(),
-          State.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry);
+          Context, EntryBlock, EntryBlock.getFirstNonPHI(),
+          FnInfo.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry);
     }
   }
 

>From 272c4de52716a0b61b47b9dec37258c533d55796 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 11 Sep 2025 12:51:14 +0000
Subject: [PATCH 2/2] Fixups

Change-Id: I05e79e16397a3e035f5d35df30c4999987a029d4
---
 llvm/lib/Target/AArch64/MachineSMEABIPass.cpp | 32 +++++++++++--------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
index fd07c939bbb0c..cced0faa28889 100644
--- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -110,7 +110,8 @@ struct PhysRegSave {
   Register X0Save = AArch64::NoRegister;
 };
 
-/// Contains the needed ZA state (and live registers) at an instruction.
+/// Contains the needed ZA state (and live registers) at an instruction. That is
+/// the state ZA must be in _before_ "InsertPt".
 struct InstInfo {
   ZAState NeededState{ZAState::ANY};
   MachineBasicBlock::iterator InsertPt;
@@ -135,7 +136,8 @@ struct FunctionInfo {
 
 /// State/helpers that is only needed when emitting code to handle
 /// saving/restoring ZA.
-struct EmitContext {
+class EmitContext {
+public:
   EmitContext() = default;
 
   /// Get or create a TPIDR2 block in \p MF.
@@ -160,7 +162,11 @@ struct EmitContext {
     return AgnosticZABufferPtr;
   }
 
+  /// Returns true if the function must allocate a ZA save buffer on entry. This
+  /// will be the case if, at any point in the function, a ZA save was emitted.
   bool needsSaveBuffer() const {
+    assert(!(TPIDR2BlockFI && AgnosticZABufferPtr) &&
+           "Cannot have both a TPIDR2 block and agnostic ZA buffer");
     return TPIDR2BlockFI || AgnosticZABufferPtr != AArch64::NoRegister;
   }
 
@@ -252,13 +258,13 @@ struct MachineSMEABI : public MachineFunctionPass {
 
   /// Assigns each edge bundle a ZA state based on the needed states of blocks
   /// that have incoming or outgoing edges in that bundle.
-  SmallVector<ZAState> assignBundleZAStates(EdgeBundles const &Bundles,
-                                            FunctionInfo const &FnInfo);
+  SmallVector<ZAState> assignBundleZAStates(const EdgeBundles &Bundles,
+                                            const FunctionInfo &FnInfo);
 
   /// Inserts code to handle changes between ZA states within the function.
   /// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
-  void insertStateChanges(EmitContext &, FunctionInfo const &FnInfo,
-                          EdgeBundles const &Bundles,
+  void insertStateChanges(EmitContext &, const FunctionInfo &FnInfo,
+                          const EdgeBundles &Bundles,
                           ArrayRef<ZAState> BundleStates);
 
   // Emission routines for private and shared ZA functions (using lazy saves).
@@ -319,7 +325,7 @@ struct MachineSMEABI : public MachineFunctionPass {
   PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
                                 MachineBasicBlock::iterator MBBI, DebugLoc DL);
   /// Restore physical registers from a save of their previous values.
-  void restorePhyRegSave(PhysRegSave const &RegSave, MachineBasicBlock &MBB,
+  void restorePhyRegSave(const PhysRegSave &RegSave, MachineBasicBlock &MBB,
                          MachineBasicBlock::iterator MBBI, DebugLoc DL);
 
 private:
@@ -410,8 +416,8 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
 /// Assigns each edge bundle a ZA state based on the needed states of blocks
 /// that have incoming or outgoing edges in that bundle.
 SmallVector<ZAState>
-MachineSMEABI::assignBundleZAStates(EdgeBundles const &Bundles,
-                                    FunctionInfo const &FnInfo) {
+MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
+                                    const FunctionInfo &FnInfo) {
   SmallVector<ZAState> BundleStates(Bundles.getNumBundles());
   for (unsigned I = 0, E = Bundles.getNumBundles(); I != E; ++I) {
     LLVM_DEBUG(dbgs() << "Assigning ZA state for edge bundle: " << I << '\n');
@@ -471,8 +477,8 @@ MachineSMEABI::assignBundleZAStates(EdgeBundles const &Bundles,
 }
 
 void MachineSMEABI::insertStateChanges(EmitContext &Context,
-                                       FunctionInfo const &FnInfo,
-                                       EdgeBundles const &Bundles,
+                                       const FunctionInfo &FnInfo,
+                                       const EdgeBundles &Bundles,
                                        ArrayRef<ZAState> BundleStates) {
   for (MachineBasicBlock &MBB : *MF) {
     const BlockInfo &Block = FnInfo.Blocks[MBB.getNumber()];
@@ -551,7 +557,7 @@ PhysRegSave MachineSMEABI::createPhysRegSave(LiveRegs PhysLiveRegs,
   return RegSave;
 }
 
-void MachineSMEABI::restorePhyRegSave(PhysRegSave const &RegSave,
+void MachineSMEABI::restorePhyRegSave(const PhysRegSave &RegSave,
                                       MachineBasicBlock &MBB,
                                       MachineBasicBlock::iterator MBBI,
                                       DebugLoc DL) {
@@ -844,7 +850,7 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
   TRI = Subtarget->getRegisterInfo();
   MRI = &MF.getRegInfo();
 
-  EdgeBundles const &Bundles =
+  const EdgeBundles &Bundles =
       getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
 
   FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);



More information about the llvm-commits mailing list