[Mlir-commits] [mlir] [MLIR] Cyclic AttrType Replacer (PR #98206)

Billy Zhu llvmlistbot at llvm.org
Thu Jul 11 22:14:19 PDT 2024


https://github.com/zyx-billy updated https://github.com/llvm/llvm-project/pull/98206

>From bcc3edaa9b80c1ee3e87ad15894382a065bef9ad Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 9 Jul 2024 10:26:05 -0700
Subject: [PATCH 1/5] add replacer cache and gtests

---
 .../mlir/Support/CyclicReplacerCache.h        | 277 ++++++++++
 mlir/unittests/Support/CMakeLists.txt         |   1 +
 .../Support/CyclicReplacerCacheTest.cpp       | 472 ++++++++++++++++++
 3 files changed, 750 insertions(+)
 create mode 100644 mlir/include/mlir/Support/CyclicReplacerCache.h
 create mode 100644 mlir/unittests/Support/CyclicReplacerCacheTest.cpp

diff --git a/mlir/include/mlir/Support/CyclicReplacerCache.h b/mlir/include/mlir/Support/CyclicReplacerCache.h
new file mode 100644
index 0000000000000..9a703676fff11
--- /dev/null
+++ b/mlir/include/mlir/Support/CyclicReplacerCache.h
@@ -0,0 +1,277 @@
+//===- CyclicReplacerCache.h ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains helper classes for caching replacer-like functions that
+// map values between two domains. They are able to handle replacer logic that
+// contains self-recursion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_CACHINGREPLACER_H
+#define MLIR_SUPPORT_CACHINGREPLACER_H
+
+#include "mlir/IR/Visitors.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
+#include <set>
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// CyclicReplacerCache
+//===----------------------------------------------------------------------===//
+
+/// A cache for replacer-like functions that map values between two domains. The
+/// difference compared to just using a map to cache in-out pairs is that this
+/// class is able to handle replacer logic that is self-recursive (and thus may
+/// cause infinite recursion in the naive case).
+///
+/// This class provides a hook for the user to perform cycle pruning when a
+/// cycle is identified, and is able to perform context-sensitive caching so
+/// that the replacement result for an input that is part of a pruned cycle can
+/// be distinct from the replacement result for the same input when it is not
+/// part of a cycle.
+///
+/// In addition, this class allows deferring cycle pruning until specific inputs
+/// are repeated. This is useful for cases where not all elements in a cycle can
+/// perform pruning. The user still must guarantee that at least one element in
+/// any given cycle can perform pruning. Even if not, an assertion will
+/// eventually be tripped instead of infinite recursion (the run-time is
+/// linearly bounded by the maximum cycle length of its input).
+template <typename InT, typename OutT>
+class CyclicReplacerCache {
+public:
+  /// User-provided cycle-breaking function. It returns the pruned replacement
+  /// for a repeated element, or std::nullopt if that element cannot break the
+  /// cycle. It must not make any recursive invocations to this cache.
+  using CycleBreakerFn = std::function<std::optional<OutT>(const InT &)>;
+
+  CyclicReplacerCache() = delete;
+  CyclicReplacerCache(CycleBreakerFn cycleBreaker)
+      : cycleBreaker(std::move(cycleBreaker)) {}
+
+  /// A possibly unresolved cache entry.
+  /// If unresolved, the entry must be resolved before it goes out of scope.
+  struct CacheEntry {
+  public:
+    ~CacheEntry() { assert(result && "unresolved cache entry"); }
+
+    /// Check whether this node was repeated during recursive replacements.
+    /// This only makes sense to be called after all recursive replacements are
+    /// completed and the current element has resurfaced to the top of the
+    /// replacement stack.
+    bool wasRepeated() const {
+      // If the top frame includes itself as a dependency, then it must have
+      // been repeated.
+      ReplacementFrame &currFrame = cache.replacementStack.back();
+      size_t currFrameIndex = cache.replacementStack.size() - 1;
+      return currFrame.dependentFrames.count(currFrameIndex) != 0;
+    }
+
+    /// Resolve an unresolved cache entry by providing the result to be stored
+    /// in the cache. This also pops the entry's frame off the cache's stack.
+    void resolve(OutT result) {
+      assert(!this->result && "cache entry already resolved");
+      this->result = result;
+      cache.finalizeReplacement(element, result);
+    }
+
+    /// Get the resolved result if one exists, std::nullopt otherwise.
+    std::optional<OutT> get() const { return result; }
+
+  private:
+    friend class CyclicReplacerCache;
+    CacheEntry() = delete;
+    CacheEntry(CyclicReplacerCache<InT, OutT> &cache, InT element,
+               std::optional<OutT> result = std::nullopt)
+        : cache(cache), element(element), result(result) {}
+
+    CyclicReplacerCache<InT, OutT> &cache;
+    InT element;
+    std::optional<OutT> result;
+  };
+
+  /// Lookup the cache for a pre-calculated replacement for `element`.
+  /// If one exists, a resolved CacheEntry will be returned. Otherwise, an
+  /// unresolved CacheEntry will be returned, and the caller must resolve it
+  /// with the calculated replacement so it can be registered in the cache for
+  /// future use.
+  /// Multiple unresolved CacheEntries may be retrieved. However, any unresolved
+  /// CacheEntries that are returned must be resolved in reverse order of
+  /// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and
+  /// the first retrieved CacheEntry must be resolved last. This should be
+  /// natural when used as a stack / inside recursion.
+  CacheEntry lookupOrInit(const InT &element);
+
+private:
+  /// Register the replacement in the cache and update the replacementStack.
+  void finalizeReplacement(const InT &element, const OutT &result);
+
+  CycleBreakerFn cycleBreaker;
+  DenseMap<InT, OutT> standaloneCache;
+
+  struct DependentReplacement {
+    OutT replacement;
+    /// The highest replacement frame index that this cache entry is dependent
+    /// on.
+    size_t highestDependentFrame;
+  };
+  DenseMap<InT, DependentReplacement> dependentCache;
+
+  struct ReplacementFrame {
+    /// The set of elements that is only legal while under this current frame.
+    /// They need to be removed from the cache when this frame is popped off the
+    /// replacement stack.
+    DenseSet<InT> dependingReplacements;
+    /// The set of frame indices that this current frame's replacement is
+    /// dependent on, ordered from highest to lowest.
+    std::set<size_t, std::greater<size_t>> dependentFrames;
+  };
+  /// Every element currently in the progress of being replaced pushes a frame
+  /// onto this stack.
+  SmallVector<ReplacementFrame> replacementStack;
+  /// Maps from each input element to its indices on the replacement stack.
+  DenseMap<InT, SmallVector<size_t, 2>> cyclicElementFrame;
+  /// If set to true, we are currently asking an element to break a cycle. No
+  /// more recursive invocations are allowed while this is true (the replacement
+  /// stack can no longer grow).
+  bool resolvingCycle = false;
+};
+
+template <typename InT, typename OutT>
+typename CyclicReplacerCache<InT, OutT>::CacheEntry
+CyclicReplacerCache<InT, OutT>::lookupOrInit(const InT &element) {
+  assert(!resolvingCycle &&
+         "illegal recursive invocation while breaking cycle");
+
+  if (auto it = standaloneCache.find(element); it != standaloneCache.end())
+    return CacheEntry(*this, element, it->second);
+
+  if (auto it = dependentCache.find(element); it != dependentCache.end()) {
+    // Update the current top frame (the element that invoked this current
+    // replacement) to include any dependencies the cache entry had.
+    ReplacementFrame &currFrame = replacementStack.back();
+    currFrame.dependentFrames.insert(it->second.highestDependentFrame);
+    return CacheEntry(*this, element, it->second.replacement);
+  }
+
+  auto [it, inserted] = cyclicElementFrame.try_emplace(element);
+  if (!inserted) {
+    // This element is already in progress on the stack. Try to break the
+    // cycle here; `resolvingCycle` makes any re-entry trip the assert above.
+    resolvingCycle = true;
+    std::optional<OutT> result = cycleBreaker(element);
+    resolvingCycle = false;
+    if (result) {
+      // Cycle was broken. The pruned result is only valid under the most
+      // recent in-progress frame of this element.
+      size_t dependentFrame = it->second.back();
+      dependentCache[element] = {*result, dependentFrame};
+      ReplacementFrame &currFrame = replacementStack.back();
+      // If this is a repeat, there is no replacement frame to pop. Mark the top
+      // frame as being dependent on this element.
+      currFrame.dependentFrames.insert(dependentFrame);
+
+      return CacheEntry(*this, element, *result);
+    }
+
+    // Cycle could not be broken.
+    // A legal setup must ensure at least one element of each cycle can break
+    // cycles. Under this setup, each element can be seen at most twice before
+    // the cycle is broken. If we see an element more than twice, we know this
+    // is an illegal setup.
+    assert(it->second.size() <= 2 && "illegal 3rd repeat of input");
+  }
+
+  // Otherwise, either this is the first time we see this element, or this
+  // element could not break this cycle. Record its new frame index and push it.
+  it->second.push_back(replacementStack.size());
+  replacementStack.emplace_back();
+
+  return CacheEntry(*this, element);
+}
+
+template <typename InT, typename OutT>
+void CyclicReplacerCache<InT, OutT>::finalizeReplacement(const InT &element,
+                                                         const OutT &result) {
+  ReplacementFrame &currFrame = replacementStack.back();
+  // With the conclusion of this replacement frame, the current element is no
+  // longer a dependent element.
+  currFrame.dependentFrames.erase(replacementStack.size() - 1);
+
+  // The bottom-most frame has no enclosing frame it could still depend on.
+  const bool isBottomFrame = replacementStack.size() == 1;
+  assert((!isBottomFrame || currFrame.dependentFrames.empty()) &&
+         "internal error: top-level dependent replacement");
+
+  if (isBottomFrame || currFrame.dependentFrames.empty()) {
+    // Cache standalone result: it does not rely on any in-progress frame.
+    standaloneCache[element] = result;
+  } else {
+    // Cache dependent result. The set is ordered from highest to lowest, so
+    // its first entry is the closest frame this result depends on.
+    size_t highestDependentFrame = *currFrame.dependentFrames.begin();
+    dependentCache[element] = {result, highestDependentFrame};
+
+    // The previous frame inherits the same dependent frames.
+    ReplacementFrame &prevFrame = *(replacementStack.end() - 2);
+    prevFrame.dependentFrames.insert(currFrame.dependentFrames.begin(),
+                                     currFrame.dependentFrames.end());
+
+    // Mark this current replacement as a depending replacement on the closest
+    // dependent frame.
+    replacementStack[highestDependentFrame].dependingReplacements.insert(
+        element);
+  }
+
+  // All depending replacements in the cache must be purged.
+  for (const InT &key : currFrame.dependingReplacements)
+    dependentCache.erase(key);
+
+  replacementStack.pop_back();
+  auto it = cyclicElementFrame.find(element);
+  it->second.pop_back();
+  if (it->second.empty())
+    cyclicElementFrame.erase(it);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer
+//===----------------------------------------------------------------------===//
+
+/// Convenience wrapper that pairs a user-provided replacer function with a
+/// `CyclicReplacerCache`, for the case where the replacer's input/output types
+/// are the same as the types stored in the cache. Instances are callable and
+/// can be used in place of the user function.
+template <typename InT, typename OutT>
+class CachedCyclicReplacer {
+public:
+  using ReplacerFn = std::function<OutT(const InT &)>;
+  using CycleBreakerFn =
+      typename CyclicReplacerCache<InT, OutT>::CycleBreakerFn;
+
+  CachedCyclicReplacer() = delete;
+  CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
+      : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
+
+  OutT operator()(const InT &element) {
+    auto entry = cache.lookupOrInit(element);
+    std::optional<OutT> cached = entry.get();
+    if (cached.has_value())
+      return cached.value();
+    OutT fresh = replacer(element);
+    entry.resolve(fresh);
+    return fresh;
+  }
+
+private:
+  ReplacerFn replacer;
+  CyclicReplacerCache<InT, OutT> cache;
+};
+
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_CACHINGREPLACER_H
diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
index 1dbf072bcbbfd..ec79a1c640909 100644
--- a/mlir/unittests/Support/CMakeLists.txt
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_unittest(MLIRSupportTests
+  CyclicReplacerCacheTest.cpp
   IndentedOstreamTest.cpp
   StorageUniquerTest.cpp
 )
diff --git a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
new file mode 100644
index 0000000000000..23748e29765cb
--- /dev/null
+++ b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
@@ -0,0 +1,472 @@
+//===- CyclicReplacerCacheTest.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/CyclicReplacerCache.h"
+#include "llvm/ADT/SetVector.h"
+#include "gmock/gmock.h"
+#include <map>
+#include <set>
+
+using namespace mlir;
+
+TEST(CachedCyclicReplacerTest, testNoRecursion) {
+  // This replacer never calls itself, so no input is ever repeated.
+  CachedCyclicReplacer<int, bool> replacer(
+      /*replacer=*/[](int n) { return n != 0; },
+      /*cycleBreaker=*/[](int) { return std::nullopt; });
+  EXPECT_TRUE(replacer(3));
+  EXPECT_FALSE(replacer(0));
+}
+
+TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) {
+  // The replacer forwards every int to its successor modulo 3, i.e. it cycles
+  // 0 -> 1 -> 2 -> 0 -> ...; any repeated input is pruned to -1.
+  CachedCyclicReplacer<int, int> replacer(
+      /*replacer=*/[&](int n) { return replacer((n + 1) % 3); },
+      /*cycleBreaker=*/[](int) { return -1; });
+
+  // The pruned value is returned unchanged, whether starting at 0 or at 2.
+  EXPECT_EQ(replacer(0), -1);
+  EXPECT_EQ(replacer(2), -1);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer: ChainRecursion
+//===----------------------------------------------------------------------===//
+
+/// Fixture whose replacer maps ints to vectors of ints.
+///
+/// Replacing `n` yields the replacement of `(n+1)%3` with one `42` appended,
+/// i.e. N ==> (N+1) % 3. Without pruning this recursion never terminates and
+/// the chain would be infinitely long. The cycle breaker stops it by returning
+/// an empty vector for a repeated input: for any repeated input when
+/// `baseCase` is unset, otherwise only when the repeated input equals
+/// `baseCase`.
+class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
+public:
+  CachedCyclicReplacerChainRecursionPruningTest()
+      : replacer(
+            [&](int n) {
+              ++invokeCount;
+              std::vector<int> chain = replacer((n + 1) % 3);
+              chain.push_back(42);
+              return chain;
+            },
+            [&](int n) -> std::optional<std::vector<int>> {
+              // With a base case set, only that input may break the cycle.
+              if (baseCase && *baseCase != n)
+                return std::nullopt;
+              return std::vector<int>{};
+            }) {}
+
+  // Builds the expected replacement: a vector holding `N` copies of 42.
+  std::vector<int> getChain(unsigned N) { return std::vector<int>(N, 42); }
+
+  CachedCyclicReplacer<int, std::vector<int>> replacer;
+  int invokeCount = 0;
+  std::optional<int> baseCase = std::nullopt;
+};
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) {
+  // Starting at 0: 0 -> 1 -> 2 -> 0 (pruned). Chain length is 3.
+  EXPECT_EQ(replacer(0), getChain(3));
+  EXPECT_EQ(invokeCount, 3);
+
+  // Starting at 1. Chain length is 5 as it ends at the cached result for 0.
+  invokeCount = 0;
+  EXPECT_EQ(replacer(1), getChain(5));
+  EXPECT_EQ(invokeCount, 2);
+
+  // Starting at 2. Chain length is 4. Entire result is cached.
+  invokeCount = 0;
+  EXPECT_EQ(replacer(2), getChain(4));
+  EXPECT_EQ(invokeCount, 0);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere1) {
+  // Starting at 1: 1 -> 2 -> 0 -> 1 (pruned). Chain length is 3.
+  EXPECT_EQ(replacer(1), getChain(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific0) {
+  baseCase = 0; // Only a repeat of 0 may break the cycle.
+
+  // Starting at 0: 0 -> 1 -> 2 -> 0 (pruned). Chain length is 3.
+  EXPECT_EQ(replacer(0), getChain(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) {
+  baseCase = 0; // Only a repeat of 0 may break the cycle.
+
+  // Starting at 1. Chain length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune at 0).
+  EXPECT_EQ(replacer(1), getChain(5));
+  EXPECT_EQ(invokeCount, 5);
+
+  // Starting at 0. Chain length is 3. Entire result is cached.
+  invokeCount = 0;
+  EXPECT_EQ(replacer(0), getChain(3));
+  EXPECT_EQ(invokeCount, 0);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer: GraphReplacement
+//===----------------------------------------------------------------------===//
+
+/// This set of tests uses a replacer function that maps from cyclic graphs to
+/// trees, pruning out cycles in the process.
+///
+/// It consists of two helper classes:
+/// - Graph
+///   - A directed graph where nodes are non-negative integers.
+/// - PrunedGraph
+///   - A Graph where edges that used to cause cycles are now represented with
+///     an indirection (a recursionId).
+class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
+public:
+  /// A directed graph where nodes are non-negative integers.
+  struct Graph {
+    using Node = int64_t;
+
+    /// Use ordered containers for deterministic output.
+    /// Nodes without outgoing edges are considered nonexistent.
+    std::map<Node, std::set<Node>> edges;
+
+    void addEdge(Node src, Node sink) { edges[src].insert(sink); }
+
+    bool isCyclic() const {
+      DenseSet<Node> visited;
+      for (Node root : llvm::make_first_range(edges)) {
+        if (visited.contains(root))
+          continue;
+
+        SetVector<Node> path;
+        SmallVector<Node> workstack;
+        workstack.push_back(root);
+        while (!workstack.empty()) {
+          Node curr = workstack.back();
+          workstack.pop_back();
+
+          if (curr < 0) {
+            // A negative entry `-(n + 1)` signals the end of processing all of
+            // node `n`'s children. Remove `n` from the path.
+            assert(path.back() == -(curr + 1) && "internal inconsistency");
+            path.pop_back();
+            continue;
+          }
+
+          if (path.contains(curr))
+            return true;
+
+          visited.insert(curr);
+          auto edgesIter = edges.find(curr);
+          if (edgesIter == edges.end() || edgesIter->second.empty())
+            continue;
+
+          path.insert(curr);
+          // Push a negative return marker; the +1 keeps node 0 representable.
+          workstack.push_back(-(curr + 1));
+          workstack.insert(workstack.end(), edgesIter->second.begin(),
+                           edgesIter->second.end());
+        }
+      }
+      return false;
+    }
+
+    /// Deterministic output for testing.
+    std::string serialize() const {
+      std::ostringstream oss;
+      for (const auto &[src, neighbors] : edges) {
+        oss << src << ":";
+        for (Graph::Node neighbor : neighbors)
+          oss << " " << neighbor;
+        oss << "\n";
+      }
+      return oss.str();
+    }
+  };
+
+  /// A Graph where edges that used to cause cycles (back-edges) are now
+  /// represented with an indirection (a recursionId).
+  ///
+  /// In addition to each node being an integer, each node also tracks the
+  /// original integer id it had in the original graph. This way for every
+  /// back-edge, we can represent it as pointing to a new instance of the
+  /// original node. Then we mark the original node and the new instance with
+  /// a new unique recursionId to indicate that they're supposed to be the same
+  /// graph.
+  struct PrunedGraph {
+    using Node = Graph::Node;
+    struct NodeInfo {
+      Graph::Node originalId;
+      // A negative recursive index means not recursive.
+      int64_t recursionId;
+    };
+
+    /// Add a regular non-recursive-self node.
+    Node addNode(Graph::Node originalId, int64_t recursionIndex = -1) {
+      Node id = nextConnectionId++;
+      info[id] = {originalId, recursionIndex};
+      return id;
+    }
+    /// Add a recursive-self-node, i.e. a duplicate of the original node that is
+    /// meant to represent an indirection to it.
+    std::pair<Node, int64_t> addRecursiveSelfNode(Graph::Node originalId) {
+      return {addNode(originalId, nextRecursionId), nextRecursionId++};
+    }
+    void addEdge(Node src, Node sink) { connections.addEdge(src, sink); }
+
+    /// Deterministic output for testing.
+    std::string serialize() const {
+      std::ostringstream oss;
+      oss << "nodes\n";
+      for (const auto &[nodeId, nodeInfo] : info) {
+        oss << nodeId << ": n" << nodeInfo.originalId;
+        if (nodeInfo.recursionId >= 0)
+          oss << '<' << nodeInfo.recursionId << '>';
+        oss << "\n";
+      }
+      oss << "edges\n";
+      oss << connections.serialize();
+      return oss.str();
+    }
+
+    bool isCyclic() const { return connections.isCyclic(); }
+
+  private:
+    Graph connections;
+    int64_t nextRecursionId = 0;
+    int64_t nextConnectionId = 0;
+    // Use ordered map for deterministic output.
+    std::map<Graph::Node, NodeInfo> info;
+  };
+
+  PrunedGraph breakCycles(const Graph &input) {
+    assert(input.isCyclic() && "input graph is not cyclic");
+
+    PrunedGraph output;
+
+    DenseMap<Graph::Node, int64_t> recMap;
+    auto cycleBreaker = [&](Graph::Node inNode) -> std::optional<Graph::Node> {
+      auto [node, recId] = output.addRecursiveSelfNode(inNode);
+      recMap[inNode] = recId;
+      return node;
+    };
+
+    CyclicReplacerCache<Graph::Node, Graph::Node> cache(cycleBreaker);
+
+    std::function<Graph::Node(Graph::Node)> replaceNode =
+        [&](Graph::Node inNode) {
+          auto cacheEntry = cache.lookupOrInit(inNode);
+          if (std::optional<Graph::Node> result = cacheEntry.get())
+            return *result;
+
+          // Recursively replace its neighbors.
+          SmallVector<Graph::Node> neighbors;
+          if (auto it = input.edges.find(inNode); it != input.edges.end())
+            neighbors = SmallVector<Graph::Node>(
+                llvm::map_range(it->second, replaceNode));
+
+          // Create a new node in the output graph.
+          int64_t recursionIndex =
+              cacheEntry.wasRepeated() ? recMap.lookup(inNode) : -1;
+          Graph::Node result = output.addNode(inNode, recursionIndex);
+
+          for (Graph::Node neighbor : neighbors)
+            output.addEdge(result, neighbor);
+
+          cacheEntry.resolve(result);
+          return result;
+        };
+
+    /// Translate starting from each node.
+    for (Graph::Node root : llvm::make_first_range(input.edges))
+      replaceNode(root);
+
+    return output;
+  }
+
+  /// Helper for serialization tests that allow putting comments in the
+  /// serialized format. Every line that begins with a `;` is considered a
+  /// comment. The entire line, incl. the terminating `\n` is removed.
+  std::string trimComments(StringRef input) {
+    std::ostringstream oss;
+    bool isNewLine = true;
+    bool isComment = false;
+    for (char c : input) {
+      // Lines beginning with ';' are comments.
+      if (isNewLine && c == ';')
+        isComment = true;
+
+      if (!isComment)
+        oss << c;
+
+      // A line starts at the very first character and after every '\n'.
+      isNewLine = c == '\n';
+      if (isNewLine)
+        isComment = false;
+    }
+    return oss.str();
+  }
+};
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) {
+  // 0 -> 1 -> 2
+  // ^         |
+  // +---------+
+  Graph input = {{{0, {1}}, {1, {2}}, {2, {0}}}};
+  PrunedGraph output = breakCycles(input); // Replaces from every source node.
+  ASSERT_FALSE(output.isCyclic()) << output.serialize();
+  EXPECT_EQ(output.serialize(), trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n2
+2: n1
+3: n0<0>
+; root 1
+4: n2
+; root 2
+5: n1
+edges
+1: 0
+2: 1
+3: 2
+4: 3
+5: 4
+)"));
+}
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testDualLoop) {
+  // +----> 1 -----+
+  // |             v
+  // 0 <---------- 3
+  // |             ^
+  // +----> 2 -----+
+  //
+  // Two loops that share the edge 3 -> 0:
+  // 0 -> 1 -> 3 -> 0
+  // 0 -> 2 -> 3 -> 0
+  Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0}}}};
+  PrunedGraph output = breakCycles(input); // Replaces from every source node.
+  ASSERT_FALSE(output.isCyclic()) << output.serialize();
+  EXPECT_EQ(output.serialize(), trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n3
+2: n1
+3: n2
+4: n0<0>
+; root 1
+5: n3
+6: n1
+; root 2
+7: n2
+edges
+1: 0
+2: 1
+3: 1
+4: 2 3
+5: 4
+6: 5
+7: 5
+)"));
+}
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testNestedLoops) {
+  // +----> 1 -----+
+  // |      ^      v
+  // 0 <----+----- 2
+  //
+  // Two nested loops that share the edge 1 -> 2:
+  // 0 -> 1 -> 2 -> 0
+  //      1 -> 2 -> 1
+  Graph input = {{{0, {1}}, {1, {2}}, {2, {0, 1}}}};
+  PrunedGraph output = breakCycles(input); // Replaces from every source node.
+  ASSERT_FALSE(output.isCyclic()) << output.serialize();
+  EXPECT_EQ(output.serialize(), trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n1<1>
+2: n2
+3: n1<1>
+4: n0<0>
+; root 1
+5: n1<2>
+6: n2
+7: n1<2>
+; root 2
+8: n2
+edges
+2: 0 1
+3: 2
+4: 3
+6: 4 5
+7: 6
+8: 4 7
+)"));
+}
+
+TEST_F(CachedCyclicReplacerGraphReplacement, testDualNestedLoops) {
+  // +----> 1 -----+
+  // |      ^      v
+  // 0 <----+----- 3
+  // |      v      ^
+  // +----> 2 -----+
+  //
+  // Two sets of nested loops (one through 1, one through 2):
+  // 0 -> 1 -> 3 -> 0
+  //      1 -> 3 -> 1
+  // 0 -> 2 -> 3 -> 0
+  //      2 -> 3 -> 2
+  Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0, 1, 2}}}};
+  PrunedGraph output = breakCycles(input); // Replaces from every source node.
+  ASSERT_FALSE(output.isCyclic()) << output.serialize();
+  EXPECT_EQ(output.serialize(), trimComments(R"(nodes
+; root 0
+0: n0<0>
+1: n1<1>
+2: n3<2>
+3: n2
+4: n3<2>
+5: n1<1>
+6: n2<3>
+7: n3
+8: n2<3>
+9: n0<0>
+; root 1
+10: n1<4>
+11: n3<5>
+12: n2
+13: n3<5>
+14: n1<4>
+; root 2
+15: n2<6>
+16: n3
+17: n2<6>
+; root 3
+18: n3
+edges
+; root 0
+3: 2
+4: 0 1 3
+5: 4
+7: 0 5 6
+8: 7
+9: 5 8
+; root 1
+12: 11
+13: 9 10 12
+14: 13
+; root 2
+16: 9 14 15
+17: 16
+; root 3
+18: 9 14 17
+)"));
+}

