[Mlir-commits] [mlir] [MLIR][Support] A cache for cyclical replacers/maps (PR #98198)
Billy Zhu
llvmlistbot at llvm.org
Tue Jul 9 11:23:20 PDT 2024
https://github.com/zyx-billy created https://github.com/llvm/llvm-project/pull/98198
This is a support data structure that acts as a cache for replacer-like functions that map values between two domains. The difference compared to just using a map to cache in-out pairs is that this class is able to handle replacer logic that is self-recursive (and thus may cause infinite recursion in the naive case).
In addition, this class allows deferring cycle pruning until specific inputs are repeated. This is useful for cases where not all elements in a cycle can perform pruning. The user must still guarantee that at least one element in any given cycle can perform pruning. If that guarantee is violated, an assertion is eventually tripped instead of recursing infinitely (the run time is linearly bounded by the maximum cycle length of its input).
---
Users of this class are pushed in separate PRs.
From 6ff2481653f1cb0bae49e0e64299039c6f7e30eb Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 9 Jul 2024 10:26:05 -0700
Subject: [PATCH] add replacer cache and gtests
---
.../mlir/Support/CyclicReplacerCache.h | 275 +++++++++++
mlir/unittests/Support/CMakeLists.txt | 1 +
.../Support/CyclicReplacerCacheTest.cpp | 438 ++++++++++++++++++
3 files changed, 714 insertions(+)
create mode 100644 mlir/include/mlir/Support/CyclicReplacerCache.h
create mode 100644 mlir/unittests/Support/CyclicReplacerCacheTest.cpp
diff --git a/mlir/include/mlir/Support/CyclicReplacerCache.h b/mlir/include/mlir/Support/CyclicReplacerCache.h
new file mode 100644
index 0000000000000..33d1ef791efdc
--- /dev/null
+++ b/mlir/include/mlir/Support/CyclicReplacerCache.h
@@ -0,0 +1,275 @@
+//===- CyclicReplacerCache.h ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains helper classes for caching replacer-like functions that
+// map values between two domains. They are able to handle replacer logic that
+// contains self-recursion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_CACHINGREPLACER_H
+#define MLIR_SUPPORT_CACHINGREPLACER_H
+
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
+#include <cassert>
+#include <functional>
+#include <optional>
+#include <set>
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// CyclicReplacerCache
+//===----------------------------------------------------------------------===//
+
+/// A cache for replacer-like functions that map values between two domains.
+/// The difference compared to a naive map that just stores in-out pairs is that
+/// this class is able to work for replacer logic that is infinitely
+/// self-recursive.
+///
+/// This class provides a hook for the user to perform cycle pruning when a
+/// cycle is identified, and is able to perform context-sensitive caching so
+/// that the replacement result for an input that is part of a pruned cycle can
+/// be distinct from the replacement result for the same input when it is not
+/// part of a cycle.
+///
+/// In addition, this class allows deferring cycle pruning until specific inputs
+/// are repeated. This is useful for cases where not all elements in a cycle can
+/// perform pruning. The user must still guarantee that at least one element in
+/// any given cycle can perform pruning.
+///
+/// Results are kept in two caches: a standalone cache for results that do not
+/// depend on any replacement still in progress, and a dependent cache for
+/// results that are only valid while a given frame of the replacement stack is
+/// alive (they are purged when that frame is popped).
+template <typename InT, typename OutT>
+class CyclicReplacerCache {
+public:
+  /// User-provided cycle-breaking function. It is consulted when an input is
+  /// looked up while it is already being replaced further up the replacement
+  /// stack. Returning a value prunes the cycle at that input; returning
+  /// std::nullopt defers pruning to a later repeat.
+  /// The cycle-breaking function must not make any more recursive invocations
+  /// to this cached replacer (this is asserted in `lookupOrInit`).
+  using CycleBreakerFn = std::function<std::optional<OutT>(const InT &)>;
+
+  CyclicReplacerCache() = delete;
+  CyclicReplacerCache(CycleBreakerFn cycleBreaker)
+      : cycleBreaker(std::move(cycleBreaker)) {}
+
+  /// A possibly unresolved cache entry returned by `lookupOrInit`.
+  /// If unresolved, the entry must be resolved (via `resolve`) before it goes
+  /// out of scope; the destructor asserts this.
+  struct CacheEntry {
+    ~CacheEntry() { assert(result && "unresolved cache entry"); }
+
+    /// Check whether this node was repeated during recursive replacements.
+    /// This only makes sense to be called on an entry that is still
+    /// unresolved, after all recursive replacements are completed and the
+    /// current element has resurfaced to the top of the replacement stack.
+    bool wasRepeated() const {
+      // If the top frame includes itself as a dependency, then it must have
+      // been repeated.
+      ReplacementFrame &currFrame = cache.replacementStack.back();
+      size_t currFrameIndex = cache.replacementStack.size() - 1;
+      return currFrame.dependentFrames.count(currFrameIndex);
+    }
+
+    /// Resolve an unresolved cache entry by providing the result to be stored
+    /// in the cache. This also pops the entry's frame off the replacement
+    /// stack, so entries must be resolved in reverse order of retrieval.
+    void resolve(OutT result) {
+      assert(!this->result && "cache entry already resolved");
+      this->result = result;
+      cache.finalizeReplacement(element, result);
+    }
+
+    /// Get the resolved result if one exists.
+    std::optional<OutT> get() const { return result; }
+
+  private:
+    friend class CyclicReplacerCache;
+    CacheEntry() = delete;
+    CacheEntry(CyclicReplacerCache<InT, OutT> &cache, InT element,
+               std::optional<OutT> result = std::nullopt)
+        : cache(cache), element(element), result(result) {}
+
+    CyclicReplacerCache<InT, OutT> &cache;
+    InT element;
+    std::optional<OutT> result;
+  };
+
+  /// Lookup the cache for a pre-calculated replacement for `element`.
+  /// If one exists, a resolved CacheEntry will be returned. Otherwise, an
+  /// unresolved CacheEntry will be returned, and the caller must resolve it
+  /// with the calculated replacement so it can be registered in the cache for
+  /// future use.
+  /// Multiple unresolved CacheEntries may be retrieved. However, any unresolved
+  /// CacheEntries that are returned must be resolved in reverse order of
+  /// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and
+  /// the first retrieved CacheEntry must be resolved last. This should be
+  /// natural when used as a stack / inside recursion.
+  /// Discarding the returned entry is a bug: an unresolved entry asserts on
+  /// destruction.
+  [[nodiscard]] CacheEntry lookupOrInit(const InT &element);
+
+private:
+  /// Register the replacement in the cache and update the replacementStack.
+  void finalizeReplacement(const InT &element, const OutT &result);
+
+  CycleBreakerFn cycleBreaker;
+  /// Results that do not depend on any frame still on the replacement stack.
+  DenseMap<InT, OutT> standaloneCache;
+
+  struct DependentReplacement {
+    OutT replacement;
+    /// The highest replacement frame index that this cache entry is dependent
+    /// on.
+    size_t highestDependentFrame;
+  };
+  /// Results that are only valid while their dependent frames are alive.
+  DenseMap<InT, DependentReplacement> dependentCache;
+
+  struct ReplacementFrame {
+    /// The set of elements that are only legal while under this current frame.
+    /// They need to be removed from the cache when this frame is popped off the
+    /// replacement stack.
+    DenseSet<InT> dependingReplacements;
+    /// The set of frame indices that this current frame's replacement is
+    /// dependent on, ordered from highest to lowest.
+    std::set<size_t, std::greater<size_t>> dependentFrames;
+  };
+  /// Every element currently in the process of being replaced pushes a frame
+  /// onto this stack.
+  SmallVector<ReplacementFrame> replacementStack;
+  /// Maps from each input element to its indices on the replacement stack.
+  DenseMap<InT, SmallVector<size_t, 2>> cyclicElementFrame;
+  /// If set to true, we are currently asking an element to break a cycle. No
+  /// more recursive invocations are allowed while this is true (the
+  /// replacement stack can no longer grow).
+  bool resolvingCycle = false;
+};
+
+/// Out-of-line definition of `lookupOrInit`. Lookup order: the standalone
+/// cache, then the dependent cache, then cycle detection via
+/// `cyclicElementFrame`. Only if all of those miss (or the cycle breaker
+/// declines to prune) is a new replacement frame pushed and an unresolved
+/// entry returned.
+template <typename InT, typename OutT>
+typename CyclicReplacerCache<InT, OutT>::CacheEntry
+CyclicReplacerCache<InT, OutT>::lookupOrInit(const InT &element) {
+ // The cycle breaker must not call back into the cache.
+ assert(!resolvingCycle &&
+ "illegal recursive invocation while breaking cycle");
+
+ // Results that do not depend on any frame still on the stack.
+ if (auto it = standaloneCache.find(element); it != standaloneCache.end())
+ return CacheEntry(*this, element, it->second);
+
+ if (auto it = dependentCache.find(element); it != dependentCache.end()) {
+ // Update the current top frame (the element that invoked this current
+ // replacement) to include any dependencies the cache entry had.
+ ReplacementFrame &currFrame = replacementStack.back();
+ currFrame.dependentFrames.insert(it->second.highestDependentFrame);
+ return CacheEntry(*this, element, it->second.replacement);
+ }
+
+ // `inserted` is false iff `element` already has at least one frame on the
+ // replacement stack, i.e. we have come back around a cycle.
+ auto [it, inserted] = cyclicElementFrame.try_emplace(element);
+ if (!inserted) {
+ // This is a repeat of a known element. Try to break cycle here.
+ resolvingCycle = true;
+ std::optional<OutT> result = cycleBreaker(element);
+ resolvingCycle = false;
+ if (result) {
+ // Cycle was broken. The pruned result depends on the most recent frame
+ // of `element`. Unlike entries created in finalizeReplacement, it is not
+ // registered in any frame's dependingReplacements; once that frame is
+ // finalized, its result either overwrites this entry or (if standalone)
+ // shadows it, since the standalone cache is consulted first.
+ size_t dependentFrame = it->second.back();
+ dependentCache[element] = {*result, dependentFrame};
+ ReplacementFrame &currFrame = replacementStack.back();
+ // If this is a repeat, there is no replacement frame to pop. Mark the top
+ // frame as being dependent on this element.
+ currFrame.dependentFrames.insert(dependentFrame);
+
+ return CacheEntry(*this, element, *result);
+ }
+
+ // Cycle could not be broken.
+ // A legal setup must ensure at least one element of each cycle can break
+ // cycles. Under this setup, each element can be seen at most twice before
+ // the cycle is broken. If we see an element more than twice, we know this
+ // is an illegal setup.
+ // NOTE(review): this bound looks tighter than the stated contract. If an
+ // element that cannot break (A) sits on several cycles whose breakers are
+ // different elements (e.g. A -> {B, C, D}, each of B/C/D -> A, with only
+ // B, C and D able to break), A appears to get one extra frame per such
+ // cycle, so this assert can fire even though every cycle has a breaker --
+ // confirm.
+ assert(it->second.size() <= 2 && "illegal 3rd repeat of input");
+ }
+
+ // Otherwise, either this is the first time we see this element, or this
+ // element could not break this cycle.
+ it->second.push_back(replacementStack.size());
+ replacementStack.emplace_back();
+
+ return CacheEntry(*this, element);
+}
+
+/// Called from CacheEntry::resolve for the entry owning the top frame.
+/// Classifies that frame's result as standalone or dependent, propagates its
+/// remaining dependencies to the frame below, purges cache entries that were
+/// only valid under this frame, and finally pops the frame.
+template <typename InT, typename OutT>
+void CyclicReplacerCache<InT, OutT>::finalizeReplacement(const InT &element,
+ const OutT &result) {
+ ReplacementFrame &currFrame = replacementStack.back();
+ // With the conclusion of this replacement frame, the current element is no
+ // longer a dependent element.
+ currFrame.dependentFrames.erase(replacementStack.size() - 1);
+
+ // The frame directly below the current one (its caller), if any.
+ auto prevLayerIter = ++replacementStack.rbegin();
+ if (prevLayerIter == replacementStack.rend()) {
+ // If this is the last frame, there should be zero dependents.
+ assert(currFrame.dependentFrames.empty() &&
+ "internal error: top-level dependent replacement");
+ // Cache standalone result.
+ standaloneCache[element] = result;
+ } else if (currFrame.dependentFrames.empty()) {
+ // Cache standalone result.
+ standaloneCache[element] = result;
+ } else {
+ // Cache dependent result. dependentFrames is ordered with std::greater,
+ // so begin() is the highest (closest) dependent frame index.
+ size_t highestDependentFrame = *currFrame.dependentFrames.begin();
+ dependentCache[element] = {result, highestDependentFrame};
+
+ // Otherwise, the previous frame inherits the same dependent frames.
+ prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(),
+ currFrame.dependentFrames.end());
+
+ // Mark this current replacement as a depending replacement on the closest
+ // dependent frame.
+ replacementStack[highestDependentFrame].dependingReplacements.insert(
+ element);
+ }
+
+ // All depending replacements in the cache must be purged. These were
+ // registered on this frame (by the insert above, while deeper frames were
+ // finalized) and are no longer valid once this frame is popped.
+ for (InT key : currFrame.dependingReplacements)
+ dependentCache.erase(key);
+
+ replacementStack.pop_back();
+ // lookupOrInit recorded this frame's index for `element` when the frame was
+ // pushed. Drop it, and forget the element once it has no frames left.
+ auto it = cyclicElementFrame.find(element);
+ it->second.pop_back();
+ if (it->second.empty())
+ cyclicElementFrame.erase(it);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer
+//===----------------------------------------------------------------------===//
+
+/// Convenience wrapper for the common case where the replacer function's
+/// input/output types match the types stored in the cache. It owns the
+/// user-provided replacer together with a CyclicReplacerCache, and can be
+/// called wherever the raw replacer would have been called (including from
+/// inside the replacer itself, for recursive replacement).
+template <typename InT, typename OutT>
+class CachedCyclicReplacer {
+public:
+  /// Signature of the wrapped, user-provided replacer.
+  using ReplacerFn = std::function<OutT(const InT &)>;
+  /// Signature of the cycle breaker, shared with the underlying cache.
+  using CycleBreakerFn =
+      typename CyclicReplacerCache<InT, OutT>::CycleBreakerFn;
+
+  CachedCyclicReplacer() = delete;
+  CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
+      : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
+
+  /// Returns the replacement for `element`. The cache is consulted first; the
+  /// wrapped replacer runs only on a miss, and its result is handed back to
+  /// the cache before being returned.
+  OutT operator()(const InT &element) {
+    auto entry = cache.lookupOrInit(element);
+    std::optional<OutT> cached = entry.get();
+    if (!cached) {
+      OutT computed = replacer(element);
+      entry.resolve(computed);
+      return computed;
+    }
+    return *cached;
+  }
+
+private:
+  ReplacerFn replacer;
+  CyclicReplacerCache<InT, OutT> cache;
+};
+
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_CACHINGREPLACER_H
diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
index 1dbf072bcbbfd..ec79a1c640909 100644
--- a/mlir/unittests/Support/CMakeLists.txt
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_unittest(MLIRSupportTests
+ CyclicReplacerCacheTest.cpp
IndentedOstreamTest.cpp
StorageUniquerTest.cpp
)
diff --git a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
new file mode 100644
index 0000000000000..f5e536614f009
--- /dev/null
+++ b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
@@ -0,0 +1,438 @@
+//===- CyclicReplacerCacheTest.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/CyclicReplacerCache.h"
+#include "llvm/ADT/SetVector.h"
+#include "gmock/gmock.h"
+#include <map>
+#include <set>
+
+using namespace mlir;
+
+TEST(CachedCyclicReplacerTest, testNoRecursion) {
+  // The replacer never recurses, so the cycle breaker is never consulted and
+  // every call is a plain int -> bool conversion.
+  CachedCyclicReplacer<int, bool> replacer(
+      /*replacer=*/[](int n) { return n != 0; },
+      /*cycleBreaker=*/[](int) -> std::optional<bool> { return std::nullopt; });
+
+  EXPECT_TRUE(replacer(3));
+  EXPECT_FALSE(replacer(0));
+}
+
+TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) {
+  // Replacer cycles through ints 0 -> 1 -> 2 -> 0 -> ...
+  // Any input may prune the cycle, and the pruned value -1 is what every
+  // replacement along the cycle ends up returning.
+  CachedCyclicReplacer<int, int> replacer(
+      /*replacer=*/[&](int n) { return replacer((n + 1) % 3); },
+      /*cycleBreaker=*/[](int) { return -1; });
+
+  auto expectPruned = [&](int start) { EXPECT_EQ(replacer(start), -1); };
+
+  // Starting at 0.
+  expectPruned(0);
+  // Starting at 2.
+  expectPruned(2);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer: ChainRecursion
+//===----------------------------------------------------------------------===//
+
+class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
+public:
+  // N ==> (N+1) % 3
+  // This will create a chain of infinite length without recursion pruning.
+  // Each replacement appends one 42 to the replacement of its successor, so a
+  // chain of length k is a vector of k copies of 42.
+  CachedCyclicReplacerChainRecursionPruningTest()
+      : replacer(
+            [&](int n) {
+              ++invokeCount;
+              std::vector<int> result = replacer((n + 1) % 3);
+              result.push_back(42);
+              return result;
+            },
+            [&](int n) -> std::optional<std::vector<int>> {
+              // With no base case configured every repeated input may prune;
+              // otherwise only `baseCase` may.
+              return baseCase.value_or(n) == n
+                         ? std::make_optional(std::vector<int>{})
+                         : std::nullopt;
+            }) {}
+
+  /// Expected replacement for a chain of length `N` (N copies of 42).
+  static std::vector<int> getChain(unsigned N) {
+    return std::vector<int>(N, 42);
+  }
+
+  CachedCyclicReplacer<int, std::vector<int>> replacer;
+  /// Number of (uncached) replacer invocations.
+  int invokeCount = 0;
+  /// If set, the only input that is allowed to prune a cycle.
+  std::optional<int> baseCase = std::nullopt;
+};
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) {
+  // Runs the replacer from `start` and checks both the produced chain length
+  // and how many fresh (uncached) replacer invocations that took.
+  auto expectRun = [&](int start, unsigned chainLength, int invocations) {
+    invokeCount = 0;
+    EXPECT_EQ(replacer(start), getChain(chainLength));
+    EXPECT_EQ(invokeCount, invocations);
+  };
+
+  // Starting at 0. Cycle length is 3.
+  expectRun(0, 3, 3);
+  // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
+  expectRun(1, 5, 2);
+  // Starting at 2. Cycle length is 4. Entire result is cached.
+  expectRun(2, 4, 0);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere1) {
+  // Starting at 1 with every input allowed to prune: the first repeat (1) cuts
+  // the cycle, giving a chain of length 3 from 3 fresh invocations.
+  const std::vector<int> chain = replacer(1);
+  EXPECT_EQ(chain, getChain(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific0) {
+  // Only input 0 may prune.
+  baseCase = 0;
+
+  // Starting at 0, the first repeated input is 0 itself, so the cycle length
+  // is 3.
+  const std::vector<int> chain = replacer(0);
+  EXPECT_EQ(chain, getChain(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) {
+  // Only input 0 may prune.
+  baseCase = 0;
+
+  // Runs the replacer from `start` and checks both the produced chain length
+  // and how many fresh (uncached) replacer invocations that took.
+  auto expectRun = [&](int start, unsigned chainLength, int invocations) {
+    invokeCount = 0;
+    EXPECT_EQ(replacer(start), getChain(chainLength));
+    EXPECT_EQ(invokeCount, invocations);
+  };
+
+  // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
+  expectRun(1, 5, 5);
+  // Starting at 0. Cycle length is 3. Entire result is cached.
+  expectRun(0, 3, 0);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer: GraphReplacement
+//===----------------------------------------------------------------------===//
+
+class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
+public:
+  /// A directed graph over integer node ids.
+  struct Graph {
+    using Node = int64_t;
+
+    // Use ordered containers for deterministic output.
+    // Nodes without outgoing edges are considered nonexistent.
+    std::map<Node, std::set<Node>> edges;
+
+    void addEdge(Node src, Node sink) { edges[src].insert(sink); }
+
+    /// Returns true if the graph contains a directed cycle.
+    /// Uses an explicit DFS stack in which each work item carries an `isExit`
+    /// flag marking the point where all of a node's children are processed.
+    /// (Encoding that marker as a negated id would misbehave for node 0, whose
+    /// negation is itself, and for negative ids.)
+    bool isCyclic() const {
+      DenseSet<Node> visited;
+      for (Node root : llvm::make_first_range(edges)) {
+        if (visited.contains(root))
+          continue;
+
+        SetVector<Node> path;
+        SmallVector<std::pair<Node, bool>> workstack;
+        workstack.emplace_back(root, /*isExit=*/false);
+        while (!workstack.empty()) {
+          auto [curr, isExit] = workstack.pop_back_val();
+
+          if (isExit) {
+            // All of this node's children are processed. Remove self from
+            // path.
+            assert(path.back() == curr && "internal inconsistency");
+            path.pop_back();
+            continue;
+          }
+
+          if (path.contains(curr))
+            return true;
+
+          visited.insert(curr);
+          auto edgesIter = edges.find(curr);
+          if (edgesIter == edges.end() || edgesIter->second.empty())
+            continue;
+
+          path.insert(curr);
+          // Push the exit marker before the children so it is popped last.
+          workstack.emplace_back(curr, /*isExit=*/true);
+          for (Node child : edgesIter->second)
+            workstack.emplace_back(child, /*isExit=*/false);
+        }
+      }
+      return false;
+    }
+
+    /// Deterministic output for testing.
+    std::string serialize() const {
+      std::ostringstream oss;
+      for (const auto &[src, neighbors] : edges) {
+        oss << src << ":";
+        for (Graph::Node neighbor : neighbors)
+          oss << " " << neighbor;
+        oss << "\n";
+      }
+      return oss.str();
+    }
+  };
+
+  /// The output of `breakCycles`. Each node records the input node it was
+  /// created for and, for nodes involved in a pruned cycle, the recursion id
+  /// assigned when that cycle was broken.
+  struct Tree {
+    using Node = Graph::Node;
+    struct NodeInfo {
+      Graph::Node originalId;
+      // A negative recursive index means not recursive.
+      int64_t recursionId;
+    };
+
+    Node addNode(Graph::Node originalId, int64_t recursionIndex = -1) {
+      Node id = nextConnectionId++;
+      info[id] = {originalId, recursionIndex};
+      return id;
+    }
+    std::pair<Node, int64_t> addRecursiveSelfNode(Graph::Node originalId) {
+      return {addNode(originalId, nextRecursionId), nextRecursionId++};
+    }
+    void addEdge(Node src, Node sink) { connections.addEdge(src, sink); }
+
+    /// Deterministic output for testing.
+    std::string serialize() const {
+      std::ostringstream oss;
+      oss << "nodes\n";
+      for (const auto &[nodeId, nodeInfo] : info) {
+        oss << nodeId << ": n" << nodeInfo.originalId;
+        if (nodeInfo.recursionId >= 0)
+          oss << '<' << nodeInfo.recursionId << '>';
+        oss << "\n";
+      }
+      oss << "edges\n";
+      oss << connections.serialize();
+      return oss.str();
+    }
+
+    bool isCyclic() const { return connections.isCyclic(); }
+
+  private:
+    Graph connections;
+    int64_t nextRecursionId = 0;
+    int64_t nextConnectionId = 0;
+    // Use ordered map for deterministic output.
+    std::map<Graph::Node, NodeInfo> info;
+  };
+
+  /// Replaces every node of `input` (in ascending root order) with a fresh
+  /// tree node, using CyclicReplacerCache to prune cycles: a repeated input
+  /// node becomes a recursive self-node, and the node that was repeated is
+  /// tagged with the same recursion id.
+  Tree breakCycles(const Graph &input) {
+    assert(input.isCyclic() && "input graph is not cyclic");
+
+    Tree output;
+
+    DenseMap<Graph::Node, int64_t> recMap;
+    auto cycleBreaker = [&](Graph::Node inNode) -> std::optional<Graph::Node> {
+      auto [node, recId] = output.addRecursiveSelfNode(inNode);
+      recMap[inNode] = recId;
+      return node;
+    };
+
+    CyclicReplacerCache<Graph::Node, Graph::Node> cache(cycleBreaker);
+
+    std::function<Graph::Node(Graph::Node)> replaceNode =
+        [&](Graph::Node inNode) {
+          auto cacheEntry = cache.lookupOrInit(inNode);
+          if (std::optional<Graph::Node> result = cacheEntry.get())
+            return *result;
+
+          // Recursively replace its neighbors.
+          SmallVector<Graph::Node> neighbors;
+          if (auto it = input.edges.find(inNode); it != input.edges.end())
+            neighbors = SmallVector<Graph::Node>(
+                llvm::map_range(it->second, replaceNode));
+
+          // Create a new node in the output graph.
+          int64_t recursionIndex =
+              cacheEntry.wasRepeated() ? recMap.lookup(inNode) : -1;
+          Graph::Node result = output.addNode(inNode, recursionIndex);
+
+          for (Graph::Node neighbor : neighbors)
+            output.addEdge(result, neighbor);
+
+          cacheEntry.resolve(result);
+          return result;
+        };
+
+    for (Graph::Node root : llvm::make_first_range(input.edges))
+      replaceNode(root);
+
+    return output;
+  }
+
+  /// Removes comment lines (lines whose first character is ';'), including
+  /// their trailing newline, from `input`. A ';' anywhere else is kept.
+  std::string trimComments(StringRef input) {
+    std::ostringstream oss;
+    // The first character of the input also starts a line.
+    bool atLineStart = true;
+    bool inComment = false;
+    for (char c : input) {
+      if (atLineStart && c == ';')
+        inComment = true;
+
+      if (!inComment)
+        oss << c;
+
+      // Only the character right after a newline starts a new line.
+      atLineStart = c == '\n';
+      if (atLineStart)
+        inComment = false;
+    }
+    return oss.str();
+  }
+};
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) {
+  // 0 -> 1 -> 2
+  // ^         |
+  // +---------+
+  Graph input = {{{0, {1}}, {1, {2}}, {2, {0}}}};
+  Tree output = breakCycles(input);
+  ASSERT_FALSE(output.isCyclic()) << output.serialize();
+  // Replacing root 1 builds n2 and n1 on top of root 0's cached result
+  // (node 3). Root 2 is then served entirely from the cache (node 4) and adds
+  // no nodes. Annotation lines starting with ';' are stripped before
+  // comparison.
+  EXPECT_EQ(output.serialize(), trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n2
+2: n1
+3: n0<0>
+; root 1
+4: n2
+5: n1
+; root 2 is fully cached
+edges
+1: 0
+2: 1
+3: 2
+4: 3
+5: 4
+)"));
+}
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testDualLoop) {
+ // +----> 1 -----+
+ // |             v
+ // 0 <---------- 3
+ // |             ^
+ // +----> 2 -----+
+ //
+ // Two loops:
+ // 0 -> 1 -> 3 -> 0
+ // 0 -> 2 -> 3 -> 0
+ Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0}}}};
+ Tree output = breakCycles(input);
+ ASSERT_FALSE(output.isCyclic()) << output.serialize();
+ // While replacing root 0, node 3's result (tree node 1) depends on the
+ // pruned n0<0>. It is reused by the 0 -> 2 path but dropped from the cache
+ // once root 0 completes, so roots 1 and 2 rebuild n3 on top of the finished
+ // root 0 (tree node 4). Lines starting with ';' are stripped before
+ // comparison.
+ EXPECT_EQ(output.serialize(), trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n3
+2: n1
+3: n2
+4: n0<0>
+; root 1
+5: n3
+6: n1
+; root 2
+7: n2
+edges
+1: 0
+2: 1
+3: 1
+4: 2 3
+5: 4
+6: 5
+7: 5
+)"));
+}
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testNestedLoops) {
+ // +----> 1 -----+
+ // |      ^      v
+ // 0 <----+----- 2
+ //
+ // Two nested loops:
+ // 0 -> 1 -> 2 -> 0
+ // 1 -> 2 -> 1
+ Graph input = {{{0, {1}}, {1, {2}}, {2, {0, 1}}}};
+ Tree output = breakCycles(input);
+ ASSERT_FALSE(output.isCyclic()) << output.serialize();
+ // From root 0, node 2 sees both pruned recursive self-nodes (n0<0> and
+ // n1<1>). Replacing root 1 afterwards prunes node 1 again under the fresh
+ // recursion id 2, while reusing root 0's result (tree node 4). Lines starting
+ // with ';' are stripped before comparison.
+ EXPECT_EQ(output.serialize(), trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n1<1>
+2: n2
+3: n1<1>
+4: n0<0>
+; root 1
+5: n1<2>
+6: n2
+7: n1<2>
+; root 2
+8: n2
+edges
+2: 0 1
+3: 2
+4: 3
+6: 4 5
+7: 6
+8: 4 7
+)"));
+}
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testDualNestedLoops) {
+ // +----> 1 -----+
+ // |      ^      v
+ // 0 <----+----- 3
+ // |      v      ^
+ // +----> 2 -----+
+ //
+ // Two sets of nested loops:
+ // 0 -> 1 -> 3 -> 0
+ // 1 -> 3 -> 1
+ // 0 -> 2 -> 3 -> 0
+ // 2 -> 3 -> 2
+ Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0, 1, 2}}}};
+ Tree output = breakCycles(input);
+ ASSERT_FALSE(output.isCyclic()) << output.serialize();
+ // Each root not already served from the cache rebuilds the cycles reachable
+ // from it on top of earlier finished roots. Recursion ids (0-6) are handed
+ // out in the order cycles are pruned. Lines starting with ';' are stripped
+ // before comparison.
+ EXPECT_EQ(output.serialize(), trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n1<1>
+2: n3<2>
+3: n2
+4: n3<2>
+5: n1<1>
+6: n2<3>
+7: n3
+8: n2<3>
+9: n0<0>
+; root 1
+10: n1<4>
+11: n3<5>
+12: n2
+13: n3<5>
+14: n1<4>
+; root 2
+15: n2<6>
+16: n3
+17: n2<6>
+; root 3
+18: n3
+edges
+; root 0
+3: 2
+4: 0 1 3
+5: 4
+7: 0 5 6
+8: 7
+9: 5 8
+; root 1
+12: 11
+13: 9 10 12
+14: 13
+; root 2
+16: 9 14 15
+17: 16
+; root 3
+18: 9 14 17
+)"));
+}
More information about the Mlir-commits
mailing list