[llvm] [ReachingDefAnalysis] Extend the analysis to stack objects. (PR #118097)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 9 21:43:23 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-systemz
Author: Mikhail Gudim (mgudim)
<details>
<summary>Changes</summary>
We now track definitions of stack objects; the implementation is identical to the tracking of registers.
This change also adds printing of all reaching definitions found, for testing purposes.
---
Patch is 30.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118097.diff
4 Files Affected:
- (modified) llvm/include/llvm/CodeGen/ReachingDefAnalysis.h (+19-7)
- (modified) llvm/lib/CodeGen/ReachingDefAnalysis.cpp (+165-74)
- (added) llvm/test/CodeGen/RISCV/rda-stack.mir (+151)
- (added) llvm/test/CodeGen/SystemZ/rda-stack-copy.mir (+30)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h b/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
index d6a1f064ec0a58..a1a8f764b9dfee 100644
--- a/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
+++ b/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
@@ -114,8 +114,11 @@ class ReachingDefAnalysis : public MachineFunctionPass {
private:
MachineFunction *MF = nullptr;
const TargetRegisterInfo *TRI = nullptr;
+ const TargetInstrInfo *TII = nullptr;
LoopTraversal::TraversalOrder TraversedMBBOrder;
unsigned NumRegUnits = 0;
+ unsigned NumStackObjects = 0;
+ int ObjectIndexBegin = 0;
/// Instruction that defined each register, relative to the beginning of the
/// current basic block. When a LiveRegsDefInfo is used to represent a
/// live-out register, this value is relative to the end of the basic block,
@@ -138,6 +141,13 @@ class ReachingDefAnalysis : public MachineFunctionPass {
DenseMap<MachineInstr *, int> InstIds;
MBBReachingDefsInfo MBBReachingDefs;
+ using MBBFrameObjsReachingDefsInfo =
+ std::vector<std::vector<std::vector<int>>>;
+ // MBBFrameObjsReachingDefs[i][j] is a list of instruction indices (relative
+ // to the beginning of MBB) that define frame index (j +
+ // MF->getFrameInfo().getObjectIndexBegin()) in MBB i. This is used in
+ // answering reaching definition queries.
+ MBBFrameObjsReachingDefsInfo MBBFrameObjsReachingDefs;
/// Default values are 'nothing happened a long time ago'.
const int ReachingDefDefaultVal = -(1 << 21);
@@ -158,6 +168,7 @@ class ReachingDefAnalysis : public MachineFunctionPass {
MachineFunctionPass::getAnalysisUsage(AU);
}
+ void printAllReachingDefs(MachineFunction &MF);
bool runOnMachineFunction(MachineFunction &MF) override;
MachineFunctionProperties getRequiredProperties() const override {
@@ -176,12 +187,13 @@ class ReachingDefAnalysis : public MachineFunctionPass {
void traverse();
/// Provides the instruction id of the closest reaching def instruction of
- /// PhysReg that reaches MI, relative to the begining of MI's basic block.
- int getReachingDef(MachineInstr *MI, MCRegister PhysReg) const;
+ /// Reg that reaches MI, relative to the beginning of MI's basic block.
+ // Note that Reg may represent a stack slot.
+ int getReachingDef(MachineInstr *MI, MCRegister Reg) const;
- /// Return whether A and B use the same def of PhysReg.
+ /// Return whether A and B use the same def of Reg.
bool hasSameReachingDef(MachineInstr *A, MachineInstr *B,
- MCRegister PhysReg) const;
+ MCRegister Reg) const;
/// Return whether the reaching def for MI also is live out of its parent
/// block.
@@ -309,9 +321,9 @@ class ReachingDefAnalysis : public MachineFunctionPass {
MachineInstr *getInstFromId(MachineBasicBlock *MBB, int InstId) const;
/// Provides the instruction of the closest reaching def instruction of
- /// PhysReg that reaches MI, relative to the begining of MI's basic block.
- MachineInstr *getReachingLocalMIDef(MachineInstr *MI,
- MCRegister PhysReg) const;
+ /// Reg that reaches MI, relative to the beginning of MI's basic block.
+ // Note that Reg may represent a stack slot.
+ MachineInstr *getReachingLocalMIDef(MachineInstr *MI, MCRegister Reg) const;
};
} // namespace llvm
diff --git a/llvm/lib/CodeGen/ReachingDefAnalysis.cpp b/llvm/lib/CodeGen/ReachingDefAnalysis.cpp
index 0e8220ec6251cb..1e8dad37bb0c7d 100644
--- a/llvm/lib/CodeGen/ReachingDefAnalysis.cpp
+++ b/llvm/lib/CodeGen/ReachingDefAnalysis.cpp
@@ -10,6 +10,8 @@
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/CodeGen/LiveRegUnits.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/Support/Debug.h"
@@ -18,6 +20,10 @@ using namespace llvm;
#define DEBUG_TYPE "reaching-deps-analysis"
+static cl::opt<bool> PrintAllReachingDefs("print-all-reaching-defs", cl::Hidden,
+ cl::desc("Used for test purpuses"),
+ cl::Hidden);
+
char ReachingDefAnalysis::ID = 0;
INITIALIZE_PASS(ReachingDefAnalysis, DEBUG_TYPE, "ReachingDefAnalysis", false,
true)
@@ -30,22 +36,33 @@ static bool isValidRegUse(const MachineOperand &MO) {
return isValidReg(MO) && MO.isUse();
}
-static bool isValidRegUseOf(const MachineOperand &MO, MCRegister PhysReg,
+static bool isValidRegUseOf(const MachineOperand &MO, MCRegister Reg,
const TargetRegisterInfo *TRI) {
if (!isValidRegUse(MO))
return false;
- return TRI->regsOverlap(MO.getReg(), PhysReg);
+ return TRI->regsOverlap(MO.getReg(), Reg);
}
static bool isValidRegDef(const MachineOperand &MO) {
return isValidReg(MO) && MO.isDef();
}
-static bool isValidRegDefOf(const MachineOperand &MO, MCRegister PhysReg,
+static bool isValidRegDefOf(const MachineOperand &MO, MCRegister Reg,
const TargetRegisterInfo *TRI) {
if (!isValidRegDef(MO))
return false;
- return TRI->regsOverlap(MO.getReg(), PhysReg);
+ return TRI->regsOverlap(MO.getReg(), Reg);
+}
+
+static bool isFIDef(const MachineInstr &MI, int FrameIndex,
+ const TargetInstrInfo *TII) {
+ int DefFrameIndex = 0;
+ int SrcFrameIndex = 0;
+ if (TII->isStoreToStackSlot(MI, DefFrameIndex) ||
+ TII->isStackSlotCopy(MI, DefFrameIndex, SrcFrameIndex)) {
+ return DefFrameIndex == FrameIndex;
+ }
+ return false;
}
void ReachingDefAnalysis::enterBasicBlock(MachineBasicBlock *MBB) {
@@ -54,6 +71,11 @@ void ReachingDefAnalysis::enterBasicBlock(MachineBasicBlock *MBB) {
"Unexpected basic block number.");
MBBReachingDefs.startBasicBlock(MBBNumber, NumRegUnits);
+ MBBFrameObjsReachingDefs[MBBNumber].resize(NumStackObjects);
+ for (unsigned FOIdx = 0; FOIdx < NumStackObjects; ++FOIdx) {
+ MBBFrameObjsReachingDefs[MBBNumber][FOIdx].push_back(-1);
+ }
+
// Reset instruction counter in each basic block.
CurInstr = 0;
@@ -126,6 +148,13 @@ void ReachingDefAnalysis::processDefs(MachineInstr *MI) {
"Unexpected basic block number.");
for (auto &MO : MI->operands()) {
+ if (MO.isFI()) {
+ int FrameIndex = MO.getIndex();
+ if (!isFIDef(*MI, FrameIndex, TII))
+ continue;
+ MBBFrameObjsReachingDefs[MBBNumber][FrameIndex - ObjectIndexBegin]
+ .push_back(CurInstr);
+ }
if (!isValidRegDef(MO))
continue;
for (MCRegUnit Unit : TRI->regunits(MO.getReg().asMCReg())) {
@@ -209,12 +238,54 @@ void ReachingDefAnalysis::processBasicBlock(
leaveBasicBlock(MBB);
}
+void ReachingDefAnalysis::printAllReachingDefs(MachineFunction &MF) {
+ dbgs() << "RDA results for " << MF.getName() << "\n";
+ int Num = 0;
+ DenseMap<MachineInstr *, int> InstToNumMap;
+ SmallPtrSet<MachineInstr *, 2> Defs;
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &MI : MBB) {
+ for (MachineOperand &MO : MI.operands()) {
+ Register Reg;
+ if (MO.isFI()) {
+ int FrameIndex = MO.getIndex();
+ Reg = Register::index2StackSlot(FrameIndex);
+ } else if (MO.isReg()) {
+ if (MO.isDef())
+ continue;
+ Reg = MO.getReg();
+ if (Reg == MCRegister::NoRegister)
+ continue;
+ } else {
+ continue;
+ }
+ Defs.clear();
+ getGlobalReachingDefs(&MI, Reg, Defs);
+ MO.print(dbgs(), TRI);
+ dbgs() << ":{ ";
+ for (MachineInstr *Def : Defs) {
+ dbgs() << InstToNumMap[Def] << " ";
+ }
+ dbgs() << "}\n";
+ }
+ dbgs() << Num << ": " << MI << "\n";
+ InstToNumMap[&MI] = Num;
+ ++Num;
+ }
+ }
+}
+
bool ReachingDefAnalysis::runOnMachineFunction(MachineFunction &mf) {
MF = &mf;
TRI = MF->getSubtarget().getRegisterInfo();
+ const TargetSubtargetInfo &STI = MF->getSubtarget();
+ TRI = STI.getRegisterInfo();
+ TII = STI.getInstrInfo();
LLVM_DEBUG(dbgs() << "********** REACHING DEFINITION ANALYSIS **********\n");
init();
traverse();
+ if (PrintAllReachingDefs)
+ printAllReachingDefs(*MF);
return false;
}
@@ -222,6 +293,7 @@ void ReachingDefAnalysis::releaseMemory() {
// Clear the internal vectors.
MBBOutRegsInfos.clear();
MBBReachingDefs.clear();
+ MBBFrameObjsReachingDefs.clear();
InstIds.clear();
LiveRegs.clear();
}
@@ -234,7 +306,10 @@ void ReachingDefAnalysis::reset() {
void ReachingDefAnalysis::init() {
NumRegUnits = TRI->getNumRegUnits();
+ NumStackObjects = MF->getFrameInfo().getNumObjects();
+ ObjectIndexBegin = MF->getFrameInfo().getObjectIndexBegin();
MBBReachingDefs.init(MF->getNumBlockIDs());
+ MBBFrameObjsReachingDefs.resize(MF->getNumBlockIDs());
// Initialize the MBBOutRegsInfos
MBBOutRegsInfos.resize(MF->getNumBlockIDs());
LoopTraversal Traversal;
@@ -261,7 +336,7 @@ void ReachingDefAnalysis::traverse() {
}
int ReachingDefAnalysis::getReachingDef(MachineInstr *MI,
- MCRegister PhysReg) const {
+ MCRegister Reg) const {
assert(InstIds.count(MI) && "Unexpected machine instuction.");
int InstId = InstIds.lookup(MI);
int DefRes = ReachingDefDefaultVal;
@@ -269,7 +344,20 @@ int ReachingDefAnalysis::getReachingDef(MachineInstr *MI,
assert(MBBNumber < MBBReachingDefs.numBlockIDs() &&
"Unexpected basic block number.");
int LatestDef = ReachingDefDefaultVal;
- for (MCRegUnit Unit : TRI->regunits(PhysReg)) {
+
+ if (Register::isStackSlot(Reg)) {
+ int FrameIndex = Register::stackSlot2Index(Reg);
+ for (int Def :
+ MBBFrameObjsReachingDefs[MBBNumber][FrameIndex - ObjectIndexBegin]) {
+ if (Def >= InstId)
+ break;
+ DefRes = Def;
+ }
+ LatestDef = std::max(LatestDef, DefRes);
+ return LatestDef;
+ }
+
+ for (MCRegUnit Unit : TRI->regunits(Reg)) {
for (int Def : MBBReachingDefs.defs(MBBNumber, Unit)) {
if (Def >= InstId)
break;
@@ -280,22 +368,21 @@ int ReachingDefAnalysis::getReachingDef(MachineInstr *MI,
return LatestDef;
}
-MachineInstr *
-ReachingDefAnalysis::getReachingLocalMIDef(MachineInstr *MI,
- MCRegister PhysReg) const {
- return hasLocalDefBefore(MI, PhysReg)
- ? getInstFromId(MI->getParent(), getReachingDef(MI, PhysReg))
- : nullptr;
+MachineInstr *ReachingDefAnalysis::getReachingLocalMIDef(MachineInstr *MI,
+ MCRegister Reg) const {
+ return hasLocalDefBefore(MI, Reg)
+ ? getInstFromId(MI->getParent(), getReachingDef(MI, Reg))
+ : nullptr;
}
bool ReachingDefAnalysis::hasSameReachingDef(MachineInstr *A, MachineInstr *B,
- MCRegister PhysReg) const {
+ MCRegister Reg) const {
MachineBasicBlock *ParentA = A->getParent();
MachineBasicBlock *ParentB = B->getParent();
if (ParentA != ParentB)
return false;
- return getReachingDef(A, PhysReg) == getReachingDef(B, PhysReg);
+ return getReachingDef(A, Reg) == getReachingDef(B, Reg);
}
MachineInstr *ReachingDefAnalysis::getInstFromId(MachineBasicBlock *MBB,
@@ -318,19 +405,18 @@ MachineInstr *ReachingDefAnalysis::getInstFromId(MachineBasicBlock *MBB,
return nullptr;
}
-int ReachingDefAnalysis::getClearance(MachineInstr *MI,
- MCRegister PhysReg) const {
+int ReachingDefAnalysis::getClearance(MachineInstr *MI, MCRegister Reg) const {
assert(InstIds.count(MI) && "Unexpected machine instuction.");
- return InstIds.lookup(MI) - getReachingDef(MI, PhysReg);
+ return InstIds.lookup(MI) - getReachingDef(MI, Reg);
}
bool ReachingDefAnalysis::hasLocalDefBefore(MachineInstr *MI,
- MCRegister PhysReg) const {
- return getReachingDef(MI, PhysReg) >= 0;
+ MCRegister Reg) const {
+ return getReachingDef(MI, Reg) >= 0;
}
void ReachingDefAnalysis::getReachingLocalUses(MachineInstr *Def,
- MCRegister PhysReg,
+ MCRegister Reg,
InstSet &Uses) const {
MachineBasicBlock *MBB = Def->getParent();
MachineBasicBlock::iterator MI = MachineBasicBlock::iterator(Def);
@@ -340,11 +426,11 @@ void ReachingDefAnalysis::getReachingLocalUses(MachineInstr *Def,
// If/when we find a new reaching def, we know that there's no more uses
// of 'Def'.
- if (getReachingLocalMIDef(&*MI, PhysReg) != Def)
+ if (getReachingLocalMIDef(&*MI, Reg) != Def)
return;
for (auto &MO : MI->operands()) {
- if (!isValidRegUseOf(MO, PhysReg, TRI))
+ if (!isValidRegUseOf(MO, Reg, TRI))
continue;
Uses.insert(&*MI);
@@ -354,15 +440,14 @@ void ReachingDefAnalysis::getReachingLocalUses(MachineInstr *Def,
}
}
-bool ReachingDefAnalysis::getLiveInUses(MachineBasicBlock *MBB,
- MCRegister PhysReg,
+bool ReachingDefAnalysis::getLiveInUses(MachineBasicBlock *MBB, MCRegister Reg,
InstSet &Uses) const {
for (MachineInstr &MI :
instructionsWithoutDebug(MBB->instr_begin(), MBB->instr_end())) {
for (auto &MO : MI.operands()) {
- if (!isValidRegUseOf(MO, PhysReg, TRI))
+ if (!isValidRegUseOf(MO, Reg, TRI))
continue;
- if (getReachingDef(&MI, PhysReg) >= 0)
+ if (getReachingDef(&MI, Reg) >= 0)
return false;
Uses.insert(&MI);
}
@@ -370,18 +455,18 @@ bool ReachingDefAnalysis::getLiveInUses(MachineBasicBlock *MBB,
auto Last = MBB->getLastNonDebugInstr();
if (Last == MBB->end())
return true;
- return isReachingDefLiveOut(&*Last, PhysReg);
+ return isReachingDefLiveOut(&*Last, Reg);
}
-void ReachingDefAnalysis::getGlobalUses(MachineInstr *MI, MCRegister PhysReg,
+void ReachingDefAnalysis::getGlobalUses(MachineInstr *MI, MCRegister Reg,
InstSet &Uses) const {
MachineBasicBlock *MBB = MI->getParent();
// Collect the uses that each def touches within the block.
- getReachingLocalUses(MI, PhysReg, Uses);
+ getReachingLocalUses(MI, Reg, Uses);
// Handle live-out values.
- if (auto *LiveOut = getLocalLiveOutMIDef(MI->getParent(), PhysReg)) {
+ if (auto *LiveOut = getLocalLiveOutMIDef(MI->getParent(), Reg)) {
if (LiveOut != MI)
return;
@@ -389,9 +474,9 @@ void ReachingDefAnalysis::getGlobalUses(MachineInstr *MI, MCRegister PhysReg,
SmallPtrSet<MachineBasicBlock*, 4>Visited;
while (!ToVisit.empty()) {
MachineBasicBlock *MBB = ToVisit.pop_back_val();
- if (Visited.count(MBB) || !MBB->isLiveIn(PhysReg))
+ if (Visited.count(MBB) || !MBB->isLiveIn(Reg))
continue;
- if (getLiveInUses(MBB, PhysReg, Uses))
+ if (getLiveInUses(MBB, Reg, Uses))
llvm::append_range(ToVisit, MBB->successors());
Visited.insert(MBB);
}
@@ -399,25 +484,25 @@ void ReachingDefAnalysis::getGlobalUses(MachineInstr *MI, MCRegister PhysReg,
}
void ReachingDefAnalysis::getGlobalReachingDefs(MachineInstr *MI,
- MCRegister PhysReg,
+ MCRegister Reg,
InstSet &Defs) const {
- if (auto *Def = getUniqueReachingMIDef(MI, PhysReg)) {
+ if (auto *Def = getUniqueReachingMIDef(MI, Reg)) {
Defs.insert(Def);
return;
}
for (auto *MBB : MI->getParent()->predecessors())
- getLiveOuts(MBB, PhysReg, Defs);
+ getLiveOuts(MBB, Reg, Defs);
}
-void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB,
- MCRegister PhysReg, InstSet &Defs) const {
+void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB, MCRegister Reg,
+ InstSet &Defs) const {
SmallPtrSet<MachineBasicBlock*, 2> VisitedBBs;
- getLiveOuts(MBB, PhysReg, Defs, VisitedBBs);
+ getLiveOuts(MBB, Reg, Defs, VisitedBBs);
}
-void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB,
- MCRegister PhysReg, InstSet &Defs,
+void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB, MCRegister Reg,
+ InstSet &Defs,
BlockSet &VisitedBBs) const {
if (VisitedBBs.count(MBB))
return;
@@ -425,28 +510,28 @@ void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB,
VisitedBBs.insert(MBB);
LiveRegUnits LiveRegs(*TRI);
LiveRegs.addLiveOuts(*MBB);
- if (LiveRegs.available(PhysReg))
+ if (Register::isPhysicalRegister(Reg) && LiveRegs.available(Reg))
return;
- if (auto *Def = getLocalLiveOutMIDef(MBB, PhysReg))
+ if (auto *Def = getLocalLiveOutMIDef(MBB, Reg))
Defs.insert(Def);
else
for (auto *Pred : MBB->predecessors())
- getLiveOuts(Pred, PhysReg, Defs, VisitedBBs);
+ getLiveOuts(Pred, Reg, Defs, VisitedBBs);
}
MachineInstr *
ReachingDefAnalysis::getUniqueReachingMIDef(MachineInstr *MI,
- MCRegister PhysReg) const {
+ MCRegister Reg) const {
// If there's a local def before MI, return it.
- MachineInstr *LocalDef = getReachingLocalMIDef(MI, PhysReg);
+ MachineInstr *LocalDef = getReachingLocalMIDef(MI, Reg);
if (LocalDef && InstIds.lookup(LocalDef) < InstIds.lookup(MI))
return LocalDef;
SmallPtrSet<MachineInstr*, 2> Incoming;
MachineBasicBlock *Parent = MI->getParent();
for (auto *Pred : Parent->predecessors())
- getLiveOuts(Pred, PhysReg, Incoming);
+ getLiveOuts(Pred, Reg, Incoming);
// Check that we have a single incoming value and that it does not
// come from the same block as MI - since it would mean that the def
@@ -469,13 +554,13 @@ MachineInstr *ReachingDefAnalysis::getMIOperand(MachineInstr *MI,
}
bool ReachingDefAnalysis::isRegUsedAfter(MachineInstr *MI,
- MCRegister PhysReg) const {
+ MCRegister Reg) const {
MachineBasicBlock *MBB = MI->getParent();
LiveRegUnits LiveRegs(*TRI);
LiveRegs.addLiveOuts(*MBB);
// Yes if the register is live out of the basic block.
- if (!LiveRegs.available(PhysReg))
+ if (!LiveRegs.available(Reg))
return true;
// Walk backwards through the block to see if the register is live at some
@@ -483,62 +568,68 @@ bool ReachingDefAnalysis::isRegUsedAfter(MachineInstr *MI,
for (MachineInstr &Last :
instructionsWithoutDebug(MBB->instr_rbegin(), MBB->instr_rend())) {
LiveRegs.stepBackward(Last);
- if (!LiveRegs.available(PhysReg))
+ if (!LiveRegs.available(Reg))
return InstIds.lookup(&Last) > InstIds.lookup(MI);
}
return false;
}
bool ReachingDefAnalysis::isRegDefinedAfter(MachineInstr *MI,
- MCRegister PhysReg) const {
+ MCRegister Reg) const {
MachineBasicBlock *MBB = MI->getParent();
auto Last = MBB->getLastNonDebugInstr();
if (Last != MBB->end() &&
- getReachingDef(MI, PhysReg) != getReachingDef(&*Last, PhysReg))
+ getReachingDef(MI, Reg) != getReachingDef(&*Last, Reg))
return true;
- if (auto *Def = getLocalLiveOutMIDef(MBB, PhysReg))
- return Def == getReachingLocalMIDef(MI, PhysReg);
+ if (auto *Def = getLocalLiveOutMIDef(MBB, Reg))
+ return Def == getReachingLocalMIDef(MI, Reg);
return false;
}
bool ReachingDefAnalysis::isReachingDefLiveOut(MachineInstr *MI,
- MCRegister PhysReg) const {
+ MCRegister Reg) const {
MachineBasicBlock *MBB = MI->getParent();
LiveRegUnits LiveRegs(*TRI);
LiveRegs.addLiveOuts(*MBB);
- if (LiveRegs.available(PhysReg))
+ if (Register::isPhysicalRegister(Reg) && LiveRegs.available(Reg))
return false;
auto Last = MBB->getLastNonDeb...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/118097
More information about the llvm-commits
mailing list