[llvm] [GVN] MemorySSA for GVN: embed the memory state in symbolic expressions (PR #123218)
Antonio Frighetto via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 2 04:15:38 PDT 2025
https://github.com/antoniofrighetto updated https://github.com/llvm/llvm-project/pull/123218
>From b4e2bbc42d1582b36b0d89ebdd10bd8113af8d6a Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 16 Jan 2025 15:31:30 +0100
Subject: [PATCH] [GVN] MemorySSA for GVN: embed the memory state in symbolic
expressions
While migrating towards MemorySSA, account for the memory state modeled
by MemorySSA by hashing it into the symbolic expressions computed for
memory operations. Likewise, when phi-translating while walking the
CFG for PRE opportunities, check whether the value number of an operand
can be refined using one of the values on the incoming edges of the
MemoryPhi associated with the current phi block.
---
llvm/include/llvm/Transforms/Scalar/GVN.h | 11 +++
llvm/lib/Transforms/Scalar/GVN.cpp | 89 +++++++++++++++++++++--
2 files changed, 95 insertions(+), 5 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Scalar/GVN.h b/llvm/include/llvm/Transforms/Scalar/GVN.h
index ffdf57cd3d8f8..507151cf3e5ab 100644
--- a/llvm/include/llvm/Transforms/Scalar/GVN.h
+++ b/llvm/include/llvm/Transforms/Scalar/GVN.h
@@ -46,7 +46,9 @@ class ImplicitControlFlowTracking;
class LoadInst;
class LoopInfo;
class MemDepResult;
+class MemoryAccess;
class MemoryDependenceResults;
+class MemoryLocation;
class MemorySSA;
class MemorySSAUpdater;
class NonLocalDepResult;
@@ -172,6 +174,10 @@ class GVNPass : public PassInfoMixin<GVNPass> {
// Value number to PHINode mapping. Used for phi-translate in scalarpre.
DenseMap<uint32_t, PHINode *> NumberingPhi;
+ // Value number to BasicBlock mapping. Used for phi-translate across
+ // MemoryPhis.
+ DenseMap<uint32_t, BasicBlock *> NumberingBB;
+
// Cache for phi-translate in scalarpre.
using PhiTranslateMap =
DenseMap<std::pair<uint32_t, const BasicBlock *>, uint32_t>;
@@ -179,6 +185,7 @@ class GVNPass : public PassInfoMixin<GVNPass> {
AAResults *AA = nullptr;
MemoryDependenceResults *MD = nullptr;
+ MemorySSA *MSSA = nullptr;
DominatorTree *DT = nullptr;
uint32_t NextValueNumber = 1;
@@ -189,12 +196,14 @@ class GVNPass : public PassInfoMixin<GVNPass> {
Expression createExtractvalueExpr(ExtractValueInst *EI);
Expression createGEPExpr(GetElementPtrInst *GEP);
uint32_t lookupOrAddCall(CallInst *C);
+ uint32_t lookupOrAddLoadStore(Instruction *I);
uint32_t phiTranslateImpl(const BasicBlock *BB, const BasicBlock *PhiBlock,
uint32_t Num, GVNPass &GVN);
bool areCallValsEqual(uint32_t Num, uint32_t NewNum, const BasicBlock *Pred,
const BasicBlock *PhiBlock, GVNPass &GVN);
std::pair<uint32_t, bool> assignExpNewValueNum(Expression &Exp);
bool areAllValsInBB(uint32_t Num, const BasicBlock *BB, GVNPass &GVN);
+ void addMemoryStateToExp(Instruction *I, Expression &Exp);
public:
ValueTable();
@@ -203,6 +212,7 @@ class GVNPass : public PassInfoMixin<GVNPass> {
~ValueTable();
ValueTable &operator=(const ValueTable &Arg);
+ uint32_t lookupOrAdd(MemoryAccess *MA);
uint32_t lookupOrAdd(Value *V);
uint32_t lookup(Value *V, bool Verify = true) const;
uint32_t lookupOrAddCmp(unsigned Opcode, CmpInst::Predicate Pred,
@@ -217,6 +227,7 @@ class GVNPass : public PassInfoMixin<GVNPass> {
void setAliasAnalysis(AAResults *A) { AA = A; }
AAResults *getAliasAnalysis() const { return AA; }
void setMemDep(MemoryDependenceResults *M) { MD = M; }
+ void setMemorySSA(MemorySSA *M) { MSSA = M; }
void setDomTree(DominatorTree *D) { DT = D; }
uint32_t getNextUnusedValueNumber() { return NextValueNumber; }
void verifyRemoved(const Value *) const;
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index 6233e8e2ee681..3fff576a9d0b1 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -475,6 +475,19 @@ void GVNPass::ValueTable::add(Value *V, uint32_t Num) {
NumberingPhi[Num] = PN;
}
+// Append the memory state that instruction \p I observes to \p Exp, so that
+// the state becomes part of the expression's hash. The state is the access
+// that clobbers I, encoded via lookupOrAdd(MemoryAccess *):
+//  * LiveOnEntry -> the value number of the function's entry block,
+//  * MemoryPhi   -> the value number of the block holding that MemoryPhi,
+//  * MemoryDef   -> the value number of the memory-writing instruction.
+void GVNPass::ValueTable::addMemoryStateToExp(Instruction *I, Expression &Exp) {
+  assert(MSSA && "addMemoryStateToExp should not be called without MemorySSA");
+  assert(MSSA->getMemoryAccess(I) && "Instruction does not access memory");
+  // The skip-self walker never reports I's own MemoryDef (e.g. for a store)
+  // as the clobber, so the recorded state is the one in effect before I.
+  auto *Walker = MSSA->getSkipSelfWalker();
+  MemoryAccess *Clobber = Walker->getClobberingMemoryAccess(I);
+  Exp.VarArgs.push_back(lookupOrAdd(Clobber));
+}
+
uint32_t GVNPass::ValueTable::lookupOrAddCall(CallInst *C) {
// FIXME: Currently the calls which may access the thread id may
// be considered as not accessing the memory. But this is
@@ -595,15 +608,48 @@ uint32_t GVNPass::ValueTable::lookupOrAddCall(CallInst *C) {
return V;
}
+ if (MSSA && AA->onlyReadsMemory(C)) {
+ Expression Exp = createExpr(C);
+ addMemoryStateToExp(C, Exp);
+ auto [V, _] = assignExpNewValueNum(Exp);
+ ValueNumbering[C] = V;
+ return V;
+ }
+
ValueNumbering[C] = NextValueNumber;
return NextValueNumber++;
}
+/// Returns the value number for the specified load or store instruction.
+///
+/// Without MemorySSA, or for volatile/atomic accesses, the instruction gets a
+/// fresh, unique value number. Otherwise it is numbered by opcode, result
+/// type, operand value numbers and incoming memory state, so that equivalent
+/// simple accesses observing the same memory share a number.
+///
+/// Volatile and atomic accesses are excluded on purpose. The MemorySSA walker
+/// can report the same clobber for a plain load and a volatile or atomic load
+/// of the same address. Equating them would let GVN substitute computations
+/// based on the volatile/atomic result with ones based on the plain load.
+uint32_t GVNPass::ValueTable::lookupOrAddLoadStore(Instruction *I) {
+  if (!MSSA || I->isVolatile() || I->isAtomic()) {
+    ValueNumbering[I] = NextValueNumber;
+    return NextValueNumber++;
+  }
+
+  Expression Exp;
+  Exp.Ty = I->getType();
+  Exp.Opcode = I->getOpcode();
+  for (Use &Op : I->operands())
+    Exp.VarArgs.push_back(lookupOrAdd(Op));
+  addMemoryStateToExp(I, Exp);
+
+  auto [V, _] = assignExpNewValueNum(Exp);
+  ValueNumbering[I] = V;
+  return V;
+}
+
/// Returns true if \p V has already been assigned a value number (i.e. it is
/// present in ValueNumbering). Unlike lookupOrAdd, this never assigns one.
bool GVNPass::ValueTable::exists(Value *V) const {
return ValueNumbering.contains(V);
}
+/// Returns the value number that stands for the memory state \p MA.
+/// A MemoryPhi or the LiveOnEntry definition is represented by the value
+/// number of its basic block; any other access by the value number of the
+/// instruction that owns the access.
+uint32_t GVNPass::ValueTable::lookupOrAdd(MemoryAccess *MA) {
+  if (isa<MemoryPhi>(MA) || MSSA->isLiveOnEntryDef(MA))
+    return lookupOrAdd(MA->getBlock());
+  auto *UseOrDef = cast<MemoryUseOrDef>(MA);
+  return lookupOrAdd(UseOrDef->getMemoryInst());
+}
+
/// lookupOrAdd - Returns the value number for the specified value, assigning
/// it a new number if it did not have one before.
uint32_t GVNPass::ValueTable::lookupOrAdd(Value *V) {
@@ -614,6 +660,8 @@ uint32_t GVNPass::ValueTable::lookupOrAdd(Value *V) {
auto *I = dyn_cast<Instruction>(V);
if (!I) {
ValueNumbering[V] = NextValueNumber;
+ if (isa<BasicBlock>(V))
+ NumberingBB[NextValueNumber] = cast<BasicBlock>(V);
return NextValueNumber++;
}
@@ -673,6 +721,9 @@ uint32_t GVNPass::ValueTable::lookupOrAdd(Value *V) {
ValueNumbering[V] = NextValueNumber;
NumberingPhi[NextValueNumber] = cast<PHINode>(V);
return NextValueNumber++;
+ case Instruction::Load:
+ case Instruction::Store:
+ return lookupOrAddLoadStore(I);
default:
ValueNumbering[V] = NextValueNumber;
return NextValueNumber++;
@@ -710,6 +761,7 @@ void GVNPass::ValueTable::clear() {
ValueNumbering.clear();
ExpressionNumbering.clear();
NumberingPhi.clear();
+ NumberingBB.clear();
PhiTranslateTable.clear();
NextValueNumber = 1;
Expressions.clear();
@@ -724,6 +776,8 @@ void GVNPass::ValueTable::erase(Value *V) {
// If V is PHINode, V <--> value number is an one-to-one mapping.
if (isa<PHINode>(V))
NumberingPhi.erase(Num);
+ else if (isa<BasicBlock>(V))
+ NumberingBB.erase(Num);
}
/// verifyRemoved - Verify that the value is removed from all internal data
@@ -2295,15 +2349,39 @@ bool GVNPass::ValueTable::areCallValsEqual(uint32_t Num, uint32_t NewNum,
uint32_t GVNPass::ValueTable::phiTranslateImpl(const BasicBlock *Pred,
const BasicBlock *PhiBlock,
uint32_t Num, GVNPass &GVN) {
+ // See if we can refine the value number by looking at the PN incoming value
+ // for the given predecessor.
if (PHINode *PN = NumberingPhi[Num]) {
- for (unsigned I = 0; I != PN->getNumIncomingValues(); ++I) {
- if (PN->getParent() == PhiBlock && PN->getIncomingBlock(I) == Pred)
- if (uint32_t TransVal = lookup(PN->getIncomingValue(I), false))
- return TransVal;
- }
+ if (PN->getParent() == PhiBlock)
+ for (unsigned I = 0; I != PN->getNumIncomingValues(); ++I)
+ if (PN->getIncomingBlock(I) == Pred)
+ if (uint32_t TransVal = lookup(PN->getIncomingValue(I), false))
+ return TransVal;
return Num;
}
+ if (BasicBlock *BB = NumberingBB[Num]) {
+ assert(MSSA && "NumberingBB is non-empty only when using MemorySSA");
+ // Value numbers of basic blocks are used to represent memory state in
+ // load/store instructions and read-only function calls when said state is
+ // set by a MemoryPhi.
+ if (BB != PhiBlock)
+ return Num;
+ MemoryPhi *MPhi = MSSA->getMemoryAccess(BB);
+ for (unsigned i = 0, N = MPhi->getNumIncomingValues(); i != N; ++i) {
+ if (MPhi->getIncomingBlock(i) != Pred)
+ continue;
+ MemoryAccess *MA = MPhi->getIncomingValue(i);
+ if (auto *PredPhi = dyn_cast<MemoryPhi>(MA))
+ return lookupOrAdd(PredPhi->getBlock());
+ if (MSSA->isLiveOnEntryDef(MA))
+ return lookupOrAdd(&BB->getParent()->getEntryBlock());
+ return lookupOrAdd(cast<MemoryUseOrDef>(MA)->getMemoryInst());
+ }
+ llvm_unreachable(
+ "CFG/MemorySSA mismatch: predecessor not found among incoming blocks");
+ }
+
// If there is any value related with Num is defined in a BB other than
// PhiBlock, it cannot depend on a phi in PhiBlock without going through
// a backedge. We can do an early exit in that case to save compile time.
@@ -2738,6 +2816,7 @@ bool GVNPass::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT,
ICF = &ImplicitCFT;
this->LI = &LI;
VN.setMemDep(MD);
+ VN.setMemorySSA(MSSA);
ORE = RunORE;
InvalidBlockRPONumbers = true;
MemorySSAUpdater Updater(MSSA);
More information about the llvm-commits
mailing list