[llvm] [SandboxVec][DAG] Extend DAG (PR #111908)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 10 17:59:37 PDT 2024
https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/111908
>From 7ddf808ee8bd09a7fca5cc18dcc7025072266578 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 7 Oct 2024 16:06:38 -0700
Subject: [PATCH] [SandboxVec][DAG] Extend DAG
This patch implements extending an existing DAG towards the top or the bottom.
Extending performs the necessary dependency checks and adds the new memory
dependencies. Extending in both directions in a single step is not supported.
---
.../SandboxVectorizer/DependencyGraph.h | 8 +-
.../SandboxVectorizer/DependencyGraph.cpp | 117 +++++++++++++++---
.../SandboxVectorizer/DependencyGraphTest.cpp | 67 ++++++++++
3 files changed, 176 insertions(+), 16 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 5fa57efc1462e8..0da52c4236d77e 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -284,6 +284,10 @@ class DependencyGraph {
/// \p DstN.
void scanAndAddDeps(MemDGNode &DstN, const Interval<MemDGNode> &SrcScanRange);
+ /// Create DAG nodes for instrs in \p NewInterval and update the MemNode
+ /// chain.
+ void createNewNodes(const Interval<Instruction> &NewInterval);
+
public:
DependencyGraph(AAResults &AA)
: BatchAA(std::make_unique<BatchAAResults>(AA)) {}
@@ -309,8 +313,10 @@ class DependencyGraph {
return It->second.get();
}
/// Build/extend the dependency graph such that it includes \p Instrs. Returns
- /// the interval spanning \p Instrs.
+ /// the range of instructions added to the DAG.
Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
+ /// \Returns the range of instructions included in the DAG.
+ Interval<Instruction> getInterval() const { return DAGInterval; }
#ifndef NDEBUG
void print(raw_ostream &OS) const;
LLVM_DUMP_METHOD void dump() const;
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 0cd2240e7ff1b3..db58069de47051 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -215,17 +215,11 @@ void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
}
}
-Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
- if (Instrs.empty())
- return {};
-
- Interval<Instruction> InstrInterval(Instrs);
-
- DGNode *LastN = getOrCreateNode(InstrInterval.top());
- // Create DGNodes for all instrs in Interval to avoid future Instruction to
- // DGNode lookups.
+void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
+ // Create Nodes only for the new sections of the DAG.
+ DGNode *LastN = getOrCreateNode(NewInterval.top());
MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
- for (Instruction &I : drop_begin(InstrInterval)) {
+ for (Instruction &I : drop_begin(NewInterval)) {
auto *N = getOrCreateNode(&I);
// Build the Mem node chain.
if (auto *MemN = dyn_cast<MemDGNode>(N)) {
@@ -235,16 +229,109 @@ Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
LastMemN = MemN;
}
}
+ // Link new MemDGNode chain with the old one, if any.
+ if (!DAGInterval.empty()) {
+ // TODO: Implement Interval::comesBefore() to replace this check.
+ bool NewIsAbove = NewInterval.bottom()->comesBefore(DAGInterval.top());
+ assert(
+ (NewIsAbove || DAGInterval.bottom()->comesBefore(NewInterval.top())) &&
+ "Expected NewInterval below DAGInterval.");
+ const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
+ const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
+ MemDGNode *LinkTopN =
+ MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
+ MemDGNode *LinkBotN =
+ MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
+ assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!");
+ if (LinkTopN != nullptr && LinkBotN != nullptr) {
+ LinkTopN->setNextNode(LinkBotN);
+ LinkBotN->setPrevNode(LinkTopN);
+ }
+#ifndef NDEBUG
+ // TODO: Remove this once we've done enough testing.
+ // Check that the chain is well formed.
+ auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval);
+ MemDGNode *ChainTopN =
+ MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this);
+ MemDGNode *ChainBotN =
+ MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this);
+ if (ChainTopN != nullptr && ChainBotN != nullptr) {
+ for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr;
+ LastN = N, N = N->getNextNode()) {
+ assert(N == LastN->getNextNode() && "Bad chain!");
+ assert(N->getPrevNode() == LastN && "Bad chain!");
+ }
+ }
+#endif // NDEBUG
+ }
+}
+
+Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
+ if (Instrs.empty())
+ return {};
+
+ Interval<Instruction> InstrsInterval(Instrs);
+ Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval);
+ auto NewInterval = Union.getSingleDiff(DAGInterval);
+ if (NewInterval.empty())
+ return {};
+
+ createNewNodes(NewInterval);
+
// Create the dependencies.
- auto DstRange = MemDGNodeIntervalBuilder::make(InstrInterval, *this);
- if (!DstRange.empty()) {
- for (MemDGNode &DstN : drop_begin(DstRange)) {
- auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
+ //
+ // 1. DAGInterval empty      2. New is below Old     3. New is above Old
+ // ------------------------  -------------------     -------------------
+ //                                        Scan:           DstN:   Scan:
+ //                           +---+        -ScanTopN  +---+DstTopN -ScanTopN
+ //                           |   |        |          |New|        |
+ //                           |Old|        |          +---+        -ScanBotN
+ //                           |   |        |          +---+
+ //      DstN:   Scan:        +---+DstN:   |          |   |
+ // +---+DstTopN -ScanTopN    +---+DstTopN |          |Old|
+ // |New|        |            |New|        |          |   |
+ // +---+DstBotN -ScanBotN    +---+DstBotN -ScanBotN  +---+DstBotN
+
+ // 1. This is a new DAG.
+ if (DAGInterval.empty()) {
+ assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!");
+ auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
+ if (!DstRange.empty()) {
+ for (MemDGNode &DstN : drop_begin(DstRange)) {
+ auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
+ scanAndAddDeps(DstN, SrcRange);
+ }
+ }
+ }
+ // 2. The new section is below the old section.
+ else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) {
+ auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
+ auto SrcRangeFull = MemDGNodeIntervalBuilder::make(
+ DAGInterval.getUnionInterval(NewInterval), *this);
+ for (MemDGNode &DstN : DstRange) {
+ auto SrcRange =
+ Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
scanAndAddDeps(DstN, SrcRange);
}
}
+ // 3. The new section is above the old section.
+ else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) {
+ auto DstRange = MemDGNodeIntervalBuilder::make(
+ NewInterval.getUnionInterval(DAGInterval), *this);
+ auto SrcRangeFull = MemDGNodeIntervalBuilder::make(NewInterval, *this);
+ if (!DstRange.empty()) {
+ for (MemDGNode &DstN : drop_begin(DstRange)) {
+ auto SrcRange =
+ Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
+ scanAndAddDeps(DstN, SrcRange);
+ }
+ }
+ } else {
+ llvm_unreachable("We don't expect extending in both directions!");
+ }
- return InstrInterval;
+ DAGInterval = Union;
+ return NewInterval;
}
#ifndef NDEBUG
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 7e2be25fa25ae6..3dbf03e4ba44e2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -681,3 +681,70 @@ define void @foo() {
EXPECT_FALSE(memDependency(StackSaveN, AllocaN));
EXPECT_FALSE(memDependency(AllocaN, StackRestoreN));
}
+
+TEST_F(DependencyGraphTest, Extend) {
+ parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
+ store i8 %v1, ptr %ptr
+ store i8 %v2, ptr %ptr
+ store i8 %v3, ptr %ptr
+ store i8 %v4, ptr %ptr
+ store i8 %v5, ptr %ptr
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = &*M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ auto *BB = &*F->begin();
+ auto It = BB->begin();
+ auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+ auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+ auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+ auto *S4 = cast<sandboxir::StoreInst>(&*It++);
+ auto *S5 = cast<sandboxir::StoreInst>(&*It++);
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+ {
+ // Scenario 1: Build new DAG
+ auto NewIntvl = DAG.extend({S3, S3});
+ EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S3, S3));
+ EXPECT_EQ(DAG.getInterval().top(), S3);
+ EXPECT_EQ(DAG.getInterval().bottom(), S3);
+ [[maybe_unused]] auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+ }
+ {
+ // Scenario 2: Extend below
+ auto NewIntvl = DAG.extend({S5, S5});
+ EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S4, S5));
+ auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+ auto *S4N = cast<sandboxir::MemDGNode>(DAG.getNode(S4));
+ auto *S5N = cast<sandboxir::MemDGNode>(DAG.getNode(S5));
+ EXPECT_TRUE(S4N->hasMemPred(S3N));
+ EXPECT_TRUE(S5N->hasMemPred(S4N));
+ EXPECT_TRUE(S5N->hasMemPred(S3N));
+ }
+ {
+ // Scenario 3: Extend above
+ auto NewIntvl = DAG.extend({S1, S2});
+ EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S1, S2));
+ auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+ auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+ auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+ auto *S4N = cast<sandboxir::MemDGNode>(DAG.getNode(S4));
+ auto *S5N = cast<sandboxir::MemDGNode>(DAG.getNode(S5));
+
+ EXPECT_TRUE(S2N->hasMemPred(S1N));
+
+ EXPECT_TRUE(S3N->hasMemPred(S2N));
+ EXPECT_TRUE(S3N->hasMemPred(S1N));
+
+ EXPECT_TRUE(S4N->hasMemPred(S3N));
+ EXPECT_TRUE(S4N->hasMemPred(S2N));
+ EXPECT_TRUE(S4N->hasMemPred(S1N));
+
+ EXPECT_TRUE(S5N->hasMemPred(S4N));
+ EXPECT_TRUE(S5N->hasMemPred(S3N));
+ EXPECT_TRUE(S5N->hasMemPred(S2N));
+ EXPECT_TRUE(S5N->hasMemPred(S1N));
+ }
+}
More information about the llvm-commits mailing list