[llvm] 166b2e8 - [SandboxVec][DAG] Update DAG when a new instruction is created (#126124)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 6 14:12:06 PST 2025
Author: vporpo
Date: 2025-02-06T14:12:03-08:00
New Revision: 166b2e88378bae4d74f9bdc56f1521150162fbf1
URL: https://github.com/llvm/llvm-project/commit/166b2e88378bae4d74f9bdc56f1521150162fbf1
DIFF: https://github.com/llvm/llvm-project/commit/166b2e88378bae4d74f9bdc56f1521150162fbf1.diff
LOG: [SandboxVec][DAG] Update DAG when a new instruction is created (#126124)
The DAG will now receive a callback whenever a new instruction is
created and will update itself accordingly.
Added:
Modified:
llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h
llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index f4e74fdee84c919..fab456d925526c2 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -263,6 +263,7 @@ class MemDGNode final : public DGNode {
void addMemPred(MemDGNode *PredN) {
[[maybe_unused]] auto Inserted = MemPreds.insert(PredN).second;
assert(Inserted && "PredN already exists!");
+ assert(PredN != this && "Trying to add a dependency to self!");
if (!Scheduled) {
++PredN->UnscheduledSuccs;
}
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h
index 18cd29e9e14ee40..f6c5a204673372f 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h
@@ -108,6 +108,14 @@ template <typename T> class Interval {
return (Top == I || Top->comesBefore(I)) &&
(I == Bottom || I->comesBefore(Bottom));
}
+ /// \Returns true if \p Elm is right before the top or right after the bottom.
+ /// An empty interval touches nothing (consistent with contains()).
+ bool touches(T *Elm) const {
+ // Guard against empty intervals: if Top/Bottom are null they would spuriously
+ // match the null next/prev node of a block's last/first element.
+ return !empty() &&
+ (Top == Elm->getNextNode() || Bottom == Elm->getPrevNode());
+ }
T *top() const { return Top; }
T *bottom() const { return Bottom; }
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index e03cf32be024406..2680667afc4de29 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -368,8 +368,13 @@ MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
}
void DependencyGraph::notifyCreateInstr(Instruction *I) {
- auto *MemN = dyn_cast<MemDGNode>(getOrCreateNode(I));
- // TODO: Update the dependencies for the new node.
+ // Nothing to do if the node is not in the focus range of the DAG.
+ if (!(DAGInterval.contains(I) || DAGInterval.touches(I)))
+ return;
+ // Include `I` into the interval.
+ DAGInterval = DAGInterval.getUnionInterval({I, I});
+ auto *N = getOrCreateNode(I);
+ auto *MemN = dyn_cast<MemDGNode>(N);
// Update the MemDGNode chain if this is a memory node.
if (MemN != nullptr) {
@@ -381,6 +386,21 @@ void DependencyGraph::notifyCreateInstr(Instruction *I) {
NextMemN->PrevMemN = MemN;
MemN->NextMemN = NextMemN;
}
+
+ // Add Mem dependencies.
+ // 1. Scan for deps above `I` for deps to `I`: AboveN->MemN.
+ if (DAGInterval.top()->comesBefore(I)) {
+ Interval<Instruction> AboveIntvl(DAGInterval.top(), I->getPrevNode());
+ auto SrcInterval = MemDGNodeIntervalBuilder::make(AboveIntvl, *this);
+ scanAndAddDeps(*MemN, SrcInterval);
+ }
+ // 2. Scan for deps below `I` for deps from `I`: MemN->BelowN.
+ if (I->comesBefore(DAGInterval.bottom())) {
+ Interval<Instruction> BelowIntvl(I->getNextNode(), DAGInterval.bottom());
+ for (MemDGNode &BelowN :
+ MemDGNodeIntervalBuilder::make(BelowIntvl, *this))
+ scanAndAddDeps(BelowN, Interval<MemDGNode>(MemN, MemN));
+ }
}
}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 263a37ac335d2ae..f1e9afefb45311b 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -832,9 +832,10 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
}
}
+// Check that the DAG gets updated when we create a new instruction.
TEST_F(DependencyGraphTest, CreateInstrCallback) {
parseIR(C, R"IR(
-define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %new1, i8 %new2) {
store i8 %v1, ptr %ptr
store i8 %v2, ptr %ptr
store i8 %v3, ptr %ptr
@@ -851,42 +852,52 @@ define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
- // Check new instruction callback.
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
- DAG.extend({S1, Ret});
- auto *Arg = F->getArg(3);
+ // Create a DAG spanning S1 to S3.
+ DAG.extend({S1, S3});
+ auto *ArgNew1 = F->getArg(4);
+ auto *ArgNew2 = F->getArg(5);
auto *Ptr = S1->getPointerOperand();
+
+ auto *S1MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+ auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+ auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+ sandboxir::MemDGNode *New1MemN = nullptr;
+ sandboxir::MemDGNode *New2MemN = nullptr;
{
+ // Create a new store before S3 (within the span of the DAG).
sandboxir::StoreInst *NewS =
- sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
+ sandboxir::StoreInst::create(ArgNew1, Ptr, Align(8), S3->getIterator(),
/*IsVolatile=*/true, Ctx);
- auto *NewSN = DAG.getNode(NewS);
- EXPECT_TRUE(NewSN != nullptr);
-
// Check the MemDGNode chain.
- auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
- auto *NewMemSN = cast<sandboxir::MemDGNode>(NewSN);
- auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
- EXPECT_EQ(S2MemN->getNextNode(), NewMemSN);
- EXPECT_EQ(NewMemSN->getPrevNode(), S2MemN);
- EXPECT_EQ(NewMemSN->getNextNode(), S3MemN);
- EXPECT_EQ(S3MemN->getPrevNode(), NewMemSN);
+ New1MemN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
+ EXPECT_EQ(S2MemN->getNextNode(), New1MemN);
+ EXPECT_EQ(New1MemN->getPrevNode(), S2MemN);
+ EXPECT_EQ(New1MemN->getNextNode(), S3MemN);
+ EXPECT_EQ(S3MemN->getPrevNode(), New1MemN);
+
+ // Check dependencies.
+ EXPECT_TRUE(memDependency(S1MemN, New1MemN));
+ EXPECT_TRUE(memDependency(S2MemN, New1MemN));
+ EXPECT_TRUE(memDependency(New1MemN, S3MemN));
}
-
{
- // Also check if new node is at the end of the BB, after Ret.
+ // Create a new store before Ret (outside the current DAG).
sandboxir::StoreInst *NewS =
- sandboxir::StoreInst::create(Arg, Ptr, Align(8), BB->end(),
+ sandboxir::StoreInst::create(ArgNew2, Ptr, Align(8), Ret->getIterator(),
/*IsVolatile=*/true, Ctx);
// Check the MemDGNode chain.
- auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
- auto *NewMemSN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
- EXPECT_EQ(S3MemN->getNextNode(), NewMemSN);
- EXPECT_EQ(NewMemSN->getPrevNode(), S3MemN);
- EXPECT_EQ(NewMemSN->getNextNode(), nullptr);
+ New2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
+ EXPECT_EQ(S3MemN->getNextNode(), New2MemN);
+ EXPECT_EQ(New2MemN->getPrevNode(), S3MemN);
+ EXPECT_EQ(New2MemN->getNextNode(), nullptr);
+
+ // Check dependencies.
+ EXPECT_TRUE(memDependency(S1MemN, New2MemN));
+ EXPECT_TRUE(memDependency(S2MemN, New2MemN));
+ EXPECT_TRUE(memDependency(New1MemN, New2MemN));
+ EXPECT_TRUE(memDependency(S3MemN, New2MemN));
}
-
- // TODO: Check the dependencies to/from NewSN after they land.
}
TEST_F(DependencyGraphTest, EraseInstrCallback) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp
index 32521ed79a314be..59498371b4d73e6 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp
@@ -87,6 +87,15 @@ define void @foo(i8 %v0) {
EXPECT_FALSE(One.contains(I1));
EXPECT_FALSE(One.contains(I2));
EXPECT_FALSE(One.contains(Ret));
+ // Check touches().
+ {
+ sandboxir::Interval<sandboxir::Instruction> Intvl(I2, I2);
+ EXPECT_TRUE(Intvl.touches(I1));
+ EXPECT_TRUE(Intvl.contains(I2));
+ EXPECT_FALSE(Intvl.touches(I2));
+ EXPECT_TRUE(Intvl.touches(Ret));
+ EXPECT_FALSE(Intvl.touches(I0));
+ }
// Check iterator.
auto BBIt = BB->begin();
for (auto &I : Intvl)
More information about the llvm-commits mailing list