[llvm] [SandboxVec][DAG] Refactoring: Move MemPreds from DGNode to MemDGNode (PR #111897)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 10 12:15:18 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-vectorizers
Author: vporpo (vporpo)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/111897.diff
3 Files Affected:
- (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+20-16)
- (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+11-9)
- (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+41-32)
``````````diff
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index eba6d7562e41de..da50e5326ea069 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -1,4 +1,4 @@
-//===- DependencyGraph.h ----------------------------------*- C++ -*-===//
+//===- DependencyGraph.h ----------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -96,9 +96,6 @@ class DGNode {
// TODO: Use a PointerIntPair for SubclassID and I.
/// For isa/dyn_cast etc.
DGNodeID SubclassID;
- // TODO: Move MemPreds to MemDGNode.
- /// Memory predecessors.
- DenseSet<MemDGNode *> MemPreds;
DGNode(Instruction *I, DGNodeID ID) : I(I), SubclassID(ID) {}
friend class MemDGNode; // For constructor.
@@ -170,17 +167,6 @@ class DGNode {
}
Instruction *getInstruction() const { return I; }
- void addMemPred(MemDGNode *PredN) { MemPreds.insert(PredN); }
- /// \Returns all memory dependency predecessors.
- iterator_range<DenseSet<MemDGNode *>::const_iterator> memPreds() const {
- return make_range(MemPreds.begin(), MemPreds.end());
- }
- /// \Returns true if there is a memory dependency N->this.
- bool hasMemPred(DGNode *N) const {
- if (auto *MN = dyn_cast<MemDGNode>(N))
- return MemPreds.count(MN);
- return false;
- }
#ifndef NDEBUG
virtual void print(raw_ostream &OS, bool PrintDeps = true) const;
@@ -198,6 +184,9 @@ class DGNode {
class MemDGNode final : public DGNode {
MemDGNode *PrevMemN = nullptr;
MemDGNode *NextMemN = nullptr;
+ /// Memory predecessors.
+ DenseSet<MemDGNode *> MemPreds;
+ friend class PredIterator; // For MemPreds.
void setNextNode(MemDGNode *N) { NextMemN = N; }
void setPrevNode(MemDGNode *N) { PrevMemN = N; }
@@ -222,6 +211,21 @@ class MemDGNode final : public DGNode {
MemDGNode *getPrevNode() const { return PrevMemN; }
/// \Returns the next Mem DGNode in instruction order.
MemDGNode *getNextNode() const { return NextMemN; }
+ /// Adds the mem dependency edge PredN->this.
+ void addMemPred(MemDGNode *PredN) { MemPreds.insert(PredN); }
+ /// \Returns true if there is a memory dependency N->this.
+ bool hasMemPred(DGNode *N) const {
+ if (auto *MN = dyn_cast<MemDGNode>(N))
+ return MemPreds.count(MN);
+ return false;
+ }
+ /// \Returns all memory dependency predecessors. Used by tests.
+ iterator_range<DenseSet<MemDGNode *>::const_iterator> memPreds() const {
+ return make_range(MemPreds.begin(), MemPreds.end());
+ }
+#ifndef NDEBUG
+ virtual void print(raw_ostream &OS, bool PrintDeps = true) const override;
+#endif // NDEBUG
};
/// Convenience builders for a MemDGNode interval.
@@ -266,7 +270,7 @@ class DependencyGraph {
/// Go through all mem nodes in \p SrcScanRange and try to add dependencies to
/// \p DstN.
- void scanAndAddDeps(DGNode &DstN, const Interval<MemDGNode> &SrcScanRange);
+ void scanAndAddDeps(MemDGNode &DstN, const Interval<MemDGNode> &SrcScanRange);
public:
DependencyGraph(AAResults &AA)
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 7aea466ed6d8db..70843812ff65bc 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -23,7 +23,8 @@ PredIterator::value_type PredIterator::operator*() {
// or a mem predecessor.
if (OpIt != OpItE)
return DAG->getNode(cast<Instruction>((Value *)*OpIt));
- assert(MemIt != cast<MemDGNode>(N)->memPreds().end() &&
+ // It's a MemDGNode with OpIt == end, so we need to use MemIt.
+ assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() &&
"Cant' dereference end iterator!");
return *MemIt;
}
@@ -45,7 +46,8 @@ PredIterator &PredIterator::operator++() {
OpIt = skipNonInstr(OpIt, OpItE);
return *this;
}
- assert(MemIt != cast<MemDGNode>(N)->memPreds().end() && "Already at end!");
+ // It's a MemDGNode with OpIt == end, so we need to increment MemIt.
+ assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!");
++MemIt;
return *this;
}
@@ -57,10 +59,14 @@ bool PredIterator::operator==(const PredIterator &Other) const {
}
#ifndef NDEBUG
-void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
+void DGNode::print(raw_ostream &OS, bool PrintDeps) const { I->dumpOS(OS); }
+void DGNode::dump() const {
+ print(dbgs());
+ dbgs() << "\n";
+}
+void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const {
I->dumpOS(OS);
if (PrintDeps) {
- OS << "\n";
// Print memory preds.
static constexpr const unsigned Indent = 4;
for (auto *Pred : MemPreds) {
@@ -70,10 +76,6 @@ void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
}
}
}
-void DGNode::dump() const {
- print(dbgs());
- dbgs() << "\n";
-}
#endif // NDEBUG
Interval<MemDGNode>
@@ -179,7 +181,7 @@ bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) {
llvm_unreachable("Unknown DependencyType enum");
}
-void DependencyGraph::scanAndAddDeps(DGNode &DstN,
+void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
const Interval<MemDGNode> &SrcScanRange) {
assert(isa<MemDGNode>(DstN) &&
"DstN is the mem dep destination, so it must be mem");
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 6b3d9cc77c9955..5a9c9815ca42fa 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -50,10 +50,10 @@ struct DependencyGraphTest : public testing::Test {
return *AA;
}
/// \Returns true if there is a dependency: SrcN->DstN.
- bool dependency(sandboxir::DGNode *SrcN, sandboxir::DGNode *DstN) {
- const auto &Preds = DstN->memPreds();
- auto It = find(Preds, SrcN);
- return It != Preds.end();
+ bool memDependency(sandboxir::DGNode *SrcN, sandboxir::DGNode *DstN) {
+ if (auto *MemDstN = dyn_cast<sandboxir::MemDGNode>(DstN))
+ return MemDstN->hasMemPred(SrcN);
+ return false;
}
};
@@ -230,9 +230,10 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
EXPECT_EQ(Span.top(), &*BB->begin());
EXPECT_EQ(Span.bottom(), BB->getTerminator());
- sandboxir::DGNode *N0 = DAG.getNode(S0);
- sandboxir::DGNode *N1 = DAG.getNode(S1);
- sandboxir::DGNode *N2 = DAG.getNode(Ret);
+ auto *N0 = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
+ auto *N1 = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+ auto *N2 = DAG.getNode(Ret);
+
// Check getInstruction().
EXPECT_EQ(N0->getInstruction(), S0);
EXPECT_EQ(N1->getInstruction(), S1);
@@ -247,7 +248,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
// Check memPreds().
EXPECT_TRUE(N0->memPreds().empty());
EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
- EXPECT_TRUE(N2->memPreds().empty());
+ EXPECT_TRUE(N2->preds(DAG).empty());
}
TEST_F(DependencyGraphTest, Preds) {
@@ -399,12 +400,14 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
- auto *Store0N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
- auto *Store1N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
+ auto *Store0N = cast<sandboxir::MemDGNode>(
+ DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
+ auto *Store1N = cast<sandboxir::MemDGNode>(
+ DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
EXPECT_TRUE(Store0N->memPreds().empty());
EXPECT_THAT(Store1N->memPreds(), testing::ElementsAre(Store0N));
- EXPECT_TRUE(RetN->memPreds().empty());
+ EXPECT_TRUE(RetN->preds(DAG).empty());
}
TEST_F(DependencyGraphTest, NonAliasingStores) {
@@ -422,13 +425,15 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) {
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
- auto *Store0N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
- auto *Store1N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
+ auto *Store0N = cast<sandboxir::MemDGNode>(
+ DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
+ auto *Store1N = cast<sandboxir::MemDGNode>(
+ DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
// We expect no dependencies because the stores don't alias.
EXPECT_TRUE(Store0N->memPreds().empty());
EXPECT_TRUE(Store1N->memPreds().empty());
- EXPECT_TRUE(RetN->memPreds().empty());
+ EXPECT_TRUE(RetN->preds(DAG).empty());
}
TEST_F(DependencyGraphTest, VolatileLoads) {
@@ -446,12 +451,14 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
- auto *Ld0N = DAG.getNode(cast<sandboxir::LoadInst>(&*It++));
- auto *Ld1N = DAG.getNode(cast<sandboxir::LoadInst>(&*It++));
+ auto *Ld0N = cast<sandboxir::MemDGNode>(
+ DAG.getNode(cast<sandboxir::LoadInst>(&*It++)));
+ auto *Ld1N = cast<sandboxir::MemDGNode>(
+ DAG.getNode(cast<sandboxir::LoadInst>(&*It++)));
auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
EXPECT_TRUE(Ld0N->memPreds().empty());
EXPECT_THAT(Ld1N->memPreds(), testing::ElementsAre(Ld0N));
- EXPECT_TRUE(RetN->memPreds().empty());
+ EXPECT_TRUE(RetN->preds(DAG).empty());
}
TEST_F(DependencyGraphTest, VolatileSotres) {
@@ -469,12 +476,14 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v) {
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto It = BB->begin();
- auto *Store0N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
- auto *Store1N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
+ auto *Store0N = cast<sandboxir::MemDGNode>(
+ DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
+ auto *Store1N = cast<sandboxir::MemDGNode>(
+ DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
EXPECT_TRUE(Store0N->memPreds().empty());
EXPECT_THAT(Store1N->memPreds(), testing::ElementsAre(Store0N));
- EXPECT_TRUE(RetN->memPreds().empty());
+ EXPECT_TRUE(RetN->preds(DAG).empty());
}
TEST_F(DependencyGraphTest, Call) {
@@ -498,12 +507,12 @@ define void @foo(float %v1, float %v2) {
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
auto It = BB->begin();
- auto *Call1N = DAG.getNode(&*It++);
+ auto *Call1N = cast<sandboxir::MemDGNode>(DAG.getNode(&*It++));
auto *AddN = DAG.getNode(&*It++);
- auto *Call2N = DAG.getNode(&*It++);
+ auto *Call2N = cast<sandboxir::MemDGNode>(DAG.getNode(&*It++));
EXPECT_THAT(Call1N->memPreds(), testing::ElementsAre());
- EXPECT_THAT(AddN->memPreds(), testing::ElementsAre());
+ EXPECT_THAT(AddN->preds(DAG), testing::ElementsAre());
EXPECT_THAT(Call2N->memPreds(), testing::ElementsAre(Call1N));
}
@@ -534,8 +543,8 @@ define void @foo() {
auto *AllocaN = DAG.getNode(&*It++);
auto *StackRestoreN = DAG.getNode(&*It++);
- EXPECT_TRUE(dependency(AllocaN, StackRestoreN));
- EXPECT_TRUE(dependency(StackSaveN, AllocaN));
+ EXPECT_TRUE(memDependency(AllocaN, StackRestoreN));
+ EXPECT_TRUE(memDependency(StackSaveN, AllocaN));
}
// Checks that stacksave and stackrestore depend on other mem instrs.
@@ -567,9 +576,9 @@ define void @foo(i8 %v0, i8 %v1, ptr %ptr) {
auto *StackRestoreN = DAG.getNode(&*It++);
auto *Store1N = DAG.getNode(&*It++);
- EXPECT_TRUE(dependency(Store0N, StackSaveN));
- EXPECT_TRUE(dependency(StackSaveN, StackRestoreN));
- EXPECT_TRUE(dependency(StackRestoreN, Store1N));
+ EXPECT_TRUE(memDependency(Store0N, StackSaveN));
+ EXPECT_TRUE(memDependency(StackSaveN, StackRestoreN));
+ EXPECT_TRUE(memDependency(StackRestoreN, Store1N));
}
// Make sure there is a dependency between a stackrestore and an alloca.
@@ -596,7 +605,7 @@ define void @foo(ptr %ptr) {
auto *StackRestoreN = DAG.getNode(&*It++);
auto *AllocaN = DAG.getNode(&*It++);
- EXPECT_TRUE(dependency(StackRestoreN, AllocaN));
+ EXPECT_TRUE(memDependency(StackRestoreN, AllocaN));
}
// Make sure there is a dependency between the alloca and stacksave
@@ -623,7 +632,7 @@ define void @foo(ptr %ptr) {
auto *AllocaN = DAG.getNode(&*It++);
auto *StackSaveN = DAG.getNode(&*It++);
- EXPECT_TRUE(dependency(AllocaN, StackSaveN));
+ EXPECT_TRUE(memDependency(AllocaN, StackSaveN));
}
// A non-InAlloca in a stacksave-stackrestore region does not need extra
@@ -655,6 +664,6 @@ define void @foo() {
auto *AllocaN = DAG.getNode(&*It++);
auto *StackRestoreN = DAG.getNode(&*It++);
- EXPECT_FALSE(dependency(StackSaveN, AllocaN));
- EXPECT_FALSE(dependency(AllocaN, StackRestoreN));
+ EXPECT_FALSE(memDependency(StackSaveN, AllocaN));
+ EXPECT_FALSE(memDependency(AllocaN, StackRestoreN));
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/111897
More information about the llvm-commits
mailing list