[llvm] [SandboxVec][DAG] Fix MemDGNode chain maintenance when move destination is non mem (PR #124227)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 23 21:31:16 PST 2025
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/124227
This patch fixes a bug in the maintenance of the MemDGNode chain of the DAG. Whenever a memory instruction is moved, the DAG is notified about the move and updates the chain of memory nodes. The bug was that if the destination of the move was not a memory instruction, the moved memory node's next-node pointer would end up pointing to the node itself.
From b71e3ff8a6367a026812ce1b988d500952f4f625 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Tue, 14 Jan 2025 10:49:22 -0800
Subject: [PATCH] [SandboxVec][DAG] Fix MemDGNode chain maintenance when move
destination is non mem
This patch fixes a bug in the maintenance of the MemDGNode chain of the DAG.
Whenever a memory instruction is moved, the DAG is notified about the move
and updates the chain of memory nodes. The bug was that if the destination
of the move was not a memory instruction, the moved memory node's next-node
pointer would end up pointing to the node itself.
---
.../SandboxVectorizer/DependencyGraph.h | 16 +++--
.../SandboxVectorizer/DependencyGraph.cpp | 70 +++++++++++++------
.../SandboxVectorizer/DependencyGraphTest.cpp | 43 ++++++++++++
3 files changed, 103 insertions(+), 26 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index b2d7c9b8aa8bbc..6e3f99d78b9329 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -218,12 +218,14 @@ class MemDGNode final : public DGNode {
friend class PredIterator; // For MemPreds.
/// Creates both edges: this<->N.
void setNextNode(MemDGNode *N) {
+ assert(N != this && "About to point to self!");
NextMemN = N;
if (NextMemN != nullptr)
NextMemN->PrevMemN = this;
}
/// Creates both edges: N<->this.
void setPrevNode(MemDGNode *N) {
+ assert(N != this && "About to point to self!");
PrevMemN = N;
if (PrevMemN != nullptr)
PrevMemN->NextMemN = this;
@@ -348,13 +350,15 @@ class DependencyGraph {
void createNewNodes(const Interval<Instruction> &NewInterval);
/// Helper for `notify*Instr()`. \Returns the first MemDGNode that comes
- /// before \p N, including or excluding \p N based on \p IncludingN, or
- /// nullptr if not found.
- MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN) const;
+ /// before \p N, skipping \p SkipN, including or excluding \p N based on
+ /// \p IncludingN, or nullptr if not found.
+ MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN,
+ MemDGNode *SkipN = nullptr) const;
/// Helper for `notifyMoveInstr()`. \Returns the first MemDGNode that comes
- /// after \p N, including or excluding \p N based on \p IncludingN, or nullptr
- /// if not found.
- MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN) const;
+ /// after \p N, skipping \p SkipN, including or excluding \p N based on \p
+ /// IncludingN, or nullptr if not found.
+ MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN,
+ MemDGNode *SkipN = nullptr) const;
/// Called by the callbacks when a new instruction \p I has been created.
void notifyCreateInstr(Instruction *I);
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index f080111f08d45e..390a5e9688cc78 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -325,29 +325,31 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
setDefUseUnscheduledSuccs(NewInterval);
}
-MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N,
- bool IncludingN) const {
+MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, bool IncludingN,
+ MemDGNode *SkipN) const {
auto *I = N->getInstruction();
for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr;
PrevI = PrevI->getPrevNode()) {
auto *PrevN = getNodeOrNull(PrevI);
if (PrevN == nullptr)
return nullptr;
- if (auto *PrevMemN = dyn_cast<MemDGNode>(PrevN))
+ auto *PrevMemN = dyn_cast<MemDGNode>(PrevN);
+ if (PrevMemN != nullptr && PrevMemN != SkipN)
return PrevMemN;
}
return nullptr;
}
-MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N,
- bool IncludingN) const {
+MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
+ MemDGNode *SkipN) const {
auto *I = N->getInstruction();
for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr;
NextI = NextI->getNextNode()) {
auto *NextN = getNodeOrNull(NextI);
if (NextN == nullptr)
return nullptr;
- if (auto *NextMemN = dyn_cast<MemDGNode>(NextN))
+ auto *NextMemN = dyn_cast<MemDGNode>(NextN);
+ if (NextMemN != nullptr && NextMemN != SkipN)
return NextMemN;
}
return nullptr;
@@ -377,6 +379,20 @@ void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
!(To == BB->end() && std::next(I->getIterator()) == BB->end()) &&
"Should not have been called if destination is same as origin.");
+ // TODO: We can only handle fully internal movements within DAGInterval or at
+ // the borders, i.e., right before the top or right after the bottom.
+ assert(To.getNodeParent() == I->getParent() &&
+ "TODO: We don't support movement across BBs!");
+ assert(
+ (To == std::next(DAGInterval.bottom()->getIterator()) ||
+ (To != BB->end() && std::next(To) == DAGInterval.top()->getIterator()) ||
+ (To != BB->end() && DAGInterval.contains(&*To))) &&
+ "TODO: To should be either within the DAGInterval or right "
+ "before/after it.");
+
+ // Make a copy of the DAGInterval before we update it.
+ auto OrigDAGInterval = DAGInterval;
+
// Maintain the DAGInterval.
DAGInterval.notifyMoveInstr(I, To);
@@ -389,23 +405,37 @@ void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
MemDGNode *MemN = dyn_cast<MemDGNode>(N);
if (MemN == nullptr)
return;
- // First detach it from the existing chain.
+
+ // First safely detach it from the existing chain.
MemN->detachFromChain();
+
// Now insert it back into the chain at the new location.
- if (To != BB->end()) {
- DGNode *ToN = getNodeOrNull(&*To);
- if (ToN != nullptr) {
- MemN->setPrevNode(getMemDGNodeBefore(ToN, /*IncludingN=*/false));
- MemN->setNextNode(getMemDGNodeAfter(ToN, /*IncludingN=*/true));
- }
+ //
+ // We won't always have a DGNode to insert before it. If `To` is BB->end() or
+ // if it points to an instr after DAGInterval.bottom() then we will have to
+ // find a node to insert *after*.
+ //
+ // BB: BB:
+ // I1 I1 ^
+ // I2 I2 | DAGInterval [I1 to I3]
+ // I3 I3 V
+ // I4 I4 <- `To` == right after DAGInterval
+ // <- `To` == BB->end()
+ //
+ if (To == BB->end() ||
+ To == std::next(OrigDAGInterval.bottom()->getIterator())) {
+ // If we don't have a node to insert before, find a node to insert after and
+ // update the chain.
+ DGNode *InsertAfterN = getNode(&*std::prev(To));
+ MemN->setPrevNode(
+ getMemDGNodeBefore(InsertAfterN, /*IncludingN=*/true, /*SkipN=*/MemN));
} else {
- // MemN becomes the last instruction in the BB.
- auto *TermN = getNodeOrNull(BB->getTerminator());
- if (TermN != nullptr) {
- MemN->setPrevNode(getMemDGNodeBefore(TermN, /*IncludingN=*/false));
- } else {
- // The terminator is outside the DAG interval so do nothing.
- }
+ // We have a node to insert before, so update the chain.
+ DGNode *BeforeToN = getNode(&*To);
+ MemN->setPrevNode(
+ getMemDGNodeBefore(BeforeToN, /*IncludingN=*/false, /*SkipN=*/MemN));
+ MemN->setNextNode(
+ getMemDGNodeAfter(BeforeToN, /*IncludingN=*/true, /*SkipN=*/MemN));
}
}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 3fa4de501f3f5d..29fc05a7f256a2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -926,3 +926,46 @@ define void @foo(ptr %ptr, ptr %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
EXPECT_EQ(LdN->getPrevNode(), S1N);
EXPECT_EQ(LdN->getNextNode(), S2N);
}
+
+// Check that the mem chain is maintained correctly when the move destination is
+// not a mem node.
+TEST_F(DependencyGraphTest, MoveInstrCallbackWithNonMemInstrs) {
+ parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %arg) {
+ %ld = load i8, ptr %ptr
+ %zext1 = zext i8 %arg to i32
+ %zext2 = zext i8 %arg to i32
+ store i8 %v1, ptr %ptr
+ store i8 %v2, ptr %ptr
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = &*M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ auto *BB = &*F->begin();
+ auto It = BB->begin();
+ auto *Ld = cast<sandboxir::LoadInst>(&*It++);
+ [[maybe_unused]] auto *Zext1 = cast<sandboxir::CastInst>(&*It++);
+ auto *Zext2 = cast<sandboxir::CastInst>(&*It++);
+ auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+ auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+ auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+ DAG.extend({Ld, S2});
+ auto *LdN = cast<sandboxir::MemDGNode>(DAG.getNode(Ld));
+ auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+ auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+ EXPECT_EQ(LdN->getNextNode(), S1N);
+ EXPECT_EQ(S1N->getNextNode(), S2N);
+
+ S1->moveBefore(Zext2);
+ EXPECT_EQ(LdN->getNextNode(), S1N);
+ EXPECT_EQ(S1N->getNextNode(), S2N);
+
+ // Try move right after the end of the DAGInterval.
+ S1->moveBefore(Ret);
+ EXPECT_EQ(S2N->getNextNode(), S1N);
+ EXPECT_EQ(S1N->getNextNode(), nullptr);
+}
More information about the llvm-commits
mailing list