[llvm] [SandboxVec][DAG] Refactoring: Outline code that looks for mem nodes (PR #111750)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 9 13:11:01 PDT 2024
https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/111750
>From 719e9d3a2c43f7bfddfbc8dffb16f2c56e646219 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Wed, 9 Oct 2024 12:00:04 -0700
Subject: [PATCH] [SandboxVec][DAG] Refactoring: Outline code that looks for
mem nodes
---
.../SandboxVectorizer/DependencyGraph.h | 8 ++++
.../SandboxVectorizer/DependencyGraph.cpp | 42 ++++++++++++++-----
.../SandboxVectorizer/DependencyGraphTest.cpp | 14 +++++++
3 files changed, 53 insertions(+), 11 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 134adc4b21ab12..5036dae1f0e278 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -154,6 +154,14 @@ class MemDGNode final : public DGNode {
/// Convenience builders for a MemDGNode interval.
class MemDGNodeIntervalBuilder {
public:
+ /// Scans the instruction chain in \p Intvl top-down, returning the top-most
+ /// MemDGNode, or nullptr if the scan finds no mem-dep candidate in \p Intvl.
+ static MemDGNode *getTopMemDGNode(const Interval<Instruction> &Intvl,
+ const DependencyGraph &DAG);
+ /// Scans the instruction chain in \p Intvl bottom-up, returning the
+ /// bottom-most MemDGNode, or nullptr if the scan finds no mem-dep candidate.
+ static MemDGNode *getBotMemDGNode(const Interval<Instruction> &Intvl,
+ const DependencyGraph &DAG);
/// Given \p Instrs it finds their closest mem nodes in the interval and
/// returns the corresponding mem range. Note: BotN (or its neighboring mem
/// node) is included in the range.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 82f253d4c63231..c02eba167390d1 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -32,23 +32,43 @@ void DGNode::dump() const {
}
#endif // NDEBUG
+/// Returns the top-most MemDGNode within \p Intvl, scanning top-down, or
+/// nullptr if \p Intvl is empty or contains no mem-dep candidate instruction.
+MemDGNode *
+MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval<Instruction> &Intvl,
+                                          const DependencyGraph &DAG) {
+  // Nothing to scan in an empty interval. Its top()/bottom() are expected to
+  // be null, and isMemDepNodeCandidate() must not be handed a null pointer.
+  if (Intvl.empty())
+    return nullptr;
+  Instruction *I = Intvl.top();
+  Instruction *BeforeI = Intvl.bottom();
+  // Walk down the chain looking for a mem-dep candidate instruction, never
+  // stepping past the bottom of the interval.
+  while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI)
+    I = I->getNextNode();
+  // The walk also stops at the bottom instruction, which need not be a
+  // candidate itself, so re-check before casting.
+  if (!DGNode::isMemDepNodeCandidate(I))
+    return nullptr;
+  return cast<MemDGNode>(DAG.getNode(I));
+}
+
+/// Returns the bottom-most MemDGNode within \p Intvl, scanning bottom-up, or
+/// nullptr if \p Intvl is empty or contains no mem-dep candidate instruction.
+MemDGNode *
+MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl,
+                                          const DependencyGraph &DAG) {
+  // Nothing to scan in an empty interval. Its top()/bottom() are expected to
+  // be null, and isMemDepNodeCandidate() must not be handed a null pointer.
+  if (Intvl.empty())
+    return nullptr;
+  Instruction *I = Intvl.bottom();
+  Instruction *AfterI = Intvl.top();
+  // Walk up the chain looking for a mem-dep candidate instruction, never
+  // stepping past the top of the interval.
+  while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI)
+    I = I->getPrevNode();
+  // The walk also stops at the top instruction, which need not be a
+  // candidate itself, so re-check before casting.
+  if (!DGNode::isMemDepNodeCandidate(I))
+    return nullptr;
+  return cast<MemDGNode>(DAG.getNode(I));
+}
+
Interval<MemDGNode>
MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
DependencyGraph &DAG) {
- // If top or bottom instructions are not mem-dep candidate nodes we need to
- // walk down/up the chain and find the mem-dep ones.
- Instruction *MemTopI = Instrs.top();
- Instruction *MemBotI = Instrs.bottom();
- while (!DGNode::isMemDepNodeCandidate(MemTopI) && MemTopI != MemBotI)
- MemTopI = MemTopI->getNextNode();
- while (!DGNode::isMemDepNodeCandidate(MemBotI) && MemBotI != MemTopI)
- MemBotI = MemBotI->getPrevNode();
+  // An empty instruction interval cannot contain any mem nodes; return early
+  // so the top-down/bottom-up scans never start from a null instruction.
+  if (Instrs.empty())
+    return {};
+  auto *TopMemN = getTopMemDGNode(Instrs, DAG);
// If we couldn't find a mem node in range TopN - BotN then it's empty.
- if (!DGNode::isMemDepNodeCandidate(MemTopI))
+  if (TopMemN == nullptr)
return {};
+  // A top mem node exists, so the bottom-up scan must find one as well (in the
+  // worst case the very same node).
+  auto *BotMemN = getBotMemDGNode(Instrs, DAG);
+  assert(BotMemN != nullptr && "TopMemN should be null too!");
// Now that we have the mem-dep nodes, create and return the range.
- return Interval<MemDGNode>(cast<MemDGNode>(DAG.getNode(MemTopI)),
- cast<MemDGNode>(DAG.getNode(MemBotI)));
+  return Interval<MemDGNode>(TopMemN, BotMemN);
}
DependencyGraph::DependencyType
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index e2f16919a5cddd..b425e5a8ad2145 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -305,6 +305,20 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+ // Check getTopMemDGNode().
+ using B = sandboxir::MemDGNodeIntervalBuilder;
+ using InstrInterval = sandboxir::Interval<sandboxir::Instruction>;
+ EXPECT_EQ(B::getTopMemDGNode(InstrInterval(S0, S0), DAG), S0N);
+ EXPECT_EQ(B::getTopMemDGNode(InstrInterval(S0, Ret), DAG), S0N);
+ EXPECT_EQ(B::getTopMemDGNode(InstrInterval(Add0, Add1), DAG), S0N);
+ EXPECT_EQ(B::getTopMemDGNode(InstrInterval(Add0, Add0), DAG), nullptr);
+
+ // Check getBotMemDGNode().
+ EXPECT_EQ(B::getBotMemDGNode(InstrInterval(S1, S1), DAG), S1N);
+ EXPECT_EQ(B::getBotMemDGNode(InstrInterval(Add0, S1), DAG), S1N);
+ EXPECT_EQ(B::getBotMemDGNode(InstrInterval(Add0, Ret), DAG), S1N);
+ EXPECT_EQ(B::getBotMemDGNode(InstrInterval(Ret, Ret), DAG), nullptr);
+
// Check empty range.
EXPECT_THAT(sandboxir::MemDGNodeIntervalBuilder::makeEmpty(),
testing::ElementsAre());
More information about the llvm-commits
mailing list