[llvm-branch-commits] [mlir] [MLIR] Cyclic AttrType Replacer (PR #98206)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jul 9 14:06:41 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Billy Zhu (zyx-billy)

<details>
<summary>Changes</summary>

The current `AttrTypeReplacer` does not allow for custom handling of replacer functions that may cause self-recursion. For example, the replacement of one attr/type may depend on the replacement of another attr/type (by calling into the replacer manually again), which in turn may depend on the replacement of the original attr/type.

To enable this functionality, this PR splits the original `AttrTypeReplacer` into two parts:
- An uncached base version (`detail::AttrTypeReplacerBase`) that allows registering replacer functions and has the logic for invoking them on attr/types & their sub-elements
- A cached version (`AttrTypeReplacer`) that provides the same caching as the original one. This is still the one used everywhere and behavior is unchanged.

On top of the uncached base version, a `CyclicAttrTypeReplacer` is introduced that provides caching & cycle-handling for replacer logic that is cyclic. Cycle-breaking & caching are provided by the `CyclicReplacerCache` from https://github.com/llvm/llvm-project/pull/98202.
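
As a rough illustration of the cycle-breaking idea only, here is a minimal standalone sketch. It is not the `CyclicReplacerCache` from the linked PR and it does no caching. `ToyCyclicReplacer`, its `ReplaceFn` and `CycleBreakerFn` aliases, and the use of `int` and `std::string` in place of attrs/types are all hypothetical stand-ins chosen so the example compiles without MLIR.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <unordered_set>
#include <utility>

// Hypothetical toy replacer: elements are ints, replacements are strings.
class ToyCyclicReplacer {
public:
  // The replacer function may call back into `replace`, which is what makes
  // cycles possible.
  using ReplaceFn = std::function<std::string(ToyCyclicReplacer &, int)>;
  // Returns a result to break the cycle, or std::nullopt to defer to the next
  // repeated element.
  using CycleBreakerFn = std::function<std::optional<std::string>(int)>;

  ToyCyclicReplacer(ReplaceFn fn, CycleBreakerFn breaker)
      : replaceFn(std::move(fn)), cycleBreaker(std::move(breaker)) {
    assert(replaceFn && cycleBreaker && "both callbacks must be callable");
  }

  std::string replace(int element) {
    // An element already in the set is still being replaced further up the
    // call stack, so seeing it again means we are inside a cycle.
    bool firstVisit = inProgress.insert(element).second;
    if (!firstVisit) {
      if (std::optional<std::string> broken = cycleBreaker(element))
        return *broken;
      // Deferred: keep going and let a later repeated element break the
      // cycle. This never terminates if no element can break it.
    }
    std::string result = replaceFn(*this, element);
    if (firstVisit)
      inProgress.erase(element);
    return result;
  }

private:
  ReplaceFn replaceFn;
  CycleBreakerFn cycleBreaker;
  std::unordered_set<int> inProgress;
};
```

With a replacer function that maps `n` to `"f(" + replace((n + 1) % 3) + ")"` and a breaker that always returns `"index"`, `replace(0)` in this sketch yields `f(f(f(index)))`, which mirrors the shape of the chain-recursion unit test in the patch.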

Both concrete implementations of the uncached base version use CRTP to avoid dynamic dispatch. The base class merely provides replacer registration & invocation, and is not meant to be used or extended elsewhere.
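
To make the split concrete, the following is a simplified sketch of the CRTP arrangement, assuming a toy `Element` type (`int`) in place of MLIR attrs/types. `ReplacerBase`, `CachedReplacer`, and `replaceAll` are hypothetical names for illustration. Only the overall pattern follows the patch: the base statically dispatches to `Concrete::replace`, and the derived class wraps `replaceBase` with a cache.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for an attribute/type handle.
using Element = int;

// Uncached base: owns the registered functions and the invocation logic.
// Calls go through Concrete::replace via static_cast, so no virtual dispatch
// is involved.
template <typename Concrete>
class ReplacerBase {
public:
  using ReplaceFn = std::function<std::optional<Element>(Element)>;

  void addReplacement(ReplaceFn fn) { fns.push_back(std::move(fn)); }

  // Replaces each element by calling the derived class's `replace` entry.
  std::vector<Element> replaceAll(const std::vector<Element> &elements) {
    std::vector<Element> out;
    for (Element e : elements)
      out.push_back(static_cast<Concrete *>(this)->replace(e));
    return out;
  }

protected:
  // Tries functions from most to least recently registered and returns the
  // element unchanged if none of them applies.
  Element replaceBase(Element e) {
    for (auto it = fns.rbegin(); it != fns.rend(); ++it)
      if (std::optional<Element> r = (*it)(e))
        return *r;
    return e;
  }

private:
  std::vector<ReplaceFn> fns;
};

// Naively cached concrete replacer, in the spirit of `AttrTypeReplacer`.
class CachedReplacer : public ReplacerBase<CachedReplacer> {
public:
  Element replace(Element e) {
    // Insert a placeholder first; a repeated lookup returns the cached value.
    auto [it, inserted] = cache.try_emplace(e, e);
    if (!inserted)
      return it->second;
    Element result = replaceBase(e);
    assert(cache.count(e) == 1 && "placeholder entry was inserted above");
    cache[e] = result;
    return result;
  }

private:
  std::unordered_map<Element, Element> cache;
};
```

Because the placeholder is inserted before `replaceBase` runs, a re-entrant request for an in-progress element in this sketch simply returns the element itself, which corresponds to the "skipped" behavior the header comment describes for the non-cyclic replacer.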

---

Stacked PRs:
- https://github.com/llvm/llvm-project/pull/98202
- ➡️ https://github.com/llvm/llvm-project/pull/98206

---

Patch is 24.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98206.diff


4 Files Affected:

- (modified) mlir/include/mlir/IR/AttrTypeSubElements.h (+119-19) 
- (modified) mlir/lib/IR/AttrTypeSubElements.cpp (+116-30) 
- (added) mlir/unittests/IR/AttrTypeReplacerTest.cpp (+231) 
- (modified) mlir/unittests/IR/CMakeLists.txt (+1) 


``````````diff
diff --git a/mlir/include/mlir/IR/AttrTypeSubElements.h b/mlir/include/mlir/IR/AttrTypeSubElements.h
index 3105040b87631..234767deea00a 100644
--- a/mlir/include/mlir/IR/AttrTypeSubElements.h
+++ b/mlir/include/mlir/IR/AttrTypeSubElements.h
@@ -16,6 +16,7 @@
 
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Visitors.h"
+#include "mlir/Support/CyclicReplacerCache.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include <optional>
@@ -116,9 +117,21 @@ class AttrTypeWalker {
 /// AttrTypeReplacer
 //===----------------------------------------------------------------------===//
 
-/// This class provides a utility for replacing attributes/types, and their sub
-/// elements. Multiple replacement functions may be registered.
-class AttrTypeReplacer {
+namespace detail {
+
+/// This class provides a base utility for replacing attributes/types, and their
+/// sub elements. Multiple replacement functions may be registered.
+///
+/// This base utility is uncached. Users can choose between two cached versions
+/// of this replacer:
+///   * For non-cyclic replacer logic, use `AttrTypeReplacer`.
+///   * For cyclic replacer logic, use `CyclicAttrTypeReplacer`.
+///
+/// Concrete implementations implement the following `replace` entry functions:
+///   * Attribute replace(Attribute attr);
+///   * Type replace(Type type);
+template <typename Concrete>
+class AttrTypeReplacerBase {
 public:
   //===--------------------------------------------------------------------===//
   // Application
@@ -139,12 +152,6 @@ class AttrTypeReplacer {
                                     bool replaceLocs = false,
                                     bool replaceTypes = false);
 
-  /// Replace the given attribute/type, and recursively replace any sub
-  /// elements. Returns either the new attribute/type, or nullptr in the case of
-  /// failure.
-  Attribute replace(Attribute attr);
-  Type replace(Type type);
-
   //===--------------------------------------------------------------------===//
   // Registration
   //===--------------------------------------------------------------------===//
@@ -206,21 +213,114 @@ class AttrTypeReplacer {
     });
   }
 
-private:
-  /// Internal implementation of the `replace` methods above.
-  template <typename T, typename ReplaceFns>
-  T replaceImpl(T element, ReplaceFns &replaceFns);
-
-  /// Replace the sub elements of the given interface.
-  template <typename T>
-  T replaceSubElements(T interface);
+protected:
+  /// Invokes the registered replacement functions from most recently registered
+  /// to least recently registered until a successful replacement is returned.
+  /// Unless skipping is requested, invokes `replace` on sub-elements of the
+  /// current attr/type.
+  Attribute replaceBase(Attribute attr);
+  Type replaceBase(Type type);
 
+private:
   /// The set of replacement functions that map sub elements.
   std::vector<ReplaceFn<Attribute>> attrReplacementFns;
   std::vector<ReplaceFn<Type>> typeReplacementFns;
+};
+
+} // namespace detail
+
+/// This is an attribute/type replacer that is naively cached. It is best used
+/// when the replacer logic is guaranteed to not contain cycles. Otherwise, any
+/// re-occurrence of an in-progress element will be skipped.
+class AttrTypeReplacer : public detail::AttrTypeReplacerBase<AttrTypeReplacer> {
+public:
+  Attribute replace(Attribute attr);
+  Type replace(Type type);
+
+private:
+  /// Shared concrete implementation of the public `replace` functions. Invokes
+  /// `replaceBase` with caching.
+  template <typename T>
+  T cachedReplaceImpl(T element);
+
+  // Stores the opaque pointer of an attribute or type.
+  DenseMap<const void *, const void *> cache;
+};
+
+/// This is an attribute/type replacer that supports custom handling of cycles
+/// in the replacer logic. In addition to registering replacer functions, it
+/// allows registering cycle-breaking functions in the same style.
+class CyclicAttrTypeReplacer
+    : public detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer> {
+public:
+  CyclicAttrTypeReplacer();
 
-  /// The set of cached mappings for attributes/types.
-  DenseMap<const void *, const void *> attrTypeMap;
+  //===--------------------------------------------------------------------===//
+  // Application
+  //===--------------------------------------------------------------------===//
+
+  Attribute replace(Attribute attr);
+  Type replace(Type type);
+
+  //===--------------------------------------------------------------------===//
+  // Registration
+  //===--------------------------------------------------------------------===//
+
+  /// A cycle-breaking function. This is invoked if the same element is asked to
+  /// be replaced again when the first instance of it is still being replaced.
+  /// This function must not perform any more recursive `replace` calls.
+  /// If it is able to break the cycle, it should return a replacement result.
+  /// Otherwise, it can return std::nullopt to defer cycle breaking to the next
+  /// repeated element. However, the user must guarantee that, in any possible
+  /// cycle, there always exists at least one element that can break the cycle.
+  template <typename T>
+  using CycleBreakerFn = std::function<std::optional<T>(T)>;
+
+  /// Register a cycle-breaking function.
+  /// When breaking cycles, the mostly recently added cycle-breaking functions
+  /// will be invoked first.
+  void addCycleBreaker(CycleBreakerFn<Attribute> fn);
+  void addCycleBreaker(CycleBreakerFn<Type> fn);
+
+  /// Register a cycle-breaking function that doesn't match the default
+  /// signature.
+  template <typename FnT,
+            typename T = typename llvm::function_traits<
+                std::decay_t<FnT>>::template arg_t<0>,
+            typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
+                                                Attribute, Type>>
+  std::enable_if_t<!std::is_same_v<T, BaseT>> addCycleBreaker(FnT &&callback) {
+    addCycleBreaker([callback = std::forward<FnT>(callback)](
+                        BaseT base) -> std::optional<BaseT> {
+      if (auto derived = dyn_cast<T>(base))
+        return callback(derived);
+      return std::nullopt;
+    });
+  }
+
+private:
+  /// Invokes the registered cycle-breaker functions from most recently
+  /// registered to least recently registered until a successful result is
+  /// returned.
+  std::optional<const void *> breakCycleImpl(void *element);
+
+  /// Shared concrete implementation of the public `replace` functions.
+  template <typename T>
+  T cachedReplaceImpl(T element);
+
+  /// The set of registered cycle-breaker functions.
+  std::vector<CycleBreakerFn<Attribute>> attrCycleBreakerFns;
+  std::vector<CycleBreakerFn<Type>> typeCycleBreakerFns;
+
+  /// A cache of previously-replaced attr/types.
+  /// The key of the cache is the opaque value of an AttrOrType. Using
+  /// AttrOrType allows distinguishing between the two types when invoking
+  /// cycle-breakers. Using its opaque value avoids the cyclic dependency issue
+  /// of directly using `AttrOrType` to instantiate the cache.
+  /// The value of the cache is just the opaque value of the attr/type itself
+  /// (not the PointerUnion).
+  using AttrOrType = PointerUnion<Attribute, Type>;
+  CyclicReplacerCache<void *, const void *> cache;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AttrTypeSubElements.cpp b/mlir/lib/IR/AttrTypeSubElements.cpp
index 79b04966be6eb..783236ed3a9df 100644
--- a/mlir/lib/IR/AttrTypeSubElements.cpp
+++ b/mlir/lib/IR/AttrTypeSubElements.cpp
@@ -67,22 +67,28 @@ WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
 }
 
 //===----------------------------------------------------------------------===//
-/// AttrTypeReplacer
+/// AttrTypeReplacerBase
 //===----------------------------------------------------------------------===//
 
-void AttrTypeReplacer::addReplacement(ReplaceFn<Attribute> fn) {
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
+    ReplaceFn<Attribute> fn) {
   attrReplacementFns.emplace_back(std::move(fn));
 }
-void AttrTypeReplacer::addReplacement(ReplaceFn<Type> fn) {
+
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
+    ReplaceFn<Type> fn) {
   typeReplacementFns.push_back(std::move(fn));
 }
 
-void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
-                                         bool replaceLocs, bool replaceTypes) {
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::replaceElementsIn(
+    Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
   // Functor that replaces the given element if the new value is different,
   // otherwise returns nullptr.
   auto replaceIfDifferent = [&](auto element) {
-    auto replacement = replace(element);
+    auto replacement = static_cast<Concrete *>(this)->replace(element);
     return (replacement && replacement != element) ? replacement : nullptr;
   };
 
@@ -127,17 +133,16 @@ void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
   }
 }
 
-void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op,
-                                                    bool replaceAttrs,
-                                                    bool replaceLocs,
-                                                    bool replaceTypes) {
+template <typename Concrete>
+void detail::AttrTypeReplacerBase<Concrete>::recursivelyReplaceElementsIn(
+    Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
   op->walk([&](Operation *nestedOp) {
     replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes);
   });
 }
 
-template <typename T>
-static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
+template <typename T, typename Replacer>
+static void updateSubElementImpl(T element, Replacer &replacer,
                                  SmallVectorImpl<T> &newElements,
                                  FailureOr<bool> &changed) {
   // Bail early if we failed at any point.
@@ -160,18 +165,18 @@ static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
   }
 }
 
-template <typename T>
-T AttrTypeReplacer::replaceSubElements(T interface) {
+template <typename T, typename Replacer>
+static T replaceSubElements(T interface, Replacer &replacer) {
   // Walk the current sub-elements, replacing them as necessary.
   SmallVector<Attribute, 16> newAttrs;
   SmallVector<Type, 16> newTypes;
   FailureOr<bool> changed = false;
   interface.walkImmediateSubElements(
       [&](Attribute element) {
-        updateSubElementImpl(element, *this, newAttrs, changed);
+        updateSubElementImpl(element, replacer, newAttrs, changed);
       },
       [&](Type element) {
-        updateSubElementImpl(element, *this, newTypes, changed);
+        updateSubElementImpl(element, replacer, newTypes, changed);
       });
   if (failed(changed))
     return nullptr;
@@ -184,13 +189,9 @@ T AttrTypeReplacer::replaceSubElements(T interface) {
 }
 
 /// Shared implementation of replacing a given attribute or type element.
-template <typename T, typename ReplaceFns>
-T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
-  const void *opaqueElement = element.getAsOpaquePointer();
-  auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement);
-  if (!inserted)
-    return T::getFromOpaquePointer(it->second);
-
+template <typename T, typename ReplaceFns, typename Replacer>
+static T replaceElementImpl(T element, ReplaceFns &replaceFns,
+                            Replacer &replacer) {
   T result = element;
   WalkResult walkResult = WalkResult::advance();
   for (auto &replaceFn : llvm::reverse(replaceFns)) {
@@ -202,29 +203,114 @@ T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
 
   // If an error occurred, return nullptr to indicate failure.
   if (walkResult.wasInterrupted() || !result) {
-    attrTypeMap[opaqueElement] = nullptr;
     return nullptr;
   }
 
   // Handle replacing sub-elements if this element is also a container.
   if (!walkResult.wasSkipped()) {
     // Replace the sub elements of this element, bailing if we fail.
-    if (!(result = replaceSubElements(result))) {
-      attrTypeMap[opaqueElement] = nullptr;
+    if (!(result = replaceSubElements(result, replacer))) {
       return nullptr;
     }
   }
 
-  attrTypeMap[opaqueElement] = result.getAsOpaquePointer();
+  return result;
+}
+
+template <typename Concrete>
+Attribute detail::AttrTypeReplacerBase<Concrete>::replaceBase(Attribute attr) {
+  return replaceElementImpl(attr, attrReplacementFns,
+                            *static_cast<Concrete *>(this));
+}
+
+template <typename Concrete>
+Type detail::AttrTypeReplacerBase<Concrete>::replaceBase(Type type) {
+  return replaceElementImpl(type, typeReplacementFns,
+                            *static_cast<Concrete *>(this));
+}
+
+//===----------------------------------------------------------------------===//
+/// AttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+template class detail::AttrTypeReplacerBase<AttrTypeReplacer>;
+
+template <typename T>
+T AttrTypeReplacer::cachedReplaceImpl(T element) {
+  const void *opaqueElement = element.getAsOpaquePointer();
+  auto [it, inserted] = cache.try_emplace(opaqueElement, opaqueElement);
+  if (!inserted)
+    return T::getFromOpaquePointer(it->second);
+
+  T result = replaceBase(element);
+
+  cache[opaqueElement] = result.getAsOpaquePointer();
   return result;
 }
 
 Attribute AttrTypeReplacer::replace(Attribute attr) {
-  return replaceImpl(attr, attrReplacementFns);
+  return cachedReplaceImpl(attr);
 }
 
-Type AttrTypeReplacer::replace(Type type) {
-  return replaceImpl(type, typeReplacementFns);
+Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(type); }
+
+//===----------------------------------------------------------------------===//
+/// CyclicAttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+template class detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer>;
+
+CyclicAttrTypeReplacer::CyclicAttrTypeReplacer()
+    : cache([&](void *attr) { return breakCycleImpl(attr); }) {}
+
+void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Attribute> fn) {
+  attrCycleBreakerFns.emplace_back(std::move(fn));
+}
+
+void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Type> fn) {
+  typeCycleBreakerFns.emplace_back(std::move(fn));
+}
+
+template <typename T>
+T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) {
+  void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue();
+  CyclicReplacerCache<void *, const void *>::CacheEntry cacheEntry =
+      cache.lookupOrInit(opaqueTaggedElement);
+  if (auto resultOpt = cacheEntry.get())
+    return T::getFromOpaquePointer(*resultOpt);
+
+  T result = replaceBase(element);
+
+  cacheEntry.resolve(result.getAsOpaquePointer());
+  return result;
+}
+
+Attribute CyclicAttrTypeReplacer::replace(Attribute attr) {
+  return cachedReplaceImpl(attr);
+}
+
+Type CyclicAttrTypeReplacer::replace(Type type) {
+  return cachedReplaceImpl(type);
+}
+
+std::optional<const void *>
+CyclicAttrTypeReplacer::breakCycleImpl(void *element) {
+  AttrOrType attrType = AttrOrType::getFromOpaqueValue(element);
+  if (auto attr = dyn_cast<Attribute>(attrType)) {
+    for (auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) {
+      if (std::optional<Attribute> newRes = cyclicReplaceFn(attr)) {
+        return newRes->getAsOpaquePointer();
+      }
+    }
+  } else {
+    auto type = dyn_cast<Type>(attrType);
+    for (auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) {
+      if (std::optional<Type> newRes = cyclicReplaceFn(type)) {
+        return newRes->getAsOpaquePointer();
+      }
+    }
+  }
+  return std::nullopt;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/IR/AttrTypeReplacerTest.cpp b/mlir/unittests/IR/AttrTypeReplacerTest.cpp
new file mode 100644
index 0000000000000..c7b42eb267c7a
--- /dev/null
+++ b/mlir/unittests/IR/AttrTypeReplacerTest.cpp
@@ -0,0 +1,231 @@
+//===- AttrTypeReplacerTest.cpp - Sub-element replacer unit tests ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AttrTypeSubElements.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// CyclicAttrTypeReplacer
+//===----------------------------------------------------------------------===//
+
+TEST(CyclicAttrTypeReplacerTest, testNoRecursion) {
+  MLIRContext ctx;
+
+  CyclicAttrTypeReplacer replacer;
+  replacer.addReplacement([&](BoolAttr b) {
+    return StringAttr::get(&ctx, b.getValue() ? "true" : "false");
+  });
+
+  EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, true)),
+            StringAttr::get(&ctx, "true"));
+  EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, false)),
+            StringAttr::get(&ctx, "false"));
+  EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)),
+            mlir::UnitAttr::get(&ctx));
+}
+
+TEST(CyclicAttrTypeReplacerTest, testInPlaceRecursionPruneAnywhere) {
+  MLIRContext ctx;
+  Builder b(&ctx);
+
+  CyclicAttrTypeReplacer replacer;
+  // Replacer cycles through integer attrs 0 -> 1 -> 2 -> 0 -> ...
+  replacer.addReplacement([&](IntegerAttr attr) {
+    return replacer.replace(b.getI8IntegerAttr((attr.getInt() + 1) % 3));
+  });
+  // The first repeat of any integer attr is pruned into a unit attr.
+  replacer.addCycleBreaker([&](IntegerAttr attr) { return b.getUnitAttr(); });
+
+  // No recursion case.
+  EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)),
+            mlir::UnitAttr::get(&ctx));
+  // Starting at 0.
+  EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(0)), mlir::UnitAttr::get(&ctx));
+  // Starting at 2.
+  EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(2)), mlir::UnitAttr::get(&ctx));
+}
+
+//===----------------------------------------------------------------------===//
+// CyclicAttrTypeReplacerTest: ChainRecursion
+//===----------------------------------------------------------------------===//
+
+class CyclicAttrTypeReplacerChainRecursionPruningTest : public ::testing::Test {
+public:
+  CyclicAttrTypeReplacerChainRecursionPruningTest() : b(&ctx) {
+    // IntegerType<width = N>
+    // ==> FunctionType<() => IntegerType< width = (N+1) % 3>>.
+    // This will create a chain of infinite length without recursion pruning.
+    replacer.addReplacement([&](mlir::IntegerType intType) {
+      ++invokeCount;
+      return b.getFunctionType(
+          {}, {mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3)});
+    });
+  }
+
+  void setBaseCase(std::optional<unsigned> pruneAt) {
+    replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) {
+      return (!pruneAt || intType.getWidth() == *pruneAt)
+                 ? std::make_optional(b.getIndexType())
+                 : std::nullopt;
+    });
+  }
+
+  Type getFunctionTypeChain(unsigned N) {
+    Type type = b.getIndexType();
+    for (unsigned i = 0; i < N; i++)
+      type = b.getFunctionType({}, type);
+    return type;
+  };
+
+  MLIRContext ctx;
+  Builder b;
+  CyclicAttrTypeReplacer replacer;
+  int invokeCount = 0;
+};
+
+TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere0) {
+  setBaseCase(std::nullopt);
+
+  // No recursion case.
+  EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType());
+  EXPECT_EQ(invokeCount, 0);
+
+  // Starting at 0. Cycle length is 3.
+  invokeCount = 0;
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
+            getFunctionTypeChain(3));
+  EXPECT_EQ(invokeCount, 3);
+
+  // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
+  invokeCount = 0;
+  EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/98206

