[llvm] [SandboxVec][DAG] Update MemDGNode chain upon instr creation (PR #116896)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 19 16:45:27 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: vporpo (vporpo)
<details>
<summary>Changes</summary>
The DAG maintains a chain of MemDGNodes that links together all the nodes that may touch memory.
Whenever a new instruction is created, we need to make sure that this chain is updated: if the new instruction touches memory, its corresponding MemDGNode must be inserted into the chain.
---
Full diff: https://github.com/llvm/llvm-project/pull/116896.diff
3 Files Affected:
- (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+12-6)
- (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+45)
- (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+33-8)
``````````diff
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 68a2daca1403df..911ee3e839521c 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -329,13 +329,19 @@ class DependencyGraph {
/// chain.
void createNewNodes(const Interval<Instruction> &NewInterval);
+ /// Helper for `notify*Instr()`. \Returns the first MemDGNode that comes
+ /// before \p N, including or excluding \p N based on \p IncludingN, or
+ /// nullptr if not found.
+ MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN) const;
+ /// Helper for `notify*Instr()`. \Returns the first MemDGNode that comes
+ /// after \p N, including or excluding \p N based on \p IncludingN, or nullptr
+ /// if not found.
+ MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN) const;
+
/// Called by the callbacks when a new instruction \p I has been created.
- void notifyCreateInstr(Instruction *I) {
- getOrCreateNode(I);
- // TODO: Update the dependencies for the new node.
- // TODO: Update the MemDGNode chain to include the new node if needed.
- }
- /// Called by the callbacks when instruction \p I is about to get deleted.
+ void notifyCreateInstr(Instruction *I);
+ /// Called by the callbacks when instruction \p I is about to get
+ /// deleted.
void notifyEraseInstr(Instruction *I) {
InstrToNodeMap.erase(I);
// TODO: Update the dependencies.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 4b0e12c28f07b7..5cf44ba9dcbaaa 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -325,6 +325,51 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
setDefUseUnscheduledSuccs(NewInterval);
}
+MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N,
+ bool IncludingN) const {
+ auto *I = N->getInstruction(); // Scan backwards in instruction order from N.
+ for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr; // N itself is examined only if IncludingN.
+ PrevI = PrevI->getPrevNode()) {
+ auto *PrevN = getNodeOrNull(PrevI);
+ if (PrevN == nullptr) // Instruction has no DAG node: stop (presumably outside the DAG's region).
+ return nullptr;
+ if (auto *PrevMemN = dyn_cast<MemDGNode>(PrevN)) // Closest preceding MemDGNode wins.
+ return PrevMemN;
+ }
+ return nullptr; // No earlier instructions left to examine.
+}
+
+MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N,
+ bool IncludingN) const {
+ auto *I = N->getInstruction(); // Scan forwards in instruction order from N.
+ for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr; // N itself is examined only if IncludingN.
+ NextI = NextI->getNextNode()) {
+ auto *NextN = getNodeOrNull(NextI);
+ if (NextN == nullptr) // Instruction has no DAG node: stop (presumably outside the DAG's region).
+ return nullptr;
+ if (auto *NextMemN = dyn_cast<MemDGNode>(NextN)) // Closest following MemDGNode wins.
+ return NextMemN;
+ }
+ return nullptr; // No later instructions left to examine.
+}
+
+void DependencyGraph::notifyCreateInstr(Instruction *I) {
+ auto *MemN = dyn_cast<MemDGNode>(getOrCreateNode(I)); // Non-null only if the new node is a MemDGNode.
+ // TODO: Update the dependencies for the new node.
+
+ // Update the MemDGNode chain if this is a memory node.
+ if (MemN != nullptr) { // A missing neighbour below leaves that link of MemN untouched.
+ if (auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false)) { // Link with closest MemDGNode above.
+ PrevMemN->NextMemN = MemN;
+ MemN->PrevMemN = PrevMemN;
+ }
+ if (auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false)) { // Link with closest MemDGNode below.
+ NextMemN->PrevMemN = MemN;
+ MemN->NextMemN = NextMemN;
+ }
+ }
+}
+
Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
if (Instrs.empty())
return {};
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index e6bb4b4684d262..1130c9c63c71da 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -814,21 +814,46 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
auto *BB = &*F->begin();
auto It = BB->begin();
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
- [[maybe_unused]] auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+ auto *S2 = cast<sandboxir::StoreInst>(&*It++);
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+ auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
// Check new instruction callback.
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
- DAG.extend({S1, S3});
+ DAG.extend({S1, Ret});
auto *Arg = F->getArg(3);
auto *Ptr = S1->getPointerOperand();
- sandboxir::StoreInst *NewS =
- sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
- /*IsVolatile=*/true, Ctx);
- auto *NewSN = DAG.getNode(NewS);
- EXPECT_TRUE(NewSN != nullptr);
+ {
+ sandboxir::StoreInst *NewS =
+ sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
+ /*IsVolatile=*/true, Ctx);
+ auto *NewSN = DAG.getNode(NewS);
+ EXPECT_TRUE(NewSN != nullptr);
+
+ // Check the MemDGNode chain.
+ auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+ auto *NewMemSN = cast<sandboxir::MemDGNode>(NewSN);
+ auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+ EXPECT_EQ(S2MemN->getNextNode(), NewMemSN);
+ EXPECT_EQ(NewMemSN->getPrevNode(), S2MemN);
+ EXPECT_EQ(NewMemSN->getNextNode(), S3MemN);
+ EXPECT_EQ(S3MemN->getPrevNode(), NewMemSN);
+ }
+
+ {
+ // Also check if new node is at the end of the BB, after Ret.
+ sandboxir::StoreInst *NewS =
+ sandboxir::StoreInst::create(Arg, Ptr, Align(8), BB->end(),
+ /*IsVolatile=*/true, Ctx);
+ // Check the MemDGNode chain.
+ auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+ auto *NewMemSN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
+ EXPECT_EQ(S3MemN->getNextNode(), NewMemSN);
+ EXPECT_EQ(NewMemSN->getPrevNode(), S3MemN);
+ EXPECT_EQ(NewMemSN->getNextNode(), nullptr);
+ }
+
// TODO: Check the dependencies to/from NewSN after they land.
- // TODO: Check the MemDGNode chain.
}
TEST_F(DependencyGraphTest, EraseInstrCallback) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/116896
More information about the llvm-commits
mailing list