[llvm] [SandboxVec][DAG] Cleanup: Move callback registration from Scheduler to DAG (PR #116455)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 18 19:46:51 PST 2024
https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/116455
>From df777524a41b8b5cb5206dd4536d70724025f08e Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 7 Nov 2024 09:25:34 -0800
Subject: [PATCH] [SandboxVec][DAG] Cleanup: Move callback registration from
Scheduler to DAG
This is a refactoring patch that moves the registration of the callback
that notifies about newly created instructions from the Scheduler to the DAG.
This makes sense from both a design and a testing point of view:
- the DAG should not rely on the Scheduler to be notified about new instructions
- the notifier (notifyCreateInstr()) no longer needs to be public
- it's easier to test the notifier directly from within the DAG unit tests
---
.../SandboxVectorizer/DependencyGraph.h | 27 ++++++--
.../Vectorize/SandboxVectorizer/Scheduler.h | 9 +--
.../SandboxVectorizer/DependencyGraphTest.cpp | 66 ++++++++++++++-----
.../SandboxVectorizer/SchedulerTest.cpp | 2 +-
4 files changed, 72 insertions(+), 32 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 5211c7922ea2fd..765b65c4971bed 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -290,6 +290,9 @@ class DependencyGraph {
/// The DAG spans across all instructions in this interval.
Interval<Instruction> DAGInterval;
+ Context *Ctx = nullptr;
+ std::optional<Context::CallbackID> CreateInstrCB;
+
std::unique_ptr<BatchAAResults> BatchAA;
enum class DependencyType {
@@ -325,9 +328,24 @@ class DependencyGraph {
/// chain.
void createNewNodes(const Interval<Instruction> &NewInterval);
+ /// Invoked by the create-instruction callback that the constructor registers
+ /// with the Context, whenever a new instruction \p I has been created.
+ /// For now it only makes sure that \p I gets a DAG node.
+ void notifyCreateInstr(Instruction *I) {
+ getOrCreateNode(I);
+ // TODO: Update the dependencies for the new node.
+ // TODO: Update the MemDGNode chain to include the new node if needed.
+ }
+
public:
- DependencyGraph(AAResults &AA)
- : BatchAA(std::make_unique<BatchAAResults>(AA)) {}
+ /// Creates an empty DAG whose memory queries go through a BatchAAResults
+ /// wrapping \p AA. It also registers a create-instruction callback with
+ /// \p Ctx that forwards to notifyCreateInstr(). The destructor dereferences
+ /// \p Ctx to unregister it, so the DAG must not outlive \p Ctx.
+ DependencyGraph(AAResults &AA, Context &Ctx)
+ : Ctx(&Ctx), BatchAA(std::make_unique<BatchAAResults>(AA)) {
+ CreateInstrCB = Ctx.registerCreateInstrCallback(
+ [this](Instruction *I) { notifyCreateInstr(I); });
+ }
+ /// The registered callback captures `this`, so a copied or moved DAG would
+ /// leave a callback pointing at a stale object and would unregister the same
+ /// callback ID twice. Forbid copies explicitly rather than relying on the
+ /// non-copyable BatchAA member; this also suppresses the implicit moves.
+ DependencyGraph(const DependencyGraph &) = delete;
+ DependencyGraph &operator=(const DependencyGraph &) = delete;
+ /// Unregisters the callback that the constructor registered with the Context.
+ ~DependencyGraph() {
+ if (CreateInstrCB)
+ Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
+ }
DGNode *getNode(Instruction *I) const {
auto It = InstrToNodeMap.find(I);
@@ -354,11 +372,6 @@ class DependencyGraph {
Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
/// \Returns the range of instructions included in the DAG.
Interval<Instruction> getInterval() const { return DAGInterval; }
- /// Called by the scheduler when a new instruction \p I has been created.
- void notifyCreateInstr(Instruction *I) {
- getOrCreateNode(I);
- // TODO: Update the dependencies for the new node.
- }
void clear() {
InstrToNodeMap.clear();
DAGInterval = {};
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
index 9c11b5dbc16432..022fd71df67dc6 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
@@ -106,8 +106,6 @@ class Scheduler {
std::optional<BasicBlock::iterator> ScheduleTopItOpt;
// TODO: This is wasting memory in exchange for fast removal using a raw ptr.
DenseMap<SchedBundle *, std::unique_ptr<SchedBundle>> Bndls;
- Context &Ctx;
- Context::CallbackID CreateInstrCB;
/// \Returns a scheduling bundle containing \p Instrs.
SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
@@ -137,11 +135,8 @@ class Scheduler {
Scheduler &operator=(const Scheduler &) = delete;
public:
- Scheduler(AAResults &AA, Context &Ctx) : DAG(AA), Ctx(Ctx) {
- CreateInstrCB = Ctx.registerCreateInstrCallback(
- [this](Instruction *I) { DAG.notifyCreateInstr(I); });
- }
- ~Scheduler() { Ctx.unregisterCreateInstrCallback(CreateInstrCB); }
+ /// The DAG registers (and unregisters) its own Context callbacks, so the
+ /// Scheduler only needs to forward \p Ctx to it.
+ Scheduler(AAResults &AA, Context &Ctx) : DAG(AA, Ctx) {}
+ ~Scheduler() = default;
bool trySchedule(ArrayRef<Instruction *> Instrs);
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 061d57c31ce236..206f6c5b4c1359 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -194,7 +194,7 @@ define void @foo(i8 %v1, ptr %ptr) {
auto *Call = cast<sandboxir::CallInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Store)));
EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Load)));
@@ -224,7 +224,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
auto *S0 = cast<sandboxir::StoreInst>(&*It++);
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
auto Span = DAG.extend({&*BB->begin(), BB->getTerminator()});
// Check extend().
EXPECT_EQ(Span.top(), &*BB->begin());
@@ -285,7 +285,7 @@ define i8 @foo(i8 %v0, i8 %v1) {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto *AddN0 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*It++));
@@ -332,7 +332,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
[[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
@@ -366,7 +366,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
@@ -436,7 +436,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
auto *Store0N = cast<sandboxir::MemDGNode>(
@@ -461,7 +461,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) {
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
auto *Store0N = cast<sandboxir::MemDGNode>(
@@ -487,7 +487,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
auto *Ld0N = cast<sandboxir::MemDGNode>(
@@ -512,7 +512,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v) {
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
auto *Store0N = cast<sandboxir::MemDGNode>(
@@ -542,7 +542,7 @@ define void @foo(float %v1, float %v2) {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
auto It = BB->begin();
@@ -574,7 +574,7 @@ define void @foo() {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
auto It = BB->begin();
@@ -606,7 +606,7 @@ define void @foo(i8 %v0, i8 %v1, ptr %ptr) {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
auto It = BB->begin();
@@ -637,7 +637,7 @@ define void @foo(ptr %ptr) {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
auto It = BB->begin();
@@ -664,7 +664,7 @@ define void @foo(ptr %ptr) {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
auto It = BB->begin();
@@ -695,7 +695,7 @@ define void @foo() {
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
auto It = BB->begin();
@@ -728,7 +728,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
auto *S4 = cast<sandboxir::StoreInst>(&*It++);
auto *S5 = cast<sandboxir::StoreInst>(&*It++);
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
{
// Scenario 1: Build new DAG
auto NewIntvl = DAG.extend({S3, S3});
@@ -788,7 +788,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
{
// Check UnscheduledSuccs when a node is scheduled
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({S2, S2});
auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
S2N->setScheduled(true);
@@ -798,3 +798,35 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
EXPECT_EQ(S1N->getNumUnscheduledSuccs(), 0u); // S1 is scheduled
}
}
+
+TEST_F(DependencyGraphTest, CreateInstrCallback) {
+ parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+ store i8 %v1, ptr %ptr
+ store i8 %v2, ptr %ptr
+ store i8 %v3, ptr %ptr
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = &*M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ auto *BB = &*F->begin();
+ auto It = BB->begin();
+ auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+ [[maybe_unused]] auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+ auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+
+ // Creating an instruction while the DAG is alive must trigger its callback.
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+ DAG.extend({S1, S3});
+ auto *Arg = F->getArg(4); // %arg (index 3 would be %v3).
+ auto *Ptr = S1->getPointerOperand();
+ sandboxir::StoreInst *NewS =
+ sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
+ /*IsVolatile=*/true, Ctx);
+ auto *NewSN = DAG.getNode(NewS);
+ EXPECT_NE(NewSN, nullptr);
+ // TODO: Check the dependencies to/from NewSN after they land.
+ // TODO: Check the MemDGNode chain.
+}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
index 94a57914429748..c5e44a97976a72 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
@@ -70,7 +70,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
- sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto *SN0 = DAG.getNode(S0);
auto *SN1 = DAG.getNode(S1);
More information about the llvm-commits
mailing list