[llvm] [CodeGen][Spill2Reg] Initial patch (PR #118832)

Wei Xiao via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 31 19:59:06 PST 2024


================
@@ -0,0 +1,534 @@
+//===- Spill2Reg.cpp - Spill To Register Optimization ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//
+/// \file This file implements Spill2Reg, an optimization which selectively
+/// replaces spills/reloads to/from the stack with register copies to/from the
+/// vector register file. This works even on targets where load/stores have
+/// similar latency to register copies because it can free up memory units which
+/// helps avoid back-end stalls.
+///
+//===----------------------------------------------------------------------===//
+
+#include "AllocationOrder.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/CodeGen/LiveRegUnits.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "Spill2Reg"
+STATISTIC(NumSpill2RegInstrs, "Number of spills/reloads replaced by spill2reg");
+
+namespace {
+
+class Spill2Reg : public MachineFunctionPass {
+public:
+  static char ID;
+  Spill2Reg() : MachineFunctionPass(ID) {
+    initializeSpill2RegPass(*PassRegistry::getPassRegistry());
+  }
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+  void releaseMemory() override;
+  bool runOnMachineFunction(MachineFunction &) override;
+
+private:
+  /// Holds the spills and reloads collected for one stack slot.
+  struct StackSlotDataEntry {
+    /// This is set to true to disable code generation for the spills/reloads
+    /// that we collected in this entry.
+    bool Disable = false;
+    /// Indentation for the dump() methods.
+    static constexpr const int DumpInd = 2;
+
+    /// The data held for each spill/reload.
+    struct MIData {
+      MIData(MachineInstr *MI, const MachineOperand *MO, unsigned SpillBits)
+          : MI(MI), MO(MO), SpillBits(SpillBits) {}
+      /// The Spill/Reload instruction.
+      MachineInstr *MI = nullptr;
+      /// The operand being spilled/reloaded.
+      const MachineOperand *MO = nullptr;
+      /// The size of the data spilled/reloaded in bits. This occasionally
+      /// differs across accesses to the same stack slot.
+      unsigned SpillBits = 0;
+#ifndef NDEBUG
+      LLVM_DUMP_METHOD virtual void dump() const;
+      virtual ~MIData() {}
+#endif
+    };
+
+    struct MIDataWithLiveIn : public MIData {
+      MIDataWithLiveIn(MachineInstr *MI, const MachineOperand *MO,
+                       unsigned SpillBits)
+          : MIData(MI, MO, SpillBits) {}
+      /// We set this to false to mark the vector register associated to this
+      /// reload as definitely not live-in. This is useful in blocks with both
+      /// spill and reload of the same stack slot, like in the example:
+      /// \verbatim
+      ///  bb:
+      ///    spill %stack.0
+      ///    reload %stack.0
+      /// \endverbatim
+      /// This information is used during `updateLiveIns()`. We are collecting
+      /// this information during `collectSpillsAndReloads()` because we are
+      /// already walking through the code there. Otherwise we would need to
+      /// walk through the code again in `updateLiveIns()` just to check for
+      /// other spills in the block, which would waste compilation time.
+      bool IsLiveIn = true;
+#ifndef NDEBUG
+      LLVM_DUMP_METHOD virtual void dump() const override;
+#endif
+    };
+    SmallVector<MIData, 1> Spills;
+    SmallVector<MIDataWithLiveIn, 1> Reloads;
+
+    /// \Returns the register operand of the first collected spill.
+    Register getSpilledReg() const { return Spills.front().MO->getReg(); }
+#ifndef NDEBUG
+    LLVM_DUMP_METHOD void dump() const;
+#endif
+  };
+  /// Look for candidates for spill2reg. These candidates are in places with
+  /// high memory unit contention. Fills in StackSlotData.
+  void collectSpillsAndReloads();
+  /// \Returns true if \p MI is profitable to apply spill-to-reg by checking
+  /// whether this would remove pipeline bubbles.
+  bool isProfitable(const MachineInstr *MI) const;
+  /// \Returns true if every stack-based spill/reload in \p Entry is profitable
+  /// to replace with a reg-based spill/reload.
+  bool allAccessesProfitable(const StackSlotDataEntry &Entry) const;
+  /// Look for a free physical register in \p LRU of reg class \p RegClass.
+  std::optional<MCRegister>
+  tryGetFreePhysicalReg(const TargetRegisterClass *RegClass,
+                        const LiveRegUnits &LRU);
+  /// Helper for generateCode(). It replaces the stack spills and reloads of
+  /// \p Entry with moves to/from \p VectorReg.
+  void replaceStackWithReg(StackSlotDataEntry &Entry, Register VectorReg);
+  /// Updates the live-ins of MBBs after we emit the new spill2reg instructions
+  /// and the vector registers become live from register spills to reloads.
+  void updateLiveIns(StackSlotDataEntry &Entry, MCRegister VectorReg);
+  /// Updates \p LRU with the liveness of physical registers around the spills
+  /// and reloads in \p Entry.
+  void calculateLiveRegs(StackSlotDataEntry &Entry, LiveRegUnits &LRU);
+  /// Replace spills to stack with spills to registers (same for reloads).
+  void generateCode();
+  /// Cleanup data structures once the pass is finished.
+  void cleanup();
+  /// The main entry point for this pass.
+  bool run();
+
+  /// Map from a stack slot to the corresponding spills and reloads.
+  DenseMap<int, StackSlotDataEntry> StackSlotData;
+  /// The registers used by each block (from LiveRegUnits). This is needed for
+  /// finding free physical registers in generateCode().
+  DenseMap<const MachineBasicBlock *, LiveRegUnits> LRUs;
+
+  MachineFunction *MF = nullptr;
+  MachineRegisterInfo *MRI = nullptr;
+  MachineFrameInfo *MFI = nullptr;
+  const TargetInstrInfo *TII = nullptr;
+  const TargetRegisterInfo *TRI = nullptr;
+  RegisterClassInfo RegClassInfo;
+};
+
+} // namespace
+
+void Spill2Reg::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.setPreservesCFG(); // Declare the CFG as preserved by this pass.
+  MachineFunctionPass::getAnalysisUsage(AU); // Defer to the base class.
+}
+
+void Spill2Reg::releaseMemory() {} // No-op: this override releases nothing.
+
+/// Pass entry point. Caches the per-function objects used by the rest of the
+/// pass and hands over to run().
+///
+/// \Returns false without touching \p MFn when the function carries the
+/// NoImplicitFloat attribute or when the target reports that it does not
+/// support spill2reg; otherwise \returns the result of run().
+bool Spill2Reg::runOnMachineFunction(MachineFunction &MFn) {
+  // NoImplicitFloat functions must not get instructions that use vectors.
+  if (MFn.getFunction().hasFnAttribute(Attribute::NoImplicitFloat))
+    return false;
+
+  // Cache everything the helpers need. This happens before the target check
+  // below, so the members are set even when we bail out.
+  const auto &Subtarget = MFn.getSubtarget();
+  MF = &MFn;
+  MRI = &MFn.getRegInfo();
+  MFI = &MFn.getFrameInfo();
+  TII = Subtarget.getInstrInfo();
+  TRI = Subtarget.getRegisterInfo();
+
+  // Run only if the target supports the appropriate vector instruction set.
+  const bool TargetHasSpill2Reg = TII->targetSupportsSpill2Reg(&Subtarget);
+  if (!TargetHasSpill2Reg)
+    return false;
+
+  RegClassInfo.runOnMachineFunction(MFn);
+  return run();
+}
+
+char Spill2Reg::ID = 0; // Identifier passed to MachineFunctionPass(ID).
+
+char &llvm::Spill2RegID = Spill2Reg::ID; // Exposes the ID in namespace llvm.
+
+/// Walks every block of the function bottom-up and fills in StackSlotData
+/// with the spills and reloads found for each stack slot. A stack slot is
+/// disabled as soon as one of its accesses is rejected by the eligibility
+/// filter, or when it is referenced by an instruction that is neither a plain
+/// spill nor a plain reload. While walking, the registers used across each
+/// block are accumulated and stored in LRUs.
+void Spill2Reg::collectSpillsAndReloads() {
+  // Spills and reloads are filtered by the same conditions, so the filter
+  // lives in one place. It yields true when the access must not be collected:
+  // either the frame index is not a spill slot, or the target reports the
+  // register as not legal for spill2reg.
+  auto SkipEntry = [this](int FrameIdx, Register AccessedReg) -> bool {
+    return !MFI->isSpillSlotObjectIndex(FrameIdx) ||
+           !TII->isLegalToSpill2Reg(AccessedReg, TRI, MRI);
+  };
+
+  for (MachineBasicBlock &MBB : *MF) {
+    // Tracks the physical registers used across the whole MBB, starting from
+    // the block's live-outs.
+    LiveRegUnits AccumMBBLRU(*TRI);
+    AccumMBBLRU.addLiveOuts(MBB);
+
+    for (MachineInstr &MI : llvm::reverse(MBB)) {
+      // Fold MI into the per-block register state as we move upwards.
+      AccumMBBLRU.accumulate(MI);
+
+      int FrameIdx;
+
+      // Case 1: MI is a spill to a stack slot.
+      if (const MachineOperand *SpilledMO =
+              TII->isStoreToStackSlotMO(MI, FrameIdx)) {
+        auto &SlotEntry = StackSlotData[FrameIdx];
+        if (SkipEntry(FrameIdx, SpilledMO->getReg())) {
+          SlotEntry.Disable = true;
+          continue;
+        }
+        unsigned NumBits = TRI->getRegSizeInBits(SpilledMO->getReg(), *MRI);
+        SlotEntry.Spills.emplace_back(&MI, SpilledMO, NumBits);
+
+        // We walk bottom-up, so any reload of this slot already collected in
+        // this MBB sits below the spill and takes its value from it: flag it
+        // as not live-in. `updateLiveIns()` relies on this flag, which saves
+        // it from walking the block again.
+        for (auto &ReloadData : SlotEntry.Reloads)
+          if (ReloadData.MI->getParent() == &MBB)
+            ReloadData.IsLiveIn = false;
+        continue;
+      }
+
+      // Case 2: MI is a reload from a stack slot.
+      if (const MachineOperand *ReloadedMO =
+              TII->isLoadFromStackSlotMO(MI, FrameIdx)) {
+        auto &SlotEntry = StackSlotData[FrameIdx];
+        if (SkipEntry(FrameIdx, ReloadedMO->getReg())) {
+          SlotEntry.Disable = true;
+          continue;
+        }
+        assert(MI.getRestoreSize(TII) && "Expected reload");
+        unsigned NumBits = TRI->getRegSizeInBits(ReloadedMO->getReg(), *MRI);
+        SlotEntry.Reloads.emplace_back(&MI, ReloadedMO, NumBits);
+        continue;
+      }
+
+      // Case 3: any other instruction. A frame-index operand here means the
+      // stack slot is used by something other than a plain spill/reload
+      // (e.g., folded spills/reloads, or non-memory instructions like x86
+      // LEA), so the slot is disabled.
+      for (const MachineOperand &MO : MI.operands())
+        if (MO.isFI())
+          StackSlotData[MO.getIndex()].Disable = true;
+    }
+
+    LRUs.insert(std::make_pair(&MBB, AccumMBBLRU));
+  }
+}
+
+/// \Returns the target's verdict on whether replacing the stack access \p MI
+/// with a register-based one is profitable. The decision is forwarded to
+/// TargetInstrInfo::isSpill2RegProfitable().
+bool Spill2Reg::isProfitable(const MachineInstr *MI) const {
+  const bool Profitable = TII->isSpill2RegProfitable(MI, TRI, MRI);
+  return Profitable;
+}
+
+/// \Returns true only when every spill and every reload recorded in \p Entry
+/// is individually profitable according to isProfitable(). Spills are checked
+/// first; the check stops at the first unprofitable access.
+bool Spill2Reg::allAccessesProfitable(const StackSlotDataEntry &Entry) const {
+  for (const auto &SpillData : Entry.Spills)
+    if (!isProfitable(SpillData.MI))
+      return false;
+  for (const auto &ReloadData : Entry.Reloads)
+    if (!isProfitable(ReloadData.MI))
+      return false;
+  return true;
+}
+
+/// Scans the allocation order of \p RegClass and \returns the first register
+/// that \p LRU reports as available, or std::nullopt when none is.
+std::optional<MCRegister>
+Spill2Reg::tryGetFreePhysicalReg(const TargetRegisterClass *RegClass,
+                                 const LiveRegUnits &LRU) {
+  for (MCRegister Candidate : RegClassInfo.getOrder(RegClass))
+    if (LRU.available(Candidate))
+      return Candidate;
+  return std::nullopt;
+}
+
+/// Perform a bottom-up depth-first traversal starting at \p MBB and moving
+/// towards its predecessor blocks. \p Visited holds the blocks reached so far
+/// and is updated as the walk proceeds; blocks already in it are skipped.
+/// \p Fn is the callback invoked on each newly visited block in pre-order. If
+/// \p Fn returns true the walk does not continue past that block.
+// TODO: Use df_iterator
+static void DFS(MachineBasicBlock *MBB, DenseSet<MachineBasicBlock *> &Visited,
+                std::function<bool(MachineBasicBlock *)> Fn) {
+  // insert() reports whether MBB is new. A repeated block ends the walk here,
+  // which also avoids infinite loops on cyclic CFGs.
+  const bool IsFirstVisit = Visited.insert(MBB).second;
+  if (!IsFirstVisit)
+    return;
+
+  // Pre-order callback: a true result prunes everything above MBB.
+  const bool StopHere = Fn(MBB);
+  if (StopHere)
+    return;
+
+  // Depth-first across predecessors.
+  for (MachineBasicBlock *PredMBB : MBB->predecessors())
+    DFS(PredMBB, Visited, Fn);
+}
+
+/// Adds \p VectorReg to the live-ins of the blocks it has to be live into so
+/// that the reloads of \p Entry can read the value written by its spills.
+///
+/// A reload that is preceded by a spill of the same stack slot inside its own
+/// block (IsLiveIn == false) takes its value from that spill, so it requires
+/// no live-in update at all, neither for its block nor for any predecessor.
+/// For every other reload, the reload's block is marked and its predecessors
+/// are walked bottom-up, marking each block until a block that contains a
+/// spill is reached.
+void Spill2Reg::updateLiveIns(StackSlotDataEntry &Entry, MCRegister VectorReg) {
+  // Collect the parent MBBs of Spills for fast lookup.
+  DenseSet<MachineBasicBlock *> SpillMBBs(Entry.Spills.size());
+  for (const auto &Data : Entry.Spills)
+    SpillMBBs.insert(Data.MI->getParent());
+
+  auto AddLiveInIfRequired = [VectorReg, &SpillMBBs](MachineBasicBlock *MBB) {
+    // If there is a spill in this MBB then we don't need to add a live-in.
+    // This works even if there is a reload above the spill, like this:
+    //   reload stack.0
+    //   spill  stack.0
+    // because the live-in due to the reload is handled at a separate walk.
+    if (SpillMBBs.count(MBB))
+      // Return true to stop the recursion.
+      return true;
+    // If there are no spills in this block then the register is live-in.
+    if (!MBB->isLiveIn(VectorReg))
+      MBB->addLiveIn(VectorReg);
+    // Return false to continue the recursion.
+    return false;
+  };
+
+  // Update the MBB live-ins. These are used for the live regs calculation.
+  // Visited is shared by all walks so that no block is processed twice.
+  DenseSet<MachineBasicBlock *> Visited;
+  for (const auto &ReloadData : Entry.Reloads) {
+    // From a previous walk in MBB we know whether the reload is live-in, or
+    // whether the value comes from an earlier spill in the same MBB. In the
+    // latter case the register is defined within the block, so neither the
+    // block nor its predecessors need a live-in because of this reload.
+    if (!ReloadData.IsLiveIn)
+      continue;
+
+    MachineBasicBlock *MBB = ReloadData.MI->getParent();
+    if (!MBB->isLiveIn(VectorReg))
+      MBB->addLiveIn(VectorReg);
+
+    for (MachineBasicBlock *PredMBB : MBB->predecessors())
+      DFS(PredMBB, Visited, AddLiveInIfRequired);
+  }
+}
+
+/// Rewrites every spill and reload recorded in \p Entry so that it goes
+/// through \p VectorReg instead of the stack. For each access the target hook
+/// emits the replacement right before the original instruction, \p VectorReg
+/// is recorded as used in that block's LRUs entry, and the original stack
+/// access is erased.
+void Spill2Reg::replaceStackWithReg(StackSlotDataEntry &Entry,
+                                    Register VectorReg) {
+  const auto *Subtarget = &MF->getSubtarget();
+
+  // Spills: move the spilled register into VectorReg.
+  for (StackSlotDataEntry::MIData &SpillData : Entry.Spills) {
+    assert(SpillData.MO->isReg() && "Expected register MO");
+    MachineInstr *SpillMI = SpillData.MI;
+    MachineBasicBlock *SpillMBB = SpillMI->getParent();
+    Register SpilledReg = SpillData.MO->getReg();
+
+    TII->spill2RegInsertToVectorReg(VectorReg, SpilledReg, SpillData.SpillBits,
+                                    SpillMBB,
+                                    /*InsertBeforeIt=*/SpillMI->getIterator(),
+                                    TRI, Subtarget);
+
+    // VectorReg is now used in this block.
+    LRUs[SpillMBB].addReg(VectorReg);
+
+    // The store to the stack has been superseded.
+    SpillMI->eraseFromParent();
+    assert(SpilledReg.isPhysical() && "Otherwise we need to removeInterval()");
+  }
+
+  // Reloads: extract the value from VectorReg back into the reloaded register.
+  for (StackSlotDataEntry::MIData &ReloadData : Entry.Reloads) {
+    assert(ReloadData.MO->isReg() && "Expected Reg MO");
+    MachineInstr *ReloadMI = ReloadData.MI;
+    MachineBasicBlock *ReloadMBB = ReloadMI->getParent();
+    Register ReloadedReg = ReloadData.MO->getReg();
+
+    TII->spill2RegExtractFromVectorReg(
+        ReloadedReg, VectorReg, ReloadData.SpillBits, ReloadMBB,
+        /*InsertBeforeIt=*/ReloadMI->getIterator(), TRI, Subtarget);
+
+    // VectorReg is now used in this block.
+    LRUs[ReloadMBB].addReg(VectorReg);
+
+    // The load from the stack has been superseded.
+    ReloadMI->eraseFromParent();
+    assert(ReloadedReg.isPhysical() && "Otherwise we need to removeInterval()");
+  }
+}
+
+void Spill2Reg::calculateLiveRegs(StackSlotDataEntry &Entry,
+                                  LiveRegUnits &LRU) {
+  // Collect the parent MBBs of Spills for fast lookup.
+  DenseSet<MachineBasicBlock *> SpillMBBs(Entry.Spills.size());
+  DenseSet<MachineInstr *> Spills(Entry.Spills.size());
+  for (const auto &Data : Entry.Spills) {
+    SpillMBBs.insert(Data.MI->getParent());
+    Spills.insert(Data.MI);
+  }
+
+  /// Walks up the instructions in \p Reload's block, stopping at a spill if
+  /// found. \Returns true if a spill was found, false otherwise.
+  auto AccumulateLRUUntilSpillFn = [&Spills, &SpillMBBs](MachineInstr *Reload,
+                                                         LiveRegUnits &LRU) {
+    MachineBasicBlock *MBB = Reload->getParent();
+    bool IsSpillBlock = SpillMBBs.count(MBB);
+    // Add all MBB's live-outs.
+    LRU.addLiveOuts(*MBB);
----------------
williamweixiao wrote:

I agree that LRU.accumulate() is correct but conservative.
Consider the example below:
```
spill rax
...
reload rax
xmm1 = ...
xmm2 = ...
```

Here xmm1/xmm2 are only written after the reload, so we can actually reuse them for the spill.

https://github.com/llvm/llvm-project/pull/118832


More information about the llvm-commits mailing list