[llvm] [AArch64][SME] Implement the SME ABI (ZA state management) in Machine IR (PR #149062)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 8 05:02:20 PDT 2025


================
@@ -0,0 +1,642 @@
+//===- MachineSMEABIPass.cpp ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements the SME ABI requirements for ZA state. This includes
+// implementing the lazy ZA state save schemes around calls.
+//
+//===----------------------------------------------------------------------===//
+
+#include "AArch64InstrInfo.h"
+#include "AArch64MachineFunctionInfo.h"
+#include "AArch64Subtarget.h"
+#include "MCTargetDesc/AArch64AddressingModes.h"
+#include "llvm/ADT/BitmaskEnum.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/EdgeBundles.h"
+#include "llvm/CodeGen/LivePhysRegs.h"
+#include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-machine-sme-abi"
+
+namespace {
+
+/// ZA states tracked by this pass. Values are spelled out explicitly so the
+/// numbering is visible at a glance; NUM_ZA_STATE is one past the last state.
+enum ZAState {
+  // No particular ZA state is required (returned for ordinary instructions).
+  ANY = 0,
+  // ZA must be active (ZA/ZT0 operands, InOutZAUsePseudo, some returns).
+  ACTIVE = 1,
+  // Required by RequiresZASavePseudo; presumably ZA has been saved locally
+  // via the lazy-save scheme -- confirm.
+  LOCAL_SAVED = 2,
+  // NOTE(review): not produced by any code visible here; meaning inferred
+  // only from the name -- confirm.
+  CALLER_DORMANT = 3,
+  // Required at returns when ZA is not live on return.
+  OFF = 4,
+  // Not a real state; presumably used as a count of the states above.
+  NUM_ZA_STATE = 5
+};
+
+/// Bitmask of physical registers live at a point where ZA state-change code
+/// is emitted (passed as PhysLiveRegs to the emission routines). X0 is the
+/// union of its low (W0) and high (W0_HI) halves.
+enum LiveRegs : uint8_t {
+  None = 0x0,
+  NZCV = 0x1,  // Condition flags.
+  W0 = 0x2,    // Low 32 bits of X0.
+  W0_HI = 0x4, // High 32 bits of X0.
+  X0 = W0 | W0_HI,
+  LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ W0_HI)
+};
+
+/// Returns true if \p State may be assigned to an edge bundle. Only ACTIVE
+/// and LOCAL_SAVED qualify; every other state (including ANY) is rejected.
+static bool isLegalEdgeBundleZAState(ZAState State) {
+  return State == ZAState::ACTIVE || State == ZAState::LOCAL_SAVED;
+}
+
+/// Tracks the TPIDR2 block for the function (presumably the lazy-save
+/// descriptor -- confirm against getTPIDR2Block).
+struct TPIDR2State {
+  // Stack frame index of the block. Defaults to -1, presumably meaning
+  // "not yet allocated".
+  int FrameIndex{-1};
+};
+
+/// Returns a printable, fully-qualified name for \p State (for example
+/// "ZAState::ACTIVE"). Any other value, including NUM_ZA_STATE, is treated
+/// as unreachable.
+StringRef getZAStateString(ZAState State) {
+  switch (State) {
+  case ZAState::ANY:
+    return "ZAState::ANY";
+  case ZAState::ACTIVE:
+    return "ZAState::ACTIVE";
+  case ZAState::LOCAL_SAVED:
+    return "ZAState::LOCAL_SAVED";
+  case ZAState::CALLER_DORMANT:
+    return "ZAState::CALLER_DORMANT";
+  case ZAState::OFF:
+    return "ZAState::OFF";
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected ZAState");
+}
+
+/// Returns true if \p MO is a physical-register operand whose register, or
+/// any of its sub-registers, is in the MPR128 or ZTR register class (the ZA
+/// and ZT0 classes, going by this function's name).
+static bool isZAorZT0RegOp(const TargetRegisterInfo &TRI,
+                           const MachineOperand &MO) {
+  // Non-register operands and virtual registers never match.
+  if (!MO.isReg())
+    return false;
+  const Register Reg = MO.getReg();
+  if (!Reg.isPhysical())
+    return false;
+
+  // Walk the register itself plus all of its sub-registers.
+  for (MCPhysReg SubReg : TRI.subregs_inclusive(Reg)) {
+    if (AArch64::MPR128RegClass.contains(SubReg) ||
+        AArch64::ZTRRegClass.contains(SubReg))
+      return true;
+  }
+  return false;
+}
+
+/// Determines the ZA state that \p MI requires and the iterator at which a
+/// transition into that state should be anchored.
+///
+/// - InOutZAUsePseudo requires ACTIVE and RequiresZASavePseudo requires
+///   LOCAL_SAVED. For both, the returned iterator is the instruction
+///   immediately *preceding* the pseudo, so the pseudo must never be the
+///   first instruction in its block.
+/// - A return requires ACTIVE if \p ZALiveAtReturn is set, otherwise OFF.
+/// - An instruction with a ZA or ZT0 physical-register operand requires
+///   ACTIVE.
+/// - Anything else is ANY (no requirement).
+/// In all but the two pseudo cases, the returned iterator points at \p MI.
+static std::pair<ZAState, MachineBasicBlock::iterator>
+getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI,
+                     bool ZALiveAtReturn) {
+  MachineBasicBlock::iterator InsertPt(MI);
+
+  const unsigned Opc = MI.getOpcode();
+  if (Opc == AArch64::InOutZAUsePseudo ||
+      Opc == AArch64::RequiresZASavePseudo) {
+    // The insertion point is moved back one instruction. std::prev on the
+    // block's begin() is undefined behaviour, so check the invariant here
+    // rather than silently corrupting the iterator.
+    assert(InsertPt != MI.getParent()->begin() &&
+           "ZA state pseudo cannot be the first instruction in its block");
+    ZAState Needed = Opc == AArch64::InOutZAUsePseudo ? ZAState::ACTIVE
+                                                      : ZAState::LOCAL_SAVED;
+    return {Needed, std::prev(InsertPt)};
+  }
+
+  if (MI.isReturn())
+    return {ZALiveAtReturn ? ZAState::ACTIVE : ZAState::OFF, InsertPt};
+
+  // Any direct use/def of ZA or ZT0 needs ZA to be active.
+  for (auto &MO : MI.operands()) {
+    if (isZAorZT0RegOp(TRI, MO))
+      return {ZAState::ACTIVE, InsertPt};
+  }
+
+  return {ZAState::ANY, InsertPt};
+}
+
+struct MachineSMEABI : public MachineFunctionPass {
+  inline static char ID = 0;
+
+  MachineSMEABI() : MachineFunctionPass(ID) {}
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+  StringRef getPassName() const override { return "Machine SME ABI pass"; }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    AU.addRequired<EdgeBundlesWrapperLegacy>();
+    AU.addPreservedID(MachineLoopInfoID);
+    AU.addPreservedID(MachineDominatorsID);
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  void collectNeededZAStates(MachineFunction &MF, SMEAttrs);
+  void pickBundleZAStates(MachineFunction &MF);
+  void insertStateChanges(MachineFunction &MF);
+
+  // Emission routines for private and shared ZA functions (using lazy saves).
+  void emitNewZAPrologue(MachineBasicBlock &MBB,
+                         MachineBasicBlock::iterator MBBI);
+  void emitRestoreLazySave(MachineBasicBlock &MBB,
+                           MachineBasicBlock::iterator MBBI,
+                           LiveRegs PhysLiveRegs);
+  void emitSetupLazySave(MachineBasicBlock &MBB,
+                         MachineBasicBlock::iterator MBBI);
+  void emitAllocateLazySaveBuffer(MachineBasicBlock &MBB,
+                                  MachineBasicBlock::iterator MBBI);
+  void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+                 bool ClearTPIDR2);
+
+  void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+                       ZAState From, ZAState To, LiveRegs PhysLiveRegs);
+
+  TPIDR2State getTPIDR2Block(MachineFunction &MF);
+
+private:
+  struct InstInfo {
+    ZAState NeededState{ZAState::ANY};
----------------
MacDue wrote:

Added a comment. Please let me know if you had any specific issue here :)

https://github.com/llvm/llvm-project/pull/149062


More information about the llvm-commits mailing list