[llvm] [AArch64][SME] Spill p-regs as z-regs when streaming hazards are possible (PR #123752)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 24 09:54:50 PST 2025


================
@@ -4155,8 +4168,318 @@ int64_t AArch64FrameLowering::assignSVEStackObjectOffsets(
                                         true);
 }
 
+/// Attempts to scavenge a register from \p ScavengeableRegs given the used
+/// registers in \p UsedRegs.
+static Register tryScavengeRegister(LiveRegUnits const &UsedRegs,
+                                    BitVector const &ScavengeableRegs) {
+  for (auto Reg : ScavengeableRegs.set_bits()) {
+    if (UsedRegs.available(Reg))
+      return Reg;
+  }
+  return AArch64::NoRegister;
+}
+
+/// Propagates frame-setup/destroy flags from \p SourceMI to all instructions in
+/// \p MachineInstrs.
+static void propagateFrameFlags(MachineInstr &SourceMI,
+                                ArrayRef<MachineInstr *> MachineInstrs) {
+  for (MachineInstr *MI : MachineInstrs) {
+    if (SourceMI.getFlag(MachineInstr::FrameSetup))
+      MI->setFlag(MachineInstr::FrameSetup);
+    if (SourceMI.getFlag(MachineInstr::FrameDestroy))
+      MI->setFlag(MachineInstr::FrameDestroy);
+  }
+}
+
+/// RAII helper class for scavenging or spilling a register. On construction
+/// it attempts to find a free register of class \p RC (given \p UsedRegs and
+/// \p AllocatableRegs); if no register can be found, it spills
+/// \p SpillCandidate to \p MaybeSpillFI to free one. The freed register is
+/// returned via the \p FreeReg output parameter. On destruction, if a spill
+/// was made, the spilled register's previous value is reloaded. The spilling
+/// and scavenging are only valid at the insertion point \p MBBI; this class
+/// should _not_ be used in places that create or manipulate basic blocks,
+/// which could move the expected insertion point.
+struct ScopedScavengeOrSpill {
+  ScopedScavengeOrSpill(const ScopedScavengeOrSpill &) = delete;
+  ScopedScavengeOrSpill(ScopedScavengeOrSpill &&) = delete;
+
+  ScopedScavengeOrSpill(MachineFunction &MF, MachineBasicBlock &MBB,
+                        MachineBasicBlock::iterator MBBI, Register &FreeReg,
+                        Register SpillCandidate, const TargetRegisterClass &RC,
+                        LiveRegUnits const &UsedRegs,
+                        BitVector const &AllocatableRegs,
+                        std::optional<int> &MaybeSpillFI)
+      : MBB(MBB), MBBI(MBBI), RC(RC), TII(static_cast<const AArch64InstrInfo &>(
+                                          *MF.getSubtarget().getInstrInfo())),
+        TRI(*MF.getSubtarget().getRegisterInfo()) {
+    FreeReg = tryScavengeRegister(UsedRegs, AllocatableRegs);
+    if (FreeReg != AArch64::NoRegister)
+      return;
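+    // No free register was found, so free one by spilling the candidate to an
+    // emergency stack slot (created on first use and reused by later spills).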
+    if (!MaybeSpillFI) {
+      MachineFrameInfo &MFI = MF.getFrameInfo();
+      MaybeSpillFI = MFI.CreateSpillStackObject(TRI.getSpillSize(RC),
+                                                TRI.getSpillAlign(RC));
+    }
+    FreeReg = SpilledReg = SpillCandidate;
+    SpillFI = *MaybeSpillFI;
+    TII.storeRegToStackSlot(MBB, MBBI, SpilledReg, false, SpillFI, &RC, &TRI,
+                            Register());
+  }
+
+  bool hasSpilled() const { return SpilledReg != AArch64::NoRegister; }
+
+  ~ScopedScavengeOrSpill() {
+    if (hasSpilled())
+      TII.loadRegFromStackSlot(MBB, MBBI, SpilledReg, SpillFI, &RC, &TRI,
+                               Register());
+  }
+
+private:
+  MachineBasicBlock &MBB;
+  MachineBasicBlock::iterator MBBI;
+  const TargetRegisterClass &RC;
+  const AArch64InstrInfo &TII;
+  const TargetRegisterInfo &TRI;
+  Register SpilledReg = AArch64::NoRegister;
+  int SpillFI = -1;
+};
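+
+// Example usage (a sketch; names are illustrative): the register returned via
+// \p FreeReg is only valid while the ScopedScavengeOrSpill object is in
+// scope, e.g.
+//
+//   Register Tmp = AArch64::NoRegister;
+//   ScopedScavengeOrSpill FindTmp(MF, MBB, MachineBasicBlock::iterator(MI),
+//                                 Tmp, AArch64::Z0, AArch64::ZPRRegClass,
+//                                 UsedRegs, ZPRRegs, SpillSlots.ZPRSpillFI);
+//   // ... emit code using Tmp before MI ...
+//
+// On destruction, any spill made to free Tmp is reloaded.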
+
+/// Emergency stack slots for expanding SPILL_PPR_TO_ZPR_SLOT_PSEUDO and
+/// FILL_PPR_FROM_ZPR_SLOT_PSEUDO.
+struct EmergencyStackSlots {
+  std::optional<int> ZPRSpillFI;
+  std::optional<int> PPRSpillFI;
+  std::optional<int> GPRSpillFI;
+};
+
+/// Expands:
+/// ```
+/// SPILL_PPR_TO_ZPR_SLOT_PSEUDO $p0, %stack.0, 0
+/// ```
+/// To:
+/// ```
+/// $z0 = CPY_ZPzI_B $p0, 1, 0
+/// STR_ZXI $z0, %stack.0, 0
+/// ```
+/// While ensuring a ZPR ($z0 in this example) is free for the predicate
+/// (spilling one if necessary).
+static void expandSpillPPRToZPRSlotPseudo(MachineBasicBlock &MBB,
+                                          MachineInstr &MI,
+                                          const TargetRegisterInfo &TRI,
+                                          LiveRegUnits const &UsedRegs,
+                                          BitVector const &ZPRRegs,
+                                          EmergencyStackSlots &SpillSlots) {
+  MachineFunction &MF = *MBB.getParent();
+  auto *TII =
+      static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
+
+  Register ZPredReg = AArch64::NoRegister;
+  ScopedScavengeOrSpill FindZPRReg(MF, MBB, MachineBasicBlock::iterator(MI),
+                                   ZPredReg, AArch64::Z0, AArch64::ZPRRegClass,
+                                   UsedRegs, ZPRRegs, SpillSlots.ZPRSpillFI);
+
+#ifndef NDEBUG
+  bool InPrologueOrEpilogue = MI.getFlag(MachineInstr::FrameSetup) ||
+                              MI.getFlag(MachineInstr::FrameDestroy);
+  assert((!FindZPRReg.hasSpilled() || !InPrologueOrEpilogue) &&
+         "SPILL_PPR_TO_ZPR_SLOT_PSEUDO expansion should not spill in prologue "
+         "or epilogue");
+#endif
+
+  SmallVector<MachineInstr *, 2> MachineInstrs;
+  const DebugLoc &DL = MI.getDebugLoc();
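+  // Copy the predicate into a ZPR (active lanes become 1, inactive lanes are
+  // zeroed), then store that ZPR to the predicate's stack slot.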
+  MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::CPY_ZPzI_B))
+                              .addReg(ZPredReg, RegState::Define)
+                              .add(MI.getOperand(0))
+                              .addImm(1)
+                              .addImm(0)
+                              .getInstr());
+  MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::STR_ZXI))
+                              .addReg(ZPredReg)
+                              .add(MI.getOperand(1))
+                              .addImm(MI.getOperand(2).getImm())
+                              .setMemRefs(MI.memoperands())
+                              .getInstr());
+  propagateFrameFlags(MI, MachineInstrs);
+}
+
+/// Expands:
+/// ```
+/// $p0 = FILL_PPR_FROM_ZPR_SLOT_PSEUDO %stack.0, 0
+/// ```
+/// To:
+/// ```
+/// $z0 = LDR_ZXI %stack.0, 0
+/// $p0 = PTRUE_B 31, implicit $vg
+/// $p0 = CMPNE_PPzZI_B $p0, $z0, 0, implicit-def $nzcv
+/// ```
+/// While ensuring a ZPR ($z0 in this example) is free for the predicate
+/// (spilling one if necessary). If the status flags are in use at the point
+/// of expansion, they are preserved (by moving them to/from a GPR). This may
+/// cause an additional spill if no GPR is free at the expansion point.
+static bool expandFillPPRFromZPRSlotPseudo(
+    MachineBasicBlock &MBB, MachineInstr &MI, const TargetRegisterInfo &TRI,
+    LiveRegUnits const &UsedRegs, BitVector const &ZPRRegs,
+    BitVector const &PPR3bRegs, BitVector const &GPRRegs,
+    EmergencyStackSlots &SpillSlots) {
+  MachineFunction &MF = *MBB.getParent();
+  auto *TII =
+      static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
+
+  Register ZPredReg = AArch64::NoRegister;
+  ScopedScavengeOrSpill FindZPRReg(MF, MBB, MachineBasicBlock::iterator(MI),
+                                   ZPredReg, AArch64::Z0, AArch64::ZPRRegClass,
+                                   UsedRegs, ZPRRegs, SpillSlots.ZPRSpillFI);
+
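+  // The CMPNE below takes its governing predicate in a PPR_3b register
+  // (p0-p7). Reuse the fill destination if it is in that class; otherwise,
+  // scavenge or spill one.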
+  Register PredReg = AArch64::NoRegister;
+  std::optional<ScopedScavengeOrSpill> FindPPR3bReg;
+  if (AArch64::PPR_3bRegClass.contains(MI.getOperand(0).getReg()))
+    PredReg = MI.getOperand(0).getReg();
+  else
+    FindPPR3bReg.emplace(MF, MBB, MachineBasicBlock::iterator(MI), PredReg,
+                         AArch64::P0, AArch64::PPR_3bRegClass, UsedRegs,
+                         PPR3bRegs, SpillSlots.PPRSpillFI);
+
+  // Elide NZCV spills if we know it is not used.
+  Register NZCVSaveReg = AArch64::NoRegister;
+  bool IsNZCVUsed = !UsedRegs.available(AArch64::NZCV);
+  std::optional<ScopedScavengeOrSpill> FindGPRReg;
+  if (IsNZCVUsed)
+    FindGPRReg.emplace(MF, MBB, MachineBasicBlock::iterator(MI), NZCVSaveReg,
+                       AArch64::X0, AArch64::GPR64RegClass, UsedRegs, GPRRegs,
+                       SpillSlots.GPRSpillFI);
+
+#ifndef NDEBUG
+  bool Spilled = FindZPRReg.hasSpilled() ||
+                 (FindPPR3bReg && FindPPR3bReg->hasSpilled()) ||
+                 (FindGPRReg && FindGPRReg->hasSpilled());
+  bool InPrologueOrEpilogue = MI.getFlag(MachineInstr::FrameSetup) ||
+                              MI.getFlag(MachineInstr::FrameDestroy);
+  assert((!Spilled || !InPrologueOrEpilogue) &&
+         "FILL_PPR_FROM_ZPR_SLOT_PSEUDO expansion should not spill in prologue "
+         "or epilogue");
+#endif
+
+  SmallVector<MachineInstr *, 4> MachineInstrs;
+  const DebugLoc &DL = MI.getDebugLoc();
+  MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::LDR_ZXI))
+                              .addReg(ZPredReg, RegState::Define)
+                              .add(MI.getOperand(1))
+                              .addImm(MI.getOperand(2).getImm())
+                              .setMemRefs(MI.memoperands())
+                              .getInstr());
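+  // If NZCV is live here, save it to a GPR before the CMPNE below clobbers
+  // it; it is restored afterwards.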
+  if (IsNZCVUsed)
+    MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::MRS))
+                                .addReg(NZCVSaveReg, RegState::Define)
+                                .addImm(AArch64SysReg::NZCV)
+                                .addReg(AArch64::NZCV, RegState::Implicit)
+                                .getInstr());
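+  // Recreate the predicate: with an all-true governing predicate, compare the
+  // reloaded vector against zero, so lanes the spill wrote as 1 become active
+  // again.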
+  MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::PTRUE_B))
+                              .addReg(PredReg, RegState::Define)
+                              .addImm(31));
+  MachineInstrs.push_back(
+      BuildMI(MBB, MI, DL, TII->get(AArch64::CMPNE_PPzZI_B))
+          .addReg(MI.getOperand(0).getReg(), RegState::Define)
+          .addReg(PredReg)
+          .addReg(ZPredReg)
+          .addImm(0)
+          .addReg(AArch64::NZCV, RegState::ImplicitDefine)
+          .getInstr());
+  if (IsNZCVUsed)
+    MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::MSR))
+                                .addImm(AArch64SysReg::NZCV)
+                                .addReg(NZCVSaveReg)
+                                .addReg(AArch64::NZCV, RegState::ImplicitDefine)
+                                .getInstr());
+
+  propagateFrameFlags(MI, MachineInstrs);
+  return FindPPR3bReg && FindPPR3bReg->hasSpilled();
+}
+
+/// Expands all FILL_PPR_FROM_ZPR_SLOT_PSEUDO and SPILL_PPR_TO_ZPR_SLOT_PSEUDO
+/// operations within the MachineBasicBlock \p MBB. Returns true if a fill
+/// expansion had to spill a predicate register.
+static bool expandSMEPPRToZPRSpillPseudos(MachineBasicBlock &MBB,
+                                          const TargetRegisterInfo &TRI,
+                                          BitVector const &ZPRRegs,
+                                          BitVector const &PPR3bRegs,
+                                          BitVector const &GPRRegs,
+                                          EmergencyStackSlots &SpillSlots) {
+  LiveRegUnits UsedRegs(TRI);
+  UsedRegs.addLiveOuts(MBB);
+  bool HasPPRSpills = false;
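+  // Walk the block backwards so that UsedRegs describes the registers live at
+  // each expansion point.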
+  for (MachineInstr &MI : make_early_inc_range(reverse(MBB))) {
+    UsedRegs.stepBackward(MI);
+    switch (MI.getOpcode()) {
+    case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO:
+      HasPPRSpills |= expandFillPPRFromZPRSlotPseudo(
+          MBB, MI, TRI, UsedRegs, ZPRRegs, PPR3bRegs, GPRRegs, SpillSlots);
+      MI.eraseFromParent();
+      break;
+    case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO:
+      expandSpillPPRToZPRSlotPseudo(MBB, MI, TRI, UsedRegs, ZPRRegs,
+                                    SpillSlots);
+      MI.eraseFromParent();
+      break;
+    default:
+      break;
+    }
+  }
+
+  return HasPPRSpills;
+}
+
 void AArch64FrameLowering::processFunctionBeforeFrameFinalized(
     MachineFunction &MF, RegScavenger *RS) const {
+
+  AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+  const TargetSubtargetInfo &TSI = MF.getSubtarget();
+  const TargetRegisterInfo &TRI = *TSI.getRegisterInfo();
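+  // A PPR spill size of 16 bytes means predicates are being spilled to
+  // ZPR-sized slots (i.e. streaming memory hazards are possible), so the
+  // PPR<->ZPR spill/fill pseudos must be expanded here.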
+  if (AFI->hasStackFrame() && TRI.getSpillSize(AArch64::PPRRegClass) == 16) {
+    const uint32_t *CSRMask =
+        TRI.getCallPreservedMask(MF, MF.getFunction().getCallingConv());
+    const MachineFrameInfo &MFI = MF.getFrameInfo();
+    assert(MFI.isCalleeSavedInfoValid());
+
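+    // Scavengeable registers exclude callee-saved registers, as their saves
+    // may not have been emitted at every expansion point.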
+    auto ComputeScavengeableRegisters = [&](unsigned RegClassID) {
+      BitVector ScavengeableRegs =
+          TRI.getAllocatableSet(MF, TRI.getRegClass(RegClassID));
+      if (CSRMask)
+        ScavengeableRegs.clearBitsInMask(CSRMask);
+      // TODO: Allow reusing callee-saved registers that have been saved.
----------------
MacDue wrote:

Done this; still need a little special case for `p4`, but it's simpler now :+1: 

https://github.com/llvm/llvm-project/pull/123752