>From 8d1dd886c4a80507c8a97dda15e91acbfa7c3619 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 9 Jul 2024 10:27:13 -0700
Subject: [PATCH 2/5] refactor attrtype replacers & add tests

---
 mlir/include/mlir/IR/AttrTypeSubElements.h | 138 ++++++++++--
 mlir/lib/IR/AttrTypeSubElements.cpp        | 146 ++++++++++---
 mlir/unittests/IR/AttrTypeReplacerTest.cpp | 231 +++++++++++++++++++++
 mlir/unittests/IR/CMakeLists.txt           |   1 +
 4 files changed, 467 insertions(+), 49 deletions(-)
 create mode 100644 mlir/unittests/IR/AttrTypeReplacerTest.cpp

diff --git a/mlir/include/mlir/IR/AttrTypeSubElements.h b/mlir/include/mlir/IR/AttrTypeSubElements.h
index 3105040b87631..234767deea00a 100644
--- a/mlir/include/mlir/IR/AttrTypeSubElements.h
+++ b/mlir/include/mlir/IR/AttrTypeSubElements.h
@@ -16,6 +16,7 @@
 
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Visitors.h"
+#include "mlir/Support/CyclicReplacerCache.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include <optional>
@@ -116,9 +117,21 @@ class AttrTypeWalker {
 /// AttrTypeReplacer
 //===----------------------------------------------------------------------===//
 
-/// This class provides a utility for replacing attributes/types, and their sub
-/// elements. Multiple replacement functions may be registered.
-class AttrTypeReplacer {
+namespace detail {
+
+/// This class provides a base utility for replacing attributes/types, and their
+/// sub elements. Multiple replacement functions may be registered.
+///
+/// This base utility is uncached. Users can choose between two cached versions
+/// of this replacer:
+///   * For non-cyclic replacer logic, use `AttrTypeReplacer`.
+///   * For cyclic replacer logic, use `CyclicAttrTypeReplacer`.
+///
+/// Concrete implementations implement the following `replace` entry functions:
+///   * Attribute replace(Attribute attr);
+///   * Type replace(Type type);
+template <typename Concrete>
+class AttrTypeReplacerBase {
 public:
   //===--------------------------------------------------------------------===//
   // Application
@@ -139,12 +152,6 @@ class AttrTypeReplacer {
                                     bool replaceLocs = false,
                                     bool replaceTypes = false);
 
-  /// Replace the given attribute/type, and recursively replace any sub
-  /// elements. Returns either the new attribute/type, or nullptr in the case of
-  /// failure.
-  Attribute replace(Attribute attr);
-  Type replace(Type type);
-
   //===--------------------------------------------------------------------===//
   // Registration
   //===--------------------------------------------------------------------===//
@@ -206,21 +213,114 @@ class AttrTypeReplacer {
     });
   }
 
-private:
-  /// Internal implementation of the `replace` methods above.
-  template <typename T, typename ReplaceFns>
-  T replaceImpl(T element, ReplaceFns &replaceFns);
-
-  /// Replace the sub elements of the given interface.
-  template <typename T>
-  T replaceSubElements(T interface);
+protected:
+  /// Invokes the registered replacement functions from most recently registered
+  /// to least recently registered until a successful replacement is returned.
+  /// Unless skipping is requested, invokes `replace` on sub-elements of the
+  /// current attr/type.
+  Attribute replaceBase(Attribute attr);
+  Type replaceBase(Type type);
 
+private:
   /// The set of replacement functions that map sub elements.
   std::vector<ReplaceFn<Attribute>> attrReplacementFns;
   std::vector<ReplaceFn<Type>> typeReplacementFns;
+};
+
+} // namespace detail
+
+/// This is an attribute/type replacer that is naively cached. It is best used
+/// when the replacer logic is guaranteed to not contain cycles. Otherwise, any
+/// re-occurrence of an in-progress element is returned as-is (not replaced).
+class AttrTypeReplacer : public detail::AttrTypeReplacerBase<AttrTypeReplacer> {
+public:
+  Attribute replace(Attribute attr);
+  Type replace(Type type);
+
+private:
+  /// Shared concrete implementation of the public `replace` functions. Invokes
+  /// `replaceBase` with caching.
+  template <typename T>
+  T cachedReplaceImpl(T element);
+
+  /// Maps the opaque pointer of an attr/type to that of its replacement.
+  DenseMap<const void *, const void *> cache;
+};
+
+/// This is an attribute/type replacer that supports custom handling of cycles
+/// in the replacer logic. In addition to registering replacer functions, it
+/// allows registering cycle-breaking functions in the same style.
+class CyclicAttrTypeReplacer
+    : public detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer> {
+public:
+  CyclicAttrTypeReplacer();
 
-  /// The set of cached mappings for attributes/types.
-  DenseMap<const void *, const void *> attrTypeMap;
+  //===--------------------------------------------------------------------===//
+  // Application
+  //===--------------------------------------------------------------------===//
+
+  Attribute replace(Attribute attr);
+  Type replace(Type type);
+
+  //===--------------------------------------------------------------------===//
+  // Registration
+  //===--------------------------------------------------------------------===//
+
+  /// A cycle-breaking function. This is invoked if the same element is asked to
+  /// be replaced again when the first instance of it is still being replaced.
+  /// This function must not perform any more recursive `replace` calls.
+  /// If it is able to break the cycle, it should return a replacement result.
+  /// Otherwise, it can return std::nullopt to defer cycle breaking to the next
+  /// repeated element. However, the user must guarantee that, in any possible
+  /// cycle, there always exists at least one element that can break the cycle.
+  template <typename T>
+  using CycleBreakerFn = std::function<std::optional<T>(T)>;
+
+  /// Register a cycle-breaking function.
+  /// When breaking cycles, the most recently added cycle-breaking functions
+  /// will be invoked first.
+  void addCycleBreaker(CycleBreakerFn<Attribute> fn);
+  void addCycleBreaker(CycleBreakerFn<Type> fn);
+
+  /// Register a cycle-breaking function that doesn't match the default
+  /// signature.
+  template <typename FnT,
+            typename T = typename llvm::function_traits<
+                std::decay_t<FnT>>::template arg_t<0>,
+            typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
+                                                Attribute, Type>>
+  std::enable_if_t<!std::is_same_v<T, BaseT>> addCycleBreaker(FnT &&callback) {
+    addCycleBreaker([callback = std::forward<FnT>(callback)](
+                        BaseT base) -> std::optional<BaseT> {
+      if (auto derived = dyn_cast<T>(base))
+        return callback(derived);
+      return std::nullopt;
+    });
+  }
+
+private:
+  /// Invokes the registered cycle-breaker functions from most recently
+  /// registered to least recently registered until a successful result is
+  /// returned.
+  std::optional<const void *> breakCycleImpl(void *element);
+
+  /// Shared concrete implementation of the public `replace` functions.
+  template <typename T>
+  T cachedReplaceImpl(T element);
+
+  /// The set of registered cycle-breaker functions.
+  std::vector<CycleBreakerFn<Attribute>> attrCycleBreakerFns;
+  std::vector<CycleBreakerFn<Type>> typeCycleBreakerFns;
+
+  /// A cache of previously-replaced attr/types.
+  /// The key of the cache is the opaque value of an AttrOrType. Using
+  /// AttrOrType allows distinguishing between the two types when invoking
+  /// cycle-breakers. Using its opaque value avoids the cyclic dependency issue
+  /// of directly using `AttrOrType` to instantiate the cache.
+  /// The value of the cache is just the opaque value of the attr/type itself
+  /// (not the PointerUnion).
+  using AttrOrType = PointerUnion<Attribute, Type>;
+  CyclicReplacerCache<void *, const void *> cache;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AttrTypeSubElements.cpp b/mlir/lib/IR/AttrTypeSubElements.cpp
index 79b04966be6eb..783236ed3a9df 100644
--- a/mlir/lib/IR/AttrTypeSubElements.cpp
+++ b/mlir/lib/IR/AttrTypeSubElements.cpp
@@ -67,22 +67,28 @@ WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
 }
 
 //===----------------------------------------------------------------------===//
