[llvm] [SandboxVec][DAG] Boilerplate (PR #108862)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 17 08:25:17 PDT 2024
https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/108862
>From ccef63e643718644fefd07b18dff975f43057107 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Fri, 13 Sep 2024 10:09:24 -0700
Subject: [PATCH] [SandboxVec][DAG] Boilerplate
This patch adds a very basic implementation of the Dependency Graph to be used
by the vectorizer.
---
.../SandboxVectorizer/DependencyGraph.h | 85 +++++++++++++++++++
llvm/lib/Transforms/Vectorize/CMakeLists.txt | 1 +
.../SandboxVectorizer/DependencyGraph.cpp | 64 ++++++++++++++
.../Transforms/Vectorize/CMakeLists.txt | 2 +
.../SandboxVectorizer/CMakeLists.txt | 12 +++
.../SandboxVectorizer/DependencyGraphTest.cpp | 63 ++++++++++++++
6 files changed, 227 insertions(+)
create mode 100644 llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
create mode 100644 llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
create mode 100644 llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
create mode 100644 llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
new file mode 100644
index 00000000000000..0adf1cbab74876
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -0,0 +1,85 @@
+//===- DependencyGraph.h ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This class is a dependency graph used by the vectorizer's instruction
+// scheduler.
+//
+// The nodes of the graph are objects of the `DGNode` class.
+// The edges between `DGNode`s are implicitly defined by the set of
+// predecessor nodes stored in each node, to save memory.
+// Finally the whole dependency graph is an object of the `DependencyGraph`
+// class, which also provides the API for creating/extending the graph from
+// input Sandbox IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_DEPENDENCYGRAPH_H
+#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_DEPENDENCYGRAPH_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/SandboxIR/SandboxIR.h"
+
+namespace llvm::sandboxir {
+
+/// A DependencyGraph Node that points to an Instruction and contains memory
+/// dependency edges.
+class DGNode {
+  /// The instruction this node represents.
+  Instruction *I;
+  /// Memory predecessors: the nodes N for which there is an edge N->this.
+  /// This is a DenseSet, so its iteration order is unspecified.
+  DenseSet<DGNode *> MemPreds;
+
+public:
+  DGNode(Instruction *I) : I(I) {}
+  /// \Returns the instruction this node represents.
+  Instruction *getInstruction() const { return I; }
+  /// Adds the memory dependency edge \p PredN->this. Adding an edge that
+  /// already exists has no effect.
+  void addMemPred(DGNode *PredN) { MemPreds.insert(PredN); }
+  /// \Returns all memory dependency predecessors, in unspecified order.
+  iterator_range<DenseSet<DGNode *>::const_iterator> memPreds() const {
+    return make_range(MemPreds.begin(), MemPreds.end());
+  }
+  /// \Returns true if there is a memory dependency N->this.
+  bool hasMemPred(DGNode *N) const { return MemPreds.contains(N); }
+#ifndef NDEBUG
+  void print(raw_ostream &OS, bool PrintDeps = true) const;
+  /// Streams the node (with its dependencies) so that `OS << N` works. The
+  /// stream must be the left operand, as for any other `operator<<`.
+  friend raw_ostream &operator<<(raw_ostream &OS, const DGNode &N) {
+    N.print(OS);
+    return OS;
+  }
+  LLVM_DUMP_METHOD void dump() const;
+#endif // NDEBUG
+};
+
+/// The dependency graph. It owns one DGNode per Instruction added to it and
+/// builds the dependency edges between those nodes.
+class DependencyGraph {
+private:
+  /// Owns the nodes, keyed by the instruction each node represents.
+  DenseMap<Instruction *, std::unique_ptr<DGNode>> InstrToNodeMap;
+
+public:
+  DependencyGraph() = default;
+
+  /// \Returns the node of \p I, or nullptr if \p I has no node in the graph.
+  DGNode *getNode(Instruction *I) const {
+    auto It = InstrToNodeMap.find(I);
+    return It != InstrToNodeMap.end() ? It->second.get() : nullptr;
+  }
+  /// \Returns the node of \p I, creating it first if it does not exist yet.
+  DGNode *getOrCreateNode(Instruction *I) {
+    auto [It, NotInMap] = InstrToNodeMap.try_emplace(I);
+    // For a new key try_emplace() stores an empty unique_ptr, so the node is
+    // allocated only on first insertion.
+    if (NotInMap)
+      It->second = std::make_unique<DGNode>(I);
+    return It->second.get();
+  }
+  // TODO: extend() should work with intervals not the whole BB.
+  /// Build the dependency graph for \p BB.
+  void extend(BasicBlock *BB);
+#ifndef NDEBUG
+  void print(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif // NDEBUG
+};
+
+} // namespace llvm::sandboxir
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_DEPENDENCYGRAPH_H
diff --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
index b11631350e8b4e..59d04ac3cecd00 100644
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -3,6 +3,7 @@ add_llvm_component_library(LLVMVectorize
LoopIdiomVectorize.cpp
LoopVectorizationLegality.cpp
LoopVectorize.cpp
+ SandboxVectorizer/DependencyGraph.cpp
SandboxVectorizer/Passes/BottomUpVec.cpp
SandboxVectorizer/SandboxVectorizer.cpp
SLPVectorizer.cpp
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
new file mode 100644
index 00000000000000..af4127655c785e
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -0,0 +1,64 @@
+//===- DependencyGraph.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
+
+using namespace llvm::sandboxir;
+
+#ifndef NDEBUG
+/// Prints this node's instruction. If \p PrintDeps is set, a newline follows,
+/// then one indented "<-" line per memory predecessor. Predecessors are
+/// printed without their own dependencies, in the set's unspecified order.
+void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
+  I->dumpOS(OS);
+  if (!PrintDeps)
+    return;
+  OS << "\n";
+  constexpr unsigned PredIndent = 4;
+  for (DGNode *PredN : MemPreds) {
+    OS.indent(PredIndent) << "<-";
+    PredN->print(OS, /*PrintDeps=*/false);
+    OS << "\n";
+  }
+}
+/// Prints this node, including its memory predecessors, to dbgs(), followed
+/// by a newline.
+void DGNode::dump() const {
+  raw_ostream &DbgOS = dbgs();
+  print(DbgOS);
+  DbgOS << "\n";
+}
+#endif // NDEBUG
+
+/// Creates a node for every instruction of \p BB. Each node gets a memory
+/// dependency on the node of the instruction right before it, so the nodes
+/// form a single chain in program order. An empty block adds nothing.
+void DependencyGraph::extend(BasicBlock *BB) {
+  // TODO: For now create a chain of dependencies.
+  DGNode *PrevN = nullptr;
+  for (auto &Inst : *BB) {
+    DGNode *CurrN = getOrCreateNode(&Inst);
+    if (PrevN != nullptr)
+      CurrN->addMemPred(PrevN);
+    PrevN = CurrN;
+  }
+}
+
+#ifndef NDEBUG
+/// Prints every node of the graph, with its memory predecessors, in the order
+/// in which the nodes' instructions appear in the basic block.
+void DependencyGraph::print(raw_ostream &OS) const {
+  // InstrToNodeMap has no meaningful iteration order, so gather the nodes
+  // first and then sort them by instruction position.
+  SmallVector<DGNode *> SortedNodes;
+  SortedNodes.reserve(InstrToNodeMap.size());
+  for (const auto &Entry : InstrToNodeMap)
+    SortedNodes.push_back(Entry.second.get());
+  auto InstrComesFirst = [](DGNode *LHS, DGNode *RHS) {
+    return LHS->getInstruction()->comesBefore(RHS->getInstruction());
+  };
+  stable_sort(SortedNodes, InstrComesFirst);
+  for (DGNode *N : SortedNodes)
+    N->print(OS, /*PrintDeps=*/true);
+}
+
+/// Prints the whole graph to dbgs(), followed by a newline.
+void DependencyGraph::dump() const {
+  raw_ostream &DbgOS = dbgs();
+  print(DbgOS);
+  DbgOS << "\n";
+}
+#endif // NDEBUG
diff --git a/llvm/unittests/Transforms/Vectorize/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/CMakeLists.txt
index 1354558a94f0d5..0df39c41a90414 100644
--- a/llvm/unittests/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Vectorize/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(SandboxVectorizer)
+
set(LLVM_LINK_COMPONENTS
Analysis
Core
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
new file mode 100644
index 00000000000000..488c9c2344b56c
--- /dev/null
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
@@ -0,0 +1,12 @@
+# LLVM component libraries linked into the SandboxVectorizer unit tests.
+set(LLVM_LINK_COMPONENTS
+ Analysis
+ Core
+ Vectorize
+ AsmParser
+ TargetParser
+ SandboxIR
+ )
+
+# Unit-test binary built from the sources listed below.
+add_llvm_unittest(SandboxVectorizerTests
+ DependencyGraphTest.cpp
+ )
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
new file mode 100644
index 00000000000000..dc85a22f7f4832
--- /dev/null
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -0,0 +1,63 @@
+//===- DependencyGraphTest.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/SandboxIR/SandboxIR.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gmock/gmock-matchers.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+/// Test fixture that owns the LLVMContext and the Module parsed by parseIR().
+struct DependencyGraphTest : public testing::Test {
+  LLVMContext C;
+  std::unique_ptr<Module> M;
+
+  /// Parses the textual IR \p IR in context \p C and stores the result in M.
+  /// On a parse error M is null and the diagnostic is printed to errs().
+  void parseIR(LLVMContext &C, const char *IR) {
+    SMDiagnostic Diag;
+    M = parseAssemblyString(IR, Diag, C);
+    if (M)
+      return;
+    Diag.print("DependencyGraphTest", errs());
+  }
+};
+
+TEST_F(DependencyGraphTest, Basic) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
+  store i8 %v0, ptr %ptr
+  store i8 %v1, ptr %ptr
+  ret void
+}
+)IR");
+  // Fail the test cleanly instead of crashing on a null Module/Function.
+  ASSERT_TRUE(M);
+  llvm::Function *LLVMF = M->getFunction("foo");
+  ASSERT_NE(LLVMF, nullptr);
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+  sandboxir::DependencyGraph DAG;
+  DAG.extend(BB);
+
+  sandboxir::DGNode *N0 = DAG.getNode(S0);
+  sandboxir::DGNode *N1 = DAG.getNode(S1);
+  sandboxir::DGNode *N2 = DAG.getNode(Ret);
+  // extend() must have created a node for every instruction of the block.
+  ASSERT_NE(N0, nullptr);
+  ASSERT_NE(N1, nullptr);
+  ASSERT_NE(N2, nullptr);
+  // Check getInstruction().
+  EXPECT_EQ(N0->getInstruction(), S0);
+  EXPECT_EQ(N1->getInstruction(), S1);
+  EXPECT_EQ(N2->getInstruction(), Ret);
+  // Check hasMemPred()
+  EXPECT_TRUE(N1->hasMemPred(N0));
+  EXPECT_FALSE(N0->hasMemPred(N1));
+
+  // Check memPreds().
+  EXPECT_TRUE(N0->memPreds().empty());
+  EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
+  EXPECT_THAT(N2->memPreds(), testing::ElementsAre(N1));
+}
More information about the llvm-commits
mailing list