[llvm] [SandboxVec][DAG] Register move instr callback (PR #120146)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 16 13:06:52 PST 2024
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/120146
This patch implements the move-instruction notifier for the DAG. Whenever an instruction moves, the notifier updates the DAG to reflect the new instruction order.
>From 26a263ae6232db7cf1592f51cec17cdfef7cd344 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 7 Nov 2024 12:46:53 -0800
Subject: [PATCH] [SandboxVec][DAG] Register move instr callback
This patch implements the move-instruction notifier for the DAG.
Whenever an instruction moves, the notifier updates the DAG to reflect the
new instruction order.
---
.../SandboxVectorizer/DependencyGraph.h | 18 ++++++++
.../SandboxVectorizer/DependencyGraph.cpp | 46 +++++++++++++++++++
.../SandboxVectorizer/DependencyGraphTest.cpp | 35 +++++++++++++-
3 files changed, 98 insertions(+), 1 deletion(-)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index b1cad2421bc0d2..f423e1ee456cd1 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -220,6 +220,14 @@ class MemDGNode final : public DGNode {
void setNextNode(MemDGNode *N) { NextMemN = N; }
void setPrevNode(MemDGNode *N) { PrevMemN = N; }
friend class DependencyGraph; // For setNextNode(), setPrevNode().
+  /// Unlinks this node from the chain of MemDGNodes. Its previous and next
+  /// nodes, when present, become linked to each other, and this node's own
+  /// links are reset to nullptr.
+  void detachFromChain() {
+    MemDGNode *OldPrev = PrevMemN;
+    MemDGNode *OldNext = NextMemN;
+    if (OldPrev)
+      OldPrev->NextMemN = OldNext;
+    if (OldNext)
+      OldNext->PrevMemN = OldPrev;
+    NextMemN = nullptr;
+    PrevMemN = nullptr;
+  }
public:
MemDGNode(Instruction *I) : DGNode(I, DGNodeID::MemDGNode) {
@@ -293,6 +301,7 @@ class DependencyGraph {
Context *Ctx = nullptr;
std::optional<Context::CallbackID> CreateInstrCB;
std::optional<Context::CallbackID> EraseInstrCB;
+ std::optional<Context::CallbackID> MoveInstrCB;
std::unique_ptr<BatchAAResults> BatchAA;
@@ -343,6 +352,9 @@ class DependencyGraph {
/// Called by the callbacks when instruction \p I is about to get
/// deleted.
void notifyEraseInstr(Instruction *I);
+ /// Called by the move-instruction callback right before instruction \p I is
+ /// moved, so that \p I ends up immediately before \p To (\p To may be the
+ /// end of the block). Keeps the DAG interval and, for memory instructions,
+ /// the MemDGNode chain consistent with the new instruction order.
+ void notifyMoveInstr(Instruction *I, const BBIterator &To);
public:
/// This constructor also registers callbacks.
@@ -352,12 +364,18 @@ class DependencyGraph {
[this](Instruction *I) { notifyCreateInstr(I); });
EraseInstrCB = Ctx.registerEraseInstrCallback(
[this](Instruction *I) { notifyEraseInstr(I); });
+ MoveInstrCB = Ctx.registerMoveInstrCallback(
+ [this](Instruction *I, const BBIterator &To) {
+ notifyMoveInstr(I, To);
+ });
}
+ /// Unregisters every callback that the constructor registered with the
+ /// Context. The callbacks capture `this`, so they must not outlive this
+ /// object.
+ /// NOTE(review): for the same reason, a copied or moved DependencyGraph
+ /// would leave callbacks bound to the old object; confirm that copy/move
+ /// are deleted.
~DependencyGraph() {
if (CreateInstrCB)
Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
if (EraseInstrCB)
Ctx->unregisterEraseInstrCallback(*EraseInstrCB);
+ if (MoveInstrCB)
+ Ctx->unregisterMoveInstrCallback(*MoveInstrCB);
}
DGNode *getNode(Instruction *I) const {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 25f2665d450d13..ba62c45a4e704e 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -370,6 +370,52 @@ void DependencyGraph::notifyCreateInstr(Instruction *I) {
}
}
+/// Keeps the DAG consistent with the move of \p I to right before \p To.
+/// This runs *before* the move, so the instruction list still holds `I` at
+/// its original position while its new neighbors are computed.
+void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
+  BasicBlock *BB = To.getNodeParent();
+
+  // Early return if `I` doesn't actually move: either it is already right
+  // before `To`, or `To` is the end of I's own block and `I` is already last.
+  bool StaysInPlace = To != BB->end()
+                          ? &*To == I->getNextNode()
+                          : I->getParent() == BB && I->getNextNode() == nullptr;
+  if (StaysInPlace)
+    return;
+
+  // Only instructions tracked by the DAG need any maintenance.
+  // NOTE(review): this also keeps untracked instructions away from
+  // DAGInterval.notifyMoveInstr(), which presumably expects `I` to be inside
+  // the interval -- confirm against Interval's contract.
+  DGNode *N = getNodeOrNull(I);
+  if (N == nullptr)
+    return;
+
+  // Maintain the DAGInterval.
+  DAGInterval.notifyMoveInstr(I, To);
+
+  // TODO: Perhaps check if this is legal by checking the dependencies?
+
+  // Update the MemDGNode chain to reflect the instr movement if necessary.
+  auto *MemN = dyn_cast<MemDGNode>(N);
+  if (MemN == nullptr)
+    return;
+  // First detach it from the existing chain.
+  MemN->detachFromChain();
+
+  // getMemDGNodeBefore/After() can start from a non-memory DGNode, so they
+  // scan neighboring instructions rather than the chain. Since `I` has not
+  // moved yet, that scan can reach `I` itself. Linking MemN to itself would
+  // create a cycle, so step past MemN in that case.
+  auto SkipSelfBefore = [this, MemN](MemDGNode *Cand) {
+    return Cand != MemN ? Cand
+                        : getMemDGNodeBefore(MemN, /*IncludingN=*/false);
+  };
+  auto SkipSelfAfter = [this, MemN](MemDGNode *Cand) {
+    return Cand != MemN ? Cand
+                        : getMemDGNodeAfter(MemN, /*IncludingN=*/false);
+  };
+
+  // Find the memory nodes that will surround `I` at its new location.
+  MemDGNode *PrevMemN = nullptr;
+  MemDGNode *NextMemN = nullptr;
+  DGNode *ToN = To != BB->end() ? getNodeOrNull(&*To) : nullptr;
+  if (ToN != nullptr) {
+    // `To` is in the DAG: insert between the last memory node before it and
+    // the first memory node at or after it.
+    PrevMemN = SkipSelfBefore(getMemDGNodeBefore(ToN, /*IncludingN=*/false));
+    NextMemN = SkipSelfAfter(getMemDGNodeAfter(ToN, /*IncludingN=*/true));
+  } else if (To != BB->begin()) {
+    // `To` is BB->end() or has no DAG node (e.g., it is right after the DAG's
+    // bottom). `I` lands right after the instruction preceding `To`, so if
+    // that instruction is in the DAG, insert MemN after its closest memory
+    // node.
+    BBIterator PrevIt = To;
+    --PrevIt;
+    if (DGNode *PrevN = getNodeOrNull(&*PrevIt)) {
+      PrevMemN =
+          SkipSelfBefore(getMemDGNodeBefore(PrevN, /*IncludingN=*/true));
+      // Splice MemN in right after PrevMemN, keeping whatever followed it.
+      if (PrevMemN != nullptr)
+        NextMemN = PrevMemN->NextMemN;
+    }
+  }
+  // Otherwise the destination is outside the DAG, so MemN stays detached.
+
+  // Now insert it back into the chain at the new location. Both neighbors are
+  // null-checked because either end of the chain may be reached.
+  MemN->PrevMemN = PrevMemN;
+  MemN->NextMemN = NextMemN;
+  if (PrevMemN != nullptr)
+    PrevMemN->NextMemN = MemN;
+  if (NextMemN != nullptr)
+    NextMemN->PrevMemN = MemN;
+}
+
void DependencyGraph::notifyEraseInstr(Instruction *I) {
// Update the MemDGNode chain if this is a memory node.
if (auto *MemN = dyn_cast_or_null<MemDGNode>(getNodeOrNull(I))) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 8c73ee1def8ae1..3fa4de501f3f5d 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -801,7 +801,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
TEST_F(DependencyGraphTest, CreateInstrCallback) {
parseIR(C, R"IR(
-define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
store i8 %v1, ptr %ptr
store i8 %v2, ptr %ptr
store i8 %v3, ptr %ptr
@@ -893,3 +893,36 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
// TODO: Check the dependencies to/from NewSN after they land.
}
+
+TEST_F(DependencyGraphTest, MoveInstrCallback) {
+  // %ptr2 is `noalias`, so the load does not conflict with the stores to
+  // %ptr and moving memory instructions across it below is a legal reordering.
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+  %ld0 = load i8, ptr %ptr2
+  store i8 %v1, ptr %ptr
+  store i8 %v2, ptr %ptr
+  store i8 %v3, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Ld = cast<sandboxir::LoadInst>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = &*It++;
+
+  // The DAG covers [Ld, S3]; `ret` is intentionally left outside of it.
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+  DAG.extend({Ld, S3});
+  auto *LdN = cast<sandboxir::MemDGNode>(DAG.getNode(Ld));
+  auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+  auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+  auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+  EXPECT_EQ(S1N->getPrevNode(), LdN);
+
+  // Move to the top of the DAG: Ld S1 S2 S3 -> S1 Ld S2 S3.
+  S1->moveBefore(Ld);
+  EXPECT_EQ(S1N->getPrevNode(), nullptr);
+  EXPECT_EQ(S1N->getNextNode(), LdN);
+  EXPECT_EQ(LdN->getPrevNode(), S1N);
+  EXPECT_EQ(LdN->getNextNode(), S2N);
+  EXPECT_EQ(S2N->getPrevNode(), LdN);
+
+  // Move right after the DAG's bottom, i.e., before `ret`, which has no DAG
+  // node: S1 Ld S2 S3 -> S1 S2 S3 Ld.
+  Ld->moveBefore(Ret);
+  EXPECT_EQ(S1N->getNextNode(), S2N);
+  EXPECT_EQ(S2N->getPrevNode(), S1N);
+  EXPECT_EQ(S3N->getNextNode(), LdN);
+  EXPECT_EQ(LdN->getPrevNode(), S3N);
+  EXPECT_EQ(LdN->getNextNode(), nullptr);
+}
+
+TEST_F(DependencyGraphTest, MoveInstrCallbackAcrossNonMemInstr) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2) {
+  store i8 %v1, ptr %ptr
+  %add = add i8 %v1, %v2
+  store i8 %v2, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Add = &*It++;
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+  DAG.extend({S1, S2});
+  auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+  auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+
+  // S1 add S2 -> add S1 S2. Scanning backwards from S2 passes over S1's old
+  // position, which must not make S1 its own neighbor.
+  S1->moveBefore(S2);
+  EXPECT_EQ(S1N->getPrevNode(), nullptr);
+  EXPECT_EQ(S1N->getNextNode(), S2N);
+  EXPECT_EQ(S2N->getPrevNode(), S1N);
+
+  // add S1 S2 -> S1 add S2. Scanning forward from `add` passes over S1's old
+  // position.
+  S1->moveBefore(Add);
+  EXPECT_EQ(S1N->getPrevNode(), nullptr);
+  EXPECT_EQ(S1N->getNextNode(), S2N);
+  EXPECT_EQ(S2N->getPrevNode(), S1N);
+}
More information about the llvm-commits
mailing list