[Mlir-commits] [mlir] 5f58e14 - [mlir] Add a generic DialectResourceBlobManager to simplify resource blob management

River Riddle llvmlistbot at llvm.org
Mon Aug 1 12:50:12 PDT 2022


Author: River Riddle
Date: 2022-08-01T12:37:16-07:00
New Revision: 5f58e14b36edbdacef86a2f3fc192b5720b2ba62

URL: https://github.com/llvm/llvm-project/commit/5f58e14b36edbdacef86a2f3fc192b5720b2ba62
DIFF: https://github.com/llvm/llvm-project/commit/5f58e14b36edbdacef86a2f3fc192b5720b2ba62.diff

LOG: [mlir] Add a generic DialectResourceBlobManager to simplify resource blob management

The DialectResourceBlobManager class provides functionality for managing resource blobs
in a generic, dialect-agnostic fashion. In addition to this class, a dialect interface and custom
resource handle are provided to simplify referencing and interacting with the manager. These
classes intend to simplify the work required for dialects that want to manage resource blobs
during compilation, such as for large elements attrs.  The old manager for the resource example
in the test dialect has been updated to use this, which provides a cleaner and more consistent API.

This commit also adds new HeapAsmResourceBlob and UnmanagedAsmResourceBlob classes to simplify
creating resource blobs in common scenarios.

Differential Revision: https://reviews.llvm.org/D130021

Added: 
    mlir/include/mlir/IR/DialectResourceBlobManager.h
    mlir/lib/IR/DialectResourceBlobManager.cpp

Modified: 
    mlir/include/mlir/IR/AsmState.h
    mlir/include/mlir/IR/Dialect.h
    mlir/lib/IR/CMakeLists.txt
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestAttributes.cpp
    mlir/test/lib/Dialect/Test/TestAttributes.h
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestDialect.h
    mlir/test/lib/Dialect/Test/TestDialect.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 7cae35458f815..b54bf176ce2a6 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -77,69 +77,74 @@ class AsmStateImpl;
 //===----------------------------------------------------------------------===//
 // Resource Entry
 
