[llvm-branch-commits] [llvm] [CodeGen] Fix incorrect rematerialization rollback order (PR #197576)

Lucas Ramirez via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jun 9 07:16:03 PDT 2026


https://github.com/lucas-rami updated https://github.com/llvm/llvm-project/pull/197576

>From bdd060475f53929dc007965eece438111550c2fc Mon Sep 17 00:00:00 2001
From: Lucas Ramirez <lucas.rami at proton.me>
Date: Tue, 21 Apr 2026 22:50:29 +0000
Subject: [PATCH 1/2] [CodeGen] Fix incorrect rematerialization rollback order

This fixes an issue in the rematerializer's rollbacker wherein adjacent
MIs that were deleted through rematerializations would
sometimes---depending on the exact order in which they were
deleted---not be re-created in their original
pre-rematerialization order. While this does not impact correctness
(i.e., use-def relations are always honored), this goes against the
rollbacker's intent to re-create the MIR exactly as it was
pre-rematerializations (up to slot index changes).
---
 llvm/include/llvm/CodeGen/Rematerializer.h    |  8 +-
 llvm/lib/CodeGen/Rematerializer.cpp           | 61 ++++++++++----
 llvm/unittests/CodeGen/RematerializerTest.cpp | 82 +++++++++++++------
 3 files changed, 107 insertions(+), 44 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/Rematerializer.h b/llvm/include/llvm/CodeGen/Rematerializer.h
index 96c00c59f3186..bbebc6e0ab8b4 100644
--- a/llvm/include/llvm/CodeGen/Rematerializer.h
+++ b/llvm/include/llvm/CodeGen/Rematerializer.h
@@ -538,11 +538,17 @@ class Rollbacker : public Rematerializer::Listener {
     /// stores its index. Otherwise equals \ref Rematerializer::NoReg.
     RegisterIdx NextRegIdx;
 
-    RollbackInfo(const Rematerializer &Remater, RegisterIdx RegIdx);
+    RollbackInfo(const Rematerializer::Reg &Reg, RegisterIdx NextRegIdx);
   };
 
   /// Original registers that have been deleted, in order of deletion.
   MapVector<RegisterIdx, RollbackInfo> DeadRegs;
+  /// When there are two ajacent rematerializable MIs in the original
+  /// instruction order and the later one is deleted, stores a mapping from the
+  /// earlier one's register index to the later one's register index. If the
+  /// earlier one is then deleted, this makes it possible to rematerialize it at
+  /// the correct position after the later one is re-created.
+  DenseMap<RegisterIdx, RegisterIdx> AdjacentDeletedMIs;
   /// Registers which have been rematerialized (from original index to
   /// rematerialized index).
   DenseMap<RegisterIdx, Rematerializer::RematsOf> Rematerializations;
diff --git a/llvm/lib/CodeGen/Rematerializer.cpp b/llvm/lib/CodeGen/Rematerializer.cpp
index c4edface1e27f..2377cbfdf7a07 100644
--- a/llvm/lib/CodeGen/Rematerializer.cpp
+++ b/llvm/lib/CodeGen/Rematerializer.cpp
@@ -762,19 +762,11 @@ Printable Rematerializer::printUser(const MachineInstr *MI,
   });
 }
 
