[llvm-branch-commits] [llvm] [CodeGen] Fix incorrect rematerialization rollback order (PR #197576)
Lucas Ramirez via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jun 9 07:16:03 PDT 2026
https://github.com/lucas-rami updated https://github.com/llvm/llvm-project/pull/197576
>From bdd060475f53929dc007965eece438111550c2fc Mon Sep 17 00:00:00 2001
From: Lucas Ramirez <lucas.rami at proton.me>
Date: Tue, 21 Apr 2026 22:50:29 +0000
Subject: [PATCH 1/2] [CodeGen] Fix incorrect rematerialization rollback order
This fixes an issue in the rematerializer's rollbacker wherein adjacent
MIs that were deleted through rematerializations would
sometimes---depending on the exact order in which they were
deleted---not be re-created in their original
pre-rematerialization order. While this does not impact correctness
(i.e., use-def relations are always honored), it goes against the
rollbacker's intent to re-create the MIR exactly as it was before
rematerialization (up to slot index changes).
---
llvm/include/llvm/CodeGen/Rematerializer.h | 8 +-
llvm/lib/CodeGen/Rematerializer.cpp | 61 ++++++++++----
llvm/unittests/CodeGen/RematerializerTest.cpp | 82 +++++++++++++------
3 files changed, 107 insertions(+), 44 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/Rematerializer.h b/llvm/include/llvm/CodeGen/Rematerializer.h
index 96c00c59f3186..bbebc6e0ab8b4 100644
--- a/llvm/include/llvm/CodeGen/Rematerializer.h
+++ b/llvm/include/llvm/CodeGen/Rematerializer.h
@@ -538,11 +538,17 @@ class Rollbacker : public Rematerializer::Listener {
/// stores its index. Otherwise equals \ref Rematerializer::NoReg.
RegisterIdx NextRegIdx;
- RollbackInfo(const Rematerializer &Remater, RegisterIdx RegIdx);
+ RollbackInfo(const Rematerializer::Reg &Reg, RegisterIdx NextRegIdx);
};
/// Original registers that have been deleted, in order of deletion.
MapVector<RegisterIdx, RollbackInfo> DeadRegs;
+ /// When there are two ajacent rematerializable MIs in the original
+ /// instruction order and the later one is deleted, stores a mapping from the
+ /// earlier one's register index to the later one's register index. If the
+ /// earlier one is then deleted, this makes it possible to rematerialize it at
+ /// the correct position after the later one is re-created.
+ DenseMap<RegisterIdx, RegisterIdx> AdjacentDeletedMIs;
/// Registers which have been rematerialized (from original index to
/// rematerialized index).
DenseMap<RegisterIdx, Rematerializer::RematsOf> Rematerializations;
diff --git a/llvm/lib/CodeGen/Rematerializer.cpp b/llvm/lib/CodeGen/Rematerializer.cpp
index c4edface1e27f..2377cbfdf7a07 100644
--- a/llvm/lib/CodeGen/Rematerializer.cpp
+++ b/llvm/lib/CodeGen/Rematerializer.cpp
@@ -762,19 +762,11 @@ Printable Rematerializer::printUser(const MachineInstr *MI,
});
}
-Rollbacker::RollbackInfo::RollbackInfo(const Rematerializer &Remater,
- RegisterIdx RegIdx) {
- const Rematerializer::Reg &Reg = Remater.getReg(RegIdx);
- DefReg = Reg.getDefReg();
- DefRegion = Reg.DefRegion;
- Dependencies = Reg.Dependencies;
-
- InsertPos = std::next(Reg.DefMI->getIterator());
- if (InsertPos != Reg.DefMI->getParent()->end())
- NextRegIdx = Remater.getDefRegIdx(*InsertPos);
- else
- NextRegIdx = Rematerializer::NoReg;
-}
+Rollbacker::RollbackInfo::RollbackInfo(const Rematerializer::Reg &Reg,
+ RegisterIdx NextRegIdx)
+ : DefReg(Reg.getDefReg()), DefRegion(Reg.DefRegion),
+ Dependencies(Reg.Dependencies),
+ InsertPos(std::next(Reg.DefMI->getIterator())), NextRegIdx(NextRegIdx) {}
void Rollbacker::rematerializerNoteRegCreated(const Rematerializer &Remater,
RegisterIdx RegIdx) {
@@ -787,7 +779,34 @@ void Rollbacker::rematerializerNoteRegDeleted(const Rematerializer &Remater,
RegisterIdx RegIdx) {
if (RollingBack || Remater.isRematerializedRegister(RegIdx))
return;
- DeadRegs.try_emplace(RegIdx, Remater, RegIdx);
+ const Rematerializer::Reg &Reg = Remater.getReg(RegIdx);
+
+ // If the MI that was originally after the one defining the register is
+ // rematerializable, derive its corresponding register index.
+ RegisterIdx NextRegIdx = Rematerializer::NoReg;
+ auto KnownNextRegIdx = AdjacentDeletedMIs.find(RegIdx);
+ if (KnownNextRegIdx != AdjacentDeletedMIs.end()) {
+ NextRegIdx = KnownNextRegIdx->second;
+ } else {
+ MachineBasicBlock::iterator InsertPos = std::next(Reg.DefMI->getIterator());
+ if (InsertPos != Reg.DefMI->getParent()->end())
+ NextRegIdx = Remater.getDefRegIdx(*InsertPos);
+ }
+ DeadRegs.try_emplace(RegIdx, Reg, NextRegIdx);
+
+ // Keep track of adjacent rematerializable MIs to allow re-creation in the
+ // same exact order regardless of rematerialization order.
+ MachineBasicBlock::iterator DefMI = Reg.DefMI->getIterator();
+ if (Reg.DefMI->getParent()->begin() != DefMI) {
+ RegisterIdx PrevRegIdx = Remater.getDefRegIdx(*std::prev(DefMI));
+ if (PrevRegIdx != Rematerializer::NoReg) {
+ // The key might already be in the map, which indicates that there were
+ // originally instructions in between the MIs which have since then been
+ // deleted. In these cases we leave the value untouched, as we care about
+ // MIs that were adjacent in the original instruction order.
+ AdjacentDeletedMIs.try_emplace(PrevRegIdx, RegIdx);
+ }
+ }
}
void Rollbacker::rollback(Rematerializer &Remater) {
@@ -804,11 +823,16 @@ void Rollbacker::rollback(Rematerializer &Remater) {
MachineBasicBlock::iterator InsertPos = Info.InsertPos;
RegisterIdx NextRegIdx = Info.NextRegIdx;
while (NextRegIdx != Rematerializer::NoReg) {
- const auto *NextRegRollback = DeadRegs.find(NextRegIdx);
- if (NextRegRollback == DeadRegs.end())
+ const Rematerializer::Reg &MaybeAliveReg = Remater.getReg(NextRegIdx);
+ // When the next MI is alive (including when it was dead and already
+ // re-created), we must use it as the position to insert before.
+ if (MaybeAliveReg.isAlive()) {
+ InsertPos = MaybeAliveReg.DefMI->getIterator();
break;
- InsertPos = NextRegRollback->second.InsertPos;
- NextRegIdx = NextRegRollback->second.NextRegIdx;
+ }
+ RollbackInfo NextInfo = DeadRegs.at(NextRegIdx);
+ InsertPos = NextInfo.InsertPos;
+ NextRegIdx = NextInfo.NextRegIdx;
}
Remater.recreateReg(RegIdx, Info.DefRegion, InsertPos, Info.DefReg,
std::move(Info.Dependencies));
@@ -829,5 +853,6 @@ void Rollbacker::rollback(Rematerializer &Remater) {
Remater.updateLiveIntervals();
DeadRegs.clear();
Rematerializations.clear();
+ AdjacentDeletedMIs.clear();
RollingBack = false;
}
diff --git a/llvm/unittests/CodeGen/RematerializerTest.cpp b/llvm/unittests/CodeGen/RematerializerTest.cpp
index 2211470297b65..1a2e0373b8cb4 100644
--- a/llvm/unittests/CodeGen/RematerializerTest.cpp
+++ b/llvm/unittests/CodeGen/RematerializerTest.cpp
@@ -635,33 +635,65 @@ TEST_F(RematerializerTest, RollbackInvalidInsertPos) {
const unsigned MBB0 = 0, MBB1 = 1;
const RegisterIdx Cst0 = 0, Cst1 = 1, Cst2 = 2, Cst3 = 3;
- // Rematerialize %0 to MBB1, deleting the original register.
- RW->rematerializeToRegion(Cst0, MBB1, DRI);
- RW.moveMIs(MBB0, MBB1, 1);
- ASSERT_REGION_SIZES();
-
- // Rematerialize %1 to MBB1, deleting the original register.
- RW->rematerializeToRegion(Cst1, MBB1, DRI.clear());
- RW.moveMIs(MBB0, MBB1, 1);
- ASSERT_REGION_SIZES();
+ auto RematToMBB1 = [&](RegisterIdx RegIdx) -> void {
+ // Rematerialize %RegIdx to MBB1, deleting the original register.
+ RW->rematerializeToRegion(RegIdx, MBB1, DRI.clear());
+ RW.moveMIs(MBB0, MBB1, 1);
+ ASSERT_REGION_SIZES();
+ };
- // Rematerialize %2 to MBB1, deleting the original register.
- RW->rematerializeToRegion(Cst2, MBB1, DRI.clear());
- RW.moveMIs(MBB0, MBB1, 1);
- ASSERT_REGION_SIZES();
+ auto GetNextMI = [&](MachineInstr *MI) -> MachineInstr * {
+ return &*std::next(MI->getIterator());
+ };
- // Now rollback and check for correct instruction order in the original
- // defining region.
- Rollback.rollback(*RW);
- RW.moveMIs(MBB1, MBB0, 3);
- ASSERT_REGION_SIZES();
+ auto RollbackAndCheckOriginalOrder = [&]() -> void {
+ // Rollback and check for correct instruction order in the original
+ // defining region. The asserts on region sizes ensure that all original
+ // registers were indeed deleted and will be re-created in the original
+ // region.
+ Rollback.rollback(*RW);
+ RW.moveMIs(MBB1, MBB0, 3);
+ ASSERT_REGION_SIZES();
+
+ MachineInstr *DefCst0 = RW->getReg(Cst0).DefMI;
+ MachineInstr *DefCst1 = RW->getReg(Cst1).DefMI;
+ MachineInstr *DefCst2 = RW->getReg(Cst2).DefMI;
+ MachineInstr *DefCst3 = RW->getReg(Cst3).DefMI;
+ EXPECT_EQ(GetNextMI(DefCst0), DefCst1);
+ EXPECT_EQ(GetNextMI(DefCst1), DefCst2);
+ EXPECT_EQ(GetNextMI(DefCst2), DefCst3);
+ };
- MachineInstr &DefCst0 = *RW->getReg(Cst0).DefMI;
- MachineInstr &DefCst1 = *RW->getReg(Cst1).DefMI;
- MachineInstr &DefCst2 = *RW->getReg(Cst2).DefMI;
- MachineInstr &DefCst3 = *RW->getReg(Cst3).DefMI;
- EXPECT_EQ(std::next(DefCst0.getIterator()), DefCst1.getIterator());
- EXPECT_EQ(std::next(DefCst1.getIterator()), DefCst2.getIterator());
- EXPECT_EQ(std::next(DefCst2.getIterator()), DefCst3.getIterator());
+ // Test every possible rematerialization order.
+
+ RematToMBB1(Cst0);
+ RematToMBB1(Cst1);
+ RematToMBB1(Cst2);
+ RollbackAndCheckOriginalOrder();
+
+ RematToMBB1(Cst0);
+ RematToMBB1(Cst2);
+ RematToMBB1(Cst1);
+ RollbackAndCheckOriginalOrder();
+
+ RematToMBB1(Cst1);
+ RematToMBB1(Cst0);
+ RematToMBB1(Cst2);
+ RollbackAndCheckOriginalOrder();
+
+ RematToMBB1(Cst1);
+ RematToMBB1(Cst2);
+ RematToMBB1(Cst0);
+ RollbackAndCheckOriginalOrder();
+
+ RematToMBB1(Cst2);
+ RematToMBB1(Cst0);
+ RematToMBB1(Cst1);
+ RollbackAndCheckOriginalOrder();
+
+ RematToMBB1(Cst2);
+ RematToMBB1(Cst1);
+ RematToMBB1(Cst0);
+ RollbackAndCheckOriginalOrder();
});
}
>From 69e176310e4f746089e3dee91b8698555af5a3e5 Mon Sep 17 00:00:00 2001
From: Lucas Ramirez <lucas.rami at proton.me>
Date: Mon, 8 Jun 2026 11:06:39 +0000
Subject: [PATCH 2/2] Change rollback method to reduce tracking needs
---
llvm/include/llvm/CodeGen/Rematerializer.h | 111 +++++++-----
llvm/lib/CodeGen/Rematerializer.cpp | 170 ++++++++++--------
llvm/unittests/CodeGen/RematerializerTest.cpp | 72 ++++++++
3 files changed, 235 insertions(+), 118 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/Rematerializer.h b/llvm/include/llvm/CodeGen/Rematerializer.h
index bbebc6e0ab8b4..f441433ac6ceb 100644
--- a/llvm/include/llvm/CodeGen/Rematerializer.h
+++ b/llvm/include/llvm/CodeGen/Rematerializer.h
@@ -14,15 +14,11 @@
#ifndef LLVM_CODEGEN_REMATERIALIZER_H
#define LLVM_CODEGEN_REMATERIALIZER_H
-#include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
-#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
-#include <iterator>
namespace llvm {
@@ -185,10 +181,11 @@ class Rematerializer {
virtual void rematerializerNoteRegCreated(const Rematerializer &Remater,
RegisterIdx NewRegIdx) {}
- /// Called juste before register \p RegIdx is deleted from the MIR. At this
+ /// Called just before register \p RegIdx is deleted from the MIR. At this
/// point the register still exists in the MIR but no longer has any user.
- virtual void rematerializerNoteRegDeleted(const Rematerializer &Remater,
- RegisterIdx RegIdx) {}
+ virtual void
+ rematerializerNoteRegWillBeDeleted(const Rematerializer &Remater,
+ RegisterIdx RegIdx) {}
virtual ~Listener() = default;
@@ -368,14 +365,12 @@ class Rematerializer {
MachineBasicBlock::iterator InsertPos,
SmallVectorImpl<Reg::Dependency> &&Dependencies);
- /// Re-creates a previously deleted register \p RegIdx before \p InsertPos in
- /// \p DefRegion. \p DefReg must be the original virtual register that \p
- /// RegIdx used to define. Sets the new register's rematerializable
- /// dependencies to \p Dependencies (these are assumed to already exist in the
- /// MIR).
- void recreateReg(RegisterIdx RegIdx, unsigned DefRegion,
- MachineBasicBlock::iterator InsertPos, Register DefReg,
- SmallVectorImpl<Reg::Dependency> &&Dependencies);
+ /// Re-creates a previously deleted register \p RegIdx before \p InsertPos,
+ /// which must be in the register's original defining region. \p DefReg must
+ /// be the original virtual register that \p RegIdx used to define.
+ /// Dependencies are assumed to already exist in the MIR.
+ void recreateReg(RegisterIdx RegIdx, MachineBasicBlock::iterator InsertPos,
+ Register DefReg);
/// Transfers all users of register \p FromRegIdx in region \p UseRegion to \p
/// ToRegIdx, the latter of which must be a rematerialization of the former or
@@ -432,9 +427,9 @@ class Rematerializer {
Listen->rematerializerNoteRegCreated(*this, RegIdx);
}
- void noteRegDeleted(RegisterIdx RegIdx) const {
+ void noteRegWillBeDeleted(RegisterIdx RegIdx) const {
for (Listener *Listen : Listeners)
- Listen->rematerializerNoteRegDeleted(*this, RegIdx);
+ Listen->rematerializerNoteRegWillBeDeleted(*this, RegIdx);
}
/// Rematerializable registers identified since the rematerializer's creation,
@@ -494,7 +489,9 @@ class Rematerializer {
/// Deletes register \p RootIdx if it no longer has any user. If the register
/// is deleted, recursively deletes any of its transitive rematerializable
- /// dependencies that no longer have users as a result.
+ /// dependencies that no longer have users as a result. In case of recursive
+ /// deletion, all of a register's users are always deleted before the register
+ /// itself.
void deleteRegIfUnused(RegisterIdx RootIdx);
/// Deletes rematerializable register \p RegIdx from the DAG and relevant
@@ -516,45 +513,71 @@ class Rollbacker : public Rematerializer::Listener {
void rematerializerNoteRegCreated(const Rematerializer &Remater,
RegisterIdx RegIdx) override;
- void rematerializerNoteRegDeleted(const Rematerializer &Remater,
- RegisterIdx RegIdx) override;
+ void rematerializerNoteRegWillBeDeleted(const Rematerializer &Remater,
+ RegisterIdx RegIdx) override;
private:
- struct RollbackInfo {
+ struct DeadReg {
+ /// Register index.
+ RegisterIdx Idx;
/// Original register.
Register DefReg;
- /// Original defining region.
- unsigned DefRegion;
- /// Original dependencies.
- SmallVector<Rematerializer::Reg::Dependency, 2> Dependencies;
- /// Position to re-create the register before in case of rollback. This
- /// becomes invalid if it originally points to an MI that is deleted later
- /// as a consequence of other rematerializations. In such cases \ref
- /// NextRegIdx is guaranteed to be an actual register index from which the
- /// rollback logic will determine a valid insert position before which to
- /// re-create this register.
- MachineBasicBlock::iterator InsertPos;
- /// If \ref InsertPos points to an MI defining a rematerializable register,
- /// stores its index. Otherwise equals \ref Rematerializer::NoReg.
- RegisterIdx NextRegIdx;
-
- RollbackInfo(const Rematerializer::Reg &Reg, RegisterIdx NextRegIdx);
+ /// Original definition of the register. The underlying MI no longer exists
+ /// at rollback time, but may be referenced as the re-creation position of
+ /// previously deleted registers.
+ MachineInstr *DefMI;
+
+ DeadReg(RegisterIdx Idx, const Rematerializer &Remater)
+ : Idx(Idx), DefReg(Remater.getReg(Idx).getDefReg()),
+ DefMI(Remater.getReg(Idx).DefMI) {}
};
+ /// An insertion position in the MIR. The pointer should be interpreted as:
+ /// - a MachineInstr* if the int is 0/false (insert before the MI).
+ /// - a MachineBasicBlock* if the int is 1/true (insert at the MBB's end).
+ using InsertBeforePos = PointerIntPair<void *, 1, bool>;
+
/// Original registers that have been deleted, in order of deletion.
- MapVector<RegisterIdx, RollbackInfo> DeadRegs;
- /// When there are two ajacent rematerializable MIs in the original
- /// instruction order and the later one is deleted, stores a mapping from the
- /// earlier one's register index to the later one's register index. If the
- /// earlier one is then deleted, this makes it possible to rematerialize it at
- /// the correct position after the later one is re-created.
- DenseMap<RegisterIdx, RegisterIdx> AdjacentDeletedMIs;
+ SmallVector<DeadReg> DeadRegs;
+ /// Re-creation positions for all original registers that have been deleted,
+ /// in register deletion order. A position is either a MachineInstr* that
+ /// existed in the MIR at the time the rollbacker was attached to the
+ /// rematerializer, or a MachineBasicBlock*.
+ SmallVector<InsertBeforePos> Positions;
+ /// Maps all re-creation positions that exist in \ref Positions to the indices
+ /// of elements holding that position in the vector.
+ DenseMap<InsertBeforePos, SmallDenseSet<unsigned, 1>> PosToIdx;
/// Registers which have been rematerialized (from original index to
/// rematerialized index).
DenseMap<RegisterIdx, Rematerializer::RematsOf> Rematerializations;
/// Used to block further recording of events whenver we are actively rolling
/// back.
bool RollingBack = false;
+
+ InsertBeforePos makePos(MachineInstr *MI) const {
+ return InsertBeforePos(MI, false);
+ }
+ InsertBeforePos makePos(MachineBasicBlock *MBB) const {
+ return InsertBeforePos(MBB, true);
+ }
+ InsertBeforePos makePos(MachineBasicBlock::iterator It,
+ MachineBasicBlock *MBB) const {
+ if (It == MBB->end())
+ return makePos(MBB);
+ return makePos(&*It);
+ }
+
+ /// Whether \p MI would be deleted if we were to rollback later. These are MIs
+ /// defining rematerializable registers whose creation has been recorded by
+ /// the rollbacker.
+ bool isRollbackableMI(const MachineInstr &MI,
+ const Rematerializer &Remater) const;
+
+ /// Switches all positions that point to \p MI to \p It in the \ref Positions
+ /// vector, and updates \ref PosToIdx accordingly. This is used when it
+ /// becomes known that \p MI is about to be permanently deleted from the MIR
+ /// and thus becomes an invalid re-creation position.
+ void invalidatePosition(MachineInstr *MI, MachineBasicBlock::iterator It);
};
} // namespace llvm
diff --git a/llvm/lib/CodeGen/Rematerializer.cpp b/llvm/lib/CodeGen/Rematerializer.cpp
index 2377cbfdf7a07..26e468cd00475 100644
--- a/llvm/lib/CodeGen/Rematerializer.cpp
+++ b/llvm/lib/CodeGen/Rematerializer.cpp
@@ -13,7 +13,6 @@
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/Rematerializer.h"
-#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/CodeGen/LiveIntervals.h"
@@ -281,23 +280,22 @@ void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) {
// Traverse the root's dependency DAG depth-first to find the set of registers
// we can delete and a legal order to delete them in.
SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
- SmallSetVector<RegisterIdx, 8> DeleteOrder;
- DeleteOrder.insert(RootIdx);
+ SmallVector<RegisterIdx, 8> DeleteOrder{RootIdx};
do {
// A deleted register's dependencies may be deletable too.
const Reg &DeleteReg = getReg(DepDAG.pop_back_val());
for (const Reg::Dependency &Dep : DeleteReg.Dependencies) {
- // All dependencies loose a user (the deleted register).
+ // All dependencies lose a user (the deleted register).
Reg &DepReg = Regs[Dep.RegIdx];
DepReg.eraseUser(DeleteReg.DefMI, DeleteReg.DefRegion);
if (DepReg.Uses.empty()) {
- DeleteOrder.insert(Dep.RegIdx);
+ DeleteOrder.push_back(Dep.RegIdx);
DepDAG.push_back(Dep.RegIdx);
}
}
} while (!DepDAG.empty());
- for (RegisterIdx RegIdx : reverse(DeleteOrder)) {
+ for (RegisterIdx RegIdx : DeleteOrder) {
Reg &DeleteReg = Regs[RegIdx];
// It is possible that the defined register we are deleting doesn't have an
@@ -322,7 +320,7 @@ void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) {
}
void Rematerializer::deleteReg(RegisterIdx RegIdx) {
- noteRegDeleted(RegIdx);
+ noteRegWillBeDeleted(RegIdx);
Reg &DeleteReg = Regs[RegIdx];
assert(DeleteReg.DefMI && "register was already deleted");
@@ -533,17 +531,14 @@ RegisterIdx Rematerializer::rematerializeReg(
return NewRegIdx;
}
-void Rematerializer::recreateReg(
- RegisterIdx RegIdx, unsigned DefRegion,
- MachineBasicBlock::iterator InsertPos, Register DefReg,
- SmallVectorImpl<Reg::Dependency> &&Dependencies) {
+void Rematerializer::recreateReg(RegisterIdx RegIdx,
+ MachineBasicBlock::iterator InsertPos,
+ Register DefReg) {
assert(RegToIdx.contains(DefReg) && "unknown defined register");
assert(RegToIdx.at(DefReg) == RegIdx && "incorrect defined register");
assert(!getReg(RegIdx).DefMI && "register is still alive");
Reg &OriginReg = Regs[RegIdx];
- OriginReg.DefRegion = DefRegion;
- OriginReg.Dependencies = std::move(Dependencies);
// Re-establish the link between origin and rematerialization if necessary.
const bool RecreateOriginalReg = isOriginalRegister(RegIdx);
@@ -563,7 +558,8 @@ void Rematerializer::recreateReg(
}
const MachineInstr &ModelDefMI = *getReg(ModelRegIdx).DefMI;
- TII.reMaterialize(*RegionMBB[DefRegion], InsertPos, DefReg, 0, ModelDefMI);
+ TII.reMaterialize(*RegionMBB[OriginReg.DefRegion], InsertPos, DefReg, 0,
+ ModelDefMI);
OriginReg.DefMI = &*std::prev(InsertPos);
postRematerialization(ModelRegIdx, RegIdx, InsertPos);
LLVM_DEBUG(dbgs() << "** Recreated " << printID(RegIdx) << " as "
@@ -762,80 +758,75 @@ Printable Rematerializer::printUser(const MachineInstr *MI,
});
}
-Rollbacker::RollbackInfo::RollbackInfo(const Rematerializer::Reg &Reg,
- RegisterIdx NextRegIdx)
- : DefReg(Reg.getDefReg()), DefRegion(Reg.DefRegion),
- Dependencies(Reg.Dependencies),
- InsertPos(std::next(Reg.DefMI->getIterator())), NextRegIdx(NextRegIdx) {}
-
void Rollbacker::rematerializerNoteRegCreated(const Rematerializer &Remater,
RegisterIdx RegIdx) {
if (RollingBack)
return;
+ assert(Remater.isRematerializedRegister(RegIdx) && "only remats are created");
Rematerializations[Remater.getOriginOf(RegIdx)].insert(RegIdx);
}
-void Rollbacker::rematerializerNoteRegDeleted(const Rematerializer &Remater,
- RegisterIdx RegIdx) {
- if (RollingBack || Remater.isRematerializedRegister(RegIdx))
+void Rollbacker::rematerializerNoteRegWillBeDeleted(
+ const Rematerializer &Remater, RegisterIdx RegIdx) {
+ if (RollingBack)
+ return;
+
+ // Find a valid re-creation position after the register's definition.
+ MachineInstr *DefMI = Remater.getReg(RegIdx).DefMI;
+ MachineBasicBlock *ParentMBB = DefMI->getParent();
+ MachineBasicBlock::iterator ValidPos = std::next(DefMI->getIterator());
+ while (ValidPos != ParentMBB->end() && isRollbackableMI(*ValidPos, Remater))
+ ValidPos = std::next(ValidPos);
+
+ if (Remater.isRematerializedRegister(RegIdx)) {
+ // Rematerializations will not be re-created. Previously deleted registers
+ // whose re-creation position is this MI must instead be re-created at the
+ // next valid position after it.
+ invalidatePosition(DefMI, ValidPos);
return;
- const Rematerializer::Reg &Reg = Remater.getReg(RegIdx);
-
- // If the MI that was originally after the one defining the register is
- // rematerializable, derive its corresponding register index.
- RegisterIdx NextRegIdx = Rematerializer::NoReg;
- auto KnownNextRegIdx = AdjacentDeletedMIs.find(RegIdx);
- if (KnownNextRegIdx != AdjacentDeletedMIs.end()) {
- NextRegIdx = KnownNextRegIdx->second;
- } else {
- MachineBasicBlock::iterator InsertPos = std::next(Reg.DefMI->getIterator());
- if (InsertPos != Reg.DefMI->getParent()->end())
- NextRegIdx = Remater.getDefRegIdx(*InsertPos);
- }
- DeadRegs.try_emplace(RegIdx, Reg, NextRegIdx);
-
- // Keep track of adjacent rematerializable MIs to allow re-creation in the
- // same exact order regardless of rematerialization order.
- MachineBasicBlock::iterator DefMI = Reg.DefMI->getIterator();
- if (Reg.DefMI->getParent()->begin() != DefMI) {
- RegisterIdx PrevRegIdx = Remater.getDefRegIdx(*std::prev(DefMI));
- if (PrevRegIdx != Rematerializer::NoReg) {
- // The key might already be in the map, which indicates that there were
- // originally instructions in between the MIs which have since then been
- // deleted. In these cases we leave the value untouched, as we care about
- // MIs that were adjacent in the original instruction order.
- AdjacentDeletedMIs.try_emplace(PrevRegIdx, RegIdx);
- }
}
+
+ // Original registers can be re-created. Positions pointing to this MI are
+ // resolved on rollback; stop tracking them by address, as MI memory is reused.
+ PosToIdx.erase(makePos(DefMI));
+ DeadRegs.push_back(DeadReg(RegIdx, Remater));
+ const InsertBeforePos InsertPos = makePos(ValidPos, ParentMBB);
+ PosToIdx[InsertPos].insert(Positions.size());
+ Positions.push_back(InsertPos);
}
void Rollbacker::rollback(Rematerializer &Remater) {
RollingBack = true;
- // Re-create deleted registers.
- for (auto &[RegIdx, Info] : DeadRegs) {
- assert(!Remater.getReg(RegIdx).isAlive() && "register should be dead");
-
- // The MI that was originally just after the MI defining the register we
- // are trying to re-create may have been deleted. In such cases, we can
- // re-create at that MI's own insert position (and apply the same logic
- // recursively).
- MachineBasicBlock::iterator InsertPos = Info.InsertPos;
- RegisterIdx NextRegIdx = Info.NextRegIdx;
- while (NextRegIdx != Rematerializer::NoReg) {
- const Rematerializer::Reg &MaybeAliveReg = Remater.getReg(NextRegIdx);
- // When the next MI is alive (including when it was dead and already
- // re-created), we must use it as the position to insert before.
- if (MaybeAliveReg.isAlive()) {
- InsertPos = MaybeAliveReg.DefMI->getIterator();
- break;
- }
- RollbackInfo NextInfo = DeadRegs.at(NextRegIdx);
- InsertPos = NextInfo.InsertPos;
- NextRegIdx = NextInfo.NextRegIdx;
+ // As we re-create registers, map deleted definitions to re-created ones. This
+ // allows to replace invalid re-creation positions that reference deleted
+ // definitions to valid new positions while restoring original MI order.
+ DenseMap<MachineInstr *, MachineInstr *> Replacements;
+ unsigned PositionIndex = Positions.size();
+
+ // Re-create deleted registers in reverse order of deletion. Related registers
+ // are deleted in reverse def-use order so this ensures we re-create registers
+ // in def-use order. This also ensures that re-creation positions that became
+ // invalid due to later MI deletions can be corrected as we go.
+ for (const DeadReg &Reg : reverse(DeadRegs)) {
+ assert(!Remater.getReg(Reg.Idx).isAlive() && "register should be dead");
+
+ // Determine re-creation position for the register's definition.
+ MachineBasicBlock::iterator InsertPosition;
+ const auto [Ptr, IsMBB] = Positions[--PositionIndex];
+ if (IsMBB) {
+ InsertPosition = static_cast<MachineBasicBlock *>(Ptr)->end();
+ } else {
+ MachineInstr *InsertBeforeMI = static_cast<MachineInstr *>(Ptr);
+ InsertBeforeMI = Replacements.lookup_or(InsertBeforeMI, InsertBeforeMI);
+ InsertPosition = InsertBeforeMI->getIterator();
}
- Remater.recreateReg(RegIdx, Info.DefRegion, InsertPos, Info.DefReg,
- std::move(Info.Dependencies));
+
+ Remater.recreateReg(Reg.Idx, InsertPosition, Reg.DefReg);
+
+ const Rematerializer::Reg &RecreateReg = Remater.getReg(Reg.Idx);
+ assert(!Replacements.contains(Reg.DefMI) && "duplicate deleted MI");
+ Replacements[Reg.DefMI] = RecreateReg.DefMI;
}
// Rollback rematerializations.
@@ -852,7 +843,38 @@ void Rollbacker::rollback(Rematerializer &Remater) {
Remater.updateLiveIntervals();
DeadRegs.clear();
+ Positions.clear();
+ PosToIdx.clear();
Rematerializations.clear();
- AdjacentDeletedMIs.clear();
RollingBack = false;
}
+
+bool Rollbacker::isRollbackableMI(const MachineInstr &MI,
+ const Rematerializer &Remater) const {
+ RegisterIdx RegIdx = Remater.getDefRegIdx(MI);
+ if (RegIdx == Rematerializer::NoReg ||
+ !Remater.isRematerializedRegister(RegIdx))
+ return false;
+ // It is possible that the MI defines a rematerialized register that was not
+ // recorded if the rollbacker was attached to the rematerializer after the
+ // rematerialization happened. In such cases the MI won't be rolled back.
+ auto RematsOf = Rematerializations.find(Remater.getOriginOf(RegIdx));
+ if (RematsOf == Rematerializations.end())
+ return false;
+ return RematsOf->getSecond().contains(RegIdx);
+}
+
+void Rollbacker::invalidatePosition(MachineInstr *MI,
+ MachineBasicBlock::iterator It) {
+ auto MIIndices = PosToIdx.find(makePos(MI));
+ if (MIIndices == PosToIdx.end())
+ return;
+ // Extract the indices first: inserting NewPos below may rehash PosToIdx.
+ SmallDenseSet<unsigned, 1> Indices = std::move(MIIndices->getSecond());
+ PosToIdx.erase(MIIndices);
+ assert(!Indices.empty() && "no index hold position");
+ const InsertBeforePos NewPos = makePos(It, MI->getParent());
+ for (unsigned I : Indices)
+ Positions[I] = NewPos;
+ PosToIdx[NewPos].insert_range(Indices);
+}
diff --git a/llvm/unittests/CodeGen/RematerializerTest.cpp b/llvm/unittests/CodeGen/RematerializerTest.cpp
index 1a2e0373b8cb4..2fdb80b9be5e6 100644
--- a/llvm/unittests/CodeGen/RematerializerTest.cpp
+++ b/llvm/unittests/CodeGen/RematerializerTest.cpp
@@ -697,3 +697,75 @@ TEST_F(RematerializerTest, RollbackInvalidInsertPos) {
RollbackAndCheckOriginalOrder();
});
}
+
+/// Checks that rollback re-creates MIs in the correct order when the next MI
+/// after a deleted one is a rematerialization of another MI.
+TEST_F(RematerializerTest, RollbackNextPosIsRemat) {
+ StringRef MIRBody = R"MIR(
+ bb.0:
+ %0:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 0, implicit $exec, implicit $mode
+ %1:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 1, implicit $exec, implicit $mode
+
+ bb.1:
+ %2:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 2, implicit $exec, implicit $mode
+ S_NOP 0, implicit %0
+
+ bb.2:
+ %3:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 3, implicit $exec, implicit $mode
+ S_NOP 0, implicit %1
+
+ bb.3:
+ S_NOP 0, implicit %2, implicit %3
+ S_ENDPGM 0
+)MIR";
+ rematerializerTest(MIRBody, [](RematerializerWrapper &RW) {
+ Rematerializer::DependencyReuseInfo DRI;
+ Rollbacker Rollback;
+
+ const unsigned MBB1 = 1, MBB2 = 2, MBB3 = 3;
+ const RegisterIdx Cst0 = 0, Cst1 = 1, Cst2 = 2, Cst3 = 3;
+
+ MachineInstr *Nop1 = &*std::prev(RW.MF.getBlockNumbered(1)->end());
+ MachineInstr *Nop2 = &*std::prev(RW.MF.getBlockNumbered(2)->end());
+ MachineInstr *Nop3 =
+ &*std::prev(std::prev(RW.MF.getBlockNumbered(3)->end()));
+
+ auto ExpectSeq = [](MachineInstr *MI, MachineInstr *ExpectedNext) {
+ MachineInstr *ActualNext = &*std::next(MI->getIterator());
+ EXPECT_EQ(ActualNext, ExpectedNext);
+ };
+
+ // This rematerialization is created right after %2, which is later
+ // rematerialized. It is *not* recorded by the rollbacker.
+ RegisterIdx RematCst0 = RW->rematerializeToRegion(Cst0, MBB1, DRI.clear());
+ ExpectSeq(RW->getReg(Cst2).DefMI, RW->getReg(RematCst0).DefMI);
+ ExpectSeq(RW->getReg(RematCst0).DefMI, Nop1);
+
+ RW->addListener(&Rollback);
+
+ // This rematerialization is created right after %3, which is later
+ // rematerialized. It is recorded by the rollbacker.
+ RegisterIdx RematCst1 = RW->rematerializeToRegion(Cst1, MBB2, DRI.clear());
+ ExpectSeq(RW->getReg(Cst3).DefMI, RW->getReg(RematCst1).DefMI);
+ ExpectSeq(RW->getReg(RematCst1).DefMI, Nop2);
+
+ RegisterIdx RematCst2 = RW->rematerializeToRegion(Cst2, MBB3, DRI.clear());
+ RegisterIdx RematCst3 = RW->rematerializeToRegion(Cst3, MBB3, DRI.clear());
+
+ ExpectSeq(RW->getReg(RematCst2).DefMI, RW->getReg(RematCst3).DefMI);
+ ExpectSeq(RW->getReg(RematCst3).DefMI, Nop3);
+
+ // After rollback, %2 and %3 should be re-created at the beginning of their
+ // respective original regions.
+ Rollback.rollback(*RW);
+
+ // The rematerialization of %0 was not recorded so isn't rolled back, %2 is
+ // re-created right before it.
+ ExpectSeq(RW->getReg(Cst2).DefMI, RW->getReg(RematCst0).DefMI);
+ ExpectSeq(RW->getReg(RematCst0).DefMI, Nop1);
+
+ // The rematerialization of %1 was recorded so is rolled back, %3 is
+ // re-created before the S_NOP in its region.
+ ExpectSeq(RW->getReg(Cst3).DefMI, Nop2);
+ });
+}
More information about the llvm-branch-commits
mailing list