[llvm] 318d2f5 - [SandboxVec][DAG] Boilerplate (#108862)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 17 12:03:58 PDT 2024


Author: vporpo
Date: 2024-09-17T12:03:52-07:00
New Revision: 318d2f5e5d4d8245ab419193266b956194116989

URL: https://github.com/llvm/llvm-project/commit/318d2f5e5d4d8245ab419193266b956194116989
DIFF: https://github.com/llvm/llvm-project/commit/318d2f5e5d4d8245ab419193266b956194116989.diff

LOG: [SandboxVec][DAG] Boilerplate (#108862)

This patch adds a very basic implementation of the Dependency Graph to
be used by the vectorizer.

Added: 
    llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
    llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
    llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
    llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

Modified: 
    llvm/lib/Transforms/Vectorize/CMakeLists.txt
    llvm/unittests/Transforms/Vectorize/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
new file mode 100644
index 00000000000000..8a2021a5e6ba60
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -0,0 +1,86 @@
+//===- DependencyGraph.h ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the dependency graph used by the vectorizer's instruction
+// scheduler.
+//
+// The nodes of the graph are objects of the `DGNode` class. Each `DGNode`
+// object points to an instruction.
+// The edges between `DGNode`s are implicitly defined by a set of
+// predecessor nodes stored in each node, to save memory.
+// Finally, the whole dependency graph is an object of the `DependencyGraph`
+// class, which also provides the API for creating/extending the graph from
+// input Sandbox IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_DEPENDENCYGRAPH_H
+#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_DEPENDENCYGRAPH_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/SandboxIR/SandboxIR.h"
+
+namespace llvm::sandboxir {
+
+/// A DependencyGraph Node that points to an Instruction and contains memory
+/// dependency edges.
+class DGNode {
+  /// The instruction this node stands for.
+  Instruction *I;
+  /// Memory predecessors, i.e. the nodes N for which there is a memory
+  /// dependency edge N->this. Stored in a DenseSet, so inserting the same
+  /// predecessor twice is a no-op and iteration order is unspecified.
+  DenseSet<DGNode *> MemPreds;
+
+public:
+  DGNode(Instruction *I) : I(I) {}
+  Instruction *getInstruction() const { return I; }
+  /// Adds the memory dependency edge \p PredN -> this.
+  void addMemPred(DGNode *PredN) { MemPreds.insert(PredN); }
+  /// \Returns all memory dependency predecessors, in unspecified order.
+  iterator_range<DenseSet<DGNode *>::const_iterator> memPreds() const {
+    return make_range(MemPreds.begin(), MemPreds.end());
+  }
+  /// \Returns true if there is a memory dependency N->this.
+  bool hasMemPred(DGNode *N) const { return MemPreds.contains(N); }
+#ifndef NDEBUG
+  void print(raw_ostream &OS, bool PrintDeps = true) const;
+  /// Prints \p N, including its memory predecessors, so that `OS << N` works.
+  /// The stream must be the first parameter for the usual `OS << N` syntax.
+  friend raw_ostream &operator<<(raw_ostream &OS, const DGNode &N) {
+    N.print(OS);
+    return OS;
+  }
+  LLVM_DUMP_METHOD void dump() const;
+#endif // NDEBUG
+};
+
+/// The dependency graph. It owns one DGNode per Instruction added to it.
+class DependencyGraph {
+private:
+  /// Maps each Instruction to its node. The map owns the nodes, so node
+  /// pointers handed out stay valid for the lifetime of the graph.
+  DenseMap<Instruction *, std::unique_ptr<DGNode>> InstrToNodeMap;
+
+public:
+  DependencyGraph() = default;
+
+  /// \Returns the node for \p I, or nullptr if no node was created for it.
+  DGNode *getNode(Instruction *I) const {
+    auto It = InstrToNodeMap.find(I);
+    return It != InstrToNodeMap.end() ? It->second.get() : nullptr;
+  }
+  /// \Returns the node for \p I, creating it first if it does not exist yet.
+  DGNode *getOrCreateNode(Instruction *I) {
+    auto [It, NotInMap] = InstrToNodeMap.try_emplace(I);
+    if (NotInMap)
+      It->second = std::make_unique<DGNode>(I);
+    return It->second.get();
+  }
+  // TODO: extend() should work with intervals not the whole BB.
+  /// Build the dependency graph for \p BB.
+  void extend(BasicBlock *BB);
+#ifndef NDEBUG
+  void print(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif // NDEBUG
+};
+
+} // namespace llvm::sandboxir
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_DEPENDENCYGRAPH_H

diff  --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
index b11631350e8b4e..59d04ac3cecd00 100644
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -3,6 +3,7 @@ add_llvm_component_library(LLVMVectorize
   LoopIdiomVectorize.cpp
   LoopVectorizationLegality.cpp
   LoopVectorize.cpp
+  SandboxVectorizer/DependencyGraph.cpp
   SandboxVectorizer/Passes/BottomUpVec.cpp
   SandboxVectorizer/SandboxVectorizer.cpp
   SLPVectorizer.cpp

diff  --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
new file mode 100644
index 00000000000000..41e50953a4ec8a
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -0,0 +1,64 @@
+//===- DependencyGraph.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
+
+using namespace llvm::sandboxir;
+
+#ifndef NDEBUG
+void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
+  I->dumpOS(OS);
+  if (!PrintDeps)
+    return;
+  OS << "\n";
+  // List each memory predecessor on its own indented line. Predecessors are
+  // printed without their own dependencies.
+  constexpr unsigned PredIndent = 4;
+  for (DGNode *PredN : MemPreds) {
+    OS.indent(PredIndent) << "<-";
+    PredN->print(OS, /*PrintDeps=*/false);
+    OS << "\n";
+  }
+}
+void DGNode::dump() const {
+  // Print the node with its dependencies to the debug stream.
+  raw_ostream &DbgOS = dbgs();
+  print(DbgOS);
+  DbgOS << "\n";
+}
+#endif // NDEBUG
+
+void DependencyGraph::extend(BasicBlock *BB) {
+  // TODO: For now create a chain of dependencies.
+  // Each instruction gets a memory edge from the instruction right before it;
+  // the first instruction of the block gets no predecessor.
+  DGNode *PrevN = nullptr;
+  for (auto &I : *BB) {
+    DGNode *CurrN = getOrCreateNode(&I);
+    if (PrevN != nullptr)
+      CurrN->addMemPred(PrevN);
+    PrevN = CurrN;
+  }
+}
+
+#ifndef NDEBUG
+void DependencyGraph::print(raw_ostream &OS) const {
+  // InstrToNodeMap has no meaningful iteration order, so collect the nodes
+  // and sort them by the position of their instruction in the BB.
+  SmallVector<DGNode *> SortedNodes;
+  SortedNodes.reserve(InstrToNodeMap.size());
+  for (const auto &Entry : InstrToNodeMap)
+    SortedNodes.push_back(Entry.second.get());
+  auto ComesFirst = [](const DGNode *A, const DGNode *B) {
+    return A->getInstruction()->comesBefore(B->getInstruction());
+  };
+  sort(SortedNodes, ComesFirst);
+  for (DGNode *N : SortedNodes)
+    N->print(OS, /*PrintDeps=*/true);
+}
+
+void DependencyGraph::dump() const {
+  // Print the whole graph to the debug stream.
+  raw_ostream &DbgOS = dbgs();
+  print(DbgOS);
+  DbgOS << "\n";
+}
+#endif // NDEBUG

diff  --git a/llvm/unittests/Transforms/Vectorize/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/CMakeLists.txt
index 1354558a94f0d5..0df39c41a90414 100644
--- a/llvm/unittests/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Vectorize/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(SandboxVectorizer)
+
 set(LLVM_LINK_COMPONENTS
   Analysis
   Core

diff  --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
new file mode 100644
index 00000000000000..488c9c2344b56c
--- /dev/null
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
@@ -0,0 +1,12 @@
+# LLVM components linked into the SandboxVectorizer unit test binary.
+set(LLVM_LINK_COMPONENTS
+  Analysis
+  Core
+  Vectorize
+  AsmParser
+  TargetParser
+  SandboxIR
+  )
+
+# Unit test executable built from the listed test sources.
+add_llvm_unittest(SandboxVectorizerTests
+  DependencyGraphTest.cpp
+  )

diff  --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
new file mode 100644
index 00000000000000..dc85a22f7f4832
--- /dev/null
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -0,0 +1,63 @@
+//===- DependencyGraphTest.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/SandboxIR/SandboxIR.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gmock/gmock-matchers.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+/// Test fixture that owns the LLVMContext and the Module parsed from
+/// textual IR by parseIR().
+struct DependencyGraphTest : public testing::Test {
+  LLVMContext C;
+  std::unique_ptr<Module> M;
+
+  /// Parses \p IR into M using context \p C; on failure M is null and the
+  /// parser diagnostic is printed to errs().
+  void parseIR(LLVMContext &C, const char *IR) {
+    SMDiagnostic Diag;
+    M = parseAssemblyString(IR, Diag, C);
+    if (M == nullptr)
+      Diag.print("DependencyGraphTest", errs());
+  }
+};
+
+TEST_F(DependencyGraphTest, Basic) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
+  store i8 %v0, ptr %ptr
+  store i8 %v1, ptr %ptr
+  ret void
+}
+)IR");
+  ASSERT_TRUE(M != nullptr);
+  llvm::Function *LLVMF = M->getFunction("foo");
+  ASSERT_NE(LLVMF, nullptr);
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+  sandboxir::DependencyGraph DAG;
+  DAG.extend(BB);
+
+  // getNode() returns nullptr for unknown instructions; fail cleanly rather
+  // than crash if extend() did not create a node.
+  sandboxir::DGNode *N0 = DAG.getNode(S0);
+  sandboxir::DGNode *N1 = DAG.getNode(S1);
+  sandboxir::DGNode *N2 = DAG.getNode(Ret);
+  ASSERT_NE(N0, nullptr);
+  ASSERT_NE(N1, nullptr);
+  ASSERT_NE(N2, nullptr);
+  // Check getInstruction().
+  EXPECT_EQ(N0->getInstruction(), S0);
+  EXPECT_EQ(N1->getInstruction(), S1);
+  EXPECT_EQ(N2->getInstruction(), Ret);
+  // Check hasMemPred(). extend() builds the chain S0 <- S1 <- Ret, so there
+  // must be no transitive edge from S0 to Ret.
+  EXPECT_TRUE(N1->hasMemPred(N0));
+  EXPECT_FALSE(N0->hasMemPred(N1));
+  EXPECT_TRUE(N2->hasMemPred(N1));
+  EXPECT_FALSE(N2->hasMemPred(N0));
+
+  // Check memPreds().
+  EXPECT_TRUE(N0->memPreds().empty());
+  EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
+  EXPECT_THAT(N2->memPreds(), testing::ElementsAre(N1));
+}


        


More information about the llvm-commits mailing list