[llvm] [SandboxVec][DAG] Cleanup: Move callback registration from Scheduler to DAG (PR #116455)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 15 16:17:00 PST 2024
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/116455
This is a refactoring patch that moves the registration of the callback that reports newly created instructions from the scheduler to the DAG. This makes sense from both a design and a testing point of view:
- the DAG should not rely on the scheduler to be notified
- the notifiers no longer need to be public
- the notifiers are easier to test directly from the DAG unit tests
>From 56284621c8941b63627a93592c940fa7e779da5f Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 7 Nov 2024 09:25:34 -0800
Subject: [PATCH] [SandboxVec][DAG] Cleanup: Move callback registration from
Scheduler to DAG
This is a refactoring patch that moves the registration of the callback
that reports newly created instructions from the scheduler to the DAG.
This makes sense from both a design and a testing point of view:
- the DAG should not rely on the scheduler to be notified
- the notifiers no longer need to be public
- the notifiers are easier to test directly from the DAG unit tests
---
.../SandboxVectorizer/DependencyGraph.h | 25 ++++++++++++---
.../Vectorize/SandboxVectorizer/Scheduler.h | 9 ++----
.../SandboxVectorizer/DependencyGraphTest.cpp | 32 +++++++++++++++++++
3 files changed, 54 insertions(+), 12 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 5211c7922ea2fd..c2c7a6697f173e 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -290,6 +290,9 @@ class DependencyGraph {
/// The DAG spans across all instructions in this interval.
Interval<Instruction> DAGInterval;
+ Context *Ctx = nullptr;
+ std::optional<Context::CallbackID> CreateInstrCB;
+
std::unique_ptr<BatchAAResults> BatchAA;
enum class DependencyType {
@@ -325,9 +328,26 @@ class DependencyGraph {
/// chain.
void createNewNodes(const Interval<Instruction> &NewInterval);
+ /// Handler for the Context create-instruction callback registered in the
+ /// constructor: makes sure the newly created instruction \p I has a node.
+ void notifyCreateInstr(Instruction *I) {
+ getOrCreateNode(I);
+ // TODO: Update the dependencies for the new node.
+ // TODO: Update the MemDGNode chain to include the new node if needed.
+ }
+
public:
DependencyGraph(AAResults &AA)
: BatchAA(std::make_unique<BatchAAResults>(AA)) {}
+ /// Also registers a create-instruction callback with \p Ctx so that new
+ /// instructions get DAG nodes; the destructor unregisters it.
+ DependencyGraph(AAResults &AA, Context &Ctx)
+ : Ctx(&Ctx), BatchAA(std::make_unique<BatchAAResults>(AA)) {
+ CreateInstrCB = Ctx.registerCreateInstrCallback(
+ [this](Instruction *I) { notifyCreateInstr(I); });
+ }
+ /// Not copyable: the registered callback captures `this`, and a copy would
+ /// unregister the same callback ID a second time.
+ DependencyGraph(const DependencyGraph &) = delete;
+ DependencyGraph &operator=(const DependencyGraph &) = delete;
+ ~DependencyGraph() {
+ if (CreateInstrCB)
+ Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
+ }
DGNode *getNode(Instruction *I) const {
auto It = InstrToNodeMap.find(I);
@@ -354,11 +374,6 @@ class DependencyGraph {
Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
/// \Returns the range of instructions included in the DAG.
Interval<Instruction> getInterval() const { return DAGInterval; }
- /// Called by the scheduler when a new instruction \p I has been created.
- void notifyCreateInstr(Instruction *I) {
- getOrCreateNode(I);
- // TODO: Update the dependencies for the new node.
- }
void clear() {
InstrToNodeMap.clear();
DAGInterval = {};
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
index 9c11b5dbc16432..022fd71df67dc6 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
@@ -106,8 +106,6 @@ class Scheduler {
std::optional<BasicBlock::iterator> ScheduleTopItOpt;
// TODO: This is wasting memory in exchange for fast removal using a raw ptr.
DenseMap<SchedBundle *, std::unique_ptr<SchedBundle>> Bndls;
- Context &Ctx;
- Context::CallbackID CreateInstrCB;
/// \Returns a scheduling bundle containing \p Instrs.
SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
@@ -137,11 +135,8 @@ class Scheduler {
Scheduler &operator=(const Scheduler &) = delete;
public:
- Scheduler(AAResults &AA, Context &Ctx) : DAG(AA), Ctx(Ctx) {
- CreateInstrCB = Ctx.registerCreateInstrCallback(
- [this](Instruction *I) { DAG.notifyCreateInstr(I); });
- }
- ~Scheduler() { Ctx.unregisterCreateInstrCallback(CreateInstrCB); }
+ /// The DAG registers (and unregisters) its own Context callbacks.
+ Scheduler(AAResults &AA, Context &Ctx) : DAG(AA, Ctx) {}
+ ~Scheduler() = default;
bool trySchedule(ArrayRef<Instruction *> Instrs);
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 061d57c31ce236..6129664c6e429a 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -798,3 +798,35 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
EXPECT_EQ(S1N->getNumUnscheduledSuccs(), 0u); // S1 is scheduled
}
}
+
+TEST_F(DependencyGraphTest, CreateInstrCallback) {
+ parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+ store i8 %v1, ptr %ptr
+ store i8 %v2, ptr %ptr
+ store i8 %v3, ptr %ptr
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = &*M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ auto *BB = &*F->begin();
+ auto It = BB->begin();
+ auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+ [[maybe_unused]] auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+ auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+
+ // Creating an instruction via Ctx should give it a node in the DAG.
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+ DAG.extend({S1, S3});
+ auto *Arg = F->getArg(4); // %arg (index 3 would be %v3).
+ auto *Ptr = S1->getPointerOperand();
+ sandboxir::StoreInst *NewS =
+ sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
+ /*IsVolatile=*/true, Ctx);
+ auto *NewSN = DAG.getNode(NewS);
+ EXPECT_NE(NewSN, nullptr);
+ // TODO: Check the dependencies to/from NewSN after they land.
+ // TODO: Check the MemDGNode chain.
+}
More information about the llvm-commits mailing list