-/// This class is used to build resource entries for use by the printer. Each
-/// resource entry is represented using a key/value pair. The provided key must
-/// be unique within the current context, which allows for a client to provide
-/// resource entries without worrying about overlap with other clients.
-class AsmResourceBuilder {
-public:
-  virtual ~AsmResourceBuilder();
-
-  /// Build a resource entry represented by the given bool.
-  virtual void buildBool(StringRef key, bool data) = 0;
-
-  /// Build a resource entry represented by the given human-readable string
-  /// value.
-  virtual void buildString(StringRef key, StringRef data) = 0;
-
-  /// Build an resource entry represented by the given binary blob data.
-  virtual void buildBlob(StringRef key, ArrayRef<char> data,
-                         uint32_t dataAlignment) = 0;
-  /// Build an resource entry represented by the given binary blob data. This is
-  /// a useful overload if the data type is known. Note that this does not
-  /// support `char` element types to avoid accidentally not providing the
-  /// expected alignment of data in situations that treat blobs generically.
-  template <typename T>
-  std::enable_if_t<!std::is_same<T, char>::value> buildBlob(StringRef key,
-                                                            ArrayRef<T> data) {
-    buildBlob(
-        key, ArrayRef<char>((const char *)data.data(), data.size() * sizeof(T)),
-        alignof(T));
-  }
-};
-
 /// This class represents a processed binary blob of data. A resource blob is
 /// essentially a collection of data, potentially mutable, with an associated
 /// deleter function (used if the data needs to be destroyed).
 class AsmResourceBlob {
 public:
-  /// A deleter function that frees a blob given the data and allocation size.
-  using DeleterFn = llvm::unique_function<void(const void *data, size_t size)>;
+  /// A deleter function that frees a blob given the data, allocation size, and
+  /// allocation alignment.
+  using DeleterFn =
+      llvm::unique_function<void(void *data, size_t size, size_t align)>;
+
+  //===--------------------------------------------------------------------===//
+  // Construction
+  //===--------------------------------------------------------------------===//

   AsmResourceBlob() = default;
-  AsmResourceBlob(ArrayRef<char> data, DeleterFn deleter, bool dataIsMutable)
-      : data(data), deleter(std::move(deleter)), dataIsMutable(dataIsMutable) {}
+  AsmResourceBlob(ArrayRef<char> data, size_t dataAlignment, DeleterFn deleter,
+                  bool dataIsMutable)
+      : data(data), dataAlignment(dataAlignment), deleter(std::move(deleter)),
+        dataIsMutable(dataIsMutable) {}
   /// Utility constructor that initializes a blob with a non-char type T.
   template <typename T, typename DelT>
   AsmResourceBlob(ArrayRef<T> data, DelT &&deleteFn, bool dataIsMutable)
       : data((const char *)data.data(), data.size() * sizeof(T)),
-        deleter([deleteFn = std::forward<DelT>(deleteFn)](const void *data,
-                                                          size_t size) {
-          return deleteFn((const T *)data, size);
+        dataAlignment(alignof(T)),
+        deleter([deleteFn = std::forward<DelT>(deleteFn)](
+                    void *data, size_t size, size_t align) {
+          return deleteFn((T *)data, size, align);
         }),
         dataIsMutable(dataIsMutable) {}
   AsmResourceBlob(AsmResourceBlob &&) = default;
-  AsmResourceBlob &operator=(AsmResourceBlob &&) = default;
+  AsmResourceBlob &operator=(AsmResourceBlob &&rhs) {
+    if (this == &rhs)
+      return *this; // Self-move must not free the data this blob keeps.
+    // Delete the current blob if necessary, then take the data from rhs.
+    if (deleter)
+      deleter(const_cast<char *>(data.data()), data.size(), dataAlignment);
+    data = rhs.data;
+    dataAlignment = rhs.dataAlignment;
+    deleter = std::move(rhs.deleter);
+    dataIsMutable = rhs.dataIsMutable;
+    return *this;
+  }
   AsmResourceBlob(const AsmResourceBlob &) = delete;
   AsmResourceBlob &operator=(const AsmResourceBlob &) = delete;
   ~AsmResourceBlob() {
     if (deleter)
-      deleter(data.data(), data.size());
+      deleter(const_cast<char *>(data.data()), data.size(), dataAlignment);
   }
 
+  //===--------------------------------------------------------------------===//
+  // Data Access
+  //===--------------------------------------------------------------------===//
+
+  /// Return the alignment, in bytes, of the underlying data.
+  size_t getDataAlignment() const { return dataAlignment; }
+
   /// Return the raw underlying data of this blob.
   ArrayRef<char> getData() const { return data; }

+  /// Return the underlying data as an array of the given type; trailing bytes
+  /// that do not form a whole `T` are dropped. This is inherently unsafe, and
+  /// should only be used when the data is known to be of the correct type.
+  template <typename T>
+  ArrayRef<T> getDataAs() const {
+    return llvm::makeArrayRef<T>((const T *)data.data(),
+                                 data.size() / sizeof(T));
+  }
+
   /// Return a mutable reference to the raw underlying data of this blob.
   /// Asserts that the blob `isMutable`.
   MutableArrayRef<char> getMutableData() {
@@ -159,6 +164,9 @@ class AsmResourceBlob {
   /// The raw, properly aligned, blob data.
   ArrayRef<char> data;
 
+  /// The alignment of the data in bytes; also passed to the deleter.
+  size_t dataAlignment = 0;
+
   /// An optional deleter function used to deallocate the underlying data when
   /// necessary.
   DeleterFn deleter;
@@ -167,6 +175,92 @@ class AsmResourceBlob {
   bool dataIsMutable;
 };
 
+/// A utility wrapper for creating heap allocated AsmResourceBlobs.
+class HeapAsmResourceBlob {
+public:
+  /// Create a new heap allocated blob with the given size and alignment.
+  /// `dataIsMutable` (default: true) indicates if the data can be mutated.
+  static AsmResourceBlob allocate(size_t size, size_t align,
+                                  bool dataIsMutable = true) {
+    return AsmResourceBlob(
+        ArrayRef<char>((char *)llvm::allocate_buffer(size, align), size), align,
+        llvm::deallocate_buffer, dataIsMutable);
+  }
+  /// Create a new heap allocated blob and copy the provided data into it.
+  static AsmResourceBlob allocateAndCopy(ArrayRef<char> data, size_t align,
+                                         bool dataIsMutable = true) {
+    AsmResourceBlob blob = allocate(data.size(), align, dataIsMutable);
+    // The heap buffer is writable, but getMutableData() asserts isMutable.
+    if (!data.empty())
+      std::memcpy((char *)blob.getData().data(), data.data(), data.size());
+    return blob;
+  }
+  template <typename T>
+  static std::enable_if_t<!std::is_same<T, char>::value, AsmResourceBlob>
+  allocateAndCopy(ArrayRef<T> data, bool dataIsMutable = true) {
+    return allocateAndCopy(
+        ArrayRef<char>((const char *)data.data(), data.size() * sizeof(T)),
+        alignof(T), dataIsMutable);
+  }
+};
+/// This class provides a simple utility wrapper for creating "unmanaged"
+/// AsmResourceBlobs. The blobs never free their data, so the caller must
+/// guarantee that the provided data outlives every blob referencing it.
+class UnmanagedAsmResourceBlob {
+public:
+  /// Create a new unmanaged resource directly referencing the provided data.
+  /// `dataIsMutable` indicates if the referenced data can be mutated. By
+  /// default, we treat unmanaged blobs as immutable.
+  static AsmResourceBlob allocate(ArrayRef<char> data, size_t align,
+                                  bool dataIsMutable = false) {
+    return AsmResourceBlob(data, align, /*deleter=*/{},
+                           /*dataIsMutable=*/dataIsMutable);
+  }
+  template <typename T>
+  static std::enable_if_t<!std::is_same<T, char>::value, AsmResourceBlob>
+  allocate(ArrayRef<T> data, bool dataIsMutable = false) {
+    return allocate(
+        ArrayRef<char>((const char *)data.data(), data.size() * sizeof(T)),
+        alignof(T), dataIsMutable);
+  }
+};
+
+/// This class is used to build resource entries for use by the printer. Each
+/// resource entry is represented using a key/value pair. The provided key must
+/// be unique within the current context, which allows for a client to provide
+/// resource entries without worrying about overlap with other clients.
+class AsmResourceBuilder {
+public:
+  virtual ~AsmResourceBuilder();
+
+  /// Build a resource entry represented by the given bool.
+  virtual void buildBool(StringRef key, bool data) = 0;
+
+  /// Build a resource entry represented by the given human-readable string
+  /// value.
+  virtual void buildString(StringRef key, StringRef data) = 0;
+
+  /// Build a resource entry represented by the given binary blob data.
+  virtual void buildBlob(StringRef key, ArrayRef<char> data,
+                         uint32_t dataAlignment) = 0;
+  /// Build a resource entry represented by the given binary blob data. This is
+  /// a useful overload if the data type is known. Note that this does not
+  /// support `char` element types to avoid accidentally not providing the
+  /// expected alignment of data in situations that treat blobs generically.
+  template <typename T>
+  std::enable_if_t<!std::is_same<T, char>::value> buildBlob(StringRef key,
+                                                            ArrayRef<T> data) {
+    buildBlob(
+        key, ArrayRef<char>((const char *)data.data(), data.size() * sizeof(T)),
+        alignof(T));
+  }
+  /// Build a resource entry represented by the given resource blob, using the
+  /// blob's data and alignment. Useful if a blob already exists in-memory.
+  void buildBlob(StringRef key, const AsmResourceBlob &blob) {
+    buildBlob(key, blob.getData(), blob.getDataAlignment());
+  }
+};
+
 /// This class represents a single parsed resource entry.
 class AsmParsedResourceEntry {
 public:
@@ -186,17 +280,24 @@ class AsmParsedResourceEntry {
   /// failure if the entry does not correspond to a string.
   virtual FailureOr<std::string> parseAsString() const = 0;
 
-  /// The type of an allocator function used to allocate memory for a blob when
-  /// required. The function is provided a size and alignment, and should return
-  /// an aligned allocation buffer.
+  /// An allocator function used to allocate memory for a blob when required.
+  /// The function is provided a size and alignment, and should return a blob
+  /// wrapping an aligned allocation buffer of that size.
   using BlobAllocatorFn =
-      function_ref<AsmResourceBlob(unsigned size, unsigned align)>;
+      function_ref<AsmResourceBlob(size_t size, size_t align)>;

   /// Parse the resource entry represented by a binary blob. Returns failure if
   /// the entry does not correspond to a blob. If the blob needed to be
   /// allocated, the given allocator function is invoked.
   virtual FailureOr<AsmResourceBlob>
   parseAsBlob(BlobAllocatorFn allocator) const = 0;
+  /// Parse the resource entry represented by a binary blob, allocating any
+  /// required storage via HeapAsmResourceBlob::allocate (mutable by default).
+  FailureOr<AsmResourceBlob> parseAsBlob() const {
+    return parseAsBlob([](size_t size, size_t align) {
+      return HeapAsmResourceBlob::allocate(size, align);
+    });
+  }
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index d09d96d21d154..00dd621e6f54f 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -157,13 +157,13 @@ class Dialect {
 
   /// Lookup an interface for the given ID if one is registered, otherwise
   /// nullptr.
-  const DialectInterface *getRegisteredInterface(TypeID interfaceID) {
+  DialectInterface *getRegisteredInterface(TypeID interfaceID) {
     auto it = registeredInterfaces.find(interfaceID);
     return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
   }
   template <typename InterfaceT>
-  const InterfaceT *getRegisteredInterface() {
-    return static_cast<const InterfaceT *>(
+  InterfaceT *getRegisteredInterface() { // Typed variant; nullptr if absent.
+    return static_cast<InterfaceT *>(
         getRegisteredInterface(InterfaceT::getInterfaceID()));
   }
 
@@ -189,6 +189,12 @@ class Dialect {
     (void)std::initializer_list<int>{
         0, (addInterface(std::make_unique<Args>(this)), 0)...};
   }
+  template <typename InterfaceT, typename... Args>
+  InterfaceT &addInterface(Args &&...args) {
+    InterfaceT *interface = new InterfaceT(this, std::forward<Args>(args)...);
+    addInterface(std::unique_ptr<DialectInterface>(interface));
+    return *interface; // Now owned via the unique_ptr handed over above.
+  }
 
 protected:
   /// The constructor takes a unique namespace for this dialect as well as the
@@ -305,15 +311,11 @@ struct isa_impl<
 };
 template <typename T>
 struct cast_retty_impl<T, ::mlir::Dialect *> {
-  using ret_type =
-      std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T *,
-                         const T *>;
+  using ret_type = T *; // Mutable for dialects and dialect interfaces alike.
 };
 template <typename T>
 struct cast_retty_impl<T, ::mlir::Dialect> {
-  using ret_type =
-      std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T &,
-                         const T &>;
+  using ret_type = T &; // Mutable for dialects and dialect interfaces alike.
 };
 
 template <typename T>
@@ -325,7 +327,7 @@ struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
   }
   template <typename To>
   static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
-                          const To &>
+                          To &>
   doitImpl(::mlir::Dialect &dialect) {
     return *dialect.getRegisteredInterface<To>();
   }

diff  --git a/mlir/include/mlir/IR/DialectResourceBlobManager.h b/mlir/include/mlir/IR/DialectResourceBlobManager.h
new file mode 100644
index 0000000000000..23dead9d54011
--- /dev/null
+++ b/mlir/include/mlir/IR/DialectResourceBlobManager.h
@@ -0,0 +1,215 @@
+//===- DialectResourceBlobManager.h - Dialect Blob Management ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines utility classes for referencing and managing asm resource
+// blobs. These classes are intended to more easily facilitate the sharing of
+// large blobs, and their definition.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_DIALECTRESOURCEBLOBMANAGER_H
+#define MLIR_IR_DIALECTRESOURCEBLOBMANAGER_H
+
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/RWMutex.h"
+#include "llvm/Support/SMLoc.h"
+
+namespace mlir {
+//===----------------------------------------------------------------------===//
+// DialectResourceBlobManager
+//===----------------------------------------------------------------------===//
+
+/// This class defines a manager for dialect resource blobs. Blobs are uniqued
+/// by a given key, and represented using AsmResourceBlobs.
+class DialectResourceBlobManager {
+public:
+  /// This class represents an individual, named entry in the blob manager.
+  class BlobEntry {
+  public:
+    /// Return the key used to reference this blob.
+    StringRef getKey() const { return key; }
+
+    /// Return the blob owned by this entry if one has been initialized. Returns
+    /// nullptr otherwise.
+    const AsmResourceBlob *getBlob() const { return blob ? &*blob : nullptr; }
+    AsmResourceBlob *getBlob() { return blob ? &*blob : nullptr; }
+
+    /// Set the blob owned by this entry, releasing any previously held blob.
+    void setBlob(AsmResourceBlob &&newBlob) { blob = std::move(newBlob); }
+
+  private:
+    BlobEntry() = default;
+    BlobEntry(BlobEntry &&) = default;
+    BlobEntry &operator=(const BlobEntry &) = delete;
+    BlobEntry &operator=(BlobEntry &&) = delete;
+
+    /// Initialize this entry. `newKey` is referenced, not copied.
+    void initialize(StringRef newKey, Optional<AsmResourceBlob> newBlob) {
+      key = newKey;
+      blob = std::move(newBlob);
+    }
+
+    /// The key used for this blob; the string data is not owned by the entry.
+    StringRef key;
+
+    /// The blob owned by this entry, or None if it has not been set yet.
+    Optional<AsmResourceBlob> blob;
+
+    /// Allow access to the constructors.
+    friend DialectResourceBlobManager;
+    friend class llvm::StringMapEntryStorage<BlobEntry>;
+  };
+
+  /// Return the entry registered for the given name, or nullptr if no entry
+  /// is registered.
+  BlobEntry *lookup(StringRef name);
+  const BlobEntry *lookup(StringRef name) const {
+    return const_cast<DialectResourceBlobManager *>(this)->lookup(name);
+  }
+
+  /// Update the blob for the entry defined by the provided name. Asserts that
+  /// the entry exists. NOTE: the entry is modified outside the map lock.
+  void update(StringRef name, AsmResourceBlob &&newBlob);
+
+  /// Insert a new entry with the provided name and optional blob data. The name
+  /// may be suffixed with `_<N>` during insertion if another entry already
+  /// exists with that name. Returns the inserted entry (see its `getKey`).
+  BlobEntry &insert(StringRef name, Optional<AsmResourceBlob> blob = {});
+  /// Insertion method that returns a dialect specific handle to the inserted
+  /// entry.
+  template <typename HandleT>
+  HandleT insert(typename HandleT::Dialect *dialect, StringRef name,
+                 Optional<AsmResourceBlob> blob = {}) {
+    BlobEntry &entry = insert(name, std::move(blob));
+    return HandleT(&entry, dialect);
+  }
+
+private:
+  /// A mutex to protect access to the blob map.
+  llvm::sys::SmartRWMutex<true> blobMapLock;
+
+  /// The internal map of tracked blobs. StringMap stores entries in distinct
+  /// allocations, so references to an entry stay valid while other entries are
+  /// inserted or erased.
+  llvm::StringMap<BlobEntry> blobMap;
+};
+
+//===----------------------------------------------------------------------===//
+// ResourceBlobManagerDialectInterface
+//===----------------------------------------------------------------------===//
+
+/// A dialect interface that provides common functionality for interacting with
+/// a resource blob manager. A new, empty manager is created on construction.
+class ResourceBlobManagerDialectInterface
+    : public DialectInterface::Base<ResourceBlobManagerDialectInterface> {
+public:
+  ResourceBlobManagerDialectInterface(Dialect *dialect)
+      : Base(dialect),
+        blobManager(std::make_shared<DialectResourceBlobManager>()) {}
+
+  /// Return the blob manager held by this interface.
+  DialectResourceBlobManager &getBlobManager() { return *blobManager; }
+  const DialectResourceBlobManager &getBlobManager() const {
+    return *blobManager;
+  }
+
+  /// Replace the blob manager held by this interface.
+  void
+  setBlobManager(std::shared_ptr<DialectResourceBlobManager> newBlobManager) {
+    blobManager = std::move(newBlobManager);
+  }
+
+private:
+  /// The blob manager used by this interface, held with shared ownership.
+  std::shared_ptr<DialectResourceBlobManager> blobManager;
+};
+
+/// This class provides a base class for dialects implementing the resource blob
+/// interface. It provides several additional dialect specific utilities on top
+/// of the generic interface. `HandleT` is the type of the handle used to
+/// reference a resource blob.
+template <typename HandleT>
+class ResourceBlobManagerDialectInterfaceBase
+    : public ResourceBlobManagerDialectInterface {
+public:
+  using ResourceBlobManagerDialectInterface::
+      ResourceBlobManagerDialectInterface;
+
+  /// Update the blob for the entry defined by the provided name. This method
+  /// asserts that an entry for the given name exists in the manager.
+  void update(StringRef name, AsmResourceBlob &&newBlob) {
+    getBlobManager().update(name, std::move(newBlob));
+  }
+
+  /// Insert a new resource blob entry with the provided name and optional blob
+  /// data. The name may be modified during insertion if another entry already
+  /// exists with that name. Returns a dialect specific handle to the inserted
+  /// entry.
+  HandleT insert(StringRef name, Optional<AsmResourceBlob> blob = {}) {
+    return getBlobManager().template insert<HandleT>(
+        cast<typename HandleT::Dialect>(getDialect()), name, std::move(blob));
+  }
+
+  /// Build a blob resource for every referenced `HandleT` that has a blob.
+  void buildResources(AsmResourceBuilder &provider,
+                      ArrayRef<AsmDialectResourceHandle> referencedResources) {
+    for (const AsmDialectResourceHandle &handle : referencedResources) {
+      if (const auto *dialectHandle = dyn_cast<HandleT>(&handle)) {
+        if (auto *blob = dialectHandle->getBlob())
+          provider.buildBlob(dialectHandle->getKey(), *blob);
+      }
+    }
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// DialectResourceBlobHandle
+//===----------------------------------------------------------------------===//
+
+/// This class defines a dialect specific handle to a resource blob. The handle
+/// references a DialectResourceBlobManager::BlobEntry, which holds a StringRef
+/// key and, once initialized, the underlying AsmResourceBlob data.
+template <typename DialectT>
+struct DialectResourceBlobHandle
+    : public AsmDialectResourceHandleBase<DialectResourceBlobHandle<DialectT>,
+                                          DialectResourceBlobManager::BlobEntry,
+                                          DialectT> {
+  using AsmDialectResourceHandleBase<DialectResourceBlobHandle<DialectT>,
+                                     DialectResourceBlobManager::BlobEntry,
+                                     DialectT>::AsmDialectResourceHandleBase;
+  using ManagerInterface = ResourceBlobManagerDialectInterfaceBase<
+      DialectResourceBlobHandle<DialectT>>;
+
+  /// Return the human readable string key for this handle.
+  StringRef getKey() const { return this->getResource()->getKey(); }
+
+  /// Return the blob referenced by this handle if the underlying resource has
+  /// been initialized. Returns nullptr otherwise.
+  AsmResourceBlob *getBlob() { return this->getResource()->getBlob(); }
+  const AsmResourceBlob *getBlob() const {
+    return this->getResource()->getBlob();
+  }
+
+  /// Get the blob manager interface of the dialect that owns handles of this
+  /// type, loading the dialect if necessary. Asserts the interface is present.
+  static ManagerInterface &getManagerInterface(MLIRContext *ctx) {
+    auto *dialect = ctx->getOrLoadDialect<DialectT>();
+    assert(dialect && "dialect not registered");
+
+    auto *iface = dialect->template getRegisteredInterface<ManagerInterface>();
+    assert(iface && "dialect doesn't provide the blob manager interface?");
+    return *iface;
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_IR_DIALECTRESOURCEBLOBMANAGER_H

diff  --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 94f06467b7c07..72f386c31a241 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRIR
   BuiltinTypeInterfaces.cpp
   Diagnostics.cpp
   Dialect.cpp
+  DialectResourceBlobManager.cpp
   Dominance.cpp
   ExtensibleDialect.cpp
   FunctionImplementation.cpp

diff  --git a/mlir/lib/IR/DialectResourceBlobManager.cpp b/mlir/lib/IR/DialectResourceBlobManager.cpp
new file mode 100644
index 0000000000000..dbfe9c1ef85e9
--- /dev/null
+++ b/mlir/lib/IR/DialectResourceBlobManager.cpp
@@ -0,0 +1,64 @@
+//===- DialectResourceBlobManager.cpp - Dialect Blob Management -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/DialectResourceBlobManager.h"
+#include "llvm/ADT/SmallString.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// DialectResourceBlobManager
+//===----------------------------------------------------------------------===//
+
+auto DialectResourceBlobManager::lookup(StringRef name) -> BlobEntry * {
+  llvm::sys::SmartScopedReader<true> reader(blobMapLock); // Shared lock.
+
+  auto it = blobMap.find(name);
+  return it != blobMap.end() ? &it->second : nullptr;
+}
+
+void DialectResourceBlobManager::update(StringRef name,
+                                        AsmResourceBlob &&newBlob) {
+  BlobEntry *entry = lookup(name);
+  assert(entry && "`update` expects an existing entry for the provided name");
+  entry->setBlob(std::move(newBlob)); // NOTE: not under blobMapLock.
+}
+
+auto DialectResourceBlobManager::insert(StringRef name,
+                                        Optional<AsmResourceBlob> blob)
+    -> BlobEntry & {
+  llvm::sys::SmartScopedWriter<true> writer(blobMapLock);
+
+  // Attempt insertion under `key`; `blob` is only moved on success.
+  auto tryInsertion = [&](StringRef key) -> BlobEntry * {
+    auto it = blobMap.try_emplace(key, BlobEntry());
+    if (it.second) {
+      it.first->second.initialize(it.first->getKey(), std::move(blob));
+      return &it.first->second;
+    }
+    return nullptr;
+  };
+
+  // Try inserting with the name provided by the user.
+  if (BlobEntry *entry = tryInsertion(name))
+    return *entry;
+
+  // If an entry already exists for the user provided name, append "_<N>" and
+  // re-attempt insertion with increasing N until the name is unique.
+  llvm::SmallString<32> nameStorage(name);
+  nameStorage.push_back('_');
+  size_t nameCounter = 1;
+  while (true) {
+    Twine(nameCounter++).toVector(nameStorage);
+
+    // Try inserting with the uniqued name (not the original `name`).
+    if (BlobEntry *entry = tryInsertion(nameStorage))
+      return *entry;
+    nameStorage.resize(name.size() + 1);
+  }
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index c86f29cf15a96..6bd1ad236d39e 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -246,7 +246,7 @@ def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
   let mnemonic = "e1di64_elements";
   let parameters = (ins
     AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type,
-    ResourceHandleParameter<"TestExternalElementsDataHandle">:$handle
+    ResourceHandleParameter<"TestDialectResourceBlobHandle">:$handle
   );
   let extraClassDeclaration = [{
     /// Return the elements referenced by this attribute.

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 1fbd2920a0b2a..e28d7426e2682 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -170,7 +170,9 @@ Attribute TestSubElementsAccessAttr::replaceImmediateSubElements(
 //===----------------------------------------------------------------------===//
 
 ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const {
-  return getHandle().getData()->getData();
+  if (auto *blob = getHandle().getBlob())
+    return blob->getDataAs<uint64_t>();
+  return llvm::None; // No blob has been set for the resource yet.
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index 74414182313b8..4cb4d61044860 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -24,9 +24,14 @@
 
 #include "TestAttrInterfaces.h.inc"
 #include "TestOpEnums.h.inc"
+#include "mlir/IR/DialectResourceBlobManager.h"
 
 namespace test {
-struct TestExternalElementsDataHandle;
+class TestDialect;
+
+/// A handle used to reference resource blobs managed by the TestDialect.
+using TestDialectResourceBlobHandle =
+    mlir::DialectResourceBlobHandle<TestDialect>;
 } // namespace test
 
 #define GET_ATTRDEF_CLASSES

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 2e27663e2cd67..e75c7ea964a4b 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -44,55 +44,6 @@ void test::registerTestDialect(DialectRegistry &registry) {
   registry.insert<TestDialect>();
 }
 
-//===----------------------------------------------------------------------===//
-// External Elements Data
-//===----------------------------------------------------------------------===//
-
-ArrayRef<uint64_t> TestExternalElementsData::getData() const {
-  ArrayRef<char> data = AsmResourceBlob::getData();
-  return ArrayRef<uint64_t>((const uint64_t *)data.data(),
-                            data.size() / sizeof(uint64_t));
-}
-
-TestExternalElementsData
-TestExternalElementsData::allocate(size_t numElements) {
-  return TestExternalElementsData(
-      llvm::ArrayRef<uint64_t>(new uint64_t[numElements], numElements),
-      [](const uint64_t *data, size_t) { delete[] data; },
-      /*dataIsMutable=*/true);
-}
-
-const TestExternalElementsData *
-TestExternalElementsDataManager::getData(StringRef name) const {
-  auto it = dataMap.find(name);
-  return it != dataMap.end() ? &*it->second : nullptr;
-}
-
-std::pair<TestExternalElementsDataManager::DataMap::iterator, bool>
-TestExternalElementsDataManager::insert(StringRef name) {
-  auto it = dataMap.try_emplace(name, nullptr);
-  if (it.second)
-    return it;
-
-  llvm::SmallString<32> nameStorage(name);
-  nameStorage.push_back('_');
-  size_t nameCounter = 1;
-  do {
-    nameStorage += std::to_string(nameCounter++);
-    auto it = dataMap.try_emplace(nameStorage, nullptr);
-    if (it.second)
-      return it;
-    nameStorage.resize(name.size() + 1);
-  } while (true);
-}
-
-void TestExternalElementsDataManager::setData(StringRef name,
-                                              TestExternalElementsData &&data) {
-  auto it = dataMap.find(name);
-  assert(it != dataMap.end() && "data not registered");
-  it->second = std::make_unique<TestExternalElementsData>(std::move(data));
-}
-
 //===----------------------------------------------------------------------===//
 // TestDialect Interfaces
 //===----------------------------------------------------------------------===//
@@ -109,9 +60,18 @@ static_assert(OpTrait::hasSingleBlockImplicitTerminator<
               "hasSingleBlockImplicitTerminator does not match "
               "SingleBlockImplicitTerminatorOp");
 
+struct TestResourceBlobManagerInterface
+    : public ResourceBlobManagerDialectInterfaceBase<
+          TestDialectResourceBlobHandle> {
+  using ResourceBlobManagerDialectInterfaceBase<
+      TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase;
+};
+
 // Test support for interacting with the AsmPrinter.
 struct TestOpAsmInterface : public OpAsmDialectInterface {
   using OpAsmDialectInterface::OpAsmDialectInterface;
+  TestOpAsmInterface(Dialect *dialect, TestResourceBlobManagerInterface &mgr)
+      : OpAsmDialectInterface(dialect), blobManager(mgr) {}
 
   //===------------------------------------------------------------------===//
   // Aliases
@@ -176,33 +136,21 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
 
   std::string
   getResourceKey(const AsmDialectResourceHandle &handle) const override {
-    return cast<TestExternalElementsDataHandle>(handle).getKey().str();
+    return cast<TestDialectResourceBlobHandle>(handle).getKey().str();
   }
 
   FailureOr<AsmDialectResourceHandle>
   declareResource(StringRef key) const final {
-    TestDialect *dialect = cast<TestDialect>(getDialect());
-    TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();
-
-    // Resolve the reference by inserting a new entry into the manager.
-    auto it = mgr.insert(key).first;
-    return TestExternalElementsDataHandle(&*it, dialect);
+    return blobManager.insert(key);
   }
 
   LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
-    TestDialect *dialect = cast<TestDialect>(getDialect());
-    TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();
-
-    // The resource entries are external constant data.
-    auto blobAllocFn = [](unsigned size, unsigned align) {
-      assert(align == alignof(uint64_t) && "unexpected data alignment");
-      return TestExternalElementsData::allocate(size / sizeof(uint64_t));
-    };
-    FailureOr<AsmResourceBlob> blob = entry.parseAsBlob(blobAllocFn);
+    FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
     if (failed(blob))
       return failure();
 
-    mgr.setData(entry.getKey(), std::move(*blob));
+    // Update the blob for this entry.
+    blobManager.update(entry.getKey(), std::move(*blob));
     return success();
   }
 
@@ -210,11 +158,12 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
   buildResources(Operation *op,
                  const SetVector<AsmDialectResourceHandle> &referencedResources,
                  AsmResourceBuilder &provider) const final {
-    for (const AsmDialectResourceHandle &handle : referencedResources) {
-      const auto &testHandle = cast<TestExternalElementsDataHandle>(handle);
-      provider.buildBlob(testHandle.getKey(), testHandle.getData()->getData());
-    }
+    blobManager.buildResources(provider, referencedResources.getArrayRef());
   }
+
+private:
+  /// The blob manager for the dialect.
+  TestResourceBlobManagerInterface &blobManager;
 };
 
 struct TestDialectFoldInterface : public DialectFoldInterface {
@@ -412,8 +361,11 @@ void TestDialect::initialize() {
   registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
   registerDynamicOp(getDynamicCustomParserPrinterOp(this));
 
-  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
-                TestInlinerInterface, TestReductionPatternInterface>();
+  auto &blobInterface = addInterface<TestResourceBlobManagerInterface>();
+  addInterface<TestOpAsmInterface>(blobInterface);
+
+  addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
+                TestReductionPatternInterface>();
   allowUnknownOperations();
 
   // Instantiate our fallback op interface that we'll use on specific

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index e11a5f8098feb..0e583d38aa1a5 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -25,6 +25,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
@@ -45,68 +46,6 @@ class DLTIDialect;
 class RewritePatternSet;
 } // namespace mlir
 
-namespace test {
-class TestDialect;
-
-//===----------------------------------------------------------------------===//
-// External Elements Data
-//===----------------------------------------------------------------------===//
-
-/// This class represents a single external elements instance. It keeps track of
-/// the data, and deallocates when destructed.
-class TestExternalElementsData : public mlir::AsmResourceBlob {
-public:
-  using mlir::AsmResourceBlob::AsmResourceBlob;
-  TestExternalElementsData(mlir::AsmResourceBlob &&blob)
-      : mlir::AsmResourceBlob(std::move(blob)) {}
-
-  /// Return the data of this external elements instance.
-  llvm::ArrayRef<uint64_t> getData() const;
-
-  /// Allocate a new external elements instance with the given number of
-  /// elements.
-  static TestExternalElementsData allocate(size_t numElements);
-};
-
-/// A handle used to reference external elements instances.
-struct TestExternalElementsDataHandle
-    : public mlir::AsmDialectResourceHandleBase<
-          TestExternalElementsDataHandle,
-          llvm::StringMapEntry<std::unique_ptr<TestExternalElementsData>>,
-          TestDialect> {
-  using AsmDialectResourceHandleBase::AsmDialectResourceHandleBase;
-
-  /// Return a key to use for this handle.
-  llvm::StringRef getKey() const { return getResource()->getKey(); }
-
-  /// Return the data referenced by this handle.
-  TestExternalElementsData *getData() const {
-    return getResource()->getValue().get();
-  }
-};
-
-/// This class acts as a manager for external elements data. It provides API
-/// for creating and accessing registered elements data.
-class TestExternalElementsDataManager {
-  using DataMap = llvm::StringMap<std::unique_ptr<TestExternalElementsData>>;
-
-public:
-  /// Return the data registered for the given name, or nullptr if no data is
-  /// registered.
-  const TestExternalElementsData *getData(llvm::StringRef name) const;
-
-  /// Register an entry with the provided name, which may be modified if another
-  /// entry was already inserted with that name. Returns the inserted entry.
-  std::pair<DataMap::iterator, bool> insert(llvm::StringRef name);
-
-  /// Set the data for the given entry, which is expected to exist.
-  void setData(llvm::StringRef name, TestExternalElementsData &&data);
-
-private:
-  llvm::StringMap<std::unique_ptr<TestExternalElementsData>> dataMap;
-};
-} // namespace test
-
 //===----------------------------------------------------------------------===//
 // TestDialect
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 891902487f806..d13d6abc77554 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -42,11 +42,6 @@ def Test_Dialect : Dialect {
      void printType(::mlir::Type type,
                     ::mlir::DialectAsmPrinter &printer) const override;
 
-    /// Returns the external elements data manager for this dialect.
-    TestExternalElementsDataManager &getExternalDataManager() {
-      return externalDataManager;
-    }
-
   private:
     // Storage for a custom fallback interface.
     void *fallbackEffectOpInterfaces;
@@ -55,9 +50,6 @@ def Test_Dialect : Dialect {
                                ::llvm::SetVector<::mlir::Type> &stack) const;
     void printTestType(::mlir::Type type, ::mlir::AsmPrinter &printer,
                        ::llvm::SetVector<::mlir::Type> &stack) const;
-
-    /// An external data manager used to test external elements data.
-    TestExternalElementsDataManager externalDataManager;
   }];
 }
 


        


More information about the Mlir-commits mailing list