[Mlir-commits] [mlir] [MLIR][Support] A cache for cyclical replacers/maps (PR #98202)

Billy Zhu llvmlistbot at llvm.org
Tue Jul 9 12:11:23 PDT 2024


https://github.com/zyx-billy created https://github.com/llvm/llvm-project/pull/98202

This is a support data structure that acts as a cache for replacer-like functions that map values between two domains. The difference compared to just using a map to cache in-out pairs is that this class is able to handle replacer logic that is self-recursive (and thus may cause infinite recursion in the naive case).

This class provides a hook for the user to perform cycle pruning when a cycle is identified, and is able to perform context-sensitive caching so that the replacement result for an input that is part of a pruned cycle can be distinct from the replacement result for the same input when it is not part of a cycle.

In addition, this class allows deferring cycle pruning until specific inputs are repeated. This is useful for cases where not all elements in a cycle can perform pruning. The user still must guarantee that at least one element in any given cycle can perform pruning. If that guarantee is violated, an assertion will eventually be tripped instead of recursing infinitely (the run-time is linearly bounded by the maximum cycle length of its input).

---

Users of this class are pushed in stacked PRs.

>From bcc3edaa9b80c1ee3e87ad15894382a065bef9ad Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 9 Jul 2024 10:26:05 -0700
Subject: [PATCH] add replacer cache and gtests

---
 .../mlir/Support/CyclicReplacerCache.h        | 277 ++++++++++
 mlir/unittests/Support/CMakeLists.txt         |   1 +
 .../Support/CyclicReplacerCacheTest.cpp       | 472 ++++++++++++++++++
 3 files changed, 750 insertions(+)
 create mode 100644 mlir/include/mlir/Support/CyclicReplacerCache.h
 create mode 100644 mlir/unittests/Support/CyclicReplacerCacheTest.cpp

diff --git a/mlir/include/mlir/Support/CyclicReplacerCache.h b/mlir/include/mlir/Support/CyclicReplacerCache.h
new file mode 100644
index 0000000000000..9a703676fff11
--- /dev/null
+++ b/mlir/include/mlir/Support/CyclicReplacerCache.h
@@ -0,0 +1,277 @@
+//===- CyclicReplacerCache.h ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains helper classes for caching replacer-like functions that
+// map values between two domains. They are able to handle replacer logic that
+// contains self-recursion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_CACHINGREPLACER_H
+#define MLIR_SUPPORT_CACHINGREPLACER_H
+
+#include "mlir/IR/Visitors.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
+#include <set>
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// CyclicReplacerCache
+//===----------------------------------------------------------------------===//
+
+/// A cache for replacer-like functions that map values between two domains. The
+/// difference compared to just using a map to cache in-out pairs is that this
+/// class is able to handle replacer logic that is self-recursive (and thus may
+/// cause infinite recursion in the naive case).
+///
+/// This class provides a hook for the user to perform cycle pruning when a
+/// cycle is identified, and is able to perform context-sensitive caching so
+/// that the replacement result for an input that is part of a pruned cycle can
+/// be distinct from the replacement result for the same input when it is not
+/// part of a cycle.
+///
+/// In addition, this class allows deferring cycle pruning until specific inputs
+/// are repeated. This is useful for cases where not all elements in a cycle can
+/// perform pruning. The user still must guarantee that at least one element in
+/// any given cycle can perform pruning. If that guarantee is violated, an
+/// assertion will eventually be tripped instead of recursing infinitely (the
+/// run-time is linearly bounded by the maximum cycle length of its input).
+template <typename InT, typename OutT>
+class CyclicReplacerCache {
+public:
+  /// User-provided cycle-breaking function. It is invoked when an input is
+  /// looked up again while its own replacement is still in progress, and
+  /// returns either the replacement to use at that point, or std::nullopt to
+  /// defer breaking the cycle to a later repeat.
+  /// The cycle-breaking function must not make any more recursive invocations
+  /// to this cached replacer.
+  using CycleBreakerFn = std::function<std::optional<OutT>(const InT &)>;
+
+  CyclicReplacerCache() = delete;
+  CyclicReplacerCache(CycleBreakerFn cycleBreaker)
+      : cycleBreaker(std::move(cycleBreaker)) {}
+
+  /// A possibly unresolved cache entry.
+  /// If unresolved, the entry must be resolved before it goes out of scope.
+  struct CacheEntry {
+  public:
+    ~CacheEntry() { assert(result && "unresolved cache entry"); }
+
+    /// Check whether this node was repeated during recursive replacements.
+    /// This only makes sense to be called on an unresolved entry, after all
+    /// recursive replacements are completed and the current element has
+    /// resurfaced to the top of the replacement stack.
+    bool wasRepeated() const {
+      // A resolved entry owns no replacement frame, so the top frame would
+      // belong to some unrelated element (or the stack could be empty).
+      assert(!result && "wasRepeated() queried on a resolved cache entry");
+      // If the top frame includes itself as a dependency, then it must have
+      // been repeated.
+      ReplacementFrame &currFrame = cache.replacementStack.back();
+      size_t currFrameIndex = cache.replacementStack.size() - 1;
+      return currFrame.dependentFrames.count(currFrameIndex);
+    }
+
+    /// Resolve an unresolved cache entry by providing the result to be stored
+    /// in the cache.
+    void resolve(OutT result) {
+      assert(!this->result && "cache entry already resolved");
+      this->result = result;
+      cache.finalizeReplacement(element, result);
+    }
+
+    /// Get the resolved result if one exists.
+    std::optional<OutT> get() const { return result; }
+
+  private:
+    friend class CyclicReplacerCache;
+    CacheEntry() = delete;
+    CacheEntry(CyclicReplacerCache<InT, OutT> &cache, InT element,
+               std::optional<OutT> result = std::nullopt)
+        : cache(cache), element(element), result(result) {}
+
+    CyclicReplacerCache<InT, OutT> &cache;
+    InT element;
+    std::optional<OutT> result;
+  };
+
+  /// Lookup the cache for a pre-calculated replacement for `element`.
+  /// If one exists, a resolved CacheEntry will be returned. Otherwise, an
+  /// unresolved CacheEntry will be returned, and the caller must resolve it
+  /// with the calculated replacement so it can be registered in the cache for
+  /// future use.
+  /// Multiple unresolved CacheEntries may be retrieved. However, any unresolved
+  /// CacheEntries that are returned must be resolved in reverse order of
+  /// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and
+  /// the first retrieved CacheEntry must be resolved last. This should be
+  /// natural when used as a stack / inside recursion.
+  CacheEntry lookupOrInit(const InT &element);
+
+private:
+  /// Register the replacement in the cache and update the replacementStack.
+  void finalizeReplacement(const InT &element, const OutT &result);
+
+  CycleBreakerFn cycleBreaker;
+  DenseMap<InT, OutT> standaloneCache;
+
+  struct DependentReplacement {
+    OutT replacement;
+    /// The highest replacement frame index that this cache entry is dependent
+    /// on.
+    size_t highestDependentFrame;
+  };
+  DenseMap<InT, DependentReplacement> dependentCache;
+
+  struct ReplacementFrame {
+    /// The set of elements that is only legal while under this current frame.
+    /// They need to be removed from the cache when this frame is popped off the
+    /// replacement stack.
+    DenseSet<InT> dependingReplacements;
+    /// The set of frame indices that this current frame's replacement is
+    /// dependent on, ordered from highest to lowest.
+    std::set<size_t, std::greater<size_t>> dependentFrames;
+  };
+  /// Every element currently in the progress of being replaced pushes a frame
+  /// onto this stack.
+  SmallVector<ReplacementFrame> replacementStack;
+  /// Maps from each input element to its indices on the replacement stack.
+  DenseMap<InT, SmallVector<size_t, 2>> cyclicElementFrame;
+  /// If set to true, we are currently asking an element to break a cycle. No
+  /// more recursive invocations are allowed while this is true (the
+  /// replacement stack can no longer grow).
+  bool resolvingCycle = false;
+};
+
+/// Look up `element` in the standalone cache, then the dependent cache. On a
+/// miss, either ask the cycle breaker (if `element` is already being replaced
+/// further down the stack) or push a new replacement frame for it and return
+/// an unresolved entry.
+template <typename InT, typename OutT>
+typename CyclicReplacerCache<InT, OutT>::CacheEntry
+CyclicReplacerCache<InT, OutT>::lookupOrInit(const InT &element) {
+  assert(!resolvingCycle &&
+         "illegal recursive invocation while breaking cycle");
+
+  if (auto it = standaloneCache.find(element); it != standaloneCache.end())
+    return CacheEntry(*this, element, it->second);
+
+  if (auto it = dependentCache.find(element); it != dependentCache.end()) {
+    // Update the current top frame (the element that invoked this current
+    // replacement) to include any dependencies the cache entry had. Only the
+    // highest dependent frame is recorded here; the entry's full dependency
+    // set was inherited by the frame below it when the entry was finalized.
+    ReplacementFrame &currFrame = replacementStack.back();
+    currFrame.dependentFrames.insert(it->second.highestDependentFrame);
+    return CacheEntry(*this, element, it->second.replacement);
+  }
+
+  auto [it, inserted] = cyclicElementFrame.try_emplace(element);
+  if (!inserted) {
+    // This is a repeat of a known element. Try to break cycle here.
+    // The cycle breaker must not re-enter this cache; `resolvingCycle` guards
+    // against that via the assert at the top of this function.
+    resolvingCycle = true;
+    std::optional<OutT> result = cycleBreaker(element);
+    resolvingCycle = false;
+    if (result) {
+      // Cycle was broken.
+      // The breaker's result is cached as dependent on the element's most
+      // recent in-progress frame; it is superseded once that frame finalizes
+      // and caches a new result for the same element.
+      size_t dependentFrame = it->second.back();
+      dependentCache[element] = {*result, dependentFrame};
+      ReplacementFrame &currFrame = replacementStack.back();
+      // If this is a repeat, there is no replacement frame to pop. Mark the top
+      // frame as being dependent on this element.
+      currFrame.dependentFrames.insert(dependentFrame);
+
+      return CacheEntry(*this, element, *result);
+    }
+
+    // Cycle could not be broken, so another frame is pushed for this element
+    // below (deferred pruning).
+    // A legal setup must ensure at least one element of each cycle can break
+    // cycles, which bounds how often an element is re-entered before some
+    // cycle is broken. `it->second.size()` is the number of in-progress frames
+    // for this element, so the assert fires on its 3rd repeat (when it already
+    // has 3 frames), which indicates an illegal setup.
+    // NOTE(review): this tolerates up to three in-progress frames per element;
+    // confirm that this is the intended bound.
+    assert(it->second.size() <= 2 && "illegal 3rd repeat of input");
+  }
+
+  // Otherwise, either this is the first time we see this element, or this
+  // element could not break this cycle.
+  it->second.push_back(replacementStack.size());
+  replacementStack.emplace_back();
+
+  return CacheEntry(*this, element);
+}
+
+/// Complete the replacement of the element owning the top replacement frame.
+/// In order, this:
+///   - caches `result` as standalone if it no longer depends on any
+///     in-progress frame, and as dependent otherwise;
+///   - propagates the remaining dependencies to the enclosing frame;
+///   - purges cache entries that were only valid under this frame;
+///   - pops the frame.
+template <typename InT, typename OutT>
+void CyclicReplacerCache<InT, OutT>::finalizeReplacement(const InT &element,
+                                                         const OutT &result) {
+  size_t currFrameIndex = replacementStack.size() - 1;
+  ReplacementFrame &currFrame = replacementStack.back();
+  // With the conclusion of this replacement frame, the current element is no
+  // longer a dependent element.
+  currFrame.dependentFrames.erase(currFrameIndex);
+
+  // All depending replacements in the cache are only legal while this frame is
+  // live and must be purged. This has to happen *before* the current result is
+  // cached. When a deeper (deferred) occurrence of `element` depended on this
+  // frame, `element` itself is in this set. Purging afterwards would then also
+  // erase the result cached below.
+  for (const InT &key : currFrame.dependingReplacements)
+    dependentCache.erase(key);
+
+  bool isBottomFrame = currFrameIndex == 0;
+  // If this is the last frame, there should be zero dependents.
+  assert((!isBottomFrame || currFrame.dependentFrames.empty()) &&
+         "internal error: top-level dependent replacement");
+
+  if (isBottomFrame || currFrame.dependentFrames.empty()) {
+    // Cache standalone result. Any dependent entry still recorded for this
+    // element (e.g. one produced by the cycle breaker) is shadowed by the
+    // standalone entry from now on, so drop it.
+    standaloneCache[element] = result;
+    dependentCache.erase(element);
+  } else {
+    // Cache dependent result.
+    size_t highestDependentFrame = *currFrame.dependentFrames.begin();
+    dependentCache[element] = {result, highestDependentFrame};
+
+    // The previous frame inherits the same dependent frames.
+    ReplacementFrame &prevFrame = replacementStack[currFrameIndex - 1];
+    prevFrame.dependentFrames.insert(currFrame.dependentFrames.begin(),
+                                     currFrame.dependentFrames.end());
+
+    // Mark this current replacement as a depending replacement on the closest
+    // dependent frame.
+    replacementStack[highestDependentFrame].dependingReplacements.insert(
+        element);
+  }
+
+  replacementStack.pop_back();
+  auto it = cyclicElementFrame.find(element);
+  assert(it != cyclicElementFrame.end() &&
+         "internal error: finalizing an element without a replacement frame");
+  it->second.pop_back();
+  if (it->second.empty())
+    cyclicElementFrame.erase(it);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer
+//===----------------------------------------------------------------------===//
+
+/// Convenience wrapper for the common case where the replacer function's
+/// input/output types are exactly the types stored in the cache. It owns both
+/// the user's replacer and the cache, and is itself callable, so it can be
+/// used anywhere the plain replacer function would be.
+template <typename InT, typename OutT>
+class CachedCyclicReplacer {
+public:
+  using ReplacerFn = std::function<OutT(const InT &)>;
+  using CycleBreakerFn =
+      typename CyclicReplacerCache<InT, OutT>::CycleBreakerFn;
+
+  CachedCyclicReplacer() = delete;
+  CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
+      : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
+
+  /// Returns the cached replacement for `element`. On a miss, the replacement
+  /// is computed with the user's replacer and registered in the cache.
+  OutT operator()(const InT &element) {
+    auto entry = cache.lookupOrInit(element);
+    std::optional<OutT> replacement = entry.get();
+    if (!replacement) {
+      // Cache miss: compute the replacement (which may recurse back into this
+      // object) and register it before returning.
+      replacement = replacer(element);
+      entry.resolve(*replacement);
+    }
+    return *replacement;
+  }
+
+private:
+  ReplacerFn replacer;
+  CyclicReplacerCache<InT, OutT> cache;
+};
+
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_CACHINGREPLACER_H
diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
index 1dbf072bcbbfd..ec79a1c640909 100644
--- a/mlir/unittests/Support/CMakeLists.txt
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_unittest(MLIRSupportTests
+  CyclicReplacerCacheTest.cpp
   IndentedOstreamTest.cpp
   StorageUniquerTest.cpp
 )
diff --git a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
new file mode 100644
index 0000000000000..23748e29765cb
--- /dev/null
+++ b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
@@ -0,0 +1,472 @@
+//===- CyclicReplacerCacheTest.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/CyclicReplacerCache.h"
+#include "llvm/ADT/SetVector.h"
+#include "gmock/gmock.h"
+#include <map>
+#include <set>
+
+using namespace mlir;
+
+TEST(CachedCyclicReplacerTest, testNoRecursion) {
+  // The replacer never recurses, so the cycle breaker is never consulted.
+  auto toBool = [](int n) { return static_cast<bool>(n); };
+  auto neverBreak = [](int) { return std::nullopt; };
+  CachedCyclicReplacer<int, bool> replacer(toBool, neverBreak);
+
+  EXPECT_TRUE(replacer(3));
+  EXPECT_FALSE(replacer(0));
+}
+
+TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) {
+  // Each input n is replaced by the replacement of (n + 1) % 3, so the
+  // replacer cycles 0 -> 1 -> 2 -> 0 -> ... Any repeated input may break the
+  // cycle, which yields -1.
+  CachedCyclicReplacer<int, int> replacer(
+      /*replacer=*/[&](int n) { return replacer((n + 1) % 3); },
+      /*cycleBreaker=*/[](int) { return -1; });
+
+  // Entering the cycle at 0, and afterwards at 2, both give the pruned value.
+  EXPECT_EQ(replacer(0), -1);
+  EXPECT_EQ(replacer(2), -1);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer: ChainRecursion
+//===----------------------------------------------------------------------===//
+
+/// This set of tests uses a replacer function that maps ints into vectors of
+/// ints.
+///
+/// The replacement result for input `n` is the replacement result of `(n+1)%3`
+/// appended with an element `42`. Theoretically, this will produce an
+/// infinitely long vector. The cycle-breaker function prunes this infinite
+/// recursion in the replacer logic by returning an empty vector upon the
+/// re-occurrence of an input value. By default every repeated value prunes.
+/// If `baseCase` is set, only repeats of that value prune, and pruning is
+/// deferred for all other values.
+class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
+public:
+  // N ==> (N+1) % 3
+  // This will create a chain of infinite length without recursion pruning.
+  CachedCyclicReplacerChainRecursionPruningTest()
+      : replacer(
+            [&](int n) {
+              ++invokeCount;
+              std::vector<int> result = replacer((n + 1) % 3);
+              result.push_back(42);
+              return result;
+            },
+            [&](int n) -> std::optional<std::vector<int>> {
+              return baseCase.value_or(n) == n
+                         ? std::make_optional(std::vector<int>{})
+                         : std::nullopt;
+            }) {}
+
+  /// Expected replacement result: a vector holding `N` copies of 42.
+  static std::vector<int> getChain(unsigned N) {
+    return std::vector<int>(N, 42);
+  }
+
+  CachedCyclicReplacer<int, std::vector<int>> replacer;
+  /// Number of times the replacer body (not the cache) was executed.
+  int invokeCount = 0;
+  /// When set, only repeats of this value may prune the recursion.
+  std::optional<int> baseCase = std::nullopt;
+};
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) {
+  // Runs `replacer(input)` from a zeroed invocation count and checks both the
+  // resulting chain length and how many times the replacer body executed.
+  auto check = [&](int input, unsigned chainLength, int expectedInvocations) {
+    invokeCount = 0;
+    EXPECT_EQ(replacer(input), getChain(chainLength));
+    EXPECT_EQ(invokeCount, expectedInvocations);
+  };
+
+  // Entering at 0 walks the whole 3-cycle once before pruning at 0.
+  check(/*input=*/0, /*chainLength=*/3, /*expectedInvocations=*/3);
+  // Entering at 1 reuses the cached result for 0: 1 -> 2 -> [0: cached chain
+  // of 3] yields a chain of 5 with only two fresh invocations.
+  check(/*input=*/1, /*chainLength=*/5, /*expectedInvocations=*/2);
+  // The result for 2 (a chain of 4) was fully cached by the previous call.
+  check(/*input=*/2, /*chainLength=*/4, /*expectedInvocations=*/0);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere1) {
+  // Entering at 1: 1 -> 2 -> 0 -> (prune at the repeated 1), a chain of 3,
+  // with every element of the cycle invoked exactly once.
+  std::vector<int> chain = replacer(1);
+  EXPECT_EQ(chain, getChain(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific0) {
+  // Only a repeated 0 may prune the recursion.
+  baseCase = 0;
+
+  // Entering at 0 means the first repeat is 0 itself: a chain of 3.
+  std::vector<int> chain = replacer(0);
+  EXPECT_EQ(chain, getChain(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) {
+  // Only a repeated 0 may prune the recursion.
+  baseCase = 0;
+
+  // Entering at 1, the repeats of 1 and 2 cannot prune, so the walk is
+  // 1 -> 2 -> 0 -> 1 -> 2 -> (prune at 0): a chain of 5, all freshly computed.
+  EXPECT_EQ(replacer(1), getChain(5));
+  EXPECT_EQ(invokeCount, 5);
+
+  // The result for 0 (a chain of 3) was cached as standalone above.
+  invokeCount = 0;
+  std::vector<int> fromZero = replacer(0);
+  EXPECT_EQ(invokeCount, 0);
+  EXPECT_EQ(fromZero, getChain(3));
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer: GraphReplacement
+//===----------------------------------------------------------------------===//
+
+/// This set of tests uses a replacer function that maps from cyclic graphs to
+/// acyclic graphs (DAGs), pruning out cycles in the process.
+///
+/// It consists of two helper classes:
+/// - Graph
+///   - A directed graph where nodes are non-negative integers.
+/// - PrunedGraph
+///   - A Graph where edges that used to cause cycles are now represented with
+///     an indirection (a recursionId).
+class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
+public:
+  /// A directed graph where nodes are non-negative integers.
+  struct Graph {
+    using Node = int64_t;
+
+    /// Use ordered containers for deterministic output.
+    /// Nodes without outgoing edges are considered nonexistent.
+    std::map<Node, std::set<Node>> edges;
+
+    void addEdge(Node src, Node sink) { edges[src].insert(sink); }
+
+    /// Returns true if a cycle is reachable from any node with outgoing
+    /// edges. Uses an iterative DFS that tracks the nodes on the current path.
+    bool isCyclic() const {
+      DenseSet<Node> visited;
+      for (Node root : llvm::make_first_range(edges)) {
+        if (visited.contains(root))
+          continue;
+
+        SetVector<Node> path;
+        // Each work item is a node plus a flag saying whether it is the "exit"
+        // marker pushed beneath the node's children. A separate flag is
+        // required: encoding exits as negated ids breaks for node 0, because
+        // -0 == 0 is indistinguishable from a regular visit of node 0.
+        SmallVector<std::pair<Node, bool>> workstack;
+        workstack.push_back({root, /*isExit=*/false});
+        while (!workstack.empty()) {
+          auto [curr, isExit] = workstack.pop_back_val();
+
+          if (isExit) {
+            // All of this node's children have been processed. Remove self
+            // from path.
+            assert(path.back() == curr && "internal inconsistency");
+            path.pop_back();
+            continue;
+          }
+
+          if (path.contains(curr))
+            return true;
+
+          visited.insert(curr);
+          auto edgesIter = edges.find(curr);
+          if (edgesIter == edges.end() || edgesIter->second.empty())
+            continue;
+
+          path.insert(curr);
+          // Push the exit marker first so it is processed after all children.
+          workstack.push_back({curr, /*isExit=*/true});
+          for (Node child : edgesIter->second)
+            workstack.push_back({child, /*isExit=*/false});
+        }
+      }
+      return false;
+    }
+
+    /// Deterministic output for testing.
+    std::string serialize() const {
+      std::ostringstream oss;
+      for (const auto &[src, neighbors] : edges) {
+        oss << src << ":";
+        for (Graph::Node neighbor : neighbors)
+          oss << " " << neighbor;
+        oss << "\n";
+      }
+      return oss.str();
+    }
+  };
+
+  /// A Graph where edges that used to cause cycles (back-edges) are now
+  /// represented with an indirection (a recursionId).
+  ///
+  /// In addition to each node being an integer, each node also tracks the
+  /// original integer id it had in the original graph. This way for every
+  /// back-edge, we can represent it as pointing to a new instance of the
+  /// original node. Then we mark the original node and the new instance with
+  /// a new unique recursionId to indicate that they're supposed to be the same
+  /// graph.
+  struct PrunedGraph {
+    using Node = Graph::Node;
+    struct NodeInfo {
+      Graph::Node originalId;
+      // A negative recursive index means not recursive.
+      int64_t recursionId;
+    };
+
+    /// Add a regular non-recursive-self node.
+    Node addNode(Graph::Node originalId, int64_t recursionIndex = -1) {
+      Node id = nextConnectionId++;
+      info[id] = {originalId, recursionIndex};
+      return id;
+    }
+    /// Add a recursive-self-node, i.e. a duplicate of the original node that is
+    /// meant to represent an indirection to it.
+    std::pair<Node, int64_t> addRecursiveSelfNode(Graph::Node originalId) {
+      return {addNode(originalId, nextRecursionId), nextRecursionId++};
+    }
+    void addEdge(Node src, Node sink) { connections.addEdge(src, sink); }
+
+    /// Deterministic output for testing.
+    std::string serialize() const {
+      std::ostringstream oss;
+      oss << "nodes\n";
+      for (const auto &[nodeId, nodeInfo] : info) {
+        oss << nodeId << ": n" << nodeInfo.originalId;
+        if (nodeInfo.recursionId >= 0)
+          oss << '<' << nodeInfo.recursionId << '>';
+        oss << "\n";
+      }
+      oss << "edges\n";
+      oss << connections.serialize();
+      return oss.str();
+    }
+
+    bool isCyclic() const { return connections.isCyclic(); }
+
+  private:
+    Graph connections;
+    int64_t nextRecursionId = 0;
+    int64_t nextConnectionId = 0;
+    // Use ordered map for deterministic output.
+    std::map<Graph::Node, NodeInfo> info;
+  };
+
+  PrunedGraph breakCycles(const Graph &input) {
+    assert(input.isCyclic() && "input graph is not cyclic");
+
+    PrunedGraph output;
+
+    DenseMap<Graph::Node, int64_t> recMap;
+    auto cycleBreaker = [&](Graph::Node inNode) -> std::optional<Graph::Node> {
+      auto [node, recId] = output.addRecursiveSelfNode(inNode);
+      recMap[inNode] = recId;
+      return node;
+    };
+
+    CyclicReplacerCache<Graph::Node, Graph::Node> cache(cycleBreaker);
+
+    std::function<Graph::Node(Graph::Node)> replaceNode =
+        [&](Graph::Node inNode) {
+          auto cacheEntry = cache.lookupOrInit(inNode);
+          if (std::optional<Graph::Node> result = cacheEntry.get())
+            return *result;
+
+          // Recursively replace its neighbors.
+          SmallVector<Graph::Node> neighbors;
+          if (auto it = input.edges.find(inNode); it != input.edges.end())
+            neighbors = SmallVector<Graph::Node>(
+                llvm::map_range(it->second, replaceNode));
+
+          // Create a new node in the output graph.
+          int64_t recursionIndex =
+              cacheEntry.wasRepeated() ? recMap.lookup(inNode) : -1;
+          Graph::Node result = output.addNode(inNode, recursionIndex);
+
+          for (Graph::Node neighbor : neighbors)
+            output.addEdge(result, neighbor);
+
+          cacheEntry.resolve(result);
+          return result;
+        };
+
+    /// Translate starting from each node.
+    for (Graph::Node root : llvm::make_first_range(input.edges))
+      replaceNode(root);
+
+    return output;
+  }
+
+  /// Helper for serialization tests that allow putting comments in the
+  /// serialized format. Every line that begins with a `;` is considered a
+  /// comment. The entire line, incl. the terminating `\n` is removed.
+  std::string trimComments(StringRef input) {
+    std::ostringstream oss;
+    // The very first character also starts a line.
+    bool isNewLine = true;
+    bool isComment = false;
+    for (char c : input) {
+      // Lines beginning with ';' are comments.
+      if (isNewLine && c == ';')
+        isComment = true;
+
+      if (!isComment)
+        oss << c;
+
+      // Only the character right after a '\n' starts a new line.
+      isNewLine = c == '\n';
+      if (isNewLine)
+        isComment = false;
+    }
+    return oss.str();
+  }
+};
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) {
+  // A single cycle through all three nodes:
+  //   0 -> 1 -> 2 -> 0
+  Graph input;
+  input.addEdge(0, 1);
+  input.addEdge(1, 2);
+  input.addEdge(2, 0);
+
+  PrunedGraph output = breakCycles(input);
+  ASSERT_FALSE(output.isCyclic()) << output.serialize();
+
+  std::string expected = trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n2
+2: n1
+3: n0<0>
+; root 1
+4: n2
+; root 2
+5: n1
+edges
+1: 0
+2: 1
+3: 2
+4: 3
+5: 4
+)");
+  EXPECT_EQ(output.serialize(), expected);
+}
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testDualLoop) {
+  // Node 0 fans out to 1 and 2, both of which feed 3, which closes back to 0:
+  //   0 -> 1 -> 3 -> 0
+  //   0 -> 2 -> 3 -> 0
+  Graph input;
+  input.addEdge(0, 1);
+  input.addEdge(0, 2);
+  input.addEdge(1, 3);
+  input.addEdge(2, 3);
+  input.addEdge(3, 0);
+
+  PrunedGraph output = breakCycles(input);
+  ASSERT_FALSE(output.isCyclic()) << output.serialize();
+
+  std::string expected = trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n3
+2: n1
+3: n2
+4: n0<0>
+; root 1
+5: n3
+6: n1
+; root 2
+7: n2
+edges
+1: 0
+2: 1
+3: 1
+4: 2 3
+5: 4
+6: 5
+7: 5
+)");
+  EXPECT_EQ(output.serialize(), expected);
+}
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testNestedLoops) {
+  // Node 2 closes two cycles at once, one through 0 and an inner one through 1:
+  //   0 -> 1 -> 2 -> 0
+  //        1 -> 2 -> 1
+  Graph input;
+  input.addEdge(0, 1);
+  input.addEdge(1, 2);
+  input.addEdge(2, 0);
+  input.addEdge(2, 1);
+
+  PrunedGraph output = breakCycles(input);
+  ASSERT_FALSE(output.isCyclic()) << output.serialize();
+
+  std::string expected = trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n1<1>
+2: n2
+3: n1<1>
+4: n0<0>
+; root 1
+5: n1<2>
+6: n2
+7: n1<2>
+; root 2
+8: n2
+edges
+2: 0 1
+3: 2
+4: 3
+6: 4 5
+7: 6
+8: 4 7
+)");
+  EXPECT_EQ(output.serialize(), expected);
+}
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testDualNestedLoops) {
+  // Node 3 closes an outer cycle through 0 and inner cycles through both 1
+  // and 2:
+  //   0 -> 1 -> 3 -> 0
+  //        1 -> 3 -> 1
+  //   0 -> 2 -> 3 -> 0
+  //        2 -> 3 -> 2
+  Graph input;
+  input.addEdge(0, 1);
+  input.addEdge(0, 2);
+  input.addEdge(1, 3);
+  input.addEdge(2, 3);
+  input.addEdge(3, 0);
+  input.addEdge(3, 1);
+  input.addEdge(3, 2);
+
+  PrunedGraph output = breakCycles(input);
+  ASSERT_FALSE(output.isCyclic()) << output.serialize();
+
+  std::string expected = trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n1<1>
+2: n3<2>
+3: n2
+4: n3<2>
+5: n1<1>
+6: n2<3>
+7: n3
+8: n2<3>
+9: n0<0>
+; root 1
+10: n1<4>
+11: n3<5>
+12: n2
+13: n3<5>
+14: n1<4>
+; root 2
+15: n2<6>
+16: n3
+17: n2<6>
+; root 3
+18: n3
+edges
+; root 0
+3: 2
+4: 0 1 3
+5: 4
+7: 0 5 6
+8: 7
+9: 5 8
+; root 1
+12: 11
+13: 9 10 12
+14: 13
+; root 2
+16: 9 14 15
+17: 16
+; root 3
+18: 9 14 17
+)");
+  EXPECT_EQ(output.serialize(), expected);
+}



More information about the Mlir-commits mailing list