[llvm-branch-commits] [mlir] [MLIR] Cyclic AttrType Replacer (PR #98206)

Billy Zhu via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jul 9 12:33:37 PDT 2024


https://github.com/zyx-billy updated https://github.com/llvm/llvm-project/pull/98206

>From 8d1dd886c4a80507c8a97dda15e91acbfa7c3619 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 9 Jul 2024 10:27:13 -0700
Subject: [PATCH] refactor attrtype replacers & add tests

---
 mlir/include/mlir/IR/AttrTypeSubElements.h | 138 ++++++++++--
 mlir/lib/IR/AttrTypeSubElements.cpp        | 146 ++++++++++---
 mlir/unittests/IR/AttrTypeReplacerTest.cpp | 231 +++++++++++++++++++++
 mlir/unittests/IR/CMakeLists.txt           |   1 +
 4 files changed, 467 insertions(+), 49 deletions(-)
 create mode 100644 mlir/unittests/IR/AttrTypeReplacerTest.cpp

diff --git a/mlir/include/mlir/IR/AttrTypeSubElements.h b/mlir/include/mlir/IR/AttrTypeSubElements.h
index 3105040b87631..234767deea00a 100644
--- a/mlir/include/mlir/IR/AttrTypeSubElements.h
+++ b/mlir/include/mlir/IR/AttrTypeSubElements.h
@@ -16,6 +16,7 @@
 
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Visitors.h"
+#include "mlir/Support/CyclicReplacerCache.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include <optional>
@@ -116,9 +117,21 @@ class AttrTypeWalker {
 /// AttrTypeReplacer
 //===----------------------------------------------------------------------===//
 
-/// This class provides a utility for replacing attributes/types, and their sub
-/// elements. Multiple replacement functions may be registered.
-class AttrTypeReplacer {
+namespace detail {
+
+/// This class provides a base utility for replacing attributes/types, and their
+/// sub elements. Multiple replacement functions may be registered.
+///
+/// This base utility is uncached. Users can choose between two cached versions
+/// of this replacer:
+///   * For non-cyclic replacer logic, use `AttrTypeReplacer`.
+///   * For cyclic replacer logic, use `CyclicAttrTypeReplacer`.
+///
+/// Concrete implementations implement the following `replace` entry functions:
+///   * Attribute replace(Attribute attr);
+///   * Type replace(Type type);
+template <typename Concrete>
+class AttrTypeReplacerBase {
 public:
   //===--------------------------------------------------------------------===//
   // Application
@@ -139,12 +152,6 @@ class AttrTypeReplacer {
                                     bool replaceLocs = false,
                                     bool replaceTypes = false);
 
-  /// Replace the given attribute/type, and recursively replace any sub
-  /// elements. Returns either the new attribute/type, or nullptr in the case of
-  /// failure.
-  Attribute replace(Attribute attr);
-  Type replace(Type type);
-
   //===--------------------------------------------------------------------===//
   // Registration
   //===--------------------------------------------------------------------===//
@@ -206,21 +213,114 @@ class AttrTypeReplacer {
     });
   }
 
-private:
-  /// Internal implementation of the `replace` methods above.
-  template <typename T, typename ReplaceFns>
-  T replaceImpl(T element, ReplaceFns &replaceFns);
-
-  /// Replace the sub elements of the given interface.
-  template <typename T>
-  T replaceSubElements(T interface);
+protected:
+  /// Invokes the registered replacement functions from most recently registered
+  /// to least recently registered until a successful replacement is returned.
+  /// Unless skipping is requested, invokes `replace` on sub-elements of the
+  /// current attr/type.
+  Attribute replaceBase(Attribute attr);
+  Type replaceBase(Type type);
 
+private:
   /// The set of replacement functions that map sub elements.
   std::vector<ReplaceFn<Attribute>> attrReplacementFns;
   std::vector<ReplaceFn<Type>> typeReplacementFns;
+};
+
+} // namespace detail
+
+/// This is an attribute/type replacer that is naively cached. It is best used
+/// when the replacer logic is guaranteed to not contain cycles. Otherwise, any
+/// re-occurrence of an in-progress element is returned as-is (not replaced).
+class AttrTypeReplacer : public detail::AttrTypeReplacerBase<AttrTypeReplacer> {
+public:
+  Attribute replace(Attribute attr);
+  Type replace(Type type);
+
+private:
+  /// Shared concrete implementation of the public `replace` functions. Invokes
+  /// `replaceBase` with caching.
+  template <typename T>
+  T cachedReplaceImpl(T element);
+
+  /// Maps each visited attr/type (opaque pointer) to its replacement.
+  DenseMap<const void *, const void *> cache;
+};
+
+/// This is an attribute/type replacer that supports custom handling of cycles
+/// in the replacer logic. In addition to registering replacer functions, it
+/// allows registering cycle-breaking functions in the same style.
+class CyclicAttrTypeReplacer
+    : public detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer> {
+public:
+  CyclicAttrTypeReplacer();
+  /// Not copyable/movable: the cache's cycle-breaker callback captures `this`.
+  CyclicAttrTypeReplacer(const CyclicAttrTypeReplacer &) = delete;
+  CyclicAttrTypeReplacer &operator=(const CyclicAttrTypeReplacer &) = delete;
 
-  /// The set of cached mappings for attributes/types.
-  DenseMap<const void *, const void *> attrTypeMap;
+  //===--------------------------------------------------------------------===//
+  // Application
+  //===--------------------------------------------------------------------===//
+
+  Attribute replace(Attribute attr);
+  Type replace(Type type);
+
+  //===--------------------------------------------------------------------===//
+  // Registration
+  //===--------------------------------------------------------------------===//
+
+  /// A cycle-breaking function. This is invoked if the same element is asked to
+  /// be replaced again when the first instance of it is still being replaced.
+  /// This function must not perform any more recursive `replace` calls.
+  /// If it is able to break the cycle, it should return a replacement result.
+  /// Otherwise, it can return std::nullopt to defer cycle breaking to the next
+  /// repeated element. However, the user must guarantee that, in any possible
+  /// cycle, there always exists at least one element that can break the cycle.
+  template <typename T>
+  using CycleBreakerFn = std::function<std::optional<T>(T)>;
+
+  /// Register a cycle-breaking function.
+  /// When breaking cycles, the most recently added cycle-breaking functions
+  /// will be invoked first.
+  void addCycleBreaker(CycleBreakerFn<Attribute> fn);
+  void addCycleBreaker(CycleBreakerFn<Type> fn);
+
+  /// Register a cycle-breaking function that doesn't match the default
+  /// signature.
+  template <typename FnT,
+            typename T = typename llvm::function_traits<
+                std::decay_t<FnT>>::template arg_t<0>,
+            typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
+                                                Attribute, Type>>
+  std::enable_if_t<!std::is_same_v<T, BaseT>> addCycleBreaker(FnT &&callback) {
+    addCycleBreaker([callback = std::forward<FnT>(callback)](
+                        BaseT base) -> std::optional<BaseT> {
+      if (auto derived = dyn_cast<T>(base))
+        return callback(derived);
+      return std::nullopt;
+    });
+  }
+
+private:
+  /// Invokes the registered cycle-breaker functions from most recently
+  /// registered to least recently registered until a successful result is
+  /// returned.
+  std::optional<const void *> breakCycleImpl(void *element);
+
+  /// Shared concrete implementation of the public `replace` functions.
+  template <typename T>
+  T cachedReplaceImpl(T element);
+
+  /// The set of registered cycle-breaker functions.
+  std::vector<CycleBreakerFn<Attribute>> attrCycleBreakerFns;
+  std::vector<CycleBreakerFn<Type>> typeCycleBreakerFns;
+
+  /// Cache of previously-replaced attrs/types, keyed by the opaque value of an
+  /// `AttrOrType` (so cycle-breakers can tell attributes and types apart, while
+  /// avoiding the cyclic dependency of instantiating the cache on `AttrOrType`
+  /// directly). Values are the replacement's untagged opaque pointer.
+  using AttrOrType = PointerUnion<Attribute, Type>;
+  CyclicReplacerCache<void *, const void *> cache;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AttrTypeSubElements.cpp b/mlir/lib/IR/AttrTypeSubElements.cpp
index 79b04966be6eb..783236ed3a9df 100644
--- a/mlir/lib/IR/AttrTypeSubElements.cpp
+++ b/mlir/lib/IR/AttrTypeSubElements.cpp
@@ -67,22 +67,28 @@ WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
 }
 
 //===----------------------------------------------------------------------===//
-/// AttrTypeReplacer
+/// AttrTypeReplacerBase
 //===----------------------------------------------------------------------===//
 
-void AttrTypeReplacer::addReplacement(ReplaceFn<Attribute> fn) {
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
+    ReplaceFn<Attribute> fn) {
   attrReplacementFns.emplace_back(std::move(fn));
 }
-void AttrTypeReplacer::addReplacement(ReplaceFn<Type> fn) {
+
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
+    ReplaceFn<Type> fn) {
   typeReplacementFns.push_back(std::move(fn));
 }
 
-void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
-                                         bool replaceLocs, bool replaceTypes) {
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::replaceElementsIn(
+    Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
   // Functor that replaces the given element if the new value is different,
   // otherwise returns nullptr.
   auto replaceIfDifferent = [&](auto element) {
-    auto replacement = replace(element);
+    auto replacement = static_cast<Concrete *>(this)->replace(element);
     return (replacement && replacement != element) ? replacement : nullptr;
   };
 
@@ -127,17 +133,16 @@ void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
   }
 }
 
