[llvm] [AArch64][SME] Implement the SME ABI (ZA state management) in Machine IR (PR #149062)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 8 05:11:11 PDT 2025


================
@@ -0,0 +1,642 @@
+//===- MachineSMEABIPass.cpp ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements the SME ABI requirements for ZA state. This includes
+// implementing the lazy ZA state save schemes around calls.
+//
+//===----------------------------------------------------------------------===//
+
+#include "AArch64InstrInfo.h"
+#include "AArch64MachineFunctionInfo.h"
+#include "AArch64Subtarget.h"
+#include "MCTargetDesc/AArch64AddressingModes.h"
+#include "llvm/ADT/BitmaskEnum.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/EdgeBundles.h"
+#include "llvm/CodeGen/LivePhysRegs.h"
+#include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-machine-sme-abi"
+
+namespace {
+
+enum ZAState {
+  ANY = 0,
+  ACTIVE,
+  LOCAL_SAVED,
+  CALLER_DORMANT,
+  OFF,
+  NUM_ZA_STATE
+};
+
+enum LiveRegs : uint8_t {
+  None = 0,
+  NZCV = 1 << 0,
+  W0 = 1 << 1,
+  W0_HI = 1 << 2,
+  X0 = W0 | W0_HI,
+  LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ W0_HI)
+};
+
+static bool isLegalEdgeBundleZAState(ZAState State) {
+  switch (State) {
+  case ZAState::ACTIVE:
+  case ZAState::LOCAL_SAVED:
+    return true;
+  default:
+    return false;
+  }
+}
+struct TPIDR2State {
+  int FrameIndex = -1;
+};
+
+StringRef getZAStateString(ZAState State) {
+#define MAKE_CASE(V)                                                           \
+  case V:                                                                      \
+    return #V;
+  switch (State) {
+    MAKE_CASE(ZAState::ANY)
+    MAKE_CASE(ZAState::ACTIVE)
+    MAKE_CASE(ZAState::LOCAL_SAVED)
+    MAKE_CASE(ZAState::CALLER_DORMANT)
+    MAKE_CASE(ZAState::OFF)
+  default:
+    llvm_unreachable("Unexpected ZAState");
+  }
+#undef MAKE_CASE
+}
+
+static bool isZAorZT0RegOp(const TargetRegisterInfo &TRI,
+                           const MachineOperand &MO) {
+  if (!MO.isReg() || !MO.getReg().isPhysical())
+    return false;
+  return any_of(TRI.subregs_inclusive(MO.getReg()), [](const MCPhysReg &SR) {
+    return AArch64::MPR128RegClass.contains(SR) ||
+           AArch64::ZTRRegClass.contains(SR);
+  });
+}
+
+static std::pair<ZAState, MachineBasicBlock::iterator>
+getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI,
+                     bool ZALiveAtReturn) {
+  MachineBasicBlock::iterator InsertPt(MI);
+
+  if (MI.getOpcode() == AArch64::InOutZAUsePseudo)
+    return {ZAState::ACTIVE, std::prev(InsertPt)};
+
+  if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
+    return {ZAState::LOCAL_SAVED, std::prev(InsertPt)};
+
+  if (MI.isReturn())
+    return {ZALiveAtReturn ? ZAState::ACTIVE : ZAState::OFF, InsertPt};
+
+  for (auto &MO : MI.operands()) {
+    if (isZAorZT0RegOp(TRI, MO))
+      return {ZAState::ACTIVE, InsertPt};
+  }
+
+  return {ZAState::ANY, InsertPt};
+}
+
+struct MachineSMEABI : public MachineFunctionPass {
+  inline static char ID = 0;
+
+  MachineSMEABI() : MachineFunctionPass(ID) {}
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+  StringRef getPassName() const override { return "Machine SME ABI pass"; }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    AU.addRequired<EdgeBundlesWrapperLegacy>();
+    AU.addPreservedID(MachineLoopInfoID);
+    AU.addPreservedID(MachineDominatorsID);
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  void collectNeededZAStates(MachineFunction &MF, SMEAttrs);
+  void pickBundleZAStates(MachineFunction &MF);
+  void insertStateChanges(MachineFunction &MF);
+
+  // Emission routines for private and shared ZA functions (using lazy saves).
+  void emitNewZAPrologue(MachineBasicBlock &MBB,
+                         MachineBasicBlock::iterator MBBI);
+  void emitRestoreLazySave(MachineBasicBlock &MBB,
+                           MachineBasicBlock::iterator MBBI,
+                           LiveRegs PhysLiveRegs);
+  void emitSetupLazySave(MachineBasicBlock &MBB,
+                         MachineBasicBlock::iterator MBBI);
+  void emitAllocateLazySaveBuffer(MachineBasicBlock &MBB,
+                                  MachineBasicBlock::iterator MBBI);
+  void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+                 bool ClearTPIDR2);
+
+  void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+                       ZAState From, ZAState To, LiveRegs PhysLiveRegs);
+
+  TPIDR2State getTPIDR2Block(MachineFunction &MF);
+
+private:
+  struct InstInfo {
+    ZAState NeededState{ZAState::ANY};
+    MachineBasicBlock::iterator InsertPt;
+    LiveRegs PhysLiveRegs = LiveRegs::None;
+  };
+
+  struct BlockInfo {
+    ZAState FixedEntryState{ZAState::ANY};
+    SmallVector<InstInfo> Insts;
+    LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
+  };
+
+  // All pass state that must be cleared between functions.
+  struct PassState {
+    SmallVector<BlockInfo> Blocks;
+    SmallVector<ZAState> BundleStates;
+    std::optional<TPIDR2State> TPIDR2Block;
+  } State;
+
+  EdgeBundles *Bundles = nullptr;
+};
+
+void MachineSMEABI::collectNeededZAStates(MachineFunction &MF,
+                                          SMEAttrs SMEFnAttrs) {
+  const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
+  assert((SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) &&
+         "Expected function to have ZA/ZT0 state!");
+
+  State.Blocks.resize(MF.getNumBlockIDs());
+  for (MachineBasicBlock &MBB : MF) {
+    BlockInfo &Block = State.Blocks[MBB.getNumber()];
+    if (&MBB == &MF.front()) {
+      // Entry block:
+      Block.FixedEntryState = SMEFnAttrs.hasPrivateZAInterface()
+                                  ? ZAState::CALLER_DORMANT
+                                  : ZAState::ACTIVE;
+    } else if (MBB.isEHPad()) {
+      // EH entry block:
+      Block.FixedEntryState = ZAState::LOCAL_SAVED;
+    }
+
+    LiveRegUnits LiveUnits(TRI);
+    LiveUnits.addLiveOuts(MBB);
+
+    auto GetPhysLiveRegs = [&] {
+      LiveRegs PhysLiveRegs = LiveRegs::None;
+      if (!LiveUnits.available(AArch64::NZCV))
+        PhysLiveRegs |= LiveRegs::NZCV;
+      // We have to track W0 and X0 separately as otherwise things can get
+      // confused we attempt to preserve X0 but only W0 was defined.
+      if (!LiveUnits.available(AArch64::W0))
+        PhysLiveRegs |= LiveRegs::W0;
+      if (!LiveUnits.available(AArch64::W0_HI))
+        PhysLiveRegs |= LiveRegs::W0_HI;
+      return PhysLiveRegs;
+    };
+
+    Block.PhysLiveRegsAtExit = GetPhysLiveRegs();
+    auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
+    for (MachineInstr &MI : reverse(MBB)) {
+      MachineBasicBlock::iterator MBBI(MI);
+      LiveUnits.stepBackward(MI);
+      LiveRegs PhysLiveRegs = GetPhysLiveRegs();
+      auto [NeededState, InsertPt] = getInstNeededZAState(
+          TRI, MI, /*ZALiveAtReturn=*/SMEFnAttrs.hasSharedZAInterface());
+      assert((InsertPt == MBBI ||
+              InsertPt->getOpcode() == AArch64::ADJCALLSTACKDOWN) &&
+             "Unexpected state change insertion point!");
+      // TODO: Do something to avoid state changes where NZCV is live.
+      if (MBBI == FirstTerminatorInsertPt)
+        Block.PhysLiveRegsAtExit = PhysLiveRegs;
+      if (NeededState != ZAState::ANY)
+        Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
+    }
+
+    // Reverse vector (as we had to iterate backwards for liveness).
+    std::reverse(Block.Insts.begin(), Block.Insts.end());
+  }
+}
+
+void MachineSMEABI::pickBundleZAStates(MachineFunction &MF) {
+  State.BundleStates.resize(Bundles->getNumBundles());
+  for (unsigned I = 0, E = Bundles->getNumBundles(); I != E; ++I) {
+    LLVM_DEBUG(dbgs() << "Picking ZA state for edge bundle: " << I << '\n');
+
+    // Attempt to pick a ZA state for this bundle that minimizes state
+    // transitions. Edges within loops are given a higher weight as we assume
+    // they will be executed more than once.
+    // TODO: We should propagate desired incoming/outgoing states through blocks
+    // that have the "ANY" state first to make better global decisions.
+    int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
+    for (unsigned BlockID : Bundles->getBlocks(I)) {
+      LLVM_DEBUG(dbgs() << "- bb." << BlockID);
+
+      BlockInfo &Block = State.Blocks[BlockID];
+      if (Block.Insts.empty()) {
+        LLVM_DEBUG(dbgs() << " (no state preference)\n");
+        continue;
+      }
+      bool InEdge = Bundles->getBundle(BlockID, /*Out=*/false) == I;
+      bool OutEdge = Bundles->getBundle(BlockID, /*Out=*/true) == I;
+
+      ZAState DesiredIncomingState = Block.Insts.front().NeededState;
+      if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) {
+        EdgeStateCounts[DesiredIncomingState]++;
+        LLVM_DEBUG(dbgs() << " DesiredIncomingState: "
+                          << getZAStateString(DesiredIncomingState));
+      }
+      ZAState DesiredOutgoingState = Block.Insts.back().NeededState;
+      if (OutEdge && isLegalEdgeBundleZAState(DesiredOutgoingState)) {
+        EdgeStateCounts[DesiredOutgoingState]++;
+        LLVM_DEBUG(dbgs() << " DesiredOutgoingState: "
+                          << getZAStateString(DesiredOutgoingState));
+      }
+      LLVM_DEBUG(dbgs() << '\n');
+    }
+
+    ZAState BundleState =
+        ZAState(max_element(EdgeStateCounts) - EdgeStateCounts);
+
+    // Force ZA to be active in bundles that don't have a preferred state.
+    // TODO: Something better here (to avoid extra mode switches).
+    if (BundleState == ZAState::ANY)
+      BundleState = ZAState::ACTIVE;
+
+    LLVM_DEBUG({
+      dbgs() << "Chosen ZA state: " << getZAStateString(BundleState) << '\n'
+             << "Edge counts:";
+      for (auto [State, Count] : enumerate(EdgeStateCounts))
+        dbgs() << " " << getZAStateString(ZAState(State)) << ": " << Count;
+      dbgs() << "\n\n";
+    });
+
+    State.BundleStates[I] = BundleState;
+  }
+}
+
+void MachineSMEABI::insertStateChanges(MachineFunction &MF) {
+  for (MachineBasicBlock &MBB : MF) {
+    BlockInfo &Block = State.Blocks[MBB.getNumber()];
+    ZAState InState =
+        State.BundleStates[Bundles->getBundle(MBB.getNumber(), /*Out=*/false)];
+    ZAState OutState =
+        State.BundleStates[Bundles->getBundle(MBB.getNumber(), /*Out=*/true)];
+
+    ZAState CurrentState = Block.FixedEntryState;
+    if (CurrentState == ZAState::ANY)
+      CurrentState = InState;
+
+    for (auto &Inst : Block.Insts) {
+      if (CurrentState != Inst.NeededState)
+        emitStateChange(MBB, Inst.InsertPt, CurrentState, Inst.NeededState,
+                        Inst.PhysLiveRegs);
+      CurrentState = Inst.NeededState;
+    }
+
+    if (MBB.succ_empty())
+      continue;
+
+    if (CurrentState != OutState)
+      emitStateChange(MBB, MBB.getFirstTerminator(), CurrentState, OutState,
+                      Block.PhysLiveRegsAtExit);
+  }
+}
+
+TPIDR2State MachineSMEABI::getTPIDR2Block(MachineFunction &MF) {
+  if (State.TPIDR2Block)
+    return *State.TPIDR2Block;
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+  State.TPIDR2Block = TPIDR2State{MFI.CreateStackObject(16, Align(16), false)};
+  return *State.TPIDR2Block;
+}
+
+static DebugLoc getDebugLoc(MachineBasicBlock &MBB,
+                            MachineBasicBlock::iterator MBBI) {
+  if (MBBI != MBB.end())
+    return MBBI->getDebugLoc();
+  return DebugLoc();
+}
+
+void MachineSMEABI::emitSetupLazySave(MachineBasicBlock &MBB,
+                                      MachineBasicBlock::iterator MBBI) {
+  MachineFunction &MF = *MBB.getParent();
+  auto &Subtarget = MF.getSubtarget<AArch64Subtarget>();
+  const TargetInstrInfo &TII = *Subtarget.getInstrInfo();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  DebugLoc DL = getDebugLoc(MBB, MBBI);
+
+  // Get pointer to TPIDR2 block.
+  Register TPIDR2 = MRI.createVirtualRegister(&AArch64::GPR64spRegClass);
+  Register TPIDR2Ptr = MRI.createVirtualRegister(&AArch64::GPR64RegClass);
+  BuildMI(MBB, MBBI, DL, TII.get(AArch64::ADDXri), TPIDR2)
+      .addFrameIndex(getTPIDR2Block(MF).FrameIndex)
+      .addImm(0)
+      .addImm(0);
+  BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::COPY), TPIDR2Ptr).addReg(TPIDR2);
+  // Set TPIDR2_EL0 to point to TPIDR2 block.
+  BuildMI(MBB, MBBI, DL, TII.get(AArch64::MSR))
+      .addImm(AArch64SysReg::TPIDR2_EL0)
+      .addReg(TPIDR2Ptr);
+}
+
+// Helper class for saving physical registers around calls.
+struct ScopedPhysRegSave {
+  ScopedPhysRegSave(MachineRegisterInfo &MRI, const TargetInstrInfo &TII,
+                    DebugLoc DL, MachineBasicBlock &MBB,
+                    MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs)
+      : TII(TII), DL(DL), MBB(MBB), MBBI(MBBI), PhysLiveRegs(PhysLiveRegs) {
+    if (PhysLiveRegs & LiveRegs::NZCV) {
+      StatusFlags = MRI.createVirtualRegister(&AArch64::GPR64RegClass);
+      BuildMI(MBB, MBBI, DL, TII.get(AArch64::MRS))
+          .addReg(StatusFlags, RegState::Define)
+          .addImm(AArch64SysReg::NZCV)
+          .addReg(AArch64::NZCV, RegState::Implicit);
+    }
+    // Note: Preserving X0 is "free" as this is before register allocation, so
+    // the register allocator is still able to optimize these copies.
+    if (PhysLiveRegs & LiveRegs::W0) {
+      X0Save = MRI.createVirtualRegister(PhysLiveRegs & LiveRegs::W0_HI
+                                             ? &AArch64::GPR64RegClass
+                                             : &AArch64::GPR32RegClass);
+      BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::COPY), X0Save)
+          .addReg(PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0);
+    }
+  }
+
+  ~ScopedPhysRegSave() {
+    if (StatusFlags != AArch64::NoRegister)
+      BuildMI(MBB, MBBI, DL, TII.get(AArch64::MSR))
+          .addImm(AArch64SysReg::NZCV)
+          .addReg(StatusFlags)
+          .addReg(AArch64::NZCV, RegState::ImplicitDefine);
+
+    if (X0Save != AArch64::NoRegister)
+      BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::COPY),
+              PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0)
+          .addReg(X0Save);
+  }
+
+  const TargetInstrInfo &TII;
+  DebugLoc DL;
+  MachineBasicBlock &MBB;
+  MachineBasicBlock::iterator MBBI;
+  LiveRegs PhysLiveRegs;
+  Register StatusFlags = AArch64::NoRegister;
+  Register X0Save = AArch64::NoRegister;
+};
+
+void MachineSMEABI::emitRestoreLazySave(MachineBasicBlock &MBB,
+                                        MachineBasicBlock::iterator MBBI,
+                                        LiveRegs PhysLiveRegs) {
+  MachineFunction &MF = *MBB.getParent();
+  auto &Subtarget = MF.getSubtarget<AArch64Subtarget>();
+  const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo();
+  const TargetInstrInfo &TII = *Subtarget.getInstrInfo();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  DebugLoc DL = getDebugLoc(MBB, MBBI);
+  Register TPIDR2EL0 = MRI.createVirtualRegister(&AArch64::GPR64RegClass);
+  Register TPIDR2 = AArch64::X0;
+
+  // TODO: Emit these within the restore MBB to prevent unnecessary saves.
+  ScopedPhysRegSave ScopedPhysRegSave(MRI, TII, DL, MBB, MBBI, PhysLiveRegs);
----------------
MacDue wrote:

Changed this into some helpers for now (this is used more in later PRs). I do prefer having the save/restore local though (as I find having them a far apart can also obfuscate things in larger functions, I guess here it's fine though). I can always do something like:

```c++
  PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
  auto RestoreSave =
      make_scope_exit([&] { restorePhyRegSave(RegSave, MBB, MBBI, DL); });
```

https://github.com/llvm/llvm-project/pull/149062


More information about the llvm-commits mailing list