[Mlir-commits] [mlir] [MLIR][LLVM] Use CyclicReplacerCache for recursive DIType import (PR #98203)
Billy Zhu
llvmlistbot at llvm.org
Thu Jul 11 22:14:57 PDT 2024
https://github.com/zyx-billy updated https://github.com/llvm/llvm-project/pull/98203
>From bcc3edaa9b80c1ee3e87ad15894382a065bef9ad Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 9 Jul 2024 10:26:05 -0700
Subject: [PATCH 1/5] add replacer cache and gtests
---
.../mlir/Support/CyclicReplacerCache.h | 277 ++++++++++
mlir/unittests/Support/CMakeLists.txt | 1 +
.../Support/CyclicReplacerCacheTest.cpp | 472 ++++++++++++++++++
3 files changed, 750 insertions(+)
create mode 100644 mlir/include/mlir/Support/CyclicReplacerCache.h
create mode 100644 mlir/unittests/Support/CyclicReplacerCacheTest.cpp
diff --git a/mlir/include/mlir/Support/CyclicReplacerCache.h b/mlir/include/mlir/Support/CyclicReplacerCache.h
new file mode 100644
index 0000000000000..9a703676fff11
--- /dev/null
+++ b/mlir/include/mlir/Support/CyclicReplacerCache.h
@@ -0,0 +1,277 @@
+//===- CyclicReplacerCache.h ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains helper classes for caching replacer-like functions that
+// map values between two domains. They are able to handle replacer logic that
+// contains self-recursion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_CYCLICREPLACERCACHE_H
+#define MLIR_SUPPORT_CYCLICREPLACERCACHE_H
+
+#include "mlir/IR/Visitors.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
+#include <set>
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// CyclicReplacerCache
+//===----------------------------------------------------------------------===//
+
+/// A cache for replacer-like functions that map values between two domains. The
+/// difference compared to just using a map to cache in-out pairs is that this
+/// class is able to handle replacer logic that is self-recursive (and thus may
+/// cause infinite recursion in the naive case).
+///
+/// This class provides a hook for the user to perform cycle pruning when a
+/// cycle is identified, and is able to perform context-sensitive caching so
+/// that the replacement result for an input that is part of a pruned cycle can
+/// be distinct from the replacement result for the same input when it is not
+/// part of a cycle.
+///
+/// In addition, this class allows deferring cycle pruning until specific inputs
+/// are repeated. This is useful for cases where not all elements in a cycle can
+/// perform pruning. The user still must guarantee that at least one element in
+/// any given cycle can perform pruning. Even if not, an assertion will
+/// eventually be tripped instead of infinite recursion (the run-time is
+/// linearly bounded by the maximum cycle length of its input).
+template <typename InT, typename OutT>
+class CyclicReplacerCache {
+public:
+ /// User-provided replacement function & cycle-breaking functions.
+ /// The cycle-breaking function must not make any more recursive invocations
+ /// to this cached replacer.
+ using CycleBreakerFn = std::function<std::optional<OutT>(const InT &)>;
+
+ CyclicReplacerCache() = delete;
+ CyclicReplacerCache(CycleBreakerFn cycleBreaker)
+ : cycleBreaker(std::move(cycleBreaker)) {}
+
+ /// A possibly unresolved cache entry.
+ /// If unresolved, the entry must be resolved before it goes out of scope.
+ struct CacheEntry {
+ public:
+ ~CacheEntry() { assert(result && "unresolved cache entry"); }
+
+ /// Check whether this node was repeated during recursive replacements.
+ /// This only makes sense to be called after all recursive replacements are
+ /// completed and the current element has resurfaced to the top of the
+ /// replacement stack.
+ bool wasRepeated() const {
+ // If the top frame includes itself as a dependency, then it must have
+ // been repeated.
+ ReplacementFrame &currFrame = cache.replacementStack.back();
+ size_t currFrameIndex = cache.replacementStack.size() - 1;
+ return currFrame.dependentFrames.count(currFrameIndex);
+ }
+
+ /// Resolve an unresolved cache entry by providing the result to be stored
+ /// in the cache.
+ void resolve(OutT result) {
+ assert(!this->result && "cache entry already resolved");
+ this->result = result;
+ cache.finalizeReplacement(element, result);
+ }
+
+ /// Get the resolved result if one exists.
+ std::optional<OutT> get() { return result; }
+
+ private:
+ friend class CyclicReplacerCache;
+ CacheEntry() = delete;
+ CacheEntry(CyclicReplacerCache<InT, OutT> &cache, InT element,
+ std::optional<OutT> result = std::nullopt)
+ : cache(cache), element(element), result(result) {}
+
+ CyclicReplacerCache<InT, OutT> &cache;
+ InT element;
+ std::optional<OutT> result;
+ };
+
+ /// Lookup the cache for a pre-calculated replacement for `element`.
+ /// If one exists, a resolved CacheEntry will be returned. Otherwise, an
+ /// unresolved CacheEntry will be returned, and the caller must resolve it
+ /// with the calculated replacement so it can be registered in the cache for
+ /// future use.
+ /// Multiple unresolved CacheEntries may be retrieved. However, any unresolved
+ /// CacheEntries that are returned must be resolved in reverse order of
+ /// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and
+ /// the first retrieved CacheEntry must be resolved last. This should be
+ /// natural when used as a stack / inside recursion.
+ CacheEntry lookupOrInit(const InT &element);
+
+private:
+ /// Register the replacement in the cache and update the replacementStack.
+ void finalizeReplacement(const InT &element, const OutT &result);
+
+ CycleBreakerFn cycleBreaker;
+ DenseMap<InT, OutT> standaloneCache;
+
+ struct DependentReplacement {
+ OutT replacement;
+ /// The highest replacement frame index that this cache entry is dependent
+ /// on.
+ size_t highestDependentFrame;
+ };
+ DenseMap<InT, DependentReplacement> dependentCache;
+
+ struct ReplacementFrame {
+ /// The set of elements that are only legal while under this current frame.
+ /// They need to be removed from the cache when this frame is popped off the
+ /// replacement stack.
+ DenseSet<InT> dependingReplacements;
+ /// The set of frame indices that this current frame's replacement is
+ /// dependent on, ordered from highest to lowest.
+ std::set<size_t, std::greater<size_t>> dependentFrames;
+ };
+ /// Every element currently in the progress of being replaced pushes a frame
+ /// onto this stack.
+ SmallVector<ReplacementFrame> replacementStack;
+ /// Maps from each input element to its indices on the replacement stack.
+ DenseMap<InT, SmallVector<size_t, 2>> cyclicElementFrame;
+ /// If set to true, we are currently asking an element to break a cycle. No
+ /// more recursive invocations are allowed while this is true (the replacement
+ /// stack can no longer grow).
+ bool resolvingCycle = false;
+};
+
+template <typename InT, typename OutT>
+typename CyclicReplacerCache<InT, OutT>::CacheEntry
+CyclicReplacerCache<InT, OutT>::lookupOrInit(const InT &element) {
+ assert(!resolvingCycle &&
+ "illegal recursive invocation while breaking cycle");
+
+ if (auto it = standaloneCache.find(element); it != standaloneCache.end())
+ return CacheEntry(*this, element, it->second);
+
+ if (auto it = dependentCache.find(element); it != dependentCache.end()) {
+ // Update the current top frame (the element that invoked this current
+ // replacement) to include any dependencies the cache entry had.
+ ReplacementFrame &currFrame = replacementStack.back();
+ currFrame.dependentFrames.insert(it->second.highestDependentFrame);
+ return CacheEntry(*this, element, it->second.replacement);
+ }
+
+ auto [it, inserted] = cyclicElementFrame.try_emplace(element);
+ if (!inserted) {
+ // This is a repeat of a known element. Try to break cycle here.
+ resolvingCycle = true;
+ std::optional<OutT> result = cycleBreaker(element);
+ resolvingCycle = false;
+ if (result) {
+ // Cycle was broken.
+ size_t dependentFrame = it->second.back();
+ dependentCache[element] = {*result, dependentFrame};
+ ReplacementFrame &currFrame = replacementStack.back();
+ // If this is a repeat, there is no replacement frame to pop. Mark the top
+ // frame as being dependent on this element.
+ currFrame.dependentFrames.insert(dependentFrame);
+
+ return CacheEntry(*this, element, *result);
+ }
+
+ // Cycle could not be broken.
+ // A legal setup must ensure at least one element of each cycle can break
+ // cycles. Under this setup, each element can be seen at most twice before
+ // the cycle is broken. If we see an element more than twice, we know this
+ // is an illegal setup.
+ assert(it->second.size() <= 2 && "illegal 3rd repeat of input");
+ }
+
+ // Otherwise, either this is the first time we see this element, or this
+ // element could not break this cycle.
+ it->second.push_back(replacementStack.size());
+ replacementStack.emplace_back();
+
+ return CacheEntry(*this, element);
+}
+
+template <typename InT, typename OutT>
+void CyclicReplacerCache<InT, OutT>::finalizeReplacement(const InT &element,
+ const OutT &result) {
+ ReplacementFrame &currFrame = replacementStack.back();
+ // With the conclusion of this replacement frame, the current element is no
+ // longer a dependent element.
+ currFrame.dependentFrames.erase(replacementStack.size() - 1);
+
+ auto prevLayerIter = ++replacementStack.rbegin();
+ if (prevLayerIter == replacementStack.rend()) {
+ // If this is the last frame, there should be zero dependents.
+ assert(currFrame.dependentFrames.empty() &&
+ "internal error: top-level dependent replacement");
+ // Cache standalone result.
+ standaloneCache[element] = result;
+ } else if (currFrame.dependentFrames.empty()) {
+ // Cache standalone result.
+ standaloneCache[element] = result;
+ } else {
+ // Cache dependent result.
+ size_t highestDependentFrame = *currFrame.dependentFrames.begin();
+ dependentCache[element] = {result, highestDependentFrame};
+
+ // Otherwise, the previous frame inherits the same dependent frames.
+ prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(),
+ currFrame.dependentFrames.end());
+
+ // Mark this current replacement as a depending replacement on the closest
+ // dependent frame.
+ replacementStack[highestDependentFrame].dependingReplacements.insert(
+ element);
+ }
+
+ // All depending replacements in the cache must be purged.
+ for (InT key : currFrame.dependingReplacements)
+ dependentCache.erase(key);
+
+ replacementStack.pop_back();
+ auto it = cyclicElementFrame.find(element);
+ it->second.pop_back();
+ if (it->second.empty())
+ cyclicElementFrame.erase(it);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer
+//===----------------------------------------------------------------------===//
+
+/// A helper class for cases where the input/output types of the replacer
+/// function is identical to the types stored in the cache. This class wraps
+/// the user-provided replacer function, and can be used in place of the user
+/// function.
+template <typename InT, typename OutT>
+class CachedCyclicReplacer {
+public:
+ using ReplacerFn = std::function<OutT(const InT &)>;
+ using CycleBreakerFn =
+ typename CyclicReplacerCache<InT, OutT>::CycleBreakerFn;
+
+ CachedCyclicReplacer() = delete;
+ CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
+ : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
+
+ OutT operator()(const InT &element) {
+ auto cacheEntry = cache.lookupOrInit(element);
+ if (std::optional<OutT> result = cacheEntry.get())
+ return *result;
+
+ OutT result = replacer(element);
+ cacheEntry.resolve(result);
+ return result;
+ }
+
+private:
+ ReplacerFn replacer;
+ CyclicReplacerCache<InT, OutT> cache;
+};
+
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_CYCLICREPLACERCACHE_H
diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
index 1dbf072bcbbfd..ec79a1c640909 100644
--- a/mlir/unittests/Support/CMakeLists.txt
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_unittest(MLIRSupportTests
+ CyclicReplacerCacheTest.cpp
IndentedOstreamTest.cpp
StorageUniquerTest.cpp
)
diff --git a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
new file mode 100644
index 0000000000000..23748e29765cb
--- /dev/null
+++ b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
@@ -0,0 +1,472 @@
+//===- CyclicReplacerCacheTest.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/CyclicReplacerCache.h"
+#include "llvm/ADT/SetVector.h"
+#include "gmock/gmock.h"
+#include <map>
+#include <set>
+
+using namespace mlir;
+
+TEST(CachedCyclicReplacerTest, testNoRecursion) {
+ CachedCyclicReplacer<int, bool> replacer(
+ /*replacer=*/[](int n) { return static_cast<bool>(n); },
+ /*cycleBreaker=*/[](int n) { return std::nullopt; });
+
+ EXPECT_EQ(replacer(3), true);
+ EXPECT_EQ(replacer(0), false);
+}
+
+TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) {
+ // Replacer cycles through ints 0 -> 1 -> 2 -> 0 -> ...
+ CachedCyclicReplacer<int, int> replacer(
+ /*replacer=*/[&](int n) { return replacer((n + 1) % 3); },
+ /*cycleBreaker=*/[&](int n) { return -1; });
+
+ // Starting at 0.
+ EXPECT_EQ(replacer(0), -1);
+ // Starting at 2.
+ EXPECT_EQ(replacer(2), -1);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer: ChainRecursion
+//===----------------------------------------------------------------------===//
+
+/// This set of tests uses a replacer function that maps ints into vectors of
+/// ints.
+///
+/// The replacement result for input `n` is the replacement result of `(n+1)%3`
+/// appended with an element `42`. Theoretically, this will produce an
+/// infinitely long vector. The cycle-breaker function prunes this infinite
+/// recursion in the replacer logic by returning an empty vector upon the first
+/// re-occurrence of an input value.
+class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
+public:
+ // N ==> (N+1) % 3
+ // This will create a chain of infinite length without recursion pruning.
+ CachedCyclicReplacerChainRecursionPruningTest()
+ : replacer(
+ [&](int n) {
+ ++invokeCount;
+ std::vector<int> result = replacer((n + 1) % 3);
+ result.push_back(42);
+ return result;
+ },
+ [&](int n) -> std::optional<std::vector<int>> {
+ return baseCase.value_or(n) == n
+ ? std::make_optional(std::vector<int>{})
+ : std::nullopt;
+ }) {}
+
+ std::vector<int> getChain(unsigned N) { return std::vector<int>(N, 42); };
+
+ CachedCyclicReplacer<int, std::vector<int>> replacer;
+ int invokeCount = 0;
+ std::optional<int> baseCase = std::nullopt;
+};
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) {
+ // Starting at 0. Cycle length is 3.
+ EXPECT_EQ(replacer(0), getChain(3));
+ EXPECT_EQ(invokeCount, 3);
+
+ // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
+ invokeCount = 0;
+ EXPECT_EQ(replacer(1), getChain(5));
+ EXPECT_EQ(invokeCount, 2);
+
+ // Starting at 2. Cycle length is 4. Entire result is cached.
+ invokeCount = 0;
+ EXPECT_EQ(replacer(2), getChain(4));
+ EXPECT_EQ(invokeCount, 0);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere1) {
+ // Starting at 1. Cycle length is 3.
+ EXPECT_EQ(replacer(1), getChain(3));
+ EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific0) {
+ baseCase = 0;
+
+ // Starting at 0. Cycle length is 3.
+ EXPECT_EQ(replacer(0), getChain(3));
+ EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) {
+ baseCase = 0;
+
+ // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
+ EXPECT_EQ(replacer(1), getChain(5));
+ EXPECT_EQ(invokeCount, 5);
+
+ // Starting at 0. Cycle length is 3. Entire result is cached.
+ invokeCount = 0;
+ EXPECT_EQ(replacer(0), getChain(3));
+ EXPECT_EQ(invokeCount, 0);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer: GraphReplacement
+//===----------------------------------------------------------------------===//
+
+/// This set of tests uses a replacer function that maps from cyclic graphs to
+/// trees, pruning out cycles in the process.
+///
+/// It consists of two helper classes:
+/// - Graph
+/// - A directed graph where nodes are non-negative integers.
+/// - PrunedGraph
+/// - A Graph where edges that used to cause cycles are now represented with
+/// an indirection (a recursionId).
+class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
+public:
+ /// A directed graph where nodes are non-negative integers.
+ struct Graph {
+ using Node = int64_t;
+
+ /// Use ordered containers for deterministic output.
+ /// Nodes without outgoing edges are considered nonexistent.
+ std::map<Node, std::set<Node>> edges;
+
+ void addEdge(Node src, Node sink) { edges[src].insert(sink); }
+
+ bool isCyclic() const {
+ DenseSet<Node> visited;
+ for (Node root : llvm::make_first_range(edges)) {
+ if (visited.contains(root))
+ continue;
+
+ SetVector<Node> path;
+ SmallVector<Node> workstack;
+ workstack.push_back(root);
+ while (!workstack.empty()) {
+ Node curr = workstack.back();
+ workstack.pop_back();
+
+ if (curr < 0) {
+ // A negative node signals the end of processing all of this node's
+ // children. Remove self from path.
+ assert(path.back() == -curr && "internal inconsistency");
+ path.pop_back();
+ continue;
+ }
+
+ if (path.contains(curr))
+ return true;
+
+ visited.insert(curr);
+ auto edgesIter = edges.find(curr);
+ if (edgesIter == edges.end() || edgesIter->second.empty())
+ continue;
+
+ path.insert(curr);
+ // Push negative node to signify recursion return.
+ workstack.push_back(-curr);
+ workstack.insert(workstack.end(), edgesIter->second.begin(),
+ edgesIter->second.end());
+ }
+ }
+ return false;
+ }
+
+ /// Deterministic output for testing.
+ std::string serialize() const {
+ std::ostringstream oss;
+ for (const auto &[src, neighbors] : edges) {
+ oss << src << ":";
+ for (Graph::Node neighbor : neighbors)
+ oss << " " << neighbor;
+ oss << "\n";
+ }
+ return oss.str();
+ }
+ };
+
+ /// A Graph where edges that used to cause cycles (back-edges) are now
+ /// represented with an indirection (a recursionId).
+ ///
+ /// In addition to each node being an integer, each node also tracks the
+ /// original integer id it had in the original graph. This way for every
+ /// back-edge, we can represent it as pointing to a new instance of the
+ /// original node. Then we mark the original node and the new instance with
+ /// a new unique recursionId to indicate that they're supposed to be the same
+ /// graph.
+ struct PrunedGraph {
+ using Node = Graph::Node;
+ struct NodeInfo {
+ Graph::Node originalId;
+ // A negative recursive index means not recursive.
+ int64_t recursionId;
+ };
+
+ /// Add a regular non-recursive-self node.
+ Node addNode(Graph::Node originalId, int64_t recursionIndex = -1) {
+ Node id = nextConnectionId++;
+ info[id] = {originalId, recursionIndex};
+ return id;
+ }
+ /// Add a recursive-self-node, i.e. a duplicate of the original node that is
+ /// meant to represent an indirection to it.
+ std::pair<Node, int64_t> addRecursiveSelfNode(Graph::Node originalId) {
+ return {addNode(originalId, nextRecursionId), nextRecursionId++};
+ }
+ void addEdge(Node src, Node sink) { connections.addEdge(src, sink); }
+
+ /// Deterministic output for testing.
+ std::string serialize() const {
+ std::ostringstream oss;
+ oss << "nodes\n";
+ for (const auto &[nodeId, nodeInfo] : info) {
+ oss << nodeId << ": n" << nodeInfo.originalId;
+ if (nodeInfo.recursionId >= 0)
+ oss << '<' << nodeInfo.recursionId << '>';
+ oss << "\n";
+ }
+ oss << "edges\n";
+ oss << connections.serialize();
+ return oss.str();
+ }
+
+ bool isCyclic() const { return connections.isCyclic(); }
+
+ private:
+ Graph connections;
+ int64_t nextRecursionId = 0;
+ int64_t nextConnectionId = 0;
+ // Use ordered map for deterministic output.
+ std::map<Graph::Node, NodeInfo> info;
+ };
+
+ PrunedGraph breakCycles(const Graph &input) {
+ assert(input.isCyclic() && "input graph is not cyclic");
+
+ PrunedGraph output;
+
+ DenseMap<Graph::Node, int64_t> recMap;
+ auto cycleBreaker = [&](Graph::Node inNode) -> std::optional<Graph::Node> {
+ auto [node, recId] = output.addRecursiveSelfNode(inNode);
+ recMap[inNode] = recId;
+ return node;
+ };
+
+ CyclicReplacerCache<Graph::Node, Graph::Node> cache(cycleBreaker);
+
+ std::function<Graph::Node(Graph::Node)> replaceNode =
+ [&](Graph::Node inNode) {
+ auto cacheEntry = cache.lookupOrInit(inNode);
+ if (std::optional<Graph::Node> result = cacheEntry.get())
+ return *result;
+
+ // Recursively replace its neighbors.
+ SmallVector<Graph::Node> neighbors;
+ if (auto it = input.edges.find(inNode); it != input.edges.end())
+ neighbors = SmallVector<Graph::Node>(
+ llvm::map_range(it->second, replaceNode));
+
+ // Create a new node in the output graph.
+ int64_t recursionIndex =
+ cacheEntry.wasRepeated() ? recMap.lookup(inNode) : -1;
+ Graph::Node result = output.addNode(inNode, recursionIndex);
+
+ for (Graph::Node neighbor : neighbors)
+ output.addEdge(result, neighbor);
+
+ cacheEntry.resolve(result);
+ return result;
+ };
+
+ /// Translate starting from each node.
+ for (Graph::Node root : llvm::make_first_range(input.edges))
+ replaceNode(root);
+
+ return output;
+ }
+
+ /// Helper for serialization tests that allow putting comments in the
+ /// serialized format. Every line that begins with a `;` is considered a
+ /// comment. The entire line, incl. the terminating `\n` is removed.
+ std::string trimComments(StringRef input) {
+ std::ostringstream oss;
+ bool isNewLine = false;
+ bool isComment = false;
+ for (char c : input) {
+ // Lines beginning with ';' are comments.
+ if (isNewLine && c == ';')
+ isComment = true;
+
+ if (!isComment)
+ oss << c;
+
+ if (c == '\n') {
+ isNewLine = true;
+ isComment = false;
+ }
+ }
+ return oss.str();
+ }
+};
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) {
+ // 0 -> 1 -> 2
+ // ^ |
+ // +---------+
+ Graph input = {{{0, {1}}, {1, {2}}, {2, {0}}}};
+ PrunedGraph output = breakCycles(input);
+ ASSERT_FALSE(output.isCyclic()) << output.serialize();
+ EXPECT_EQ(output.serialize(), trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n2
+2: n1
+3: n0<0>
+; root 1
+4: n2
+; root 2
+5: n1
+edges
+1: 0
+2: 1
+3: 2
+4: 3
+5: 4
+)"));
+}
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testDualLoop) {
+ // +----> 1 -----+
+ // | v
+ // 0 <---------- 3
+ // | ^
+ // +----> 2 -----+
+ //
+ // Two loops:
+ // 0 -> 1 -> 3 -> 0
+ // 0 -> 2 -> 3 -> 0
+ Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0}}}};
+ PrunedGraph output = breakCycles(input);
+ ASSERT_FALSE(output.isCyclic()) << output.serialize();
+ EXPECT_EQ(output.serialize(), trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n3
+2: n1
+3: n2
+4: n0<0>
+; root 1
+5: n3
+6: n1
+; root 2
+7: n2
+edges
+1: 0
+2: 1
+3: 1
+4: 2 3
+5: 4
+6: 5
+7: 5
+)"));
+}
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testNestedLoops) {
+ // +----> 1 -----+
+ // | ^ v
+ // 0 <----+----- 2
+ //
+ // Two nested loops:
+ // 0 -> 1 -> 2 -> 0
+ // 1 -> 2 -> 1
+ Graph input = {{{0, {1}}, {1, {2}}, {2, {0, 1}}}};
+ PrunedGraph output = breakCycles(input);
+ ASSERT_FALSE(output.isCyclic()) << output.serialize();
+ EXPECT_EQ(output.serialize(), trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n1<1>
+2: n2
+3: n1<1>
+4: n0<0>
+; root 1
+5: n1<2>
+6: n2
+7: n1<2>
+; root 2
+8: n2
+edges
+2: 0 1
+3: 2
+4: 3
+6: 4 5
+7: 6
+8: 4 7
+)"));
+}
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testDualNestedLoops) {
+ // +----> 1 -----+
+ // | ^ v
+ // 0 <----+----- 3
+ // | v ^
+ // +----> 2 -----+
+ //
+ // Two sets of nested loops:
+ // 0 -> 1 -> 3 -> 0
+ // 1 -> 3 -> 1
+ // 0 -> 2 -> 3 -> 0
+ // 2 -> 3 -> 2
+ Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0, 1, 2}}}};
+ PrunedGraph output = breakCycles(input);
+ ASSERT_FALSE(output.isCyclic()) << output.serialize();
+ EXPECT_EQ(output.serialize(), trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n1<1>
+2: n3<2>
+3: n2
+4: n3<2>
+5: n1<1>
+6: n2<3>
+7: n3
+8: n2<3>
+9: n0<0>
+; root 1
+10: n1<4>
+11: n3<5>
+12: n2
+13: n3<5>
+14: n1<4>
+; root 2
+15: n2<6>
+16: n3
+17: n2<6>
+; root 3
+18: n3
+edges
+; root 0
+3: 2
+4: 0 1 3
+5: 4
+7: 0 5 6
+8: 7
+9: 5 8
+; root 1
+12: 11
+13: 9 10 12
+14: 13
+; root 2
+16: 9 14 15
+17: 16
+; root 3
+18: 9 14 17
+)"));
+}
>From a0b59b246a1b8860fad16f1f74537c38b5abf2cf Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 9 Jul 2024 10:28:21 -0700
Subject: [PATCH 2/5] use cyclic cache in importer and update tests
---
mlir/lib/Target/LLVMIR/DebugImporter.cpp | 143 ++++---------------
mlir/lib/Target/LLVMIR/DebugImporter.h | 101 ++-----------
mlir/test/Target/LLVMIR/Import/debug-info.ll | 34 +++--
3 files changed, 51 insertions(+), 227 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 15737aa681c59..f104b72209c39 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -28,7 +28,7 @@ using namespace mlir::LLVM::detail;
DebugImporter::DebugImporter(ModuleOp mlirModule,
bool dropDICompositeTypeElements)
- : recursionPruner(mlirModule.getContext()),
+ : cache([&](llvm::DINode *node) { return createRecSelf(node); }),
context(mlirModule.getContext()), mlirModule(mlirModule),
dropDICompositeTypeElements(dropDICompositeTypeElements) {}
@@ -287,16 +287,9 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
return nullptr;
// Check for a cached instance.
- if (DINodeAttr attr = nodeToAttr.lookup(node))
- return attr;
-
- // Register with the recursive translator. If it can be handled without
- // recursing into it, return the result immediately.
- if (DINodeAttr attr = recursionPruner.pruneOrPushTranslationStack(node))
- return attr;
-
- auto guard = llvm::make_scope_exit(
- [&]() { recursionPruner.popTranslationStack(node); });
+ auto cacheEntry = cache.lookupOrInit(node);
+ if (std::optional<DINodeAttr> result = cacheEntry.get())
+ return *result;
// Convert the debug metadata if possible.
auto translateNode = [this](llvm::DINode *node) -> DINodeAttr {
@@ -335,20 +328,20 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
return nullptr;
};
if (DINodeAttr attr = translateNode(node)) {
- auto [result, isSelfContained] =
- recursionPruner.finalizeTranslation(node, attr);
- // Only cache fully self-contained nodes.
- if (isSelfContained)
- nodeToAttr.try_emplace(node, result);
- return result;
+ // If this node was repeated, lookup its recursive ID and assign it to the
+ // base result.
+ if (cacheEntry.wasRepeated()) {
+ DistinctAttr recId = nodeToRecId.lookup(node);
+ auto recType = cast<DIRecursiveTypeAttrInterface>(attr);
+ attr = cast<DINodeAttr>(recType.withRecId(recId));
+ }
+ cacheEntry.resolve(attr);
+ return attr;
}
+ cacheEntry.resolve(nullptr);
return nullptr;
}
-//===----------------------------------------------------------------------===//
-// RecursionPruner
-//===----------------------------------------------------------------------===//
-
/// Get the `getRecSelf` constructor for the translated type of `node` if its
/// translated DITypeAttr supports recursion. Otherwise, returns nullptr.
static function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
@@ -361,104 +354,20 @@ getRecSelfConstructor(llvm::DINode *node) {
.Default(CtorType());
}
-DINodeAttr DebugImporter::RecursionPruner::pruneOrPushTranslationStack(
- llvm::DINode *node) {
- // If the node type is capable of being recursive, check if it's seen
- // before.
+std::optional<DINodeAttr> DebugImporter::createRecSelf(llvm::DINode *node) {
auto recSelfCtor = getRecSelfConstructor(node);
- if (recSelfCtor) {
- // If a cyclic dependency is detected since the same node is being
- // traversed twice, emit a recursive self type, and mark the duplicate
- // node on the translationStack so it can emit a recursive decl type.
- auto [iter, inserted] = translationStack.try_emplace(node);
- if (!inserted) {
- // The original node may have already been assigned a recursive ID from
- // a different self-reference. Use that if possible.
- DIRecursiveTypeAttrInterface recSelf = iter->second.recSelf;
- if (!recSelf) {
- DistinctAttr recId = nodeToRecId.lookup(node);
- if (!recId) {
- recId = DistinctAttr::create(UnitAttr::get(context));
- nodeToRecId[node] = recId;
- }
- recSelf = recSelfCtor(recId);
- iter->second.recSelf = recSelf;
- }
- // Inject the self-ref into the previous layer.
- translationStack.back().second.unboundSelfRefs.insert(recSelf);
- return cast<DINodeAttr>(recSelf);
- }
+ if (!recSelfCtor)
+ return std::nullopt;
+
+ // The original node may have already been assigned a recursive ID from
+ // a different self-reference. Use that if possible.
+ DistinctAttr recId = nodeToRecId.lookup(node);
+ if (!recId) {
+ recId = DistinctAttr::create(UnitAttr::get(context));
+ nodeToRecId[node] = recId;
}
-
- return lookup(node);
-}
-
-std::pair<DINodeAttr, bool>
-DebugImporter::RecursionPruner::finalizeTranslation(llvm::DINode *node,
- DINodeAttr result) {
- // If `node` is not a potentially recursive type, it will not be on the
- // translation stack. Nothing to set in this case.
- if (translationStack.empty())
- return {result, true};
- if (translationStack.back().first != node)
- return {result, translationStack.back().second.unboundSelfRefs.empty()};
-
- TranslationState &state = translationStack.back().second;
-
- // If this node is actually recursive, set the recId onto `result`.
- if (DIRecursiveTypeAttrInterface recSelf = state.recSelf) {
- auto recType = cast<DIRecursiveTypeAttrInterface>(result);
- result = cast<DINodeAttr>(recType.withRecId(recSelf.getRecId()));
- // Remove this recSelf from the set of unbound selfRefs.
- state.unboundSelfRefs.erase(recSelf);
- }
-
- // Insert the result into our internal cache if it's not self-contained.
- if (!state.unboundSelfRefs.empty()) {
- [[maybe_unused]] auto [_, inserted] = dependentCache.try_emplace(
- node, DependentTranslation{result, state.unboundSelfRefs});
- assert(inserted && "invalid state: caching the same DINode twice");
- return {result, false};
- }
- return {result, true};
-}
-
-void DebugImporter::RecursionPruner::popTranslationStack(llvm::DINode *node) {
- // If `node` is not a potentially recursive type, it will not be on the
- // translation stack. Nothing to handle in this case.
- if (translationStack.empty() || translationStack.back().first != node)
- return;
-
- // At the end of the stack, all unbound self-refs must be resolved already,
- // and the entire cache should be accounted for.
- TranslationState &currLayerState = translationStack.back().second;
- if (translationStack.size() == 1) {
- assert(currLayerState.unboundSelfRefs.empty() &&
- "internal error: unbound recursive self reference at top level.");
- translationStack.pop_back();
- return;
- }
-
- // Copy unboundSelfRefs down to the previous level.
- TranslationState &nextLayerState = (++translationStack.rbegin())->second;
- nextLayerState.unboundSelfRefs.insert(currLayerState.unboundSelfRefs.begin(),
- currLayerState.unboundSelfRefs.end());
- translationStack.pop_back();
-}
-
-DINodeAttr DebugImporter::RecursionPruner::lookup(llvm::DINode *node) {
- auto cacheIter = dependentCache.find(node);
- if (cacheIter == dependentCache.end())
- return {};
-
- DependentTranslation &entry = cacheIter->second;
- if (llvm::set_is_subset(entry.unboundSelfRefs,
- translationStack.back().second.unboundSelfRefs))
- return entry.attr;
-
- // Stale cache entry.
- dependentCache.erase(cacheIter);
- return {};
+ DIRecursiveTypeAttrInterface recSelf = recSelfCtor(recId);
+ return cast<DINodeAttr>(recSelf);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.h b/mlir/lib/Target/LLVMIR/DebugImporter.h
index fac86dbe2cdd2..4a2bf35c160e1 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.h
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.h
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/CyclicReplacerCache.h"
#include "llvm/IR/DebugInfoMetadata.h"
namespace mlir {
@@ -87,102 +88,18 @@ class DebugImporter {
/// for it, or create a new one if not.
DistinctAttr getOrCreateDistinctID(llvm::DINode *node);
- /// A mapping between LLVM debug metadata and the corresponding attribute.
- DenseMap<llvm::DINode *, DINodeAttr> nodeToAttr;
+ std::optional<DINodeAttr> createRecSelf(llvm::DINode *node);
+
/// A mapping between distinct LLVM debug metadata nodes and the corresponding
/// distinct id attribute.
DenseMap<llvm::DINode *, DistinctAttr> nodeToDistinctAttr;
- /// Translation helper for recursive DINodes.
- /// Works alongside a stack-based DINode translator (the "main translator")
- /// for gracefully handling DINodes that are recursive.
- ///
- /// Usage:
- /// - Before translating a node, call `pruneOrPushTranslationStack` to see if
- /// the pruner can preempt this translation. If this is a node that the
- /// pruner already knows how to handle, it will return the translated
- /// DINodeAttr.
- /// - After a node is successfully translated by the main translator, call
- /// `finalizeTranslation` to save the translated result with the pruner, and
- /// give it a chance to further modify the result.
- /// - Regardless of success or failure by the main translator, always call
- /// `popTranslationStack` at the end of translating a node. This is
- /// necessary to keep the internal book-keeping in sync.
- ///
- /// This helper maintains an internal cache so that no recursive type will
- /// be translated more than once by the main translator.
- /// This internal cache is different from the cache maintained by the main
- /// translator because it may store nodes that are not self-contained (i.e.
- /// contain unbounded recursive self-references).
- class RecursionPruner {
- public:
- RecursionPruner(MLIRContext *context) : context(context) {}
-
- /// If this node is a recursive instance that was previously seen, returns a
- /// self-reference. If this node was previously cached, returns the cached
- /// result. Otherwise, returns null attr, and a translation stack frame is
- /// created for this node. Expects `finalizeTranslation` &
- /// `popTranslationStack` to be called on this node later.
- DINodeAttr pruneOrPushTranslationStack(llvm::DINode *node);
-
- /// Register the translated result of `node`. Returns the finalized result
- /// (with recId if recursive) and whether the result is self-contained
- /// (i.e. contains no unbound self-refs).
- std::pair<DINodeAttr, bool> finalizeTranslation(llvm::DINode *node,
- DINodeAttr result);
-
- /// Pop off a frame from the translation stack after a node is done being
- /// translated.
- void popTranslationStack(llvm::DINode *node);
-
- private:
- /// Returns the cached result (if exists) or null.
- /// The cache entry will be removed if not all of its dependent self-refs
- /// exists.
- DINodeAttr lookup(llvm::DINode *node);
-
- MLIRContext *context;
-
- /// A cached translation that contains the translated attribute as well
- /// as any unbound self-references that it depends on.
- struct DependentTranslation {
- /// The translated attr. May contain unbound self-references for other
- /// recursive attrs.
- DINodeAttr attr;
- /// The set of unbound self-refs that this cached entry refers to. All
- /// these self-refs must exist for the cached entry to be valid.
- DenseSet<DIRecursiveTypeAttrInterface> unboundSelfRefs;
- };
- /// A mapping between LLVM debug metadata and the corresponding attribute.
- /// Only contains those with unboundSelfRefs. Fully self-contained attrs
- /// will be cached by the outer main translator.
- DenseMap<llvm::DINode *, DependentTranslation> dependentCache;
-
- /// Each potentially recursive node will have a TranslationState pushed onto
- /// the `translationStack` to keep track of whether this node is actually
- /// recursive (i.e. has self-references inside), and other book-keeping.
- struct TranslationState {
- /// The rec-self if this node is indeed a recursive node (i.e. another
- /// instance of itself is seen while translating it). Null if this node
- /// has not been seen again deeper in the translation stack.
- DIRecursiveTypeAttrInterface recSelf;
- /// All the unbound recursive self references in this layer of the
- /// translation stack.
- DenseSet<DIRecursiveTypeAttrInterface> unboundSelfRefs;
- };
- /// A stack that stores the metadata nodes that are being traversed. The
- /// stack is used to handle cyclic dependencies during metadata translation.
- /// Each node is pushed with an empty TranslationState. If it is ever seen
- /// later when the stack is deeper, the node is recursive, and its
- /// TranslationState is assigned a recSelf.
- llvm::MapVector<llvm::DINode *, TranslationState> translationStack;
-
- /// A mapping between DINodes that are recursive, and their assigned recId.
- /// This is kept so that repeated occurrences of the same node can reuse the
- /// same ID and be deduplicated.
- DenseMap<llvm::DINode *, DistinctAttr> nodeToRecId;
- };
- RecursionPruner recursionPruner;
+ /// A mapping between DINodes that are recursive, and their assigned recId.
+ /// This is kept so that repeated occurrences of the same node can reuse the
+ /// same ID and be deduplicated.
+ DenseMap<llvm::DINode *, DistinctAttr> nodeToRecId;
+
+ CyclicReplacerCache<llvm::DINode *, DINodeAttr> cache;
MLIRContext *context;
ModuleOp mlirModule;
diff --git a/mlir/test/Target/LLVMIR/Import/debug-info.ll b/mlir/test/Target/LLVMIR/Import/debug-info.ll
index 2173cc33e68a4..a194ecbf2eb20 100644
--- a/mlir/test/Target/LLVMIR/Import/debug-info.ll
+++ b/mlir/test/Target/LLVMIR/Import/debug-info.ll
@@ -308,16 +308,16 @@ define void @class_method() {
}
; Verify the cyclic composite type is identified, even though conversion begins from the subprogram type.
-; CHECK: #[[COMP_SELF:.+]] = #llvm.di_composite_type<tag = DW_TAG_null, recId = [[REC_ID:.+]]>
-; CHECK: #[[COMP_PTR:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP_SELF]], sizeInBits = 64>
-; CHECK: #[[SP_TYPE:.+]] = #llvm.di_subroutine_type<types = #{{.*}}, #[[COMP_PTR]]>
-; CHECK: #[[SP_INNER:.+]] = #llvm.di_subprogram<id = [[SP_ID:.+]], compileUnit = #{{.*}}, scope = #[[COMP_SELF]], name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE]]>
-; CHECK: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, recId = [[REC_ID]], name = "class_name", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial", elements = #[[SP_INNER]]>
+; CHECK-DAG: #[[COMP_SELF:.+]] = #llvm.di_composite_type<tag = DW_TAG_null, recId = [[REC_ID:.+]]>
+; CHECK-DAG: #[[COMP_PTR:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP_SELF]], sizeInBits = 64>
+; CHECK-DAG: #[[SP_TYPE:.+]] = #llvm.di_subroutine_type<types = #{{.*}}, #[[COMP_PTR]]>
+; CHECK-DAG: #[[SP_INNER:.+]] = #llvm.di_subprogram<id = [[SP_ID:.+]], compileUnit = #{{.*}}, scope = #[[COMP_SELF]], name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE]]>
+; CHECK-DAG: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, recId = [[REC_ID]], name = "class_name", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial", elements = #[[SP_INNER]]>
-; CHECK: #[[COMP_PTR_OUTER:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP]], sizeInBits = 64>
-; CHECK: #[[SP_TYPE_OUTER:.+]] = #llvm.di_subroutine_type<types = #{{.*}}, #[[COMP_PTR_OUTER]]>
-; CHECK: #[[SP_OUTER:.+]] = #llvm.di_subprogram<id = [[SP_ID]], compileUnit = #{{.*}}, scope = #[[COMP]], name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE_OUTER]]>
-; CHECK: #[[LOC]] = loc(fused<#[[SP_OUTER]]>
+; CHECK-DAG: #[[COMP_PTR_OUTER:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP]], sizeInBits = 64>
+; CHECK-DAG: #[[SP_TYPE_OUTER:.+]] = #llvm.di_subroutine_type<types = #{{.*}}, #[[COMP_PTR_OUTER]]>
+; CHECK-DAG: #[[SP_OUTER:.+]] = #llvm.di_subprogram<id = [[SP_ID]], compileUnit = #{{.*}}, scope = #[[COMP]], name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE_OUTER]]>
+; CHECK-DAG: #[[LOC]] = loc(fused<#[[SP_OUTER]]>
!llvm.dbg.cu = !{!1}
!llvm.module.flags = !{!0}
@@ -335,12 +335,12 @@ define void @class_method() {
; // -----
; Verify the cyclic composite type is handled correctly.
-; CHECK: #[[COMP_SELF:.+]] = #llvm.di_composite_type<tag = DW_TAG_null, recId = [[REC_ID:.+]]>
-; CHECK: #[[COMP_PTR_INNER:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP_SELF]]>
-; CHECK: #[[FIELD:.+]] = #llvm.di_derived_type<tag = DW_TAG_member, name = "call_field", baseType = #[[COMP_PTR_INNER]]>
-; CHECK: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, recId = [[REC_ID]], name = "class_field", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial", elements = #[[FIELD]]>
-; CHECK: #[[COMP_PTR_OUTER:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP]]>
-; CHECK: #[[VAR0:.+]] = #llvm.di_local_variable<scope = #{{.*}}, name = "class_field", file = #{{.*}}, type = #[[COMP_PTR_OUTER]]>
+; CHECK-DAG: #[[COMP_SELF:.+]] = #llvm.di_composite_type<tag = DW_TAG_null, recId = [[REC_ID:.+]]>
+; CHECK-DAG: #[[COMP_PTR_INNER:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP_SELF]]>
+; CHECK-DAG: #[[FIELD:.+]] = #llvm.di_derived_type<tag = DW_TAG_member, name = "call_field", baseType = #[[COMP_PTR_INNER]]>
+; CHECK-DAG: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, recId = [[REC_ID]], name = "class_field", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial", elements = #[[FIELD]]>
+; CHECK-DAG: #[[COMP_PTR_OUTER:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP]]>
+; CHECK-DAG: #[[VAR0:.+]] = #llvm.di_local_variable<scope = #{{.*}}, name = "class_field", file = #{{.*}}, type = #[[COMP_PTR_OUTER]]>
; CHECK: @class_field
; CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
@@ -727,9 +727,7 @@ define void @class_field(ptr %arg1) !dbg !18 {
; CHECK-DAG: #[[B_TO_C]] = #llvm.di_derived_type<{{.*}}name = "->C", {{.*}}baseType = #[[C_FROM_B:.+]]>
; CHECK-DAG: #[[C_FROM_B]] = #llvm.di_composite_type<{{.*}}recId = [[C_RECID:.+]], {{.*}}name = "C", {{.*}}elements = #[[TO_A_SELF:.+]], #[[TO_B_SELF:.+]], #[[TO_C_SELF:.+]]>
-; CHECK-DAG: #[[C_FROM_A]] = #llvm.di_composite_type<{{.*}}recId = [[C_RECID]], {{.*}}name = "C", {{.*}}elements = #[[TO_A_SELF]], #[[TO_B_INNER:.+]], #[[TO_C_SELF]]
-; CHECK-DAG: #[[TO_B_INNER]] = #llvm.di_derived_type<{{.*}}name = "->B", {{.*}}baseType = #[[B_INNER:.+]]>
-; CHECK-DAG: #[[B_INNER]] = #llvm.di_composite_type<{{.*}}name = "B", {{.*}}elements = #[[TO_C_SELF]]>
+; CHECK-DAG: #[[C_FROM_A]] = #llvm.di_composite_type<{{.*}}recId = [[C_RECID]], {{.*}}name = "C", {{.*}}elements = #[[TO_A_SELF]], #[[A_TO_B:.+]], #[[TO_C_SELF]]
; CHECK-DAG: #[[TO_A_SELF]] = #llvm.di_derived_type<{{.*}}name = "->A", {{.*}}baseType = #[[A_SELF:.+]]>
; CHECK-DAG: #[[TO_B_SELF]] = #llvm.di_derived_type<{{.*}}name = "->B", {{.*}}baseType = #[[B_SELF:.+]]>
>From 5ccbf4bc13ac48240d9fcb14411617809e527304 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 9 Jul 2024 13:58:34 -0700
Subject: [PATCH 3/5] typo & comments
---
mlir/unittests/Support/CyclicReplacerCacheTest.cpp | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
index 23748e29765cb..a4a92dbe147d4 100644
--- a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
+++ b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
@@ -195,17 +195,19 @@ class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
/// A Graph where edges that used to cause cycles (back-edges) are now
/// represented with an indirection (a recursionId).
///
- /// In addition to each node being an integer, each node also tracks the
- /// original integer id it had in the original graph. This way for every
+ /// In addition to each node having an integer ID, each node also tracks the
+ /// original integer ID it had in the original graph. This way for every
/// back-edge, we can represent it as pointing to a new instance of the
/// original node. Then we mark the original node and the new instance with
/// a new unique recursionId to indicate that they're supposed to be the same
- /// graph.
+ /// node.
struct PrunedGraph {
using Node = Graph::Node;
struct NodeInfo {
Graph::Node originalId;
- // A negative recursive index means not recursive.
+ /// A negative recursive index means not recursive. Otherwise nodes with
+ /// the same originalId & recursionId are the same node in the original
+ /// graph.
int64_t recursionId;
};
@@ -243,7 +245,7 @@ class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
Graph connections;
int64_t nextRecursionId = 0;
int64_t nextConnectionId = 0;
- // Use ordered map for deterministic output.
+ /// Use ordered map for deterministic output.
std::map<Graph::Node, NodeInfo> info;
};
>From a27223a7375afaaf5abd66c26d6b0f4fda372abf Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Thu, 11 Jul 2024 10:59:52 -0700
Subject: [PATCH 4/5] Apply suggestions from code review
Co-authored-by: Jeff Niu <jeff at modular.com>
---
mlir/include/mlir/Support/CyclicReplacerCache.h | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Support/CyclicReplacerCache.h b/mlir/include/mlir/Support/CyclicReplacerCache.h
index 9a703676fff11..7ad8c717199ec 100644
--- a/mlir/include/mlir/Support/CyclicReplacerCache.h
+++ b/mlir/include/mlir/Support/CyclicReplacerCache.h
@@ -82,14 +82,14 @@ class CyclicReplacerCache {
}
/// Get the resolved result if one exists.
- std::optional<OutT> get() { return result; }
+ const std::optional<OutT> &get() { return result; }
private:
friend class CyclicReplacerCache;
CacheEntry() = delete;
CacheEntry(CyclicReplacerCache<InT, OutT> &cache, InT element,
std::optional<OutT> result = std::nullopt)
- : cache(cache), element(element), result(result) {}
+ : cache(cache), element(std::move(element)), result(result) {}
CyclicReplacerCache<InT, OutT> &cache;
InT element;
@@ -153,7 +153,7 @@ CyclicReplacerCache<InT, OutT>::lookupOrInit(const InT &element) {
return CacheEntry(*this, element, it->second);
if (auto it = dependentCache.find(element); it != dependentCache.end()) {
- // pdate the current top frame (the element that invoked this current
+ // Update the current top frame (the element that invoked this current
// replacement) to include any dependencies the cache entry had.
ReplacementFrame &currFrame = replacementStack.back();
currFrame.dependentFrames.insert(it->second.highestDependentFrame);
>From 1abb63f81154de65757b25717a1e616b5b8971d1 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Thu, 11 Jul 2024 11:18:02 -0700
Subject: [PATCH 5/5] double down on trivial type support
---
.../mlir/Support/CyclicReplacerCache.h | 29 ++++++++++---------
.../Support/CyclicReplacerCacheTest.cpp | 4 +++
2 files changed, 20 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Support/CyclicReplacerCache.h b/mlir/include/mlir/Support/CyclicReplacerCache.h
index 7ad8c717199ec..42428c1507ffb 100644
--- a/mlir/include/mlir/Support/CyclicReplacerCache.h
+++ b/mlir/include/mlir/Support/CyclicReplacerCache.h
@@ -12,8 +12,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_SUPPORT_CACHINGREPLACER_H
-#define MLIR_SUPPORT_CACHINGREPLACER_H
+#ifndef MLIR_SUPPORT_CYCLICREPLACERCACHE_H
+#define MLIR_SUPPORT_CYCLICREPLACERCACHE_H
#include "mlir/IR/Visitors.h"
#include "llvm/ADT/DenseSet.h"
@@ -43,13 +43,16 @@ namespace mlir {
/// any given cycle can perform pruning. Even if not, an assertion will
/// eventually be tripped instead of infinite recursion (the run-time is
/// linearly bounded by the maximum cycle length of its input).
+///
+/// WARNING: This class works best with InT & OutT that are trivial scalar
+/// types. The input/output elements will be frequently copied and hashed.
template <typename InT, typename OutT>
class CyclicReplacerCache {
public:
/// User-provided replacement function & cycle-breaking functions.
/// The cycle-breaking function must not make any more recursive invocations
/// to this cached replacer.
- using CycleBreakerFn = std::function<std::optional<OutT>(const InT &)>;
+ using CycleBreakerFn = std::function<std::optional<OutT>(InT)>;
CyclicReplacerCache() = delete;
CyclicReplacerCache(CycleBreakerFn cycleBreaker)
@@ -77,12 +80,12 @@ class CyclicReplacerCache {
/// in the cache.
void resolve(OutT result) {
assert(!this->result && "cache entry already resolved");
- this->result = result;
cache.finalizeReplacement(element, result);
+ this->result = std::move(result);
}
/// Get the resolved result if one exists.
- const std::optional<OutT> &get() { return result; }
+ const std::optional<OutT> &get() const { return result; }
private:
friend class CyclicReplacerCache;
@@ -106,11 +109,11 @@ class CyclicReplacerCache {
/// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and
/// the first retrieved CacheEntry must be resolved last. This should be
/// natural when used as a stack / inside recursion.
- CacheEntry lookupOrInit(const InT &element);
+ CacheEntry lookupOrInit(InT element);
private:
/// Register the replacement in the cache and update the replacementStack.
- void finalizeReplacement(const InT &element, const OutT &result);
+ void finalizeReplacement(InT element, OutT result);
CycleBreakerFn cycleBreaker;
DenseMap<InT, OutT> standaloneCache;
@@ -145,7 +148,7 @@ class CyclicReplacerCache {
template <typename InT, typename OutT>
typename CyclicReplacerCache<InT, OutT>::CacheEntry
-CyclicReplacerCache<InT, OutT>::lookupOrInit(const InT &element) {
+CyclicReplacerCache<InT, OutT>::lookupOrInit(InT element) {
assert(!resolvingCycle &&
"illegal recursive invocation while breaking cycle");
@@ -195,8 +198,8 @@ CyclicReplacerCache<InT, OutT>::lookupOrInit(const InT &element) {
}
template <typename InT, typename OutT>
-void CyclicReplacerCache<InT, OutT>::finalizeReplacement(const InT &element,
- const OutT &result) {
+void CyclicReplacerCache<InT, OutT>::finalizeReplacement(InT element,
+ OutT result) {
ReplacementFrame &currFrame = replacementStack.back();
// With the conclusion of this replacement frame, the current element is no
// longer a dependent element.
@@ -249,7 +252,7 @@ void CyclicReplacerCache<InT, OutT>::finalizeReplacement(const InT &element,
template <typename InT, typename OutT>
class CachedCyclicReplacer {
public:
- using ReplacerFn = std::function<OutT(const InT &)>;
+ using ReplacerFn = std::function<OutT(InT)>;
using CycleBreakerFn =
typename CyclicReplacerCache<InT, OutT>::CycleBreakerFn;
@@ -257,7 +260,7 @@ class CachedCyclicReplacer {
CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
: replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
- OutT operator()(const InT &element) {
+ OutT operator()(InT element) {
auto cacheEntry = cache.lookupOrInit(element);
if (std::optional<OutT> result = cacheEntry.get())
return *result;
@@ -274,4 +277,4 @@ class CachedCyclicReplacer {
} // namespace mlir
-#endif // MLIR_SUPPORT_CACHINGREPLACER_H
+#endif // MLIR_SUPPORT_CYCLICREPLACERCACHE_H
diff --git a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
index a4a92dbe147d4..ca02a3d692b2a 100644
--- a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
+++ b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
@@ -47,6 +47,7 @@ TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) {
/// infinitely long vector. The cycle-breaker function prunes this infinite
/// recursion in the replacer logic by returning an empty vector upon the first
/// re-occurrence of an input value.
+namespace {
class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
public:
// N ==> (N+1) % 3
@@ -71,6 +72,7 @@ class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
int invokeCount = 0;
std::optional<int> baseCase = std::nullopt;
};
+} // namespace
TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) {
// Starting at 0. Cycle length is 3.
@@ -128,6 +130,7 @@ TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) {
/// - PrunedGraph
/// - A Graph where edges that used to cause cycles are now represented with
/// an indirection (a recursionId).
+namespace {
class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
public:
/// A directed graph where nodes are non-negative integers.
@@ -317,6 +320,7 @@ class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
return oss.str();
}
};
+} // namespace
TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) {
// 0 -> 1 -> 2
More information about the Mlir-commits
mailing list