-/// AttrTypeReplacer
+/// AttrTypeReplacerBase
 //===----------------------------------------------------------------------===//
 
-void AttrTypeReplacer::addReplacement(ReplaceFn<Attribute> fn) {
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
+    ReplaceFn<Attribute> fn) {
   attrReplacementFns.emplace_back(std::move(fn));
 }
-void AttrTypeReplacer::addReplacement(ReplaceFn<Type> fn) {
+
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
+    ReplaceFn<Type> fn) {
   typeReplacementFns.push_back(std::move(fn));
 }
 
-void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
-                                         bool replaceLocs, bool replaceTypes) {
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::replaceElementsIn(
+    Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
   // Functor that replaces the given element if the new value is different,
   // otherwise returns nullptr.
   auto replaceIfDifferent = [&](auto element) {
-    auto replacement = replace(element);
+    auto replacement = static_cast<Concrete *>(this)->replace(element);
     return (replacement && replacement != element) ? replacement : nullptr;
   };
 
@@ -127,17 +133,16 @@ void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
   }
 }
 
-void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op,
-                                                    bool replaceAttrs,
-                                                    bool replaceLocs,
-                                                    bool replaceTypes) {
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::recursivelyReplaceElementsIn(
+    Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
   op->walk([&](Operation *nestedOp) {
     replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes);
   });
 }
 