-void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op,
-                                                    bool replaceAttrs,
-                                                    bool replaceLocs,
-                                                    bool replaceTypes) {
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::recursivelyReplaceElementsIn(
+    Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
   op->walk([&](Operation *nestedOp) {
     replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes);
   });
 }
 
-template <typename T>
-static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
+template <typename T, typename Replacer>
+static void updateSubElementImpl(T element, Replacer &replacer,
                                  SmallVectorImpl<T> &newElements,
                                  FailureOr<bool> &changed) {
   // Bail early if we failed at any point.
@@ -160,18 +165,18 @@ static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
   }
 }
 
-template <typename T>
-T AttrTypeReplacer::replaceSubElements(T interface) {
+template <typename T, typename Replacer>
+static T replaceSubElements(T interface, Replacer &replacer) {
   // Walk the current sub-elements, replacing them as necessary.
   SmallVector<Attribute, 16> newAttrs;
   SmallVector<Type, 16> newTypes;
   FailureOr<bool> changed = false;
   interface.walkImmediateSubElements(
       [&](Attribute element) {
-        updateSubElementImpl(element, *this, newAttrs, changed);
+        updateSubElementImpl(element, replacer, newAttrs, changed);
       },
       [&](Type element) {
-        updateSubElementImpl(element, *this, newTypes, changed);
+        updateSubElementImpl(element, replacer, newTypes, changed);
       });
   if (failed(changed))
     return nullptr;
@@ -184,13 +189,9 @@ T AttrTypeReplacer::replaceSubElements(T interface) {
 }
 
 /// Shared implementation of replacing a given attribute or type element.
-template <typename T, typename ReplaceFns>
-T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
-  const void *opaqueElement = element.getAsOpaquePointer();
-  auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement);
-  if (!inserted)
-    return T::getFromOpaquePointer(it->second);
-
+template <typename T, typename ReplaceFns, typename Replacer>
+static T replaceElementImpl(T element, ReplaceFns &replaceFns,
+                            Replacer &replacer) {
   T result = element;
   WalkResult walkResult = WalkResult::advance();
   for (auto &replaceFn : llvm::reverse(replaceFns)) {
@@ -202,29 +203,114 @@ T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
 
   // If an error occurred, return nullptr to indicate failure.
   if (walkResult.wasInterrupted() || !result) {
-    attrTypeMap[opaqueElement] = nullptr;
     return nullptr;
   }
 
   // Handle replacing sub-elements if this element is also a container.
   if (!walkResult.wasSkipped()) {
     // Replace the sub elements of this element, bailing if we fail.
-    if (!(result = replaceSubElements(result))) {
-      attrTypeMap[opaqueElement] = nullptr;
+    if (!(result = replaceSubElements(result, replacer))) {
       return nullptr;
     }
   }
 
-  attrTypeMap[opaqueElement] = result.getAsOpaquePointer();
+  return result;
+}
+
+template <typename Concrete>
+Attribute detail::AttrTypeReplacerBase<Concrete>::replaceBase(Attribute attr) {
+  return replaceElementImpl(attr, attrReplacementFns,
+                            *static_cast<Concrete *>(this));
+}
+
+template <typename Concrete>
+Type detail::AttrTypeReplacerBase<Concrete>::replaceBase(Type type) {
+  return replaceElementImpl(type, typeReplacementFns,
+                            *static_cast<Concrete *>(this));
+}
+
+//===----------------------------------------------------------------------===//
+/// AttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+template class detail::AttrTypeReplacerBase<AttrTypeReplacer>;
+
+template <typename T>
+T AttrTypeReplacer::cachedReplaceImpl(T element) {
+  const void *key = element.getAsOpaquePointer();
+  // Seed the entry with the element itself: re-entrant lookups of an
+  // in-progress element then return it unchanged instead of recursing.
+  if (auto [it, inserted] = cache.try_emplace(key, key); !inserted)
+    return T::getFromOpaquePointer(it->second);
+  T result = replaceBase(element);
+  // Nested `replace` calls may rehash `cache`, so don't reuse `it` here.
+  cache[key] = result.getAsOpaquePointer();
   return result;
 }
 
 Attribute AttrTypeReplacer::replace(Attribute attr) {
-  return replaceImpl(attr, attrReplacementFns);
+  return cachedReplaceImpl(attr);
 }
 
-Type AttrTypeReplacer::replace(Type type) {
-  return replaceImpl(type, typeReplacementFns);
+Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(type); }
+
+//===----------------------------------------------------------------------===//
+/// CyclicAttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+template class detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer>;
+
+CyclicAttrTypeReplacer::CyclicAttrTypeReplacer()
+    : cache([this](void *element) { return breakCycleImpl(element); }) {}
+
+void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Attribute> fn) {
+  attrCycleBreakerFns.push_back(std::move(fn));
+}
+
+void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Type> fn) {
+  typeCycleBreakerFns.push_back(std::move(fn));
+}
+
+template <typename T>
+T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) {
+  // Key by the tagged AttrOrType value so breakCycleImpl can recover the kind.
+  void *taggedKey = AttrOrType(element).getOpaqueValue();
+  auto cacheEntry = cache.lookupOrInit(taggedKey);
+  if (auto cachedOpaque = cacheEntry.get())
+    return T::getFromOpaquePointer(*cachedOpaque);
+
+  // No result available yet: compute the replacement and resolve the entry.
+  T result = replaceBase(element);
+  cacheEntry.resolve(result.getAsOpaquePointer());
+  return result;
+}
+
+Attribute CyclicAttrTypeReplacer::replace(Attribute attr) {
+  return cachedReplaceImpl(attr);
+}
+
+Type CyclicAttrTypeReplacer::replace(Type type) {
+  return cachedReplaceImpl(type);
+}
+
+std::optional<const void *>
+CyclicAttrTypeReplacer::breakCycleImpl(void *element) {
+  // Runs `breakerFns` from most to least recently registered and returns the
+  // opaque pointer of the first replacement produced, if any.
+  auto runBreakers = [](auto &breakerFns,
+                        auto elem) -> std::optional<const void *> {
+    for (auto &breakerFn : llvm::reverse(breakerFns))
+      if (auto broken = breakerFn(elem))
+        return broken->getAsOpaquePointer();
+    return std::nullopt;
+  };
+
+  // The PointerUnion tag in the opaque value records whether `element` is an
+  // attribute or a type.
+  AttrOrType attrOrType = AttrOrType::getFromOpaqueValue(element);
+  if (auto attr = dyn_cast<Attribute>(attrOrType))
+    return runBreakers(attrCycleBreakerFns, attr);
+  return runBreakers(typeCycleBreakerFns, dyn_cast<Type>(attrOrType));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/IR/AttrTypeReplacerTest.cpp b/mlir/unittests/IR/AttrTypeReplacerTest.cpp
new file mode 100644
index 0000000000000..c7b42eb267c7a
--- /dev/null
+++ b/mlir/unittests/IR/AttrTypeReplacerTest.cpp
@@ -0,0 +1,231 @@
+//===- AttrTypeReplacerTest.cpp - Sub-element replacer unit tests ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AttrTypeSubElements.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// CyclicAttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+TEST(CyclicAttrTypeReplacerTest, testNoRecursion) {
+  MLIRContext ctx;
+
+  CyclicAttrTypeReplacer replacer;
+  replacer.addReplacement([&](BoolAttr b) {
+    return StringAttr::get(&ctx, b.getValue() ? "true" : "false");
+  });
+
+  EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, true)),
+            StringAttr::get(&ctx, "true"));
+  EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, false)),
+            StringAttr::get(&ctx, "false"));
+  EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)),
+            mlir::UnitAttr::get(&ctx));
+}
+
+TEST(CyclicAttrTypeReplacerTest, testInPlaceRecursionPruneAnywhere) {
+  MLIRContext ctx;
+  Builder b(&ctx);
+
+  CyclicAttrTypeReplacer replacer;
+  // Replacer cycles through integer attrs 0 -> 1 -> 2 -> 0 -> ...
+  replacer.addReplacement([&](IntegerAttr attr) {
+    return replacer.replace(b.getI8IntegerAttr((attr.getInt() + 1) % 3));
+  });
+  // The first repeat of any integer attr is pruned into a unit attr.
+  replacer.addCycleBreaker([&](IntegerAttr attr) { return b.getUnitAttr(); });
+
+  // No recursion case.
+  EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)),
+            mlir::UnitAttr::get(&ctx));
+  // Starting at 0.
+  EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(0)), mlir::UnitAttr::get(&ctx));
+  // Starting at 2.
+  EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(2)), mlir::UnitAttr::get(&ctx));
+}
+
+//===----------------------------------------------------------------------===//
+// CyclicAttrTypeReplacerTest: ChainRecursion
+//===----------------------------------------------------------------------===//
+
+class CyclicAttrTypeReplacerChainRecursionPruningTest : public ::testing::Test {
+public:
+  CyclicAttrTypeReplacerChainRecursionPruningTest() : b(&ctx) {
+    // IntegerType<width = N>
+    // ==> FunctionType<() => IntegerType< width = (N+1) % 3>>.
+    // This will create a chain of infinite length without recursion pruning.
+    replacer.addReplacement([&](mlir::IntegerType intType) {
+      ++invokeCount;
+      return b.getFunctionType(
+          {}, {mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3)});
+    });
+  }
+
+  void setBaseCase(std::optional<unsigned> pruneAt) {
+    replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) {
+      return (!pruneAt || intType.getWidth() == *pruneAt)
+                 ? std::make_optional(b.getIndexType())
+                 : std::nullopt;
+    });
+  }
+
+  Type getFunctionTypeChain(unsigned N) {
+    Type type = b.getIndexType();
+    for (unsigned i = 0; i < N; i++)
+      type = b.getFunctionType({}, type);
+    return type;
+  }
+
+  MLIRContext ctx;
+  Builder b;
+  CyclicAttrTypeReplacer replacer;
+  int invokeCount = 0;
+};
+
+TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere0) {
+  setBaseCase(std::nullopt);
+
+  // No recursion case.
+  EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType());
+  EXPECT_EQ(invokeCount, 0);
+
+  // Starting at 0. Cycle length is 3.
+  invokeCount = 0;
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
+            getFunctionTypeChain(3));
+  EXPECT_EQ(invokeCount, 3);
+
+  // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
+  invokeCount = 0;
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+            getFunctionTypeChain(5));
+  EXPECT_EQ(invokeCount, 2);
+}
+
+TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere1) {
+  setBaseCase(std::nullopt);
+
+  // Starting at 1. Cycle length is 3.
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+            getFunctionTypeChain(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific0) {
+  setBaseCase(0);
+
+  // Starting at 0. Cycle length is 3.
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
+            getFunctionTypeChain(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific1) {
+  setBaseCase(0);
+
+  // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+            getFunctionTypeChain(5));
+  EXPECT_EQ(invokeCount, 5);
+}
+
+//===----------------------------------------------------------------------===//
+// CyclicAttrTypeReplacerTest: BranchRecursion
+//===----------------------------------------------------------------------===//
+
+class CyclicAttrTypeReplacerBranchRecursionPruningTest
+    : public ::testing::Test {
+public:
+  CyclicAttrTypeReplacerBranchRecursionPruningTest() : b(&ctx) {
+    // IntegerType<width = N>
+    // ==> FunctionType<
+    //       IntegerType< width = (N+1) % 3> =>
+    //         IntegerType< width = (N+1) % 3>>.
+    // This will create a binary tree of infinite depth without pruning.
+    replacer.addReplacement([&](mlir::IntegerType intType) {
+      ++invokeCount;
+      Type child = mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3);
+      return b.getFunctionType({child}, {child});
+    });
+  }
+
+  void setBaseCase(std::optional<unsigned> pruneAt) {
+    replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) {
+      return (!pruneAt || intType.getWidth() == *pruneAt)
+                 ? std::make_optional(b.getIndexType())
+                 : std::nullopt;
+    });
+  }
+
+  Type getFunctionTypeTree(unsigned N) {
+    Type type = b.getIndexType();
+    for (unsigned i = 0; i < N; i++)
+      type = b.getFunctionType(type, type);
+    return type;
+  }
+
+  MLIRContext ctx;
+  Builder b;
+  CyclicAttrTypeReplacer replacer;
+  int invokeCount = 0;
+};
+
+TEST_F(CyclicAttrTypeReplacerBranchRecursionPruningTest, testPruneAnywhere0) {
+  setBaseCase(std::nullopt);
+
+  // No recursion case.
+  EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType());
+  EXPECT_EQ(invokeCount, 0);
+
+  // Starting at 0. Cycle length is 3.
+  invokeCount = 0;
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
+            getFunctionTypeTree(3));
+  // Since both branches are identical, this should incur linear invocations
+  // of the replacement function instead of exponential.
+  EXPECT_EQ(invokeCount, 3);
+
+  // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
+  invokeCount = 0;
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+            getFunctionTypeTree(5));
+  EXPECT_EQ(invokeCount, 2);
+}
+
+TEST_F(CyclicAttrTypeReplacerBranchRecursionPruningTest, testPruneAnywhere1) {
+  setBaseCase(std::nullopt);
+
+  // Starting at 1. Cycle length is 3.
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+            getFunctionTypeTree(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CyclicAttrTypeReplacerBranchRecursionPruningTest, testPruneSpecific0) {
+  setBaseCase(0);
+
+  // Starting at 0. Cycle length is 3.
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
+            getFunctionTypeTree(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CyclicAttrTypeReplacerBranchRecursionPruningTest, testPruneSpecific1) {
+  setBaseCase(0);
+
+  // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+            getFunctionTypeTree(5));
+  EXPECT_EQ(invokeCount, 5);
+}
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 71f8f449756ec..05cb36e190316 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_unittest(MLIRIRTests
   AffineExprTest.cpp
   AffineMapTest.cpp
   AttributeTest.cpp
+  AttrTypeReplacerTest.cpp
   DialectTest.cpp
   InterfaceTest.cpp
   IRMapping.cpp



More information about the llvm-branch-commits mailing list