[llvm] eeb55d3 - [SandboxVec][DAG] Update MemDGNode chain upon instr creation (#116896)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 5 20:23:09 PST 2024
Author: vporpo
Date: 2024-12-05T20:23:06-08:00
New Revision: eeb55d3af63e10c573d4bb1f0fe69a55eafa52cb
URL: https://github.com/llvm/llvm-project/commit/eeb55d3af63e10c573d4bb1f0fe69a55eafa52cb
DIFF: https://github.com/llvm/llvm-project/commit/eeb55d3af63e10c573d4bb1f0fe69a55eafa52cb.diff
LOG: [SandboxVec][DAG] Update MemDGNode chain upon instr creation (#116896)
The DAG maintains a chain of MemDGNodes that links together all the
nodes that may touch memory.
Whenever a new instruction gets created, we need to make sure that this
chain gets updated. If the new instruction touches memory, then its
corresponding MemDGNode should be inserted into the chain.
Added:
Modified:
llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 68a2daca1403df..911ee3e839521c 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -329,13 +329,19 @@ class DependencyGraph {
/// chain.
void createNewNodes(const Interval<Instruction> &NewInterval);
+ /// Helper for `notify*Instr()`. \Returns the first MemDGNode that comes
+ /// before \p N, including or excluding \p N based on \p IncludingN, or
+ /// nullptr if not found.
+ MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN) const;
+ /// Helper for `notifyMoveInstr()`. \Returns the first MemDGNode that comes
+ /// after \p N, including or excluding \p N based on \p IncludingN, or nullptr
+ /// if not found.
+ MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN) const;
+
/// Called by the callbacks when a new instruction \p I has been created.
- void notifyCreateInstr(Instruction *I) {
- getOrCreateNode(I);
- // TODO: Update the dependencies for the new node.
- // TODO: Update the MemDGNode chain to include the new node if needed.
- }
- /// Called by the callbacks when instruction \p I is about to get deleted.
+ void notifyCreateInstr(Instruction *I);
+ /// Called by the callbacks when instruction \p I is about to get
+ /// deleted.
void notifyEraseInstr(Instruction *I) {
InstrToNodeMap.erase(I);
// TODO: Update the dependencies.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 4b0e12c28f07b7..5cf44ba9dcbaaa 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -325,6 +325,51 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
setDefUseUnscheduledSuccs(NewInterval);
}
+MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N,
+ bool IncludingN) const {
+ auto *I = N->getInstruction();
+ for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr;
+ PrevI = PrevI->getPrevNode()) {
+ auto *PrevN = getNodeOrNull(PrevI);
+ if (PrevN == nullptr)
+ return nullptr;
+ if (auto *PrevMemN = dyn_cast<MemDGNode>(PrevN))
+ return PrevMemN;
+ }
+ return nullptr;
+}
+
+MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N,
+ bool IncludingN) const {
+ auto *I = N->getInstruction();
+ for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr;
+ NextI = NextI->getNextNode()) {
+ auto *NextN = getNodeOrNull(NextI);
+ if (NextN == nullptr)
+ return nullptr;
+ if (auto *NextMemN = dyn_cast<MemDGNode>(NextN))
+ return NextMemN;
+ }
+ return nullptr;
+}
+
+void DependencyGraph::notifyCreateInstr(Instruction *I) {
+ auto *MemN = dyn_cast<MemDGNode>(getOrCreateNode(I));
+ // TODO: Update the dependencies for the new node.
+
+ // Update the MemDGNode chain if this is a memory node.
+ if (MemN != nullptr) {
+ if (auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false)) {
+ PrevMemN->NextMemN = MemN;
+ MemN->PrevMemN = PrevMemN;
+ }
+ if (auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false)) {
+ NextMemN->PrevMemN = MemN;
+ MemN->NextMemN = NextMemN;
+ }
+ }
+}
+
Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
if (Instrs.empty())
return {};
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index e6bb4b4684d262..1130c9c63c71da 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -814,21 +814,46 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
auto *BB = &*F->begin();
auto It = BB->begin();
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
- [[maybe_unused]] auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+ auto *S2 = cast<sandboxir::StoreInst>(&*It++);
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+ auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
// Check new instruction callback.
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
- DAG.extend({S1, S3});
+ DAG.extend({S1, Ret});
auto *Arg = F->getArg(3);
auto *Ptr = S1->getPointerOperand();
- sandboxir::StoreInst *NewS =
- sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
- /*IsVolatile=*/true, Ctx);
- auto *NewSN = DAG.getNode(NewS);
- EXPECT_TRUE(NewSN != nullptr);
+ {
+ sandboxir::StoreInst *NewS =
+ sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
+ /*IsVolatile=*/true, Ctx);
+ auto *NewSN = DAG.getNode(NewS);
+ EXPECT_TRUE(NewSN != nullptr);
+
+ // Check the MemDGNode chain.
+ auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+ auto *NewMemSN = cast<sandboxir::MemDGNode>(NewSN);
+ auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+ EXPECT_EQ(S2MemN->getNextNode(), NewMemSN);
+ EXPECT_EQ(NewMemSN->getPrevNode(), S2MemN);
+ EXPECT_EQ(NewMemSN->getNextNode(), S3MemN);
+ EXPECT_EQ(S3MemN->getPrevNode(), NewMemSN);
+ }
+
+ {
+ // Also check if new node is at the end of the BB, after Ret.
+ sandboxir::StoreInst *NewS =
+ sandboxir::StoreInst::create(Arg, Ptr, Align(8), BB->end(),
+ /*IsVolatile=*/true, Ctx);
+ // Check the MemDGNode chain.
+ auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+ auto *NewMemSN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
+ EXPECT_EQ(S3MemN->getNextNode(), NewMemSN);
+ EXPECT_EQ(NewMemSN->getPrevNode(), S3MemN);
+ EXPECT_EQ(NewMemSN->getNextNode(), nullptr);
+ }
+
// TODO: Check the dependencies to/from NewSN after they land.
- // TODO: Check the MemDGNode chain.
}
TEST_F(DependencyGraphTest, EraseInstrCallback) {
More information about the llvm-commits
mailing list