-template <typename T>
-static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
+template <typename T, typename Replacer>
+static void updateSubElementImpl(T element, Replacer &replacer,
                                  SmallVectorImpl<T> &newElements,
                                  FailureOr<bool> &changed) {
   // Bail early if we failed at any point.
@@ -160,18 +165,18 @@ static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
   }
 }
 
-template <typename T>
-T AttrTypeReplacer::replaceSubElements(T interface) {
+template <typename T, typename Replacer>
+static T replaceSubElements(T interface, Replacer &replacer) {
   // Walk the current sub-elements, replacing them as necessary.
   SmallVector<Attribute, 16> newAttrs;
   SmallVector<Type, 16> newTypes;
   FailureOr<bool> changed = false;
   interface.walkImmediateSubElements(
       [&](Attribute element) {
-        updateSubElementImpl(element, *this, newAttrs, changed);
+        updateSubElementImpl(element, replacer, newAttrs, changed);
       },
       [&](Type element) {
-        updateSubElementImpl(element, *this, newTypes, changed);
+        updateSubElementImpl(element, replacer, newTypes, changed);
       });
   if (failed(changed))
     return nullptr;
@@ -184,13 +189,9 @@ T AttrTypeReplacer::replaceSubElements(T interface) {
 }
 
 /// Shared implementation of replacing a given attribute or type element.
-template <typename T, typename ReplaceFns>
-T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
-  const void *opaqueElement = element.getAsOpaquePointer();
-  auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement);
-  if (!inserted)
-    return T::getFromOpaquePointer(it->second);
-
+template <typename T, typename ReplaceFns, typename Replacer>
+static T replaceElementImpl(T element, ReplaceFns &replaceFns,
+                            Replacer &replacer) {
   T result = element;
   WalkResult walkResult = WalkResult::advance();
   for (auto &replaceFn : llvm::reverse(replaceFns)) {
@@ -202,29 +203,114 @@ T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
 
   // If an error occurred, return nullptr to indicate failure.
   if (walkResult.wasInterrupted() || !result) {
-    attrTypeMap[opaqueElement] = nullptr;
     return nullptr;
   }
 
   // Handle replacing sub-elements if this element is also a container.
   if (!walkResult.wasSkipped()) {
     // Replace the sub elements of this element, bailing if we fail.
-    if (!(result = replaceSubElements(result))) {
-      attrTypeMap[opaqueElement] = nullptr;
+    if (!(result = replaceSubElements(result, replacer))) {
       return nullptr;
     }
   }
 
-  attrTypeMap[opaqueElement] = result.getAsOpaquePointer();
+  return result;
+}
+
+template <typename Concrete>
+Attribute detail::AttrTypeReplacerBase<Concrete>::replaceBase(Attribute attr) {
+  return replaceElementImpl(attr, attrReplacementFns,
+                            *static_cast<Concrete *>(this));
+}
+
+template <typename Concrete>
+Type detail::AttrTypeReplacerBase<Concrete>::replaceBase(Type type) {
+  return replaceElementImpl(type, typeReplacementFns,
+                            *static_cast<Concrete *>(this));
+}
+
+//===----------------------------------------------------------------------===//
+/// AttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+template class detail::AttrTypeReplacerBase<AttrTypeReplacer>;
+
+template <typename T>
+T AttrTypeReplacer::cachedReplaceImpl(T element) {
+  const void *opaqueElement = element.getAsOpaquePointer();
+  auto [it, inserted] = cache.try_emplace(opaqueElement, opaqueElement);
+  if (!inserted)
+    return T::getFromOpaquePointer(it->second);
+
+  T result = replaceBase(element);
+
+  cache[opaqueElement] = result.getAsOpaquePointer();
   return result;
 }
 
 Attribute AttrTypeReplacer::replace(Attribute attr) {
-  return replaceImpl(attr, attrReplacementFns);
+  return cachedReplaceImpl(attr);
 }
 
-Type AttrTypeReplacer::replace(Type type) {
-  return replaceImpl(type, typeReplacementFns);
+Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(type); }
+
+//===----------------------------------------------------------------------===//
+/// CyclicAttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+template class detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer>;
+
+CyclicAttrTypeReplacer::CyclicAttrTypeReplacer()
+    : cache([&](void *attr) { return breakCycleImpl(attr); }) {}
+
+void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Attribute> fn) {
+  attrCycleBreakerFns.emplace_back(std::move(fn));
+}
+
+void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Type> fn) {
+  typeCycleBreakerFns.emplace_back(std::move(fn));
+}
+
+template <typename T>
+T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) {
+  void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue();
+  CyclicReplacerCache<void *, const void *>::CacheEntry cacheEntry =
+      cache.lookupOrInit(opaqueTaggedElement);
+  if (auto resultOpt = cacheEntry.get())
+    return T::getFromOpaquePointer(*resultOpt);
+
+  T result = replaceBase(element);
+
+  cacheEntry.resolve(result.getAsOpaquePointer());
+  return result;
+}
+
+Attribute CyclicAttrTypeReplacer::replace(Attribute attr) {
+  return cachedReplaceImpl(attr);
+}
+
+Type CyclicAttrTypeReplacer::replace(Type type) {
+  return cachedReplaceImpl(type);
+}
+
+std::optional<const void *>
+CyclicAttrTypeReplacer::breakCycleImpl(void *element) {
+  AttrOrType attrType = AttrOrType::getFromOpaqueValue(element);
+  if (auto attr = dyn_cast<Attribute>(attrType)) {
+    for (auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) {
+      if (std::optional<Attribute> newRes = cyclicReplaceFn(attr)) {
+        return newRes->getAsOpaquePointer();
+      }
+    }
+  } else {
+    auto type = dyn_cast<Type>(attrType);
+    for (auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) {
+      if (std::optional<Type> newRes = cyclicReplaceFn(type)) {
+        return newRes->getAsOpaquePointer();
+      }
+    }
+  }
+  return std::nullopt;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/IR/AttrTypeReplacerTest.cpp b/mlir/unittests/IR/AttrTypeReplacerTest.cpp
new file mode 100644
index 0000000000000..c7b42eb267c7a
--- /dev/null
+++ b/mlir/unittests/IR/AttrTypeReplacerTest.cpp
@@ -0,0 +1,231 @@
+//===- AttrTypeReplacerTest.cpp - Sub-element replacer unit tests ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AttrTypeSubElements.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// CyclicAttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+TEST(CyclicAttrTypeReplacerTest, testNoRecursion) {
+  MLIRContext ctx;
+
+  CyclicAttrTypeReplacer replacer;
+  replacer.addReplacement([&](BoolAttr b) {
+    return StringAttr::get(&ctx, b.getValue() ? "true" : "false");
+  });
+
+  EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, true)),
+            StringAttr::get(&ctx, "true"));
+  EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, false)),
+            StringAttr::get(&ctx, "false"));
+  EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)),
+            mlir::UnitAttr::get(&ctx));
+}
+
+TEST(CyclicAttrTypeReplacerTest, testInPlaceRecursionPruneAnywhere) {
+  MLIRContext ctx;
+  Builder b(&ctx);
+
+  CyclicAttrTypeReplacer replacer;
+  // Replacer cycles through integer attrs 0 -> 1 -> 2 -> 0 -> ...
+  replacer.addReplacement([&](IntegerAttr attr) {
+    return replacer.replace(b.getI8IntegerAttr((attr.getInt() + 1) % 3));
+  });
+  // The first repeat of any integer attr is pruned into a unit attr.
+  replacer.addCycleBreaker([&](IntegerAttr attr) { return b.getUnitAttr(); });
+
+  // No recursion case.
+  EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)),
+            mlir::UnitAttr::get(&ctx));
+  // Starting at 0.
+  EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(0)), mlir::UnitAttr::get(&ctx));
+  // Starting at 2.
+  EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(2)), mlir::UnitAttr::get(&ctx));
+}
+
+//===----------------------------------------------------------------------===//
+// CyclicAttrTypeReplacerTest: ChainRecursion
+//===----------------------------------------------------------------------===//
+
+class CyclicAttrTypeReplacerChainRecursionPruningTest : public ::testing::Test {
+public:
+  CyclicAttrTypeReplacerChainRecursionPruningTest() : b(&ctx) {
+    // IntegerType<width = N>
+    // ==> FunctionType<() => IntegerType< width = (N+1) % 3>>.
+    // This will create a chain of infinite length without recursion pruning.
+    replacer.addReplacement([&](mlir::IntegerType intType) {
+      ++invokeCount;
+      return b.getFunctionType(
+          {}, {mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3)});
+    });
+  }
+
+  void setBaseCase(std::optional<unsigned> pruneAt) {
+    replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) {
+      return (!pruneAt || intType.getWidth() == *pruneAt)
+                 ? std::make_optional(b.getIndexType())
+                 : std::nullopt;
+    });
+  }
+
+  Type getFunctionTypeChain(unsigned N) {
+    Type type = b.getIndexType();
+    for (unsigned i = 0; i < N; i++)
+      type = b.getFunctionType({}, type);
+    return type;
+  };
+
+  MLIRContext ctx;
+  Builder b;
+  CyclicAttrTypeReplacer replacer;
+  int invokeCount = 0;
+};
+
+TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere0) {
+  setBaseCase(std::nullopt);
+
+  // No recursion case.
+  EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType());
+  EXPECT_EQ(invokeCount, 0);
+
+  // Starting at 0. Cycle length is 3.
+  invokeCount = 0;
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
+            getFunctionTypeChain(3));
+  EXPECT_EQ(invokeCount, 3);
+
+  // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
+  invokeCount = 0;
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+            getFunctionTypeChain(5));
+  EXPECT_EQ(invokeCount, 2);
+}
+
+TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere1) {
+  setBaseCase(std::nullopt);
+
+  // Starting at 1. Cycle length is 3.
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+            getFunctionTypeChain(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific0) {
+  setBaseCase(0);
+
+  // Starting at 0. Cycle length is 3.
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
+            getFunctionTypeChain(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific1) {
+  setBaseCase(0);
+
+  // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+            getFunctionTypeChain(5));
+  EXPECT_EQ(invokeCount, 5);
+}
+
+//===----------------------------------------------------------------------===//
+// CyclicAttrTypeReplacerTest: BranchingRecursion
+//===----------------------------------------------------------------------===//
+
+class CyclicAttrTypeReplacerBranchingRecusionPruningTest
+    : public ::testing::Test {
+public:
+  CyclicAttrTypeReplacerBranchingRecusionPruningTest() : b(&ctx) {
+    // IntegerType<width = N>
+    // ==> FunctionType<
+    //       IntegerType< width = (N+1) % 3> =>
+    //         IntegerType< width = (N+1) % 3>>.
+    // This will create a binary tree of infinite depth without pruning.
+    replacer.addReplacement([&](mlir::IntegerType intType) {
+      ++invokeCount;
+      Type child = mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3);
+      return b.getFunctionType({child}, {child});
+    });
+  }
+
+  void setBaseCase(std::optional<unsigned> pruneAt) {
+    replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) {
+      return (!pruneAt || intType.getWidth() == *pruneAt)
+                 ? std::make_optional(b.getIndexType())
+                 : std::nullopt;
+    });
+  }
+
+  Type getFunctionTypeTree(unsigned N) {
+    Type type = b.getIndexType();
+    for (unsigned i = 0; i < N; i++)
+      type = b.getFunctionType(type, type);
+    return type;
+  };
+
+  MLIRContext ctx;
+  Builder b;
+  CyclicAttrTypeReplacer replacer;
+  int invokeCount = 0;
+};
+
+TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere0) {
+  setBaseCase(std::nullopt);
+
+  // No recursion case.
+  EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType());
+  EXPECT_EQ(invokeCount, 0);
+
+  // Starting at 0. Cycle length is 3.
+  invokeCount = 0;
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
+            getFunctionTypeTree(3));
+  // Since both branches are identical, this should incur linear invocations
+  // of the replacement function instead of exponential.
+  EXPECT_EQ(invokeCount, 3);
+
+  // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
+  invokeCount = 0;
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+            getFunctionTypeTree(5));
+  EXPECT_EQ(invokeCount, 2);
+}
+
+TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere1) {
+  setBaseCase(std::nullopt);
+
+  // Starting at 1. Cycle length is 3.
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+            getFunctionTypeTree(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific0) {
+  setBaseCase(0);
+
+  // Starting at 0. Cycle length is 3.
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
+            getFunctionTypeTree(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific1) {
+  setBaseCase(0);
+
+  // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+            getFunctionTypeTree(5));
+  EXPECT_EQ(invokeCount, 5);
+}
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 71f8f449756ec..05cb36e190316 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_unittest(MLIRIRTests
   AffineExprTest.cpp
   AffineMapTest.cpp
   AttributeTest.cpp
+  AttrTypeReplacerTest.cpp
   DialectTest.cpp
   InterfaceTest.cpp
   IRMapping.cpp

>From 5ccbf4bc13ac48240d9fcb14411617809e527304 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 9 Jul 2024 13:58:34 -0700
Subject: [PATCH 3/5] typo & comments

---
 mlir/unittests/Support/CyclicReplacerCacheTest.cpp | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
index 23748e29765cb..a4a92dbe147d4 100644
--- a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
+++ b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
@@ -195,17 +195,19 @@ class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
   /// A Graph where edges that used to cause cycles (back-edges) are now
   /// represented with an indirection (a recursionId).
   ///
-  /// In addition to each node being an integer, each node also tracks the
-  /// original integer id it had in the original graph. This way for every
+  /// In addition to each node having an integer ID, each node also tracks the
+  /// original integer ID it had in the original graph. This way for every
   /// back-edge, we can represent it as pointing to a new instance of the
   /// original node. Then we mark the original node and the new instance with
   /// a new unique recursionId to indicate that they're supposed to be the same
-  /// graph.
+  /// node.
   struct PrunedGraph {
     using Node = Graph::Node;
     struct NodeInfo {
       Graph::Node originalId;
-      // A negative recursive index means not recursive.
+      /// A negative recursive index means not recursive. Otherwise nodes with
+      /// the same originalId & recursionId are the same node in the original
+      /// graph.
       int64_t recursionId;
     };
 
@@ -243,7 +245,7 @@ class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
     Graph connections;
     int64_t nextRecursionId = 0;
     int64_t nextConnectionId = 0;
-    // Use ordered map for deterministic output.
+    /// Use ordered map for deterministic output.
     std::map<Graph::Node, NodeInfo> info;
   };
 

>From a27223a7375afaaf5abd66c26d6b0f4fda372abf Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Thu, 11 Jul 2024 10:59:52 -0700
Subject: [PATCH 4/5] Apply suggestions from code review

Co-authored-by: Jeff Niu <jeff at modular.com>
---
 mlir/include/mlir/Support/CyclicReplacerCache.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Support/CyclicReplacerCache.h b/mlir/include/mlir/Support/CyclicReplacerCache.h
index 9a703676fff11..7ad8c717199ec 100644
--- a/mlir/include/mlir/Support/CyclicReplacerCache.h
+++ b/mlir/include/mlir/Support/CyclicReplacerCache.h
@@ -82,14 +82,14 @@ class CyclicReplacerCache {
     }
 
     /// Get the resolved result if one exists.
-    std::optional<OutT> get() { return result; }
+    const std::optional<OutT> &get() { return result; }
 
   private:
     friend class CyclicReplacerCache;
     CacheEntry() = delete;
     CacheEntry(CyclicReplacerCache<InT, OutT> &cache, InT element,
                std::optional<OutT> result = std::nullopt)
-        : cache(cache), element(element), result(result) {}
+        : cache(cache), element(std::move(element)), result(result) {}
 
     CyclicReplacerCache<InT, OutT> &cache;
     InT element;
@@ -153,7 +153,7 @@ CyclicReplacerCache<InT, OutT>::lookupOrInit(const InT &element) {
     return CacheEntry(*this, element, it->second);
 
   if (auto it = dependentCache.find(element); it != dependentCache.end()) {
-    // pdate the current top frame (the element that invoked this current
+    // Update the current top frame (the element that invoked this current
     // replacement) to include any dependencies the cache entry had.
     ReplacementFrame &currFrame = replacementStack.back();
     currFrame.dependentFrames.insert(it->second.highestDependentFrame);

>From 1abb63f81154de65757b25717a1e616b5b8971d1 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Thu, 11 Jul 2024 11:18:02 -0700
Subject: [PATCH 5/5] double down on trivial type support

---
 .../mlir/Support/CyclicReplacerCache.h        | 29 ++++++++++---------
 .../Support/CyclicReplacerCacheTest.cpp       |  4 +++
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Support/CyclicReplacerCache.h b/mlir/include/mlir/Support/CyclicReplacerCache.h
index 7ad8c717199ec..42428c1507ffb 100644
--- a/mlir/include/mlir/Support/CyclicReplacerCache.h
+++ b/mlir/include/mlir/Support/CyclicReplacerCache.h
@@ -12,8 +12,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_SUPPORT_CACHINGREPLACER_H
-#define MLIR_SUPPORT_CACHINGREPLACER_H
+#ifndef MLIR_SUPPORT_CYCLICREPLACERCACHE_H
+#define MLIR_SUPPORT_CYCLICREPLACERCACHE_H
 
 #include "mlir/IR/Visitors.h"
 #include "llvm/ADT/DenseSet.h"
@@ -43,13 +43,16 @@ namespace mlir {
 /// any given cycle can perform pruning. Even if not, an assertion will
 /// eventually be tripped instead of infinite recursion (the run-time is
 /// linearly bounded by the maximum cycle length of its input).
+///
+/// WARNING: This class works best with InT & OutT that are trivial scalar
+/// types. The input/output elements will be frequently copied and hashed.
 template <typename InT, typename OutT>
 class CyclicReplacerCache {
 public:
   /// User-provided replacement function & cycle-breaking functions.
   /// The cycle-breaking function must not make any more recursive invocations
   /// to this cached replacer.
-  using CycleBreakerFn = std::function<std::optional<OutT>(const InT &)>;
+  using CycleBreakerFn = std::function<std::optional<OutT>(InT)>;
 
   CyclicReplacerCache() = delete;
   CyclicReplacerCache(CycleBreakerFn cycleBreaker)
@@ -77,12 +80,12 @@ class CyclicReplacerCache {
     /// in the cache.
     void resolve(OutT result) {
       assert(!this->result && "cache entry already resolved");
-      this->result = result;
       cache.finalizeReplacement(element, result);
+      this->result = std::move(result);
     }
 
     /// Get the resolved result if one exists.
-    const std::optional<OutT> &get() { return result; }
+    const std::optional<OutT> &get() const { return result; }
 
   private:
     friend class CyclicReplacerCache;
@@ -106,11 +109,11 @@ class CyclicReplacerCache {
   /// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and
   /// the first retrieved CacheEntry must be resolved last. This should be
   /// natural when used as a stack / inside recursion.
-  CacheEntry lookupOrInit(const InT &element);
+  CacheEntry lookupOrInit(InT element);
 
 private:
   /// Register the replacement in the cache and update the replacementStack.
-  void finalizeReplacement(const InT &element, const OutT &result);
+  void finalizeReplacement(InT element, OutT result);
 
   CycleBreakerFn cycleBreaker;
   DenseMap<InT, OutT> standaloneCache;
@@ -145,7 +148,7 @@ class CyclicReplacerCache {
 
 template <typename InT, typename OutT>
 typename CyclicReplacerCache<InT, OutT>::CacheEntry
-CyclicReplacerCache<InT, OutT>::lookupOrInit(const InT &element) {
+CyclicReplacerCache<InT, OutT>::lookupOrInit(InT element) {
   assert(!resolvingCycle &&
          "illegal recursive invocation while breaking cycle");
 
@@ -195,8 +198,8 @@ CyclicReplacerCache<InT, OutT>::lookupOrInit(const InT &element) {
 }
 
 template <typename InT, typename OutT>
-void CyclicReplacerCache<InT, OutT>::finalizeReplacement(const InT &element,
-                                                         const OutT &result) {
+void CyclicReplacerCache<InT, OutT>::finalizeReplacement(InT element,
+                                                         OutT result) {
   ReplacementFrame &currFrame = replacementStack.back();
   // With the conclusion of this replacement frame, the current element is no
   // longer a dependent element.
@@ -249,7 +252,7 @@ void CyclicReplacerCache<InT, OutT>::finalizeReplacement(const InT &element,
 template <typename InT, typename OutT>
 class CachedCyclicReplacer {
 public:
-  using ReplacerFn = std::function<OutT(const InT &)>;
+  using ReplacerFn = std::function<OutT(InT)>;
   using CycleBreakerFn =
       typename CyclicReplacerCache<InT, OutT>::CycleBreakerFn;
 
@@ -257,7 +260,7 @@ class CachedCyclicReplacer {
   CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
       : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
 
-  OutT operator()(const InT &element) {
+  OutT operator()(InT element) {
     auto cacheEntry = cache.lookupOrInit(element);
     if (std::optional<OutT> result = cacheEntry.get())
       return *result;
@@ -274,4 +277,4 @@ class CachedCyclicReplacer {
 
 } // namespace mlir
 
-#endif // MLIR_SUPPORT_CACHINGREPLACER_H
+#endif // MLIR_SUPPORT_CYCLICREPLACERCACHE_H
diff --git a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
index a4a92dbe147d4..ca02a3d692b2a 100644
--- a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
+++ b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
@@ -47,6 +47,7 @@ TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) {
 /// infinitely long vector. The cycle-breaker function prunes this infinite
 /// recursion in the replacer logic by returning an empty vector upon the first
 /// re-occurrence of an input value.
+namespace {
 class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
 public:
   // N ==> (N+1) % 3
@@ -71,6 +72,7 @@ class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
   int invokeCount = 0;
   std::optional<int> baseCase = std::nullopt;
 };
+} // namespace
 
 TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) {
   // Starting at 0. Cycle length is 3.
@@ -128,6 +130,7 @@ TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) {
 /// - PrunedGraph
 ///   - A Graph where edges that used to cause cycles are now represented with
 ///     an indirection (a recursionId).
+namespace {
 class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
 public:
   /// A directed graph where nodes are non-negative integers.
@@ -317,6 +320,7 @@ class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
     return oss.str();
   }
 };
+} // namespace
 
 TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) {
   // 0 -> 1 -> 2



More information about the Mlir-commits mailing list