-Rollbacker::RollbackInfo::RollbackInfo(const Rematerializer &Remater,
-                                       RegisterIdx RegIdx) {
-  const Rematerializer::Reg &Reg = Remater.getReg(RegIdx);
-  DefReg = Reg.getDefReg();
-  DefRegion = Reg.DefRegion;
-  Dependencies = Reg.Dependencies;
-
-  InsertPos = std::next(Reg.DefMI->getIterator());
-  if (InsertPos != Reg.DefMI->getParent()->end())
-    NextRegIdx = Remater.getDefRegIdx(*InsertPos);
-  else
-    NextRegIdx = Rematerializer::NoReg;
-}
+Rollbacker::RollbackInfo::RollbackInfo(const Rematerializer::Reg &Reg,
+                                       RegisterIdx NextRegIdx)
+    : DefReg(Reg.getDefReg()), DefRegion(Reg.DefRegion),
+      Dependencies(Reg.Dependencies),
+      InsertPos(std::next(Reg.DefMI->getIterator())), NextRegIdx(NextRegIdx) {}
 
 void Rollbacker::rematerializerNoteRegCreated(const Rematerializer &Remater,
                                               RegisterIdx RegIdx) {
@@ -787,7 +779,34 @@ void Rollbacker::rematerializerNoteRegDeleted(const Rematerializer &Remater,
                                               RegisterIdx RegIdx) {
   if (RollingBack || Remater.isRematerializedRegister(RegIdx))
     return;
-  DeadRegs.try_emplace(RegIdx, Remater, RegIdx);
+  const Rematerializer::Reg &Reg = Remater.getReg(RegIdx);
+
+  // If the MI that was originally after the one defining the register is
+  // rematerializable, derive its corresponding register index.
+  RegisterIdx NextRegIdx = Rematerializer::NoReg;
+  auto KnownNextRegIdx = AdjacentDeletedMIs.find(RegIdx);
+  if (KnownNextRegIdx != AdjacentDeletedMIs.end()) {
+    NextRegIdx = KnownNextRegIdx->second;
+  } else {
+    MachineBasicBlock::iterator InsertPos = std::next(Reg.DefMI->getIterator());
+    if (InsertPos != Reg.DefMI->getParent()->end())
+      NextRegIdx = Remater.getDefRegIdx(*InsertPos);
+  }
+  DeadRegs.try_emplace(RegIdx, Reg, NextRegIdx);
+
+  // Keep track of adjacent rematerializable MIs to allow re-creation in the
+  // same exact order regardless of rematerialization order.
+  MachineBasicBlock::iterator DefMI = Reg.DefMI->getIterator();
+  if (Reg.DefMI->getParent()->begin() != DefMI) {
+    RegisterIdx PrevRegIdx = Remater.getDefRegIdx(*std::prev(DefMI));
+    if (PrevRegIdx != Rematerializer::NoReg) {
+      // The key might already be in the map, which indicates that there were
+      // originally instructions in between the MIs which have since then been
+      // deleted. In these cases we leave the value untouched, as we care about
+      // MIs that were adjacent in the original instruction order.
+      AdjacentDeletedMIs.try_emplace(PrevRegIdx, RegIdx);
+    }
+  }
 }
 
 void Rollbacker::rollback(Rematerializer &Remater) {
@@ -804,11 +823,16 @@ void Rollbacker::rollback(Rematerializer &Remater) {
     MachineBasicBlock::iterator InsertPos = Info.InsertPos;
     RegisterIdx NextRegIdx = Info.NextRegIdx;
     while (NextRegIdx != Rematerializer::NoReg) {
-      const auto *NextRegRollback = DeadRegs.find(NextRegIdx);
-      if (NextRegRollback == DeadRegs.end())
+      const Rematerializer::Reg &MaybeAliveReg = Remater.getReg(NextRegIdx);
+      // When the next MI is alive (including when it was dead and already
+      // re-created), we must use it as the position to insert before.
+      if (MaybeAliveReg.isAlive()) {
+        InsertPos = MaybeAliveReg.DefMI->getIterator();
         break;
-      InsertPos = NextRegRollback->second.InsertPos;
-      NextRegIdx = NextRegRollback->second.NextRegIdx;
+      }
+      RollbackInfo NextInfo = DeadRegs.at(NextRegIdx);
+      InsertPos = NextInfo.InsertPos;
+      NextRegIdx = NextInfo.NextRegIdx;
     }
     Remater.recreateReg(RegIdx, Info.DefRegion, InsertPos, Info.DefReg,
                         std::move(Info.Dependencies));
@@ -829,5 +853,6 @@ void Rollbacker::rollback(Rematerializer &Remater) {
   Remater.updateLiveIntervals();
   DeadRegs.clear();
   Rematerializations.clear();
+  AdjacentDeletedMIs.clear();
   RollingBack = false;
 }
diff --git a/llvm/unittests/CodeGen/RematerializerTest.cpp b/llvm/unittests/CodeGen/RematerializerTest.cpp
index 2211470297b65..1a2e0373b8cb4 100644
--- a/llvm/unittests/CodeGen/RematerializerTest.cpp
+++ b/llvm/unittests/CodeGen/RematerializerTest.cpp
@@ -635,33 +635,65 @@ TEST_F(RematerializerTest, RollbackInvalidInsertPos) {
     const unsigned MBB0 = 0, MBB1 = 1;
     const RegisterIdx Cst0 = 0, Cst1 = 1, Cst2 = 2, Cst3 = 3;
 
-    // Rematerialize %0 to MBB1, deleting the original register.
-    RW->rematerializeToRegion(Cst0, MBB1, DRI);
-    RW.moveMIs(MBB0, MBB1, 1);
-    ASSERT_REGION_SIZES();
-
-    // Rematerialize %1 to MBB1, deleting the original register.
-    RW->rematerializeToRegion(Cst1, MBB1, DRI.clear());
-    RW.moveMIs(MBB0, MBB1, 1);
-    ASSERT_REGION_SIZES();
+    auto RematToMBB1 = [&](RegisterIdx RegIdx) -> void {
+      // Rematerialize %RegIdx to MBB1, deleting the original register.
+      RW->rematerializeToRegion(RegIdx, MBB1, DRI.clear());
+      RW.moveMIs(MBB0, MBB1, 1);
+      ASSERT_REGION_SIZES();
+    };
 
-    // Rematerialize %2 to MBB1, deleting the original register.
-    RW->rematerializeToRegion(Cst2, MBB1, DRI.clear());
-    RW.moveMIs(MBB0, MBB1, 1);
-    ASSERT_REGION_SIZES();
+    auto GetNextMI = [&](MachineInstr *MI) -> MachineInstr * {
+      return &*std::next(MI->getIterator());
+    };
 
-    // Now rollback and check for correct instruction order in the original
-    // defining region.
-    Rollback.rollback(*RW);
-    RW.moveMIs(MBB1, MBB0, 3);
-    ASSERT_REGION_SIZES();
+    auto RollbackAndCheckOriginalOrder = [&]() -> void {
+      // Rollback and check for correct instruction order in the original
+      // defining region. The asserts on region sizes ensure that all original
+      // registers were indeed deleted and will be re-created in the original
+      // region.
+      Rollback.rollback(*RW);
+      RW.moveMIs(MBB1, MBB0, 3);
+      ASSERT_REGION_SIZES();
+
+      MachineInstr *DefCst0 = RW->getReg(Cst0).DefMI;
+      MachineInstr *DefCst1 = RW->getReg(Cst1).DefMI;
+      MachineInstr *DefCst2 = RW->getReg(Cst2).DefMI;
+      MachineInstr *DefCst3 = RW->getReg(Cst3).DefMI;
+      EXPECT_EQ(GetNextMI(DefCst0), DefCst1);
+      EXPECT_EQ(GetNextMI(DefCst1), DefCst2);
+      EXPECT_EQ(GetNextMI(DefCst2), DefCst3);
+    };
 
-    MachineInstr &DefCst0 = *RW->getReg(Cst0).DefMI;
-    MachineInstr &DefCst1 = *RW->getReg(Cst1).DefMI;
-    MachineInstr &DefCst2 = *RW->getReg(Cst2).DefMI;
-    MachineInstr &DefCst3 = *RW->getReg(Cst3).DefMI;
-    EXPECT_EQ(std::next(DefCst0.getIterator()), DefCst1.getIterator());
-    EXPECT_EQ(std::next(DefCst1.getIterator()), DefCst2.getIterator());
-    EXPECT_EQ(std::next(DefCst2.getIterator()), DefCst3.getIterator());
+    // Test every possible rematerialization order.
+
+    RematToMBB1(Cst0);
+    RematToMBB1(Cst1);
+    RematToMBB1(Cst2);
+    RollbackAndCheckOriginalOrder();
+
+    RematToMBB1(Cst0);
+    RematToMBB1(Cst2);
+    RematToMBB1(Cst1);
+    RollbackAndCheckOriginalOrder();
+
+    RematToMBB1(Cst1);
+    RematToMBB1(Cst0);
+    RematToMBB1(Cst2);
+    RollbackAndCheckOriginalOrder();
+
+    RematToMBB1(Cst1);
+    RematToMBB1(Cst2);
+    RematToMBB1(Cst0);
+    RollbackAndCheckOriginalOrder();
+
+    RematToMBB1(Cst2);
+    RematToMBB1(Cst0);
+    RematToMBB1(Cst1);
+    RollbackAndCheckOriginalOrder();
+
+    RematToMBB1(Cst2);
+    RematToMBB1(Cst1);
+    RematToMBB1(Cst0);
+    RollbackAndCheckOriginalOrder();
   });
 }

>From 69e176310e4f746089e3dee91b8698555af5a3e5 Mon Sep 17 00:00:00 2001
From: Lucas Ramirez <lucas.rami at proton.me>
Date: Mon, 8 Jun 2026 11:06:39 +0000
Subject: [PATCH 2/2] Change rollback method to reduce tracking need

---
 llvm/include/llvm/CodeGen/Rematerializer.h    | 111 +++++++-----
 llvm/lib/CodeGen/Rematerializer.cpp           | 170 ++++++++++--------
 llvm/unittests/CodeGen/RematerializerTest.cpp |  72 ++++++++
 3 files changed, 235 insertions(+), 118 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/Rematerializer.h b/llvm/include/llvm/CodeGen/Rematerializer.h
index bbebc6e0ab8b4..f441433ac6ceb 100644
--- a/llvm/include/llvm/CodeGen/Rematerializer.h
+++ b/llvm/include/llvm/CodeGen/Rematerializer.h
@@ -14,15 +14,11 @@
 #ifndef LLVM_CODEGEN_REMATERIALIZER_H
 #define LLVM_CODEGEN_REMATERIALIZER_H
 
-#include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/CodeGen/LiveIntervals.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
-#include "llvm/CodeGen/TargetOpcodes.h"
 #include "llvm/CodeGen/TargetRegisterInfo.h"
-#include <iterator>
 
 namespace llvm {
 
@@ -185,10 +181,11 @@ class Rematerializer {
     virtual void rematerializerNoteRegCreated(const Rematerializer &Remater,
                                               RegisterIdx NewRegIdx) {}
 
-    /// Called juste before register \p RegIdx is deleted from the MIR. At this
+    /// Called just before register \p RegIdx is deleted from the MIR. At this
     /// point the register still exists in the MIR but no longer has any user.
-    virtual void rematerializerNoteRegDeleted(const Rematerializer &Remater,
-                                              RegisterIdx RegIdx) {}
+    virtual void
+    rematerializerNoteRegWillBeDeleted(const Rematerializer &Remater,
+                                       RegisterIdx RegIdx) {}
 
     virtual ~Listener() = default;
 
@@ -368,14 +365,12 @@ class Rematerializer {
                                MachineBasicBlock::iterator InsertPos,
                                SmallVectorImpl<Reg::Dependency> &&Dependencies);
 
-  /// Re-creates a previously deleted register \p RegIdx before \p InsertPos in
-  /// \p DefRegion. \p DefReg must be the original virtual register that \p
-  /// RegIdx used to define. Sets the new register's rematerializable
-  /// dependencies to \p Dependencies (these are assumed to already exist in the
-  /// MIR).
-  void recreateReg(RegisterIdx RegIdx, unsigned DefRegion,
-                   MachineBasicBlock::iterator InsertPos, Register DefReg,
-                   SmallVectorImpl<Reg::Dependency> &&Dependencies);
+  /// Re-creates a previously deleted register \p RegIdx before \p InsertPos,
+  /// which must be in the register's original defining region. \p DefReg must
+  /// be the original virtual register that \p RegIdx used to define.
+  /// Dependencies are assumed to already exist in the MIR.
+  void recreateReg(RegisterIdx RegIdx, MachineBasicBlock::iterator InsertPos,
+                   Register DefReg);
 
   /// Transfers all users of register \p FromRegIdx in region \p UseRegion to \p
   /// ToRegIdx, the latter of which must be a rematerialization of the former or
@@ -432,9 +427,9 @@ class Rematerializer {
       Listen->rematerializerNoteRegCreated(*this, RegIdx);
   }
 
-  void noteRegDeleted(RegisterIdx RegIdx) const {
+  void noteRegWillBeDeleted(RegisterIdx RegIdx) const {
     for (Listener *Listen : Listeners)
-      Listen->rematerializerNoteRegDeleted(*this, RegIdx);
+      Listen->rematerializerNoteRegWillBeDeleted(*this, RegIdx);
   }
 
   /// Rematerializable registers identified since the rematerializer's creation,
@@ -494,7 +489,9 @@ class Rematerializer {
 
   /// Deletes register \p RootIdx if it no longer has any user. If the register
   /// is deleted, recursively deletes any of its transitive rematerializable
-  /// dependencies that no longer have users as a result.
+  /// dependencies that no longer have users as a result. In case of recursive
+  /// deletion, all of a register's users are always deleted before the register
+  /// itself.
   void deleteRegIfUnused(RegisterIdx RootIdx);
 
   /// Deletes rematerializable register \p RegIdx from the DAG and relevant
@@ -516,45 +513,71 @@ class Rollbacker : public Rematerializer::Listener {
   void rematerializerNoteRegCreated(const Rematerializer &Remater,
                                     RegisterIdx RegIdx) override;
 
-  void rematerializerNoteRegDeleted(const Rematerializer &Remater,
-                                    RegisterIdx RegIdx) override;
+  void rematerializerNoteRegWillBeDeleted(const Rematerializer &Remater,
+                                          RegisterIdx RegIdx) override;
 
 private:
-  struct RollbackInfo {
+  struct DeadReg {
+    /// Register index.
+    RegisterIdx Idx;
     /// Original register.
     Register DefReg;
-    /// Original defining region.
-    unsigned DefRegion;
-    /// Original dependencies.
-    SmallVector<Rematerializer::Reg::Dependency, 2> Dependencies;
-    /// Position to re-create the register before in case of rollback. This
-    /// becomes invalid if it originally points to an MI that is deleted later
-    /// as a consequence of other rematerializations. In such cases \ref
-    /// NextRegIdx is guaranteed to be an actual register index from which the
-    /// rollback logic will determine a valid insert position before which to
-    /// re-create this register.
-    MachineBasicBlock::iterator InsertPos;
-    /// If \ref InsertPos points to an MI defining a rematerializable register,
-    /// stores its index. Otherwise equals \ref Rematerializer::NoReg.
-    RegisterIdx NextRegIdx;
-
-    RollbackInfo(const Rematerializer::Reg &Reg, RegisterIdx NextRegIdx);
+    /// Original definition of the register. The underlying MI no longer exists
+    /// at rollback time, but may be referenced as the re-creation position of
+    /// previously deleted registers.
+    MachineInstr *DefMI;
+
+    DeadReg(RegisterIdx Idx, const Rematerializer &Remater)
+        : Idx(Idx), DefReg(Remater.getReg(Idx).getDefReg()),
+          DefMI(Remater.getReg(Idx).DefMI) {}
   };
 
+  /// An insertion position in the MIR. The pointer should be interpreted as:
+  /// - a MachineInstr* if the int is 0/false (insert before the MI).
+  /// - a MachineBasicBlock* if the int is 1/true (insert at the MBB's end).
+  using InsertBeforePos = PointerIntPair<void *, 1, bool>;
+
   /// Original registers that have been deleted, in order of deletion.
-  MapVector<RegisterIdx, RollbackInfo> DeadRegs;
-  /// When there are two ajacent rematerializable MIs in the original
-  /// instruction order and the later one is deleted, stores a mapping from the
-  /// earlier one's register index to the later one's register index. If the
-  /// earlier one is then deleted, this makes it possible to rematerialize it at
-  /// the correct position after the later one is re-created.
-  DenseMap<RegisterIdx, RegisterIdx> AdjacentDeletedMIs;
+  SmallVector<DeadReg> DeadRegs;
+  /// Re-creation positions for all original registers that have been deleted,
+  /// in register deletion order. A position is either a MachineInstr* that
+  /// existed in the MIR at the time the rollbacker was attached to the
+  /// rematerializer, or a MachineBasicBlock*.
+  SmallVector<InsertBeforePos> Positions;
+  /// Maps all re-creation positions that exist in \ref Positions to the indices
+  /// of elements holding that position in the vector.
+  DenseMap<InsertBeforePos, SmallDenseSet<unsigned, 1>> PosToIdx;
   /// Registers which have been rematerialized (from original index to
   /// rematerialized index).
   DenseMap<RegisterIdx, Rematerializer::RematsOf> Rematerializations;
   /// Used to block further recording of events whenver we are actively rolling
   /// back.
   bool RollingBack = false;
+
+  InsertBeforePos makePos(MachineInstr *MI) const {
+    return InsertBeforePos(MI, false);
+  }
+  InsertBeforePos makePos(MachineBasicBlock *MBB) const {
+    return InsertBeforePos(MBB, true);
+  }
+  InsertBeforePos makePos(MachineBasicBlock::iterator It,
+                          MachineBasicBlock *MBB) const {
+    if (It == MBB->end())
+      return makePos(MBB);
+    return makePos(&*It);
+  }
+
+  /// Whether \p MI would be deleted if we were to roll back later. These are
+  /// MIs defining rematerializable registers whose creation has been recorded
+  /// by the rollbacker.
+  bool isRollbackableMI(const MachineInstr &MI,
+                        const Rematerializer &Remater) const;
+
+  /// Switches all positions that point to \p MI to \p It in the \ref Positions
+  /// vector, and updates \ref PosToIdx accordingly. This is used when it
+  /// becomes known that \p MI is about to be permanently deleted from the MIR
+  /// and thus becomes an invalid re-creation position.
+  void invalidatePosition(MachineInstr *MI, MachineBasicBlock::iterator It);
 };
 
 } // namespace llvm
diff --git a/llvm/lib/CodeGen/Rematerializer.cpp b/llvm/lib/CodeGen/Rematerializer.cpp
index 2377cbfdf7a07..26e468cd00475 100644
--- a/llvm/lib/CodeGen/Rematerializer.cpp
+++ b/llvm/lib/CodeGen/Rematerializer.cpp
@@ -13,7 +13,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/CodeGen/Rematerializer.h"
-#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/CodeGen/LiveIntervals.h"
@@ -281,23 +280,22 @@ void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) {
   // Traverse the root's dependency DAG depth-first to find the set of registers
   // we can delete and a legal order to delete them in.
   SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
-  SmallSetVector<RegisterIdx, 8> DeleteOrder;
-  DeleteOrder.insert(RootIdx);
+  SmallVector<RegisterIdx, 8> DeleteOrder{RootIdx};
   do {
     // A deleted register's dependencies may be deletable too.
     const Reg &DeleteReg = getReg(DepDAG.pop_back_val());
     for (const Reg::Dependency &Dep : DeleteReg.Dependencies) {
-      // All dependencies loose a user (the deleted register).
+      // All dependencies lose a user (the deleted register).
       Reg &DepReg = Regs[Dep.RegIdx];
       DepReg.eraseUser(DeleteReg.DefMI, DeleteReg.DefRegion);
       if (DepReg.Uses.empty()) {
-        DeleteOrder.insert(Dep.RegIdx);
+        DeleteOrder.push_back(Dep.RegIdx);
         DepDAG.push_back(Dep.RegIdx);
       }
     }
   } while (!DepDAG.empty());
 
-  for (RegisterIdx RegIdx : reverse(DeleteOrder)) {
+  for (RegisterIdx RegIdx : DeleteOrder) {
     Reg &DeleteReg = Regs[RegIdx];
 
     // It is possible that the defined register we are deleting doesn't have an
@@ -322,7 +320,7 @@ void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) {
 }
 
 void Rematerializer::deleteReg(RegisterIdx RegIdx) {
-  noteRegDeleted(RegIdx);
+  noteRegWillBeDeleted(RegIdx);
 
   Reg &DeleteReg = Regs[RegIdx];
   assert(DeleteReg.DefMI && "register was already deleted");
@@ -533,17 +531,14 @@ RegisterIdx Rematerializer::rematerializeReg(
   return NewRegIdx;
 }
 
-void Rematerializer::recreateReg(
-    RegisterIdx RegIdx, unsigned DefRegion,
-    MachineBasicBlock::iterator InsertPos, Register DefReg,
-    SmallVectorImpl<Reg::Dependency> &&Dependencies) {
+void Rematerializer::recreateReg(RegisterIdx RegIdx,
+                                 MachineBasicBlock::iterator InsertPos,
+                                 Register DefReg) {
   assert(RegToIdx.contains(DefReg) && "unknown defined register");
   assert(RegToIdx.at(DefReg) == RegIdx && "incorrect defined register");
   assert(!getReg(RegIdx).DefMI && "register is still alive");
 
   Reg &OriginReg = Regs[RegIdx];
-  OriginReg.DefRegion = DefRegion;
-  OriginReg.Dependencies = std::move(Dependencies);
 
   // Re-establish the link between origin and rematerialization if necessary.
   const bool RecreateOriginalReg = isOriginalRegister(RegIdx);
@@ -563,7 +558,8 @@ void Rematerializer::recreateReg(
   }
   const MachineInstr &ModelDefMI = *getReg(ModelRegIdx).DefMI;
 
-  TII.reMaterialize(*RegionMBB[DefRegion], InsertPos, DefReg, 0, ModelDefMI);
+  TII.reMaterialize(*RegionMBB[OriginReg.DefRegion], InsertPos, DefReg, 0,
+                    ModelDefMI);
   OriginReg.DefMI = &*std::prev(InsertPos);
   postRematerialization(ModelRegIdx, RegIdx, InsertPos);
   LLVM_DEBUG(dbgs() << "** Recreated " << printID(RegIdx) << " as "
@@ -762,80 +758,75 @@ Printable Rematerializer::printUser(const MachineInstr *MI,
   });
 }
 
-Rollbacker::RollbackInfo::RollbackInfo(const Rematerializer::Reg &Reg,
-                                       RegisterIdx NextRegIdx)
-    : DefReg(Reg.getDefReg()), DefRegion(Reg.DefRegion),
-      Dependencies(Reg.Dependencies),
-      InsertPos(std::next(Reg.DefMI->getIterator())), NextRegIdx(NextRegIdx) {}
-
 void Rollbacker::rematerializerNoteRegCreated(const Rematerializer &Remater,
                                               RegisterIdx RegIdx) {
   if (RollingBack)
     return;
+  assert(Remater.isRematerializedRegister(RegIdx) && "only remats are created");
   Rematerializations[Remater.getOriginOf(RegIdx)].insert(RegIdx);
 }
 
-void Rollbacker::rematerializerNoteRegDeleted(const Rematerializer &Remater,
-                                              RegisterIdx RegIdx) {
-  if (RollingBack || Remater.isRematerializedRegister(RegIdx))
+void Rollbacker::rematerializerNoteRegWillBeDeleted(
+    const Rematerializer &Remater, RegisterIdx RegIdx) {
+  if (RollingBack)
+    return;
+
+  // Find a valid re-creation position after the register's definition.
+  MachineInstr *DefMI = Remater.getReg(RegIdx).DefMI;
+  MachineBasicBlock *ParentMBB = DefMI->getParent();
+  MachineBasicBlock::iterator ValidPos = std::next(DefMI->getIterator());
+  while (ValidPos != ParentMBB->end() && isRollbackableMI(*ValidPos, Remater))
+    ValidPos = std::next(ValidPos);
+
+  if (Remater.isRematerializedRegister(RegIdx)) {
+    // Rematerializations are not re-created, so positions referencing an
+    // unrecorded one must move past it. Recorded ones never serve as positions
+    // but may reuse the address of a deleted MI still used as a position.
+    if (!isRollbackableMI(*DefMI, Remater))
+      invalidatePosition(DefMI, ValidPos);
     return;
-  const Rematerializer::Reg &Reg = Remater.getReg(RegIdx);
-
-  // If the MI that was originally after the one defining the register is
-  // rematerializable, derive its corresponding register index.
-  RegisterIdx NextRegIdx = Rematerializer::NoReg;
-  auto KnownNextRegIdx = AdjacentDeletedMIs.find(RegIdx);
-  if (KnownNextRegIdx != AdjacentDeletedMIs.end()) {
-    NextRegIdx = KnownNextRegIdx->second;
-  } else {
-    MachineBasicBlock::iterator InsertPos = std::next(Reg.DefMI->getIterator());
-    if (InsertPos != Reg.DefMI->getParent()->end())
-      NextRegIdx = Remater.getDefRegIdx(*InsertPos);
-  }
-  DeadRegs.try_emplace(RegIdx, Reg, NextRegIdx);
-
-  // Keep track of adjacent rematerializable MIs to allow re-creation in the
-  // same exact order regardless of rematerialization order.
-  MachineBasicBlock::iterator DefMI = Reg.DefMI->getIterator();
-  if (Reg.DefMI->getParent()->begin() != DefMI) {
-    RegisterIdx PrevRegIdx = Remater.getDefRegIdx(*std::prev(DefMI));
-    if (PrevRegIdx != Rematerializer::NoReg) {
-      // The key might already be in the map, which indicates that there were
-      // originally instructions in between the MIs which have since then been
-      // deleted. In these cases we leave the value untouched, as we care about
-      // MIs that were adjacent in the original instruction order.
-      AdjacentDeletedMIs.try_emplace(PrevRegIdx, RegIdx);
-    }
   }
+
+  // Original registers can be re-created. Add a re-creation position for the
+  // definition of the rematerializable register.
+  DeadRegs.push_back(DeadReg(RegIdx, Remater));
+  const InsertBeforePos InsertPos = makePos(ValidPos, ParentMBB);
+  PosToIdx[InsertPos].insert(Positions.size());
+  Positions.push_back(InsertPos);
 }
 
 void Rollbacker::rollback(Rematerializer &Remater) {
   RollingBack = true;
 
-  // Re-create deleted registers.
-  for (auto &[RegIdx, Info] : DeadRegs) {
-    assert(!Remater.getReg(RegIdx).isAlive() && "register should be dead");
-
-    // The MI that was originally just after the MI defining the register we
-    // are trying to re-create may have been deleted. In such cases, we can
-    // re-create at that MI's own insert position (and apply the same logic
-    // recursively).
-    MachineBasicBlock::iterator InsertPos = Info.InsertPos;
-    RegisterIdx NextRegIdx = Info.NextRegIdx;
-    while (NextRegIdx != Rematerializer::NoReg) {
-      const Rematerializer::Reg &MaybeAliveReg = Remater.getReg(NextRegIdx);
-      // When the next MI is alive (including when it was dead and already
-      // re-created), we must use it as the position to insert before.
-      if (MaybeAliveReg.isAlive()) {
-        InsertPos = MaybeAliveReg.DefMI->getIterator();
-        break;
-      }
-      RollbackInfo NextInfo = DeadRegs.at(NextRegIdx);
-      InsertPos = NextInfo.InsertPos;
-      NextRegIdx = NextInfo.NextRegIdx;
+  // As we re-create registers, map deleted definitions to re-created ones. This
+  // lets us replace invalid re-creation positions that reference deleted
+  // definitions with valid new positions while restoring original MI order.
+  DenseMap<MachineInstr *, MachineInstr *> Replacements;
+  unsigned PositionIndex = Positions.size();
+
+  // Re-create deleted registers in reverse order of deletion. Related registers
+  // are deleted in reverse def-use order so this ensures we re-create registers
+  // in def-use order. This also ensures that re-creation positions that became
+  // invalid due to later MI deletions can be corrected as we go.
+  for (const DeadReg &Reg : reverse(DeadRegs)) {
+    assert(!Remater.getReg(Reg.Idx).isAlive() && "register should be dead");
+
+    // Determine re-creation position for the register's definition.
+    MachineBasicBlock::iterator InsertPosition;
+    const auto [Ptr, IsMBB] = Positions[--PositionIndex];
+    if (IsMBB) {
+      InsertPosition = static_cast<MachineBasicBlock *>(Ptr)->end();
+    } else {
+      MachineInstr *InsertBeforeMI = static_cast<MachineInstr *>(Ptr);
+      InsertBeforeMI = Replacements.lookup_or(InsertBeforeMI, InsertBeforeMI);
+      InsertPosition = InsertBeforeMI->getIterator();
     }
-    Remater.recreateReg(RegIdx, Info.DefRegion, InsertPos, Info.DefReg,
-                        std::move(Info.Dependencies));
+
+    Remater.recreateReg(Reg.Idx, InsertPosition, Reg.DefReg);
+
+    const Rematerializer::Reg &RecreateReg = Remater.getReg(Reg.Idx);
+    assert(!Replacements.contains(Reg.DefMI) && "duplicate deleted MI");
+    Replacements[Reg.DefMI] = RecreateReg.DefMI;
   }
 
   // Rollback rematerializations.
@@ -852,7 +843,38 @@ void Rollbacker::rollback(Rematerializer &Remater) {
 
   Remater.updateLiveIntervals();
   DeadRegs.clear();
+  Positions.clear();
+  PosToIdx.clear();
   Rematerializations.clear();
-  AdjacentDeletedMIs.clear();
   RollingBack = false;
 }
+
+bool Rollbacker::isRollbackableMI(const MachineInstr &MI,
+                                  const Rematerializer &Remater) const {
+  RegisterIdx RegIdx = Remater.getDefRegIdx(MI);
+  if (RegIdx == Rematerializer::NoReg ||
+      !Remater.isRematerializedRegister(RegIdx))
+    return false;
+  // It is possible that the MI defines a rematerializable register that was not
+  // recorded if the rollbacker was attached to the rematerializer after the
+  // rematerialization happened. In such cases the MI won't be rolled back.
+  auto RematsOf = Rematerializations.find(Remater.getOriginOf(RegIdx));
+  if (RematsOf == Rematerializations.end())
+    return false;
+  return RematsOf->getSecond().contains(RegIdx);
+}
+
+void Rollbacker::invalidatePosition(MachineInstr *MI,
+                                    MachineBasicBlock::iterator It) {
+  auto MIIndices = PosToIdx.find(makePos(MI));
+  if (MIIndices == PosToIdx.end())
+    return;
+  // Take the indices out first: inserting NewPos may rehash PosToIdx.
+  SmallDenseSet<unsigned, 1> Indices = std::move(MIIndices->getSecond());
+  PosToIdx.erase(MIIndices);
+  assert(!Indices.empty() && "no index hold position");
+  const InsertBeforePos NewPos = makePos(It, MI->getParent());
+  for (unsigned I : Indices)
+    Positions[I] = NewPos;
+  PosToIdx[NewPos].insert_range(Indices);
+}
diff --git a/llvm/unittests/CodeGen/RematerializerTest.cpp b/llvm/unittests/CodeGen/RematerializerTest.cpp
index 1a2e0373b8cb4..2fdb80b9be5e6 100644
--- a/llvm/unittests/CodeGen/RematerializerTest.cpp
+++ b/llvm/unittests/CodeGen/RematerializerTest.cpp
@@ -697,3 +697,75 @@ TEST_F(RematerializerTest, RollbackInvalidInsertPos) {
     RollbackAndCheckOriginalOrder();
   });
 }
+
+/// Checks that rollback re-creates MIs in the correct order when the next MI
+/// after a deleted one is a rematerialization of another MI.
+TEST_F(RematerializerTest, RollbackNextPosIsRemat) {
+  StringRef MIRBody = R"MIR(
+  bb.0:
+    %0:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 0, implicit $exec, implicit $mode
+    %1:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 1, implicit $exec, implicit $mode
+    
+  bb.1:
+    %2:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 2, implicit $exec, implicit $mode
+    S_NOP 0, implicit %0
+
+  bb.2:
+    %3:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 3, implicit $exec, implicit $mode
+    S_NOP 0, implicit %1
+
+  bb.3:
+    S_NOP 0, implicit %2, implicit %3
+    S_ENDPGM 0
+)MIR";
+  rematerializerTest(MIRBody, [](RematerializerWrapper &RW) {
+    Rematerializer::DependencyReuseInfo DRI;
+    Rollbacker Rollback;
+
+    const unsigned MBB1 = 1, MBB2 = 2, MBB3 = 3;
+    const RegisterIdx Cst0 = 0, Cst1 = 1, Cst2 = 2, Cst3 = 3;
+
+    MachineInstr *Nop1 = &*std::prev(RW.MF.getBlockNumbered(1)->end());
+    MachineInstr *Nop2 = &*std::prev(RW.MF.getBlockNumbered(2)->end());
+    MachineInstr *Nop3 =
+        &*std::prev(std::prev(RW.MF.getBlockNumbered(3)->end()));
+
+    auto ExpectSeq = [](MachineInstr *MI, MachineInstr *ExpectedNext) {
+      MachineInstr *ActualNext = &*std::next(MI->getIterator());
+      EXPECT_EQ(ActualNext, ExpectedNext);
+    };
+
+    // This rematerialization is created right after %2, which is later
+    // rematerialized. It is *not* recorded by the rollbacker.
+    RegisterIdx RematCst0 = RW->rematerializeToRegion(Cst0, MBB1, DRI.clear());
+    ExpectSeq(RW->getReg(Cst2).DefMI, RW->getReg(RematCst0).DefMI);
+    ExpectSeq(RW->getReg(RematCst0).DefMI, Nop1);
+
+    RW->addListener(&Rollback);
+
+    // This rematerialization is created right after %3, which is later
+    // rematerialized. It is recorded by the rollbacker.
+    RegisterIdx RematCst1 = RW->rematerializeToRegion(Cst1, MBB2, DRI.clear());
+    ExpectSeq(RW->getReg(Cst3).DefMI, RW->getReg(RematCst1).DefMI);
+    ExpectSeq(RW->getReg(RematCst1).DefMI, Nop2);
+
+    RegisterIdx RematCst2 = RW->rematerializeToRegion(Cst2, MBB3, DRI.clear());
+    RegisterIdx RematCst3 = RW->rematerializeToRegion(Cst3, MBB3, DRI.clear());
+
+    ExpectSeq(RW->getReg(RematCst2).DefMI, RW->getReg(RematCst3).DefMI);
+    ExpectSeq(RW->getReg(RematCst3).DefMI, Nop3);
+
+    // After rollback, %2 and %3 should be re-created at the beginning of their
+    // respective original regions.
+    Rollback.rollback(*RW);
+
+    // The rematerialization of %0 was not recorded so isn't rolled back, %2 is
+    // re-created right before it.
+    ExpectSeq(RW->getReg(Cst2).DefMI, RW->getReg(RematCst0).DefMI);
+    ExpectSeq(RW->getReg(RematCst0).DefMI, Nop1);
+
+    // The rematerialization of %1 was recorded so is rolled back, %3 is
+    // re-created before the S_NOP in its region.
+    ExpectSeq(RW->getReg(Cst3).DefMI, Nop2);
+  });
+}



More information about the llvm-branch-commits mailing list