[llvm-branch-commits] [mlir] [MLIR] Cyclic AttrType Replacer (PR #98206)
Billy Zhu via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jul 9 12:33:37 PDT 2024
https://github.com/zyx-billy updated https://github.com/llvm/llvm-project/pull/98206
>From 8d1dd886c4a80507c8a97dda15e91acbfa7c3619 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 9 Jul 2024 10:27:13 -0700
Subject: [PATCH] refactor attrtype replacers & add tests
---
mlir/include/mlir/IR/AttrTypeSubElements.h | 138 ++++++++++--
mlir/lib/IR/AttrTypeSubElements.cpp | 146 ++++++++++---
mlir/unittests/IR/AttrTypeReplacerTest.cpp | 231 +++++++++++++++++++++
mlir/unittests/IR/CMakeLists.txt | 1 +
4 files changed, 467 insertions(+), 49 deletions(-)
create mode 100644 mlir/unittests/IR/AttrTypeReplacerTest.cpp
diff --git a/mlir/include/mlir/IR/AttrTypeSubElements.h b/mlir/include/mlir/IR/AttrTypeSubElements.h
index 3105040b87631..234767deea00a 100644
--- a/mlir/include/mlir/IR/AttrTypeSubElements.h
+++ b/mlir/include/mlir/IR/AttrTypeSubElements.h
@@ -16,6 +16,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
+#include "mlir/Support/CyclicReplacerCache.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include <optional>
@@ -116,9 +117,21 @@ class AttrTypeWalker {
/// AttrTypeReplacer
//===----------------------------------------------------------------------===//
-/// This class provides a utility for replacing attributes/types, and their sub
-/// elements. Multiple replacement functions may be registered.
-class AttrTypeReplacer {
+namespace detail {
+
+/// This class provides a base utility for replacing attributes/types, and their
+/// sub-elements. Multiple replacement functions may be registered.
+///
+/// This base utility is uncached. Users can choose between two cached versions
+/// of this replacer:
+/// * For non-cyclic replacer logic, use `AttrTypeReplacer`.
+/// * For cyclic replacer logic, use `CyclicAttrTypeReplacer`.
+///
+/// `Concrete` is the derived replacer class (CRTP). It must implement the
+/// following `replace` entry functions, which this base dispatches to (via
+/// `static_cast<Concrete *>(this)`) both for top-level elements, e.g. in
+/// `replaceElementsIn`, and for the sub-elements of a replaced element:
+/// * Attribute replace(Attribute attr);
+/// * Type replace(Type type);
+/// Each returns the new attribute/type, or nullptr in the case of failure.
+template <typename Concrete>
+class AttrTypeReplacerBase {
public:
//===--------------------------------------------------------------------===//
// Application
@@ -139,12 +152,6 @@ class AttrTypeReplacer {
bool replaceLocs = false,
bool replaceTypes = false);
- /// Replace the given attribute/type, and recursively replace any sub
- /// elements. Returns either the new attribute/type, or nullptr in the case of
- /// failure.
- Attribute replace(Attribute attr);
- Type replace(Type type);
-
//===--------------------------------------------------------------------===//
// Registration
//===--------------------------------------------------------------------===//
@@ -206,21 +213,114 @@ class AttrTypeReplacer {
});
}
-private:
- /// Internal implementation of the `replace` methods above.
- template <typename T, typename ReplaceFns>
- T replaceImpl(T element, ReplaceFns &replaceFns);
-
- /// Replace the sub elements of the given interface.
- template <typename T>
- T replaceSubElements(T interface);
+protected:
+ /// Invokes the registered replacement functions from most recently registered
+ /// to least recently registered until a successful replacement is returned.
+ /// Unless that replacement requested that sub-elements be skipped, the
+ /// sub-elements of the result are then replaced through the `Concrete`
+ /// replacer's `replace`. No caching happens at this level; callers (the
+ /// concrete replacers) are responsible for it. Returns nullptr on failure.
+ Attribute replaceBase(Attribute attr);
+ Type replaceBase(Type type);
+private:
/// The set of replacement functions that map sub elements.
std::vector<ReplaceFn<Attribute>> attrReplacementFns;
std::vector<ReplaceFn<Type>> typeReplacementFns;
+};
+
+} // namespace detail
+
+/// This is an attribute/type replacer that is naively cached. It is best used
+/// when the replacer logic is guaranteed to not contain cycles. Otherwise, any
+/// re-occurrence of an element that is still being replaced is not replaced
+/// again: it resolves to the original element, unchanged.
+class AttrTypeReplacer : public detail::AttrTypeReplacerBase<AttrTypeReplacer> {
+public:
+ /// Replace the given attribute/type, and recursively replace any
+ /// sub-elements. Returns either the new attribute/type, or nullptr in the
+ /// case of failure. Results, including failures, are cached per element.
+ Attribute replace(Attribute attr);
+ Type replace(Type type);
+
+private:
+ /// Shared concrete implementation of the public `replace` functions. Invokes
+ /// `replaceBase` with caching.
+ template <typename T>
+ T cachedReplaceImpl(T element);
+
+ /// Maps the opaque pointer of each visited attribute/type to the opaque
+ /// pointer of its replacement (nullptr if its replacement failed).
+ DenseMap<const void *, const void *> cache;
+};
+
+/// This is an attribute/type replacer that supports custom handling of cycles
+/// in the replacer logic. In addition to registering replacer functions, it
+/// allows registering cycle-breaking functions in the same style.
+class CyclicAttrTypeReplacer
+ : public detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer> {
+public:
+ CyclicAttrTypeReplacer();
+
+ /// Non-copyable: the cache's cycle-breaking callback is bound to `this` on
+ /// construction, so a copy would keep calling back into (and using the cycle
+ /// breakers of) the original object, even after it has been destroyed.
+ CyclicAttrTypeReplacer(const CyclicAttrTypeReplacer &) = delete;
+ CyclicAttrTypeReplacer &operator=(const CyclicAttrTypeReplacer &) = delete;
- /// The set of cached mappings for attributes/types.
- DenseMap<const void *, const void *> attrTypeMap;
+
+ //===--------------------------------------------------------------------===//
+ // Application
+ //===--------------------------------------------------------------------===//
+
+ /// Replace the given attribute/type, and recursively replace any
+ /// sub-elements. Returns either the new attribute/type, or nullptr in the
+ /// case of failure. A repeated request for an element that is still being
+ /// replaced is resolved through the registered cycle-breaking functions.
+ Attribute replace(Attribute attr);
+ Type replace(Type type);
+
+ //===--------------------------------------------------------------------===//
+ // Registration
+ //===--------------------------------------------------------------------===//
+
+ /// A cycle-breaking function. This is invoked if the same element is asked to
+ /// be replaced again when the first instance of it is still being replaced.
+ /// This function must not perform any more recursive `replace` calls.
+ /// If it is able to break the cycle, it should return a replacement result.
+ /// Otherwise, it can return std::nullopt to defer cycle breaking to the next
+ /// repeated element. However, the user must guarantee that, in any possible
+ /// cycle, there always exists at least one element that can break the cycle.
+ template <typename T>
+ using CycleBreakerFn = std::function<std::optional<T>(T)>;
+
+ /// Register a cycle-breaking function.
+ /// When breaking cycles, the most recently added cycle-breaking functions
+ /// will be invoked first.
+ void addCycleBreaker(CycleBreakerFn<Attribute> fn);
+ void addCycleBreaker(CycleBreakerFn<Type> fn);
+
+ /// Register a cycle-breaking function whose parameter is a specific
+ /// Attribute or Type subclass. It is only invoked for elements of that
+ /// subclass; for any other element it defers by returning std::nullopt.
+ template <typename FnT,
+ typename T = typename llvm::function_traits<
+ std::decay_t<FnT>>::template arg_t<0>,
+ typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
+ Attribute, Type>>
+ std::enable_if_t<!std::is_same_v<T, BaseT>> addCycleBreaker(FnT &&callback) {
+ addCycleBreaker([callback = std::forward<FnT>(callback)](
+ BaseT base) -> std::optional<BaseT> {
+ if (auto derived = dyn_cast<T>(base))
+ return callback(derived);
+ return std::nullopt;
+ });
+ }
+
+private:
+ /// Invokes the registered cycle-breaker functions from most recently
+ /// registered to least recently registered until a successful result is
+ /// returned. Returns std::nullopt if none of them breaks the cycle.
+ std::optional<const void *> breakCycleImpl(void *element);
+
+ /// Shared concrete implementation of the public `replace` functions.
+ template <typename T>
+ T cachedReplaceImpl(T element);
+
+ /// The set of registered cycle-breaker functions.
+ std::vector<CycleBreakerFn<Attribute>> attrCycleBreakerFns;
+ std::vector<CycleBreakerFn<Type>> typeCycleBreakerFns;
+
+ /// A cache of previously-replaced attr/types.
+ /// The key of the cache is the opaque value of an AttrOrType. Using
+ /// AttrOrType allows distinguishing between the two types when invoking
+ /// cycle-breakers. Using its opaque value avoids the cyclic dependency issue
+ /// of directly using `AttrOrType` to instantiate the cache.
+ /// The value of the cache is just the opaque value of the attr/type itself
+ /// (not the PointerUnion).
+ using AttrOrType = PointerUnion<Attribute, Type>;
+ CyclicReplacerCache<void *, const void *> cache;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AttrTypeSubElements.cpp b/mlir/lib/IR/AttrTypeSubElements.cpp
index 79b04966be6eb..783236ed3a9df 100644
--- a/mlir/lib/IR/AttrTypeSubElements.cpp
+++ b/mlir/lib/IR/AttrTypeSubElements.cpp
@@ -67,22 +67,28 @@ WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
}
//===----------------------------------------------------------------------===//
-/// AttrTypeReplacer
+/// AttrTypeReplacerBase
//===----------------------------------------------------------------------===//
-void AttrTypeReplacer::addReplacement(ReplaceFn<Attribute> fn) {
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
+ ReplaceFn<Attribute> fn) {
attrReplacementFns.emplace_back(std::move(fn));
}
-void AttrTypeReplacer::addReplacement(ReplaceFn<Type> fn) {
+
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
+ ReplaceFn<Type> fn) {
typeReplacementFns.push_back(std::move(fn));
}
-void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
- bool replaceLocs, bool replaceTypes) {
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::replaceElementsIn(
+ Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
// Functor that replaces the given element if the new value is different,
// otherwise returns nullptr.
auto replaceIfDifferent = [&](auto element) {
- auto replacement = replace(element);
+ // Dispatch through the derived (CRTP) replacer: the base declares no
+ // `replace` of its own; the derived class implements it, with caching, on
+ // top of `replaceBase`.
+ auto replacement = static_cast<Concrete *>(this)->replace(element);
return (replacement && replacement != element) ? replacement : nullptr;
};
@@ -127,17 +133,16 @@ void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
}
}
-void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op,
- bool replaceAttrs,
- bool replaceLocs,
- bool replaceTypes) {
+/// Applies `replaceElementsIn`, with the same flags, to every operation
+/// visited by `op->walk`.
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::recursivelyReplaceElementsIn(
+ Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
op->walk([&](Operation *nestedOp) {
replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes);
});
}
-template <typename T>
-static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
+/// `Replacer` is the concrete replacer (`AttrTypeReplacer` or
+/// `CyclicAttrTypeReplacer`) forwarded down from `replaceSubElements`.
+template <typename T, typename Replacer>
+static void updateSubElementImpl(T element, Replacer &replacer,
SmallVectorImpl<T> &newElements,
FailureOr<bool> &changed) {
// Bail early if we failed at any point.
@@ -160,18 +165,18 @@ static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
}
}
-template <typename T>
-T AttrTypeReplacer::replaceSubElements(T interface) {
+/// Replaces the immediate sub-elements of `interface` through `replacer`.
+/// Returns nullptr if replacing any of the sub-elements failed.
+template <typename T, typename Replacer>
+static T replaceSubElements(T interface, Replacer &replacer) {
// Walk the current sub-elements, replacing them as necessary.
SmallVector<Attribute, 16> newAttrs;
SmallVector<Type, 16> newTypes;
FailureOr<bool> changed = false;
interface.walkImmediateSubElements(
[&](Attribute element) {
- updateSubElementImpl(element, *this, newAttrs, changed);
+ updateSubElementImpl(element, replacer, newAttrs, changed);
},
[&](Type element) {
- updateSubElementImpl(element, *this, newTypes, changed);
+ updateSubElementImpl(element, replacer, newTypes, changed);
});
if (failed(changed))
return nullptr;
@@ -184,13 +189,9 @@ T AttrTypeReplacer::replaceSubElements(T interface) {
}
/// Shared implementation of replacing a given attribute or type element.
-template <typename T, typename ReplaceFns>
-T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
- const void *opaqueElement = element.getAsOpaquePointer();
- auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement);
- if (!inserted)
- return T::getFromOpaquePointer(it->second);
-
+/// Runs `replaceFns` from most recently to least recently registered. Then,
+/// unless the replacement requested skipping them, replaces the sub-elements of
+/// the result through `replacer`. This helper is uncached; caching is the
+/// responsibility of the concrete replacer. Returns nullptr on failure.
+template <typename T, typename ReplaceFns, typename Replacer>
+static T replaceElementImpl(T element, ReplaceFns &replaceFns,
+ Replacer &replacer) {
T result = element;
WalkResult walkResult = WalkResult::advance();
for (auto &replaceFn : llvm::reverse(replaceFns)) {
@@ -202,29 +203,114 @@ T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
// If an error occurred, return nullptr to indicate failure.
if (walkResult.wasInterrupted() || !result) {
- attrTypeMap[opaqueElement] = nullptr;
return nullptr;
}
// Handle replacing sub-elements if this element is also a container.
if (!walkResult.wasSkipped()) {
// Replace the sub elements of this element, bailing if we fail.
- if (!(result = replaceSubElements(result))) {
- attrTypeMap[opaqueElement] = nullptr;
+ if (!(result = replaceSubElements(result, replacer))) {
return nullptr;
}
}
- attrTypeMap[opaqueElement] = result.getAsOpaquePointer();
+ return result;
+}
+
+/// Uncached replacement of `attr`. Sub-elements are replaced through the
+/// derived replacer, so its caching applies to them.
+template <typename Concrete>
+Attribute detail::AttrTypeReplacerBase<Concrete>::replaceBase(Attribute attr) {
+ Concrete &derived = *static_cast<Concrete *>(this);
+ return replaceElementImpl(attr, attrReplacementFns, derived);
+}
+
+/// Uncached replacement of `type`; mirrors the attribute overload above.
+template <typename Concrete>
+Type detail::AttrTypeReplacerBase<Concrete>::replaceBase(Type type) {
+ Concrete &derived = *static_cast<Concrete *>(this);
+ return replaceElementImpl(type, typeReplacementFns, derived);
+}
+
+//===----------------------------------------------------------------------===//
+/// AttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+template class detail::AttrTypeReplacerBase<AttrTypeReplacer>;
+
+template <typename T>
+T AttrTypeReplacer::cachedReplaceImpl(T element) {
+ const void *key = element.getAsOpaquePointer();
+ // Seed the entry with an identity mapping, so that a re-entrant request for
+ // this element while it is still being replaced yields the element itself.
+ auto [entryIt, isNew] = cache.try_emplace(key, key);
+ if (!isNew)
+ return T::getFromOpaquePointer(entryIt->second);
+
+ T result = replaceBase(element);
+
+ // Index the map again instead of reusing `entryIt`: the recursive
+ // replacement above may have grown `cache` and invalidated the iterator.
+ cache[key] = result.getAsOpaquePointer();
return result;
}
Attribute AttrTypeReplacer::replace(Attribute attr) {
- return replaceImpl(attr, attrReplacementFns);
+ return cachedReplaceImpl(attr);
}
-Type AttrTypeReplacer::replace(Type type) {
- return replaceImpl(type, typeReplacementFns);
+Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(type); }
+
+//===----------------------------------------------------------------------===//
+/// CyclicAttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+template class detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer>;
+
+// The cache calls back into this object to break cycles, so the callback is
+// bound to `this` explicitly.
+CyclicAttrTypeReplacer::CyclicAttrTypeReplacer()
+ : cache([this](void *element) { return breakCycleImpl(element); }) {}
+
+void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Attribute> fn) {
+ attrCycleBreakerFns.emplace_back(std::move(fn));
+}
+
+void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Type> fn) {
+ typeCycleBreakerFns.emplace_back(std::move(fn));
+}
+
+template <typename T>
+T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) {
+ // Key on the tagged AttrOrType value so that `breakCycleImpl` can tell
+ // attributes and types apart.
+ void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue();
+ CyclicReplacerCache<void *, const void *>::CacheEntry cacheEntry =
+ cache.lookupOrInit(opaqueTaggedElement);
+ // A value is either a previously resolved replacement, or the result of a
+ // cycle breaker because `element` is already being replaced further up.
+ if (auto resultOpt = cacheEntry.get())
+ return T::getFromOpaquePointer(*resultOpt);
+
+ T result = replaceBase(element);
+
+ cacheEntry.resolve(result.getAsOpaquePointer());
+ return result;
+}
+
+Attribute CyclicAttrTypeReplacer::replace(Attribute attr) {
+ return cachedReplaceImpl(attr);
+}
+
+Type CyclicAttrTypeReplacer::replace(Type type) {
+ return cachedReplaceImpl(type);
+}
+
+std::optional<const void *>
+CyclicAttrTypeReplacer::breakCycleImpl(void *element) {
+ AttrOrType attrType = AttrOrType::getFromOpaqueValue(element);
+ if (auto attr = dyn_cast<Attribute>(attrType)) {
+ for (auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns))
+ if (std::optional<Attribute> newRes = cyclicReplaceFn(attr))
+ return newRes->getAsOpaquePointer();
+ } else {
+ // The union holds either an Attribute or a Type, so this cannot fail; use
+ // `cast` to state (and, in asserting builds, check) that invariant.
+ auto type = cast<Type>(attrType);
+ for (auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns))
+ if (std::optional<Type> newRes = cyclicReplaceFn(type))
+ return newRes->getAsOpaquePointer();
+ }
+ // No registered cycle breaker handled this element.
+ return std::nullopt;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/IR/AttrTypeReplacerTest.cpp b/mlir/unittests/IR/AttrTypeReplacerTest.cpp
new file mode 100644
index 0000000000000..c7b42eb267c7a
--- /dev/null
+++ b/mlir/unittests/IR/AttrTypeReplacerTest.cpp
@@ -0,0 +1,231 @@
+//===- AttrTypeReplacerTest.cpp - Sub-element replacer unit tests ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AttrTypeSubElements.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// CyclicAttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+TEST(CyclicAttrTypeReplacerTest, testNoRecursion) {
+ MLIRContext ctx;
+ Attribute trueStr = StringAttr::get(&ctx, "true");
+ Attribute falseStr = StringAttr::get(&ctx, "false");
+ Attribute unit = mlir::UnitAttr::get(&ctx);
+
+ // BoolAttrs become their string spelling; anything else is left untouched.
+ CyclicAttrTypeReplacer replacer;
+ replacer.addReplacement([&](BoolAttr boolAttr) {
+ return StringAttr::get(&ctx, boolAttr.getValue() ? "true" : "false");
+ });
+
+ EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, true)), trueStr);
+ EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, false)), falseStr);
+ EXPECT_EQ(replacer.replace(unit), unit);
+}
+
+TEST(CyclicAttrTypeReplacerTest, testInPlaceRecursionPruneAnywhere) {
+ MLIRContext ctx;
+ Builder b(&ctx);
+ Attribute unit = b.getUnitAttr();
+
+ CyclicAttrTypeReplacer replacer;
+ // Integer attr N is replaced by the replacement of (N + 1) % 3, so the
+ // replacer cycles through 0 -> 1 -> 2 -> 0 -> ...
+ replacer.addReplacement([&](IntegerAttr attr) {
+ IntegerAttr next = b.getI8IntegerAttr((attr.getInt() + 1) % 3);
+ return replacer.replace(next);
+ });
+ // The first repeat of any integer attr is pruned into a unit attr.
+ replacer.addCycleBreaker([&](IntegerAttr attr) { return b.getUnitAttr(); });
+
+ // No recursion case.
+ EXPECT_EQ(replacer.replace(unit), unit);
+ // Starting at 0.
+ EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(0)), unit);
+ // Starting at 2.
+ EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(2)), unit);
+}
+
+//===----------------------------------------------------------------------===//
+// CyclicAttrTypeReplacerTest: ChainRecursion
+//===----------------------------------------------------------------------===//
+
+/// Fixture whose replacer rewrites IntegerType<N> into a nullary function type
+/// returning IntegerType<(N+1) % 3>, i.e. an infinite chain unless pruned.
+class CyclicAttrTypeReplacerChainRecursionPruningTest : public ::testing::Test {
+public:
+ CyclicAttrTypeReplacerChainRecursionPruningTest() : b(&ctx) {
+ // IntegerType<width = N>
+ // ==> FunctionType<() => IntegerType< width = (N+1) % 3>>.
+ // This will create a chain of infinite length without recursion pruning.
+ replacer.addReplacement([&](mlir::IntegerType intType) {
+ ++invokeCount;
+ return b.getFunctionType(
+ {}, {mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3)});
+ });
+ }
+
+ /// Registers a cycle breaker that replaces a repeated IntegerType with
+ /// `index`: for any width if `pruneAt` is unset, otherwise only for width
+ /// `*pruneAt`.
+ void setBaseCase(std::optional<unsigned> pruneAt) {
+ replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) {
+ return (!pruneAt || intType.getWidth() == *pruneAt)
+ ? std::make_optional(b.getIndexType())
+ : std::nullopt;
+ });
+ }
+
+ /// Builds the expected result for a chain of length N: `index` wrapped in N
+ /// nullary function types.
+ Type getFunctionTypeChain(unsigned N) {
+ Type type = b.getIndexType();
+ for (unsigned i = 0; i < N; i++)
+ type = b.getFunctionType({}, type);
+ return type;
+ }
+
+ MLIRContext ctx;
+ Builder b;
+ CyclicAttrTypeReplacer replacer;
+ int invokeCount = 0;
+};
+
+TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere0) {
+ setBaseCase(std::nullopt);
+
+ // No recursion case.
+ EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType());
+ EXPECT_EQ(invokeCount, 0);
+
+ // Starting at 0. Cycle length is 3.
+ invokeCount = 0;
+ EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
+ getFunctionTypeChain(3));
+ EXPECT_EQ(invokeCount, 3);
+
+ // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
+ invokeCount = 0;
+ EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+ getFunctionTypeChain(5));
+ EXPECT_EQ(invokeCount, 2);
+}
+
+TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere1) {
+ setBaseCase(std::nullopt);
+
+ // Starting at 1. Cycle length is 3.
+ EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+ getFunctionTypeChain(3));
+ EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific0) {
+ setBaseCase(0);
+
+ // Starting at 0. Cycle length is 3.
+ EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
+ getFunctionTypeChain(3));
+ EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific1) {
+ setBaseCase(0);
+
+ // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
+ EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+ getFunctionTypeChain(5));
+ EXPECT_EQ(invokeCount, 5);
+}
+
+//===----------------------------------------------------------------------===//
+// CyclicAttrTypeReplacerTest: BranchingRecursion
+//===----------------------------------------------------------------------===//
+
+/// Fixture whose replacer rewrites IntegerType<N> into a function type whose
+/// single input and single result are both IntegerType<(N+1) % 3>, i.e. an
+/// infinitely deep binary tree unless cycles are pruned.
+class CyclicAttrTypeReplacerBranchingRecursionPruningTest
+ : public ::testing::Test {
+public:
+ CyclicAttrTypeReplacerBranchingRecursionPruningTest() : b(&ctx) {
+ // IntegerType<width = N>
+ // ==> FunctionType<
+ // IntegerType< width = (N+1) % 3> =>
+ // IntegerType< width = (N+1) % 3>>.
+ // This will create a binary tree of infinite depth without pruning.
+ replacer.addReplacement([&](mlir::IntegerType intType) {
+ ++invokeCount;
+ Type child = mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3);
+ return b.getFunctionType({child}, {child});
+ });
+ }
+
+ /// Registers a cycle breaker that replaces a repeated IntegerType with
+ /// `index`: for any width if `pruneAt` is unset, otherwise only for width
+ /// `*pruneAt`.
+ void setBaseCase(std::optional<unsigned> pruneAt) {
+ replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) {
+ return (!pruneAt || intType.getWidth() == *pruneAt)
+ ? std::make_optional(b.getIndexType())
+ : std::nullopt;
+ });
+ }
+
+ /// Builds the expected result for a tree of depth N: starting from `index`,
+ /// N times wrap the current type T into the function type `(T) -> T`.
+ Type getFunctionTypeTree(unsigned N) {
+ Type type = b.getIndexType();
+ for (unsigned i = 0; i < N; i++)
+ type = b.getFunctionType(type, type);
+ return type;
+ }
+
+ MLIRContext ctx;
+ Builder b;
+ CyclicAttrTypeReplacer replacer;
+ int invokeCount = 0;
+};
+
+TEST_F(CyclicAttrTypeReplacerBranchingRecursionPruningTest,
+ testPruneAnywhere0) {
+ setBaseCase(std::nullopt);
+
+ // No recursion case.
+ EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType());
+ EXPECT_EQ(invokeCount, 0);
+
+ // Starting at 0. Cycle length is 3.
+ invokeCount = 0;
+ EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
+ getFunctionTypeTree(3));
+ // Since both branches are identical, this should incur linear invocations
+ // of the replacement function instead of exponential.
+ EXPECT_EQ(invokeCount, 3);
+
+ // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
+ invokeCount = 0;
+ EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+ getFunctionTypeTree(5));
+ EXPECT_EQ(invokeCount, 2);
+}
+
+TEST_F(CyclicAttrTypeReplacerBranchingRecursionPruningTest,
+ testPruneAnywhere1) {
+ setBaseCase(std::nullopt);
+
+ // Starting at 1. Cycle length is 3.
+ EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+ getFunctionTypeTree(3));
+ EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CyclicAttrTypeReplacerBranchingRecursionPruningTest,
+ testPruneSpecific0) {
+ setBaseCase(0);
+
+ // Starting at 0. Cycle length is 3.
+ EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
+ getFunctionTypeTree(3));
+ EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CyclicAttrTypeReplacerBranchingRecursionPruningTest,
+ testPruneSpecific1) {
+ setBaseCase(0);
+
+ // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
+ EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
+ getFunctionTypeTree(5));
+ EXPECT_EQ(invokeCount, 5);
+}
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 71f8f449756ec..05cb36e190316 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_unittest(MLIRIRTests
AffineExprTest.cpp
AffineMapTest.cpp
AttributeTest.cpp
+ AttrTypeReplacerTest.cpp
DialectTest.cpp
InterfaceTest.cpp
IRMapping.cpp
More information about the llvm-branch-commits
mailing list