[llvm-branch-commits] [llvm] [CodeGen] Move rollback capabilities outside of the rematerializer (PR #184341)
Lucas Ramirez via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Mar 3 05:49:31 PST 2026
https://github.com/lucas-rami created https://github.com/llvm/llvm-project/pull/184341
The rematerializer implements support for rolling back rematerializations by modifying MIs that should normally be deleted in an attempt to make them "transparent" to other analyses. This involves:
1. setting their opcode to DBG_VALUE and
2. setting their read register operands to the sentinel register.
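The old mechanism can be illustrated with a small C++ toy model. `ToyMI`, `NoReg`, `SavedOperands`, `makeTransparent` and `revive` are hypothetical stand-ins, not LLVM APIs. An instruction is reduced to an opcode string, a defined register and a list of read registers.

"Deleting" an instruction swaps its opcode for DBG_VALUE and nulls its reads. This only works if something keeps the saved operands around for a later revival. The removed code recovered the original opcode from a surviving rematerialization, and the `Remat` argument mimics that.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy stand-ins; these are not LLVM types.
constexpr int NoReg = 0; // sentinel ("null") register

struct ToyMI {
  std::string Opcode;
  int Def;                // register defined by the instruction
  std::vector<int> Reads; // registers read by the instruction
};

// Side table the old scheme needs: read-operand index -> original register.
using SavedOperands = std::map<unsigned, int>;

// "Deletes" MI by making it transparent instead of erasing it: the opcode
// becomes DBG_VALUE and every read operand is set to the sentinel register.
SavedOperands makeTransparent(ToyMI &MI) {
  SavedOperands Saved;
  for (unsigned Idx = 0; Idx < MI.Reads.size(); ++Idx) {
    if (MI.Reads[Idx] != NoReg) {
      Saved[Idx] = MI.Reads[Idx];
      MI.Reads[Idx] = NoReg;
    }
  }
  MI.Opcode = "DBG_VALUE";
  return Saved;
}

// Revives a transparent MI: the opcode is copied from a surviving
// rematerialization and the saved read operands are written back.
void revive(ToyMI &MI, const ToyMI &Remat, const SavedOperands &Saved) {
  assert(MI.Opcode == "DBG_VALUE" && "MI is not transparent");
  MI.Opcode = Remat.Opcode;
  for (const auto &[Idx, Reg] : Saved)
    MI.Reads[Idx] = Reg;
}
```

Every transparent MI needs its own saved-operand entry. That per-MI bookkeeping is the kind of tracking the first drawback listed below refers to.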
This approach has several drawbacks.
1. It forces the rematerializer to support tracking these "dead MIs" (even if support is optional, these data structures have to exist).
2. It is not actually clear whether this mechanism will interact well with all other analyses. This is an issue since the intent of the rematerializer is to be usable in as many contexts as possible.
3. In practice, it has shown itself to be relatively error-prone.
This commit removes rollback support from the rematerializer and moves those capabilities to a rematerializer listener (Rollbacker) that can be instantiated on demand and that implements the same functionality on top of standard rematerializer operations. The rematerializer now actually deletes MIs that are no longer useful after rematerializations, and it supports re-creating them on demand without requiring additional tracking on its part.
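The listener-based design can be sketched as a toy model, under heavy simplifying assumptions. Registers are plain indices and users are integers. `ToyRemater`, `Listener` and the toy `Rollbacker` are hypothetical classes. They borrow only the hook names from the patch (`beforeRegRematerialized`, `beforeRegDeleted`, `rollback`); their signatures are not the LLVM ones.

The key point is that the toy rematerializer keeps no rollback state. The listener records deleted original registers and each origin's rematerializations. It then replays them with ordinary operations: first it re-creates the dead originals, then it moves all users of each live rematerialization back to its origin.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <utility>
#include <vector>

struct ToyRemater;

// Hypothetical listener interface; default hooks do nothing.
struct Listener {
  virtual ~Listener() = default;
  virtual void beforeRegRematerialized(const ToyRemater &, int, int) {}
  virtual void beforeRegDeleted(const ToyRemater &, int) {}
};

// Hypothetical toy rematerializer: notifies listeners, keeps no rollback state.
struct ToyRemater {
  struct Reg {
    int Origin;          // index of the original register (self if original)
    bool Alive;
    std::set<int> Users; // integers standing in for user instructions
  };
  std::vector<Reg> Regs;
  std::vector<Listener *> Listeners;

  int addOriginal(std::set<int> Users) {
    int Idx = static_cast<int>(Regs.size());
    Regs.push_back({Idx, true, std::move(Users)});
    return Idx;
  }
  bool isOriginal(int R) const { return Regs[R].Origin == R; }

  // Creates a new copy of R with no users yet.
  int rematerialize(int R) {
    int New = static_cast<int>(Regs.size());
    for (Listener *L : Listeners)
      L->beforeRegRematerialized(*this, R, New);
    Regs.push_back({Regs[R].Origin, true, {}});
    return New;
  }

  // Moves every user of From to To; From is then unused and gets deleted.
  void transferAllUsers(int From, int To) {
    Regs[To].Users.insert(Regs[From].Users.begin(), Regs[From].Users.end());
    Regs[From].Users.clear();
    for (Listener *L : Listeners)
      L->beforeRegDeleted(*this, From);
    Regs[From].Alive = false;
  }

  // Brings a deleted register back (insert positions are ignored in this toy).
  void recreate(int R) {
    assert(!Regs[R].Alive && "register is still alive");
    Regs[R].Alive = true;
  }
};

// Toy rollbacker: records events, then undoes them with ordinary operations.
struct Rollbacker : Listener {
  std::vector<int> DeadOriginals;      // in order of deletion
  std::map<int, std::set<int>> Remats; // origin -> rematerializations
  bool RollingBack = false;

  void beforeRegRematerialized(const ToyRemater &RM, int R, int New) override {
    if (!RollingBack)
      Remats[RM.Regs[R].Origin].insert(New);
  }
  void beforeRegDeleted(const ToyRemater &RM, int R) override {
    // Only deleted originals must be re-created; deleted copies are dropped.
    if (!RollingBack && RM.isOriginal(R))
      DeadOriginals.push_back(R);
  }
  void rollback(ToyRemater &RM) {
    RollingBack = true;
    for (int R : DeadOriginals)
      RM.recreate(R);
    for (const auto &[Origin, Copies] : Remats)
      for (int Copy : Copies)
        if (RM.Regs[Copy].Alive)
          RM.transferAllUsers(Copy, Origin);
    DeadOriginals.clear();
    Remats.clear();
    RollingBack = false;
  }
};
```

The `RollingBack` flag plays the same role as in the patch. The user transfers performed during `rollback` delete rematerializations, and those deletions must not be recorded as new events.

The sketch deliberately omits what makes the real `recreateReg` harder. First, it must choose an insert position: the patch follows the chain of `NextRegIdx` entries when the MI that used to follow a dead register was itself deleted. Second, it must restore dependencies.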
From 7c2c5050bd9c0f1e646d15385b442bcd3b6f0619 Mon Sep 17 00:00:00 2001
From: Lucas Ramirez <lucas.rami at proton.me>
Date: Tue, 3 Mar 2026 13:00:27 +0000
Subject: [PATCH] [CodeGen] Move rollback capabilities outside of the
rematerializer
The rematerializer implements support for rolling back
rematerializations by modifying MIs that should normally be deleted in
an attempt to make them "transparent" to other analyses. This involves:
1. setting their opcode to DBG_VALUE and
2. setting their read register operands to the sentinel register.
This approach has several drawbacks.
1. It forces the rematerializer to support tracking these "dead MIs".
2. It is not actually clear whether this mechanism will interact well
with all other analyses. This is an issue since the intent of the
rematerializer is to be usable in as many contexts as possible.
3. In practice, it has shown itself to be relatively error-prone.
This commit removes rollback support from the rematerializer and moves
those capabilities to a rematerializer listener that can be instantiated
on-demand and implements the same functionality on top of standard
rematerializer operations. The rematerializer now actually deletes MIs
that are no longer useful after rematerializations, and has support for
re-creating them on-demand without requiring additional tracking on its
part.
---
llvm/include/llvm/CodeGen/Rematerializer.h | 178 ++++++-----
llvm/lib/CodeGen/Rematerializer.cpp | 283 ++++++++++--------
llvm/unittests/CodeGen/RematerializerTest.cpp | 131 ++++++--
3 files changed, 359 insertions(+), 233 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/Rematerializer.h b/llvm/include/llvm/CodeGen/Rematerializer.h
index 11ac95fc9368d..3a42a301264b2 100644
--- a/llvm/include/llvm/CodeGen/Rematerializer.h
+++ b/llvm/include/llvm/CodeGen/Rematerializer.h
@@ -14,6 +14,7 @@
#ifndef LLVM_CODEGEN_REMATERIALIZER_H
#define LLVM_CODEGEN_REMATERIALIZER_H
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
@@ -76,18 +77,7 @@ namespace llvm {
/// The rematerializer supports rematerializing arbitrary complex DAGs of
/// registers to regions where these registers are used, with the option of
/// re-using non-root registers or their previous rematerializations instead of
-/// rematerializing them again. It also optionally supports rolling back
-/// previous rematerializations (set during analysis phase, see \ref
-/// Rematerializer::analyze) to restore the MIR state to what it was
-/// pre-rematerialization. When enabled, machine instructions defining
-/// rematerializable registers that no longer have any uses following previous
-/// rematerializations will not be deleted from the MIR; their opcode will
-/// instead be set to a DEBUG_VALUE and their read register operands set to the
-/// null register. This maintains their position in the MIR and keeps the
-/// original register alive for potential rollback while allowing other
-/// passes/analyzes (e.g., machine scheduler, live-interval analysis) to ignore
-/// them. \ref Rematerializer::commitRematerializations actually deletes those
-/// instructions when rollback is deemed unnecessary.
+/// rematerializing them again.
///
/// Throughout its lifetime, the rematerializer tracks new registers it creates
/// (which are rematerializable by construction) and their relations to other
@@ -121,10 +111,7 @@ class Rematerializer {
/// arbitrary number of regions, potentially including its own defining
/// region. When rematerializations lead to operand changes in users, a
/// register may find itself without any user left, at which point the
- /// rematerializer marks it for deletion. Its defining instruction either
- /// becomes nullptr (without rollback support) or its opcode is set to
- /// TargetOpcode::DBG_VALUE (with rollback support) until \ref
- /// Rematerializer::commitRematerializations is called.
+ /// rematerializer deletes it (setting its defining MI to nullptr).
struct Reg {
/// Single MI defining the rematerializable register.
MachineInstr *DefMI;
@@ -174,9 +161,7 @@ class Rematerializer {
std::pair<MachineInstr *, MachineInstr *>
getRegionUseBounds(unsigned UseRegion, const LiveIntervals &LIS) const;
- bool isAlive() const {
- return DefMI && DefMI->getOpcode() != TargetOpcode::DBG_VALUE;
- }
+ bool isAlive() const { return DefMI; }
private:
void addUser(MachineInstr *MI, unsigned Region);
@@ -225,6 +210,8 @@ class Rematerializer {
using RegionBoundaries =
std::pair<MachineBasicBlock::iterator, MachineBasicBlock::iterator>;
+ using RematsOf = SmallDenseSet<RegisterIdx, 4>;
+
/// Simply initializes some internal state, does not identify
/// rematerialization candidates.
Rematerializer(MachineFunction &MF,
@@ -232,11 +219,8 @@ class Rematerializer {
LiveIntervals &LIS);
/// Goes through the whole MF and identifies all rematerializable registers.
- /// When \p SupportRollback is set, rematerializations of original registers
- /// can be rolled back and original registers are maintained in the IR even
- /// when they longer have any users. Returns whether there is any
- /// rematerializable register in regions.
- bool analyze(bool SupportRollback);
+ /// Returns whether there is any rematerializable register in regions.
+ bool analyze();
/// Adds a listener to the rematerializer.
void addListener(Listener *Listen) { Listeners.push_back(Listen); }
@@ -250,12 +234,16 @@ class Rematerializer {
inline ArrayRef<Reg> getRegs() const { return Regs; };
inline unsigned getNumRegs() const { return Regs.size(); };
- inline const RegionBoundaries &getRegion(RegisterIdx RegionIdx) {
+ inline const RegionBoundaries &getRegion(RegisterIdx RegionIdx) const {
assert(RegionIdx < Regions.size() && "out of bounds");
return Regions[RegionIdx];
}
inline unsigned getNumRegions() const { return Regions.size(); }
+ /// Whether register \p RegIdx is an original register.
+ inline bool isOriginalRegister(RegisterIdx RegIdx) const {
+ return !isRematerializedRegister(RegIdx);
+ }
/// Whether register \p RegIdx is a rematerialization of some original
/// register.
inline bool isRematerializedRegister(RegisterIdx RegIdx) const {
@@ -276,10 +264,16 @@ class Rematerializer {
}
/// Returns operand indices corresponding to unrematerializable operands for
/// any register \p RegIdx.
- inline ArrayRef<unsigned> getUnrematableOprds(unsigned RegIdx) const {
+ inline ArrayRef<unsigned> getUnrematableOprds(RegisterIdx RegIdx) const {
return UnrematableOprds[getOriginOrSelf(RegIdx)];
}
+ /// If \p MI's first operand defines a register and that register is a
+ /// rematerializable register tracked by the rematerializer, returns its
+ /// index in the \ref Regs vector. Otherwise returns \ref
+ /// Rematerializer::NoReg.
+ RegisterIdx getDefRegIdx(const MachineInstr &MI) const;
+
/// When rematerializating a register (called the "root" register in this
/// context) to a given position, we must decide what to do with all its
/// rematerializable dependencies (for unrematerializable dependencies, we
@@ -356,27 +350,26 @@ class Rematerializer {
MachineBasicBlock::iterator InsertPos,
DependencyReuseInfo &DRI);
- /// Rolls back all rematerializations of original register \p RootIdx,
- /// transfering all their users back to it and permanently deleting them from
- /// the MIR. The root register is revived if it was fully rematerialized (this
- /// requires that rollback support was set at that time). Transitive
- /// dependencies of the root register that were fully rematerialized are
- /// re-vived at their original positions; this requires that rollback support
- /// was set when they were rematerialized.
- void rollbackRematsOf(RegisterIdx RootIdx);
-
- /// Rolls back register \p RematIdx (which must be a rematerialization)
- /// transfering all its users back to its origin. The latter is revived if it
- /// was fully rematerialized (this requires that rollback support was set at
- /// that time).
- void rollback(RegisterIdx RematIdx);
-
- /// Revives original register \p RootIdx at its original position in the MIR
- /// if it was fully rematerialized with rollback support set. Transitive
- /// dependencies of the root register that were fully rematerialized are
- /// revived at their original positions; this requires that rollback support
- /// was set when they were themselves rematerialized.
- void reviveRegIfDead(RegisterIdx RootIdx);
+ /// Rematerializes register \p RegIdx at \p InsertPos in \p UseRegion, adding
+ /// the new rematerializable register to the backing vector \ref Regs and
+ /// returning its index inside the vector. Sets the new register's
+ /// rematerializable dependencies to \p Dependencies (these are assumed to
+ /// already exist in the MIR) and its unrematerializable dependencies to the
+ /// same as \p RegIdx. The new register initially has no user. Since the
+ /// method appends to \ref Regs, references to elements within it should be
+ /// considered invalidated across calls to this method unless the vector can
+ /// be guaranteed to have enough space for an extra element.
+ RegisterIdx rematerializeReg(RegisterIdx RegIdx, unsigned UseRegion,
+ MachineBasicBlock::iterator InsertPos,
+ SmallVectorImpl<Reg::Dependency> &&Dependencies);
+
+ /// Re-creates a previously deleted register \p RegIdx at \p InsertPos in \p
+ /// DefRegion. \p DefReg must be the original virtual register that \p RegIdx
+ /// used to define. Sets the new register's rematerializable dependencies to
+ /// \p Dependencies (these are assumed to already exist in the MIR).
+ void recreateReg(RegisterIdx RegIdx, unsigned DefRegion,
+ MachineBasicBlock::iterator InsertPos, Register DefReg,
+ SmallVectorImpl<Reg::Dependency> &&Dependencies);
/// Transfers all users of register \p FromRegIdx in region \p UseRegion to \p
/// ToRegIdx, the latter of which must be a rematerialization of the former or
@@ -392,14 +385,16 @@ class Rematerializer {
void transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx,
unsigned UserRegion, MachineInstr &UserMI);
+ /// Transfers all users of register \p FromRegIdx to register \p ToRegIdx, the
+ /// latter of which must be a rematerialization of the former or have the same
+ /// origin register. Users of \p FromRegIdx must be reachable from \p
+ /// ToRegIdx.
+ void transferAllUsers(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx);
+
/// Recomputes all live intervals that have changed as a result of previous
- /// rematerializations/rollbacks.
+ /// rematerializations.
void updateLiveIntervals();
- /// Deletes unused rematerialized registers that were left in the MIR to
- /// support rollback.
- void commitRematerializations();
-
/// Determines whether (sub-)register operand \p MO has the same value at
/// all \p Uses as at \p MO. This implies that it is also available at all \p
/// Uses according to its current live interval.
@@ -444,9 +439,8 @@ class Rematerializer {
/// Indicates the original register index of each rematerialization, in the
/// order in which they are created. The size of the vector indicates the
/// total number of rematerializations ever created, including those that were
- /// deleted or rolled back.
+ /// deleted.
SmallVector<RegisterIdx> Origins;
- using RematsOf = SmallDenseSet<RegisterIdx, 4>;
/// Maps original register indices to their currently alive
/// rematerializations. In practice most registers don't have
/// rematerializations so this is represented as a map to lower memory cost.
@@ -459,15 +453,13 @@ class Rematerializer {
/// Parent block of each region, in order.
SmallVector<MachineBasicBlock *> RegionMBB;
/// Set of registers whose live-range may have changed during past
- /// rematerializations/rollbacks.
+ /// rematerializations.
DenseSet<RegisterIdx> LISUpdates;
- /// Keys are fully rematerialized registers whose rematerializations are
- /// currently rollback-able. Values map register machine operand indices to
- /// their original register.
- DenseMap<RegisterIdx, DenseMap<unsigned, Register>> Revivable;
- /// Whether all rematerializations of registers identified during the last
- /// analysis phase will be rollback-able.
- bool SupportRollback = false;
+
+ /// Common post-processing step after creating a new register \p RematRegIdx
+ /// at \p InsertPos based on register \p ModelRegIdx.
+ void postRematerialization(RegisterIdx ModelRegIdx, RegisterIdx RematRegIdx,
+ MachineBasicBlock::iterator InsertPos);
/// During the analysis phase, creates a \ref Rematerializer::Reg object for
/// virtual register \p VirtRegIdx if it is rematerializable. \p MIRegion maps
@@ -484,19 +476,6 @@ class Rematerializer {
/// defined once.
bool isMIRematerializable(const MachineInstr &MI) const;
- /// Rematerializes register \p RegIdx at \p InsertPos in \p UseRegion, adding
- /// the new rematerializable register to the backing vector \ref Regs and
- /// returning its index inside the vector. Sets the new registers'
- /// rematerializable dependencies to \p Dependencies (these are assumed to
- /// already exist in the MIR) and its unrematerializable dependencies to the
- /// same as \p RegIdx. The new register initially has no user. Since the
- /// method appends to \ref Regs, references to elements within it should be
- /// considered invalidated across calls to this method unless the vector can
- /// be guaranteed to have enough space for an extra element.
- RegisterIdx rematerializeReg(RegisterIdx RegIdx, unsigned UseRegion,
- MachineBasicBlock::iterator InsertPos,
- SmallVectorImpl<Reg::Dependency> &&Dependencies);
-
/// Implementation of \ref Rematerializer::transferUser that doesn't update
/// register users.
void transferUserImpl(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx,
@@ -510,12 +489,51 @@ class Rematerializer {
/// Deletes rematerializable register \p RegIdx from the DAG and relevant
/// internal state.
void deleteReg(RegisterIdx RegIdx);
+};
- /// If \p MI's first operand defines a register and that register is a
- /// rematerializable register tracked by the rematerializer, returns its
- /// index in the \ref Regs vector. Otherwise returns \ref
- /// Rematerializer::NoReg.
- RegisterIdx getDefRegIdx(const MachineInstr &MI) const;
+/// Rematerializer listener with the ability to re-create deleted registers and
+/// roll back rematerializations. Starts recording register deletions and
+/// rematerializations as soon as it is attached to the rematerializer.
+class Rollbacker : public Rematerializer::Listener {
+public:
+ Rollbacker() = default;
+
+ /// Re-creates all deleted registers and rolls back all rematerializations
+ /// that were recorded.
+ void rollback(Rematerializer &Remater);
+
+private:
+ struct RollbackInfo {
+ /// Original register.
+ Register DefReg;
+ /// Original defining region.
+ unsigned DefRegion;
+ /// Original dependencies.
+ SmallVector<Rematerializer::Reg::Dependency, 2> Dependencies;
+ /// Position to re-insert the defining MI before in case of rollback.
+ MachineBasicBlock::iterator InsertPos;
+ /// If \ref InsertPos points to an MI defining a rematerializable register,
+ /// stores its index. Otherwise \ref Rematerializer::NoReg.
+ RegisterIdx NextRegIdx;
+
+ RollbackInfo(const Rematerializer &Remater, RegisterIdx RegIdx);
+ };
+
+ /// Original registers that have been deleted, in order of deletion.
+ MapVector<RegisterIdx, RollbackInfo> DeadRegs;
+ /// Registers which have been rematerialized (from original index to
+ /// rematerialized index).
+ DenseMap<RegisterIdx, Rematerializer::RematsOf> Rematerializations;
+ /// Used to block further recording of events whenever we are actively rolling
+ /// back.
+ bool RollingBack = false;
+
+ void beforeRegRematerialized(const Rematerializer &Remater,
+ RegisterIdx RegIdx,
+ RegisterIdx RematRegIdx) override;
+
+ void beforeRegDeleted(const Rematerializer &Remater,
+ RegisterIdx RegIdx) override;
};
} // namespace llvm
diff --git a/llvm/lib/CodeGen/Rematerializer.cpp b/llvm/lib/CodeGen/Rematerializer.cpp
index e07e7d0745d91..63a3fbffe2eac 100644
--- a/llvm/lib/CodeGen/Rematerializer.cpp
+++ b/llvm/lib/CodeGen/Rematerializer.cpp
@@ -127,86 +127,6 @@ Rematerializer::rematerializeToPos(RegisterIdx RootIdx, unsigned UseRegion,
return LastNewIdx;
}
-void Rematerializer::rollbackRematsOf(RegisterIdx RootIdx) {
- auto Remats = Rematerializations.find(RootIdx);
- if (Remats == Rematerializations.end())
- return;
-
- LLVM_DEBUG(dbgs() << "Rolling back rematerializations of " << printID(RootIdx)
- << '\n');
-
- reviveRegIfDead(RootIdx);
- // All of the rematerialization's users must use the revived register.
- for (RegisterIdx RematRegIdx : Remats->getSecond()) {
- for (const auto &[UseRegion, RegionUsers] : Regs[RematRegIdx].Uses)
- transferRegionUsers(RematRegIdx, RootIdx, UseRegion);
- }
- Rematerializations.erase(RootIdx);
-
- LLVM_DEBUG(dbgs() << "** Rolled back rematerializations of "
- << printID(RootIdx) << '\n');
-}
-
-void Rematerializer::rollback(RegisterIdx RematIdx) {
- assert(getReg(RematIdx).DefMI && !Revivable.contains(RematIdx) &&
- "cannot rollback dead register");
- const RegisterIdx OriginRegIdx = getOriginOf(RematIdx);
- reviveRegIfDead(OriginRegIdx);
- for (const auto &[UseRegion, RegionUsers] : Regs[RematIdx].Uses)
- transferRegionUsers(RematIdx, OriginRegIdx, UseRegion);
-}
-
-void Rematerializer::reviveRegIfDead(RegisterIdx RootIdx) {
- if (getReg(RootIdx).isAlive())
- return;
- assert(Revivable.contains(RootIdx) && "not revivable");
-
- // Traverse the root's dependency DAG depth-first to find the set of
- // registers we must revive and a legal order to revive them in.
- SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
- SmallSetVector<RegisterIdx, 8> ReviveOrder;
- ReviveOrder.insert(RootIdx);
- do {
- // All dependencies of a revived register need to be alive too.
- const Reg &ReviveReg = getReg(DepDAG.pop_back_val());
- for (const Reg::Dependency &Dep : ReviveReg.Dependencies) {
- // We may have already seen the dependency in the dependency DAG.
- if (ReviveOrder.contains(Dep.RegIdx))
- continue;
-
- // Dead dependencies need to be revived.
- Reg &DepReg = Regs[Dep.RegIdx];
- if (!DepReg.isAlive()) {
- assert(Revivable.contains(Dep.RegIdx) && "not revivable");
- ReviveOrder.insert(Dep.RegIdx);
- DepDAG.push_back(Dep.RegIdx);
- }
-
- // All dependencies get a new user (the revived register).
- DepReg.addUser(ReviveReg.DefMI, ReviveReg.DefRegion);
- LISUpdates.insert(Dep.RegIdx);
- }
- } while (!DepDAG.empty());
-
- for (RegisterIdx RegIdx : reverse(ReviveOrder)) {
- // Pick any rematerialization to retrieve the original opcode from.
- Reg &ReviveReg = Regs[RegIdx];
- assert(Rematerializations.contains(RegIdx) && "no remats");
- RegisterIdx RematIdx = *Rematerializations.at(RegIdx).begin();
- ReviveReg.DefMI->setDesc(getReg(RematIdx).DefMI->getDesc());
- for (const auto &[MOIdx, Reg] : Revivable.at(RegIdx))
- ReviveReg.DefMI->getOperand(MOIdx).setReg(Reg);
- Revivable.erase(RegIdx);
- LISUpdates.insert(RegIdx);
-
- LLVM_DEBUG({
- dbgs() << "** Revived " << printID(RegIdx) << " @ ";
- LIS.getInstructionIndex(*ReviveReg.DefMI).print(dbgs());
- dbgs() << '\n';
- });
- }
-}
-
void Rematerializer::transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx,
unsigned UserRegion, MachineInstr &UserMI) {
assert(getReg(FromRegIdx).Uses.contains(UserRegion) && "no user in region");
@@ -235,6 +155,18 @@ void Rematerializer::transferRegionUsers(RegisterIdx FromRegIdx,
deleteRegIfUnused(FromRegIdx);
}
+void Rematerializer::transferAllUsers(RegisterIdx FromRegIdx,
+ RegisterIdx ToRegIdx) {
+ Reg &FromReg = Regs[FromRegIdx], &ToReg = Regs[ToRegIdx];
+ for (const auto &[UseRegion, RegionUsers] : FromReg.Uses) {
+ for (MachineInstr *UserMI : RegionUsers)
+ transferUserImpl(FromRegIdx, ToRegIdx, *UserMI);
+ ToReg.addUsers(RegionUsers, UseRegion);
+ }
+ FromReg.Uses.clear();
+ deleteRegIfUnused(FromRegIdx);
+}
+
void Rematerializer::transferUserImpl(RegisterIdx FromRegIdx,
RegisterIdx ToRegIdx,
MachineInstr &UserMI) {
@@ -268,7 +200,7 @@ void Rematerializer::updateLiveIntervals() {
DenseSet<Register> SeenUnrematRegs;
for (RegisterIdx RegIdx : LISUpdates) {
const Reg &UpdateReg = getReg(RegIdx);
- assert((UpdateReg.DefMI || Revivable.contains(RegIdx)) && "dead reg");
+ assert(UpdateReg.isAlive() && "dead register");
Register DefReg = UpdateReg.getDefReg();
if (LIS.hasInterval(DefReg))
@@ -299,12 +231,6 @@ void Rematerializer::updateLiveIntervals() {
LISUpdates.clear();
}
-void Rematerializer::commitRematerializations() {
- for (auto &[RegIdx, _] : Revivable)
- deleteReg(RegIdx);
- Revivable.clear();
-}
-
bool Rematerializer::isMOIdenticalAtUses(MachineOperand &MO,
ArrayRef<SlotIndex> Uses) const {
if (Uses.empty())
@@ -361,7 +287,7 @@ void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) {
// A deleted register's dependencies may be deletable too.
const Reg &DeleteReg = getReg(DepDAG.pop_back_val());
for (const Reg::Dependency &Dep : DeleteReg.Dependencies) {
- // All dependencies loose a user (the delete register).
+ // All dependencies lose a user (the deleted register).
Reg &DepReg = Regs[Dep.RegIdx];
DepReg.eraseUser(DeleteReg.DefMI, DeleteReg.DefRegion);
if (DepReg.Uses.empty()) {
@@ -373,27 +299,16 @@ void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) {
for (RegisterIdx RegIdx : reverse(DeleteOrder)) {
Reg &DeleteReg = Regs[RegIdx];
- LIS.removeInterval(DeleteReg.getDefReg());
+
+ // It is possible that the defined register we are deleting doesn't have an
+ // interval yet if the LIS hasn't been updated since it was created.
+ Register DefReg = DeleteReg.getDefReg();
+ if (LIS.hasInterval(DefReg))
+ LIS.removeInterval(DefReg);
LISUpdates.erase(RegIdx);
- const bool IsRematerializedReg = isRematerializedRegister(RegIdx);
- if (SupportRollback && !IsRematerializedReg) {
- // Replace all read registers with the null one to prevent them from
- // showing up in use-lists, which is disallowed for debug instructions in
- // live interval calculations. Store mappings between operand indices and
- // original registers for potential rollback.
- DenseMap<unsigned, Register> &RegMap =
- Revivable.try_emplace(RegIdx).first->getSecond();
- for (auto [Idx, MO] : enumerate(DeleteReg.DefMI->operands())) {
- if (MO.isReg() && MO.readsReg()) {
- RegMap.insert({Idx, MO.getReg()});
- MO.setReg(Register());
- }
- }
- DeleteReg.DefMI->setDesc(TII.get(TargetOpcode::DBG_VALUE));
- } else {
- deleteReg(RegIdx);
- }
- if (IsRematerializedReg) {
+
+ deleteReg(RegIdx);
+ if (isRematerializedRegister(RegIdx)) {
// Delete rematerialized register from its origin's rematerializations.
RematsOf &OriginRemats = Rematerializations.at(getOriginOf(RegIdx));
assert(OriginRemats.contains(RegIdx) && "broken remat<->origin link");
@@ -444,7 +359,7 @@ Rematerializer::Rematerializer(MachineFunction &MF,
#endif
}
-bool Rematerializer::analyze(bool SupportRollback) {
+bool Rematerializer::analyze() {
Regs.clear();
UnrematableOprds.clear();
Origins.clear();
@@ -452,8 +367,6 @@ bool Rematerializer::analyze(bool SupportRollback) {
RegionMBB.clear();
RegToIdx.clear();
LISUpdates.clear();
- Revivable.clear();
- this->SupportRollback = SupportRollback;
if (Regions.empty())
return false;
@@ -613,17 +526,63 @@ RegisterIdx Rematerializer::rematerializeReg(
*FromReg.DefMI);
NewReg.DefMI = &*std::prev(InsertPos);
RegToIdx.insert({NewDefReg, NewRegIdx});
+ postRematerialization(RegIdx, NewRegIdx, InsertPos);
+ LLVM_DEBUG(dbgs() << "** Rematerialized " << printID(RegIdx) << " as "
+ << printRematReg(NewRegIdx) << '\n');
+ return NewRegIdx;
+}
+
+void Rematerializer::recreateReg(
+ RegisterIdx RegIdx, unsigned DefRegion,
+ MachineBasicBlock::iterator InsertPos, Register DefReg,
+ SmallVectorImpl<Reg::Dependency> &&Dependencies) {
+ assert(RegToIdx.contains(DefReg) && "unknown defined register");
+ assert(RegToIdx.at(DefReg) == RegIdx && "incorrect defined register");
+ assert(!getReg(RegIdx).DefMI && "register is still alive");
+
+ Reg &OriginReg = Regs[RegIdx];
+ OriginReg.DefRegion = DefRegion;
+ OriginReg.Dependencies = std::move(Dependencies);
+
+ // Re-establish the link between origin and rematerialization if necessary.
+ const bool RecreateOriginalReg = isOriginalRegister(RegIdx);
+ if (!RecreateOriginalReg)
+ Rematerializations[getOriginOf(RegIdx)].insert(RegIdx);
+
+ // Rematerialize from one of the existing rematerializations or from the
+ // origin. We expect at least one to exist, otherwise it would mean the value
+ // held by the original register is no longer available anywhere in the MF.
+ RegisterIdx ModelRegIdx;
+ if (RecreateOriginalReg) {
+ assert(Rematerializations.contains(RegIdx) && "expected remats");
+ ModelRegIdx = *Rematerializations.at(RegIdx).begin();
+ } else {
+ assert(getReg(getOriginOf(RegIdx)).DefMI && "expected alive origin");
+ ModelRegIdx = getOriginOf(RegIdx);
+ }
+ const MachineInstr &ModelDefMI = *getReg(ModelRegIdx).DefMI;
+
+ TII.reMaterialize(*RegionMBB[DefRegion], InsertPos, DefReg, 0, ModelDefMI);
+ OriginReg.DefMI = &*std::prev(InsertPos);
+ postRematerialization(ModelRegIdx, RegIdx, InsertPos);
+ LLVM_DEBUG(dbgs() << "** Recreated " << printID(RegIdx) << " as "
+ << printRematReg(RegIdx) << '\n');
+}
- // Update the DAG.
- RegionBoundaries &Bounds = Regions[UseRegion];
- if (Bounds.first == std::next(MachineBasicBlock::iterator(NewReg.DefMI)))
- Bounds.first = NewReg.DefMI;
- LIS.InsertMachineInstrInMaps(*NewReg.DefMI);
- LISUpdates.insert(NewRegIdx);
+void Rematerializer::postRematerialization(
+ RegisterIdx ModelRegIdx, RegisterIdx RematRegIdx,
+ MachineBasicBlock::iterator InsertPos) {
+
+ // The start of the new register's region may have changed.
+ Reg &ModelReg = Regs[ModelRegIdx], &RematReg = Regs[RematRegIdx];
+ LIS.InsertMachineInstrInMaps(*RematReg.DefMI);
+ MachineBasicBlock::iterator &RegionBegin = Regions[RematReg.DefRegion].first;
+ if (RegionBegin == std::next(MachineBasicBlock::iterator(RematReg.DefMI)))
+ RegionBegin = RematReg.DefMI;
// Replace dependencies as needed in the rematerialized MI. All dependencies
// of the latter gain a new user.
- auto ZipedDeps = zip_equal(FromReg.Dependencies, NewReg.Dependencies);
+ auto ZipedDeps = zip_equal(ModelReg.Dependencies, RematReg.Dependencies);
for (const auto &[OldDep, NewDep] : ZipedDeps) {
assert(OldDep.MOIdx == NewDep.MOIdx && "operand mismatch");
LLVM_DEBUG(dbgs() << " Operand #" << OldDep.MOIdx << ": "
@@ -632,22 +591,15 @@ RegisterIdx Rematerializer::rematerializeReg(
Reg &NewDepReg = Regs[NewDep.RegIdx];
if (OldDep.RegIdx != NewDep.RegIdx) {
- Register OldDefReg = FromReg.DefMI->getOperand(OldDep.MOIdx).getReg();
- NewReg.DefMI->substituteRegister(OldDefReg, NewDepReg.getDefReg(), 0,
- TRI);
+ Register OldDefReg = ModelReg.DefMI->getOperand(OldDep.MOIdx).getReg();
+ RematReg.DefMI->substituteRegister(OldDefReg, NewDepReg.getDefReg(), 0,
+ TRI);
LISUpdates.insert(OldDep.RegIdx);
}
- NewDepReg.addUser(NewReg.DefMI, UseRegion);
+ NewDepReg.addUser(RematReg.DefMI, RematReg.DefRegion);
LISUpdates.insert(NewDep.RegIdx);
}
-
- notifyListeners(&Listener::newRegCreated, NewRegIdx);
-
- LLVM_DEBUG({
- dbgs() << "** Rematerialized " << printID(RegIdx) << " as "
- << printRematReg(NewRegIdx) << '\n';
- });
- return NewRegIdx;
+ notifyListeners(&Listener::newRegCreated, RematRegIdx);
}
std::pair<MachineInstr *, MachineInstr *>
@@ -803,3 +755,74 @@ Printable Rematerializer::printUser(const MachineInstr *MI) const {
LIS.getInstructionIndex(*MI).print(dbgs());
});
}
+
+Rollbacker::RollbackInfo::RollbackInfo(const Rematerializer &Remater,
+ RegisterIdx RegIdx) {
+ const Rematerializer::Reg &Reg = Remater.getReg(RegIdx);
+ DefReg = Reg.getDefReg();
+ DefRegion = Reg.DefRegion;
+ Dependencies = Reg.Dependencies;
+
+ InsertPos = std::next(Reg.DefMI->getIterator());
+ if (InsertPos != Reg.DefMI->getParent()->end())
+ NextRegIdx = Remater.getDefRegIdx(*InsertPos);
+ else
+ NextRegIdx = Rematerializer::NoReg;
+}
+
+void Rollbacker::beforeRegRematerialized(const Rematerializer &Remater,
+ RegisterIdx RegIdx,
+ RegisterIdx RematRegIdx) {
+ if (RollingBack)
+ return;
+ Rematerializations[Remater.getOriginOrSelf(RegIdx)].insert(RematRegIdx);
+}
+
+void Rollbacker::beforeRegDeleted(const Rematerializer &Remater,
+ RegisterIdx RegIdx) {
+ if (RollingBack || Remater.isRematerializedRegister(RegIdx))
+ return;
+ DeadRegs.try_emplace(RegIdx, Remater, RegIdx);
+}
+
+void Rollbacker::rollback(Rematerializer &Remater) {
+ RollingBack = true;
+
+ // Re-create deleted registers.
+ for (auto &[RegIdx, Info] : DeadRegs) {
+ assert(!Remater.getReg(RegIdx).isAlive() && "register should be dead");
+
+ // The MI that was originally just after the MI defining the register we
+ // are trying to re-create may have been moved or deleted. In such cases,
+ // we can re-create at that MI's own insert position (and apply the same
+ // logic recursively).
+ MachineBasicBlock::iterator InsertPos = Info.InsertPos;
+ RegisterIdx NextRegIdx = Info.NextRegIdx;
+ while (NextRegIdx != Rematerializer::NoReg) {
+ const auto *NextRegRollback = DeadRegs.find(NextRegIdx);
+ if (NextRegRollback == DeadRegs.end())
+ break;
+ InsertPos = NextRegRollback->second.InsertPos;
+ NextRegIdx = NextRegRollback->second.NextRegIdx;
+ }
+ Remater.recreateReg(RegIdx, Info.DefRegion, InsertPos, Info.DefReg,
+ std::move(Info.Dependencies));
+ }
+
+ // Rollback rematerializations.
+ for (const auto &[RegIdx, RematsOf] : Rematerializations) {
+ for (RegisterIdx RematRegIdx : RematsOf) {
+ // It is possible that rematerializations were deleted. Their users would
+ // have been transferred to some other rematerialization so we can safely
+ // ignore them. Original registers that were deleted were just re-created
+ // so we do not need to check for that.
+ if (Remater.getReg(RematRegIdx).isAlive())
+ Remater.transferAllUsers(RematRegIdx, RegIdx);
+ }
+ }
+
+ Remater.updateLiveIntervals();
+ DeadRegs.clear();
+ Rematerializations.clear();
+ RollingBack = false;
+}
diff --git a/llvm/unittests/CodeGen/RematerializerTest.cpp b/llvm/unittests/CodeGen/RematerializerTest.cpp
index ca2bc3b86d47c..49bbfb9ebfe34 100644
--- a/llvm/unittests/CodeGen/RematerializerTest.cpp
+++ b/llvm/unittests/CodeGen/RematerializerTest.cpp
@@ -80,8 +80,7 @@ class RematerializerTest : public testing::Test {
MAM.registerPass([&] { return MachineModuleAnalysis(*MMI); });
}
- bool parseMIRAndInit(StringRef MIRCode, StringRef FunName,
- bool SupportRollback) {
+ bool parseMIRAndInit(StringRef MIRCode, StringRef FunName) {
SMDiagnostic Diagnostic;
std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode);
MIR = createMIRParser(std::move(MBuffer), Context);
@@ -122,7 +121,7 @@ class RematerializerTest : public testing::Test {
}
Remater = std::make_unique<Rematerializer>(*MF, *Regions, LIS);
- Remater->analyze(SupportRollback);
+ Remater->analyze();
return true;
}
@@ -197,10 +196,11 @@ body: |
S_ENDPGM 0
...
)";
- ASSERT_TRUE(
- parseMIRAndInit(MIR, "TreeRematRollback", /*SupportRollback=*/true));
+ ASSERT_TRUE(parseMIRAndInit(MIR, "TreeRematRollback"));
Rematerializer &Remater = getRematerializer();
Rematerializer::DependencyReuseInfo DRI;
+ Rollbacker Rollbacker;
+ Remater.addListener(&Rollbacker);
// MBB/Region indices.
const unsigned MBB0 = 0, MBB1 = 1;
@@ -218,15 +218,16 @@ body: |
Remater.rematerializeToRegion(/*RootIdx=*/Add23, /*UseRegion=*/MBB1, DRI);
Remater.updateLiveIntervals();
- // None of the original registers have any users, but they still are in the
- // MIR because we enabled rollback support.
+ // None of the original registers have any users left.
EXPECT_NO_USERS(Cst0);
EXPECT_NO_USERS(Cst1);
EXPECT_NO_USERS(Add01);
EXPECT_NO_USERS(Cst3);
EXPECT_NO_USERS(Add23);
- // Copies of all MIs were inserted into the second MBB.
+ // Copies of all MIs were inserted into the second MBB. Original registers
+ // were deleted.
+ RegionSizes[MBB0] -= 5;
RegionSizes[MBB1] += 5;
ASSERT_REGION_SIZES(RegionSizes);
NumRegs += 5;
@@ -234,7 +235,8 @@ body: |
}
// After rollback all rematerializations are removed from the MIR.
- Remater.rollbackRematsOf(Add23);
+ Rollbacker.rollback(Remater);
+ RegionSizes[MBB0] += 5;
RegionSizes[MBB1] -= 5;
ASSERT_REGION_SIZES(RegionSizes);
@@ -253,6 +255,7 @@ body: |
EXPECT_NO_USERS(Add23);
// Only immediate dependencies are copied to the second MBB.
+ RegionSizes[MBB0] -= 3;
RegionSizes[MBB1] += 3;
ASSERT_REGION_SIZES(RegionSizes);
NumRegs += 3;
@@ -260,7 +263,8 @@ body: |
}
// After rollback all rematerializations are removed from the MIR.
- Remater.rollbackRematsOf(Add23);
+ Rollbacker.rollback(Remater);
+ RegionSizes[MBB0] += 3;
RegionSizes[MBB1] -= 3;
ASSERT_REGION_SIZES(RegionSizes);
@@ -302,21 +306,15 @@ body: |
EXPECT_NO_USERS(Add23);
EXPECT_NUM_USERS(RematAdd23, 1);
+ RegionSizes[MBB0] -= 3;
RegionSizes[MBB1] += 3;
ASSERT_REGION_SIZES(RegionSizes);
NumRegs += 3;
ASSERT_EQ(Remater.getNumRegs(), NumRegs);
}
- // This time don't rollback; commit the rematerializations. This finally
- // deletes unused registers in the first block. However the number of
- // registers tracked by the rematerializer doesn't change.
+ // This time, do not roll back.
Remater.updateLiveIntervals();
- Remater.commitRematerializations();
- RegionSizes[MBB0] -= 3;
- ASSERT_REGION_SIZES(RegionSizes);
- ASSERT_EQ(Remater.getNumRegs(), NumRegs);
-
EXPECT_TRUE(getMF().verify());
}
@@ -345,8 +343,7 @@ body: |
S_ENDPGM 0
...
)";
- ASSERT_TRUE(
- parseMIRAndInit(MIR, "MultiRegionsRemat", /*SupportRollback=*/false));
+ ASSERT_TRUE(parseMIRAndInit(MIR, "MultiRegionsRemat"));
Rematerializer &Remater = getRematerializer();
Rematerializer::DependencyReuseInfo DRI;
@@ -416,7 +413,7 @@ body: |
S_ENDPGM 0
...
)";
- ASSERT_TRUE(parseMIRAndInit(MIR, "MultiStep", /*SupportRollback=*/false));
+ ASSERT_TRUE(parseMIRAndInit(MIR, "MultiStep"));
Rematerializer &Remater = getRematerializer();
Rematerializer::DependencyReuseInfo DRI;
@@ -497,7 +494,7 @@ body: |
S_ENDPGM 0
...
)";
- ASSERT_TRUE(parseMIRAndInit(MIR, "EmptyRegion", /*SupportRollback=*/false));
+ ASSERT_TRUE(parseMIRAndInit(MIR, "EmptyRegion"));
Rematerializer &Remater = getRematerializer();
Rematerializer::DependencyReuseInfo DRI;
@@ -566,7 +563,7 @@ body: |
S_ENDPGM 0
...
)";
- ASSERT_TRUE(parseMIRAndInit(MIR, "SubReg", /*SupportRollback=*/false));
+ ASSERT_TRUE(parseMIRAndInit(MIR, "SubReg"));
Rematerializer &Remater = getRematerializer();
Rematerializer::DependencyReuseInfo DRI;
@@ -592,3 +589,91 @@ body: |
Remater.updateLiveIntervals();
EXPECT_TRUE(getMF().verify());
}
+
+/// Checks that rollback works as expected when the rollback listener is added
+/// mid-rematerializations.
+TEST_F(RematerializerTest, Rollback) {
+ StringRef MIR = R"(
+name: Rollback
+tracksRegLiveness: true
+machineFunctionInfo:
+ isEntryFunction: true
+body: |
+ bb.0:
+ %0:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 0, implicit $exec, implicit $mode
+ %1:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 1, implicit $exec, implicit $mode
+
+ bb.1:
+ S_NOP 0, implicit %0, implicit %1
+
+ bb.2:
+ S_NOP 0, implicit %0, implicit %1
+ S_ENDPGM 0
+)";
+ ASSERT_TRUE(parseMIRAndInit(MIR, "Rollback"));
+ Rematerializer &Remater = getRematerializer();
+ Rematerializer::DependencyReuseInfo DRI;
+
+ // MBB/Region indices.
+ const unsigned MBB0 = 0, MBB1 = 1, MBB2 = 2;
+ SmallVector<unsigned, 4> RegionSizes{2, 1, 1};
+ ASSERT_REGION_SIZES(RegionSizes);
+
+ // Indices of rematerializable registers.
+ unsigned NumRegs = 0;
+ const RegisterIdx Cst0 = NumRegs++, Cst1 = NumRegs++;
+ ASSERT_EQ(Remater.getNumRegs(), NumRegs);
+
+ // Rematerialize %0 to MBB1, taking one user from the original register.
+ RegisterIdx RematCst0MBB1 = Remater.rematerializeToRegion(Cst0, MBB1, DRI);
+ RegionSizes[MBB1] += 1;
+ ASSERT_REGION_SIZES(RegionSizes);
+ NumRegs += 1;
+ ASSERT_EQ(Remater.getNumRegs(), NumRegs);
+
+ Rollbacker Rollback;
+ Remater.addListener(&Rollback);
+
+ // Rematerialize %0 to MBB2 and %1 to MBB1/MBB2; each rematerialization ends
+ // up with a single user and both original registers are deleted.
+ RegisterIdx RematCst0MBB2 =
+ Remater.rematerializeToRegion(Cst0, MBB2, DRI.clear());
+ RegisterIdx RematCst1MBB1 =
+ Remater.rematerializeToRegion(Cst1, MBB1, DRI.clear());
+ RegisterIdx RematCst1MBB2 =
+ Remater.rematerializeToRegion(Cst1, MBB2, DRI.clear());
+
+ RegionSizes[MBB0] -= 2;
+ RegionSizes[MBB1] += 1;
+ RegionSizes[MBB2] += 2;
+ ASSERT_REGION_SIZES(RegionSizes);
+ NumRegs += 3;
+ ASSERT_EQ(Remater.getNumRegs(), NumRegs);
+
+ EXPECT_NO_USERS(Cst0);
+ EXPECT_NO_USERS(Cst1);
+ EXPECT_NUM_USERS(RematCst0MBB1, 1);
+ EXPECT_NUM_USERS(RematCst0MBB2, 1);
+ EXPECT_NUM_USERS(RematCst1MBB1, 1);
+ EXPECT_NUM_USERS(RematCst1MBB2, 1);
+
+ // Roll back all changes made since the rollbacker was added. The first
+ // rematerialization of %0 to MBB1 happened before that, so it is not rolled
+ // back. However, %0 is still re-created because it was deleted afterwards.
+ Rollback.rollback(Remater);
+
+ RegionSizes[MBB0] += 2;
+ RegionSizes[MBB1] -= 1;
+ RegionSizes[MBB2] -= 2;
+ ASSERT_REGION_SIZES(RegionSizes);
+ ASSERT_EQ(Remater.getNumRegs(), NumRegs);
+
+ EXPECT_NUM_USERS(Cst0, 1);
+ EXPECT_NUM_USERS(Cst1, 2);
+ EXPECT_NUM_USERS(RematCst0MBB1, 1);
+ EXPECT_NO_USERS(RematCst0MBB2);
+ EXPECT_NO_USERS(RematCst1MBB1);
+ EXPECT_NO_USERS(RematCst1MBB2);
+
+ EXPECT_TRUE(getMF().verify());
+}
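The insert-position fallback used when the rollbacker re-creates deleted registers (follow `NextRegIdx` while the next defining MI is itself dead) can be modeled in isolation. The following is a minimal, self-contained C++ sketch, not the patch's implementation: `DeadRegInfo`, `resolveInsertPos`, the integer `InsertPos` standing in for `MachineBasicBlock::iterator`, and the `NoReg` sentinel are hypothetical simplifications of the `DeadRegs` map, `RegisterIdx`, and `Rematerializer::NoReg` that appear in the diff.

```cpp
#include <cassert>
#include <unordered_map>

// Hypothetical stand-ins for the types used in the patch.
using RegisterIdx = unsigned;
using InsertPos = int; // models a MachineBasicBlock::iterator
constexpr RegisterIdx NoReg = ~0u;

// What is remembered about a deleted register (hypothetical layout).
struct DeadRegInfo {
  InsertPos Pos;          // position of the MI that followed the deleted def
  RegisterIdx NextRegIdx; // register defined by that following MI, or NoReg
};

// Returns where a deleted register should be re-created. If the MI that
// followed its definition was itself deleted, use that MI's own recorded
// position instead, and keep following the chain until reaching an MI that
// is not in DeadRegs or the chain ends.
InsertPos
resolveInsertPos(const std::unordered_map<RegisterIdx, DeadRegInfo> &DeadRegs,
                 RegisterIdx RegIdx) {
  auto Self = DeadRegs.find(RegIdx);
  assert(Self != DeadRegs.end() && "register should be dead");
  InsertPos Pos = Self->second.Pos;
  RegisterIdx Next = Self->second.NextRegIdx;
  while (Next != NoReg) {
    auto It = DeadRegs.find(Next);
    if (It == DeadRegs.end())
      break; // Following MI still exists, so Pos is a valid position.
    Pos = It->second.Pos;
    Next = It->second.NextRegIdx;
  }
  return Pos;
}
```

For example, if registers 0, 1 and 2 were defined back to back and all three were deleted, re-creating register 0 resolves to the position recorded for register 2. The walk is written as a loop rather than recursion, mirroring the `while` loop in the diff, so long runs of deleted neighbors do not grow the call stack.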