[llvm] [SandboxVec][DAG] Implement PredIterator (PR #111604)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 8 16:24:06 PDT 2024
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/111604
This patch implements PredIterator, an iterator over the predecessors of a DGNode: its use-def (operand) dependencies and, for MemDGNodes, also its memory dependencies.
>From 6eeacad79d8134989c909d31c216f3ab6f5d5228 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Fri, 4 Oct 2024 10:57:03 -0700
Subject: [PATCH] [SandboxVec][DAG] Implement PredIterator
This patch implements PredIterator, an iterator over the predecessors of a
DGNode: its use-def (operand) dependencies and, for MemDGNodes, also its
memory dependencies.
---
.../SandboxVectorizer/DependencyGraph.h | 73 +++++++++++++++++++
.../SandboxVectorizer/DependencyGraph.cpp | 35 +++++++++
.../SandboxVectorizer/DependencyGraphTest.cpp | 41 +++++++++++
3 files changed, 149 insertions(+)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index ab49c3aa27143c..8c466465438194 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -40,6 +40,54 @@ enum class DGNodeID {
MemDGNode,
};
+class DGNode;
+class MemDGNode;
+class DependencyGraph;
+
+/// While OpIt points to a Value that is not an Instruction keep incrementing
+/// it. \Returns the first iterator in [\p OpIt, \p OpItE) that points to an
+/// Instruction, or \p OpItE if there is none.
+///
+/// Defined in a header, so it is `inline` rather than `static`: a `static`
+/// definition would give every including translation unit its own copy and
+/// trigger "defined but not used" warnings in those that never call it.
+[[nodiscard]] inline User::op_iterator skipNonInstr(User::op_iterator OpIt,
+                                                    User::op_iterator OpItE) {
+  while (OpIt != OpItE && !isa<Instruction>((*OpIt).get()))
+    ++OpIt;
+  return OpIt;
+}
+
+/// Iterates over the predecessors of a DGNode. It first visits the nodes of
+/// the operands that are Instructions (use-def dependencies). Then, only for a
+/// MemDGNode, it visits the memory dependencies (MemDGNode::memPreds()).
+class PredIterator {
+  /// Current and end positions in the operand list of N's instruction.
+  User::op_iterator OpIt;
+  User::op_iterator OpItE;
+  /// Position in N's memory predecessors. Only set when N is a MemDGNode.
+  /// For a plain DGNode it is left default-constructed, and operator==()
+  /// still compares it.
+  DenseSet<MemDGNode *>::iterator MemIt;
+  /// The node whose predecessors are being visited.
+  DGNode *N = nullptr;
+  /// The DAG used to map operand instructions to their nodes.
+  DependencyGraph *DAG = nullptr;
+
+  PredIterator(const User::op_iterator &OpIt, const User::op_iterator &OpItE,
+               const DenseSet<MemDGNode *>::iterator &MemIt, DGNode *N,
+               DependencyGraph &DAG)
+      : OpIt(OpIt), OpItE(OpItE), MemIt(MemIt), N(N), DAG(&DAG) {}
+  PredIterator(const User::op_iterator &OpIt, const User::op_iterator &OpItE,
+               DGNode *N, DependencyGraph &DAG)
+      : OpIt(OpIt), OpItE(OpItE), N(N), DAG(&DAG) {}
+  friend class DGNode;    // For constructor
+  friend class MemDGNode; // For constructor
+
+public:
+  using difference_type = std::ptrdiff_t;
+  using value_type = DGNode *;
+  using pointer = value_type *;
+  // operator*() returns the node by value, so `reference` must be the value
+  // type itself: an input iterator's `reference` is the type of `*It`.
+  using reference = value_type;
+  using iterator_category = std::input_iterator_tag;
+  value_type operator*();
+  PredIterator &operator++();
+  PredIterator operator++(int) {
+    auto Copy = *this;
+    ++(*this);
+    return Copy;
+  }
+  bool operator==(const PredIterator &Other) const;
+  bool operator!=(const PredIterator &Other) const { return !(*this == Other); }
+};
+
/// A DependencyGraph Node that points to an Instruction and contains memory
/// dependency edges.
class DGNode {
@@ -63,6 +111,23 @@ class DGNode {
virtual ~DGNode() = default;
/// \Returns true if this is before \p Other in program order.
bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
+  using iterator = PredIterator;
+  /// \Returns an iterator to the first predecessor of this node: the node of
+  /// the first operand that is an Instruction (other operands are skipped).
+  virtual iterator preds_begin(DependencyGraph &DAG) {
+    User::op_iterator OpE = I->op_end();
+    User::op_iterator FirstInstrOp = skipNonInstr(I->op_begin(), OpE);
+    return PredIterator(FirstInstrOp, OpE, this, DAG);
+  }
+  /// \Returns the past-the-end predecessor iterator.
+  virtual iterator preds_end(DependencyGraph &DAG) {
+    User::op_iterator OpE = I->op_end();
+    return PredIterator(OpE, OpE, this, DAG);
+  }
+  /// Const overloads that forward to the (possibly overridden) virtual
+  /// versions above.
+  iterator preds_begin(DependencyGraph &DAG) const {
+    auto *Self = const_cast<DGNode *>(this);
+    return Self->preds_begin(DAG);
+  }
+  iterator preds_end(DependencyGraph &DAG) const {
+    auto *Self = const_cast<DGNode *>(this);
+    return Self->preds_end(DAG);
+  }
+  /// \Returns a range over all predecessors of this node.
+  iterator_range<iterator> preds(DependencyGraph &DAG) const {
+    iterator Begin = preds_begin(DAG);
+    iterator End = preds_end(DAG);
+    return make_range(Begin, End);
+  }
static bool isStackSaveOrRestoreIntrinsic(Instruction *I) {
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
@@ -145,6 +210,14 @@ class MemDGNode final : public DGNode {
static bool classof(const DGNode *Other) {
return Other->SubclassID == DGNodeID::MemDGNode;
}
+  // Re-expose DGNode's const preds_begin()/preds_end() overloads. Without
+  // these using-declarations the overrides below would hide them, so they
+  // could not be called through a `const MemDGNode *`.
+  using DGNode::preds_begin;
+  using DGNode::preds_end;
+  /// \Returns an iterator to the first predecessor. The operand nodes are
+  /// visited first, followed by the memory predecessors in MemPreds.
+  iterator preds_begin(DependencyGraph &DAG) override {
+    auto OpEndIt = I->op_end();
+    return PredIterator(skipNonInstr(I->op_begin(), OpEndIt), OpEndIt,
+                        MemPreds.begin(), this, DAG);
+  }
+  /// \Returns the past-the-end iterator, i.e., past both the operands and the
+  /// memory predecessors.
+  iterator preds_end(DependencyGraph &DAG) override {
+    auto OpEndIt = I->op_end();
+    return PredIterator(OpEndIt, OpEndIt, MemPreds.end(), this, DAG);
+  }
/// \Returns the previous Mem DGNode in instruction order.
MemDGNode *getPrevNode() const { return PrevMemN; }
/// \Returns the next Mem DGNode in instruction order.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 845fadefc9bf03..1730c4a9e4bc0f 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -8,10 +8,45 @@
#include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Utils.h"
namespace llvm::sandboxir {
+/// \Returns the predecessor the iterator currently points to. While operands
+/// remain, this is the node of the current operand. Otherwise (MemDGNode
+/// only), it is the current memory predecessor.
+PredIterator::value_type PredIterator::operator*() {
+  // If it's a DGNode, or a MemDGNode with an OpIt != end.
+  if (!isa<MemDGNode>(N) || OpIt != OpItE) {
+    assert(OpIt != OpItE && "Can't dereference end iterator!");
+    Value *OpV = *OpIt;
+    // NOTE(review): assumes every Instruction operand has a node in DAG;
+    // confirm what getNode() returns for operands outside the DAG's region.
+    return DAG->getNode(cast<Instruction>(OpV));
+  }
+  // It's a MemDGNode with OpIt == end, so we need to use MemIt.
+  assert(MemIt != cast<MemDGNode>(N)->memPreds().end() &&
+         "Can't dereference end iterator!");
+  return *MemIt;
+}
+
+/// Advances to the next predecessor. Operand nodes are visited first. Once
+/// the operands are exhausted, a MemDGNode moves on to its memory
+/// predecessors.
+PredIterator &PredIterator::operator++() {
+  const bool OnOperands = !isa<MemDGNode>(N) || OpIt != OpItE;
+  if (OnOperands) {
+    assert(OpIt != OpItE && "Already at end!");
+    // Step past the current operand, then past any non-Instruction operands.
+    ++OpIt;
+    OpIt = skipNonInstr(OpIt, OpItE);
+  } else {
+    assert(MemIt != cast<MemDGNode>(N)->memPreds().end() &&
+           "Already at end!");
+    ++MemIt;
+  }
+  return *this;
+}
+
+/// Two iterators are equal when they point at the same operand and the same
+/// memory predecessor. Comparing iterators of different nodes or different
+/// DAGs is a programming error.
+bool PredIterator::operator==(const PredIterator &Other) const {
+  assert(DAG == Other.DAG && "Iterators of different DAGs!");
+  assert(N == Other.N && "Iterators of different nodes!");
+  if (OpIt != Other.OpIt)
+    return false;
+  return MemIt == Other.MemIt;
+}
+
#ifndef NDEBUG
void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
I->dumpOS(OS);
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index e2f16919a5cddd..6b3d9cc77c9955 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -240,12 +240,53 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
EXPECT_TRUE(N1->hasMemPred(N0));
EXPECT_FALSE(N0->hasMemPred(N1));
+ // Check preds().
+ EXPECT_TRUE(N0->preds(DAG).empty());
+ EXPECT_THAT(N1->preds(DAG), testing::ElementsAre(N0));
+
// Check memPreds().
EXPECT_TRUE(N0->memPreds().empty());
EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
EXPECT_TRUE(N2->memPreds().empty());
}
+TEST_F(DependencyGraphTest, Preds) {
+  parseIR(C, R"IR(
+declare ptr @bar(i8)
+define i8 @foo(i8 %v0, i8 %v1) {
+ %add0 = add i8 %v0, %v0
+ %add1 = add i8 %v1, %v1
+ %add2 = add i8 %add0, %add1
+ %ptr = call ptr @bar(i8 %add1)
+ store i8 %add2, ptr %ptr
+ ret i8 %add2
+}
+)IR");
+  llvm::Function *LLVMF = M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  DAG.extend({&*BB->begin(), BB->getTerminator()});
+
+  // Fetch the node of each instruction, in program order.
+  auto InstrIt = BB->begin();
+  auto *AddN0 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*InstrIt++));
+  auto *AddN1 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*InstrIt++));
+  auto *AddN2 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*InstrIt++));
+  auto *CallN = DAG.getNode(cast<sandboxir::CallInst>(&*InstrIt++));
+  auto *StN = DAG.getNode(cast<sandboxir::StoreInst>(&*InstrIt++));
+  auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*InstrIt++));
+
+  // Operands that are function arguments are not predecessors.
+  EXPECT_THAT(AddN0->preds(DAG), testing::ElementsAre());
+  EXPECT_THAT(AddN1->preds(DAG), testing::ElementsAre());
+  // Use-def predecessors only.
+  EXPECT_THAT(AddN2->preds(DAG), testing::ElementsAre(AddN0, AddN1));
+  EXPECT_THAT(CallN->preds(DAG), testing::ElementsAre(AddN1));
+  // The store depends on the call both through its pointer operand and
+  // through memory, so the call node is reported twice.
+  EXPECT_THAT(StN->preds(DAG),
+              testing::UnorderedElementsAre(CallN, CallN, AddN2));
+  EXPECT_THAT(RetN->preds(DAG), testing::ElementsAre(AddN2));
+}
+
TEST_F(DependencyGraphTest, MemDGNode_getPrevNode_getNextNode) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
More information about the llvm-commits
mailing list