[Mlir-commits] [mlir] b299ec1 - Expose callbacks for encoding of types/attributes

Mehdi Amini llvmlistbot at llvm.org
Fri Jul 28 10:44:31 PDT 2023


Author: Mehdi Amini
Date: 2023-07-28T10:44:02-07:00
New Revision: b299ec16661f653df66cdaf161cdc5441bc9803c

URL: https://github.com/llvm/llvm-project/commit/b299ec16661f653df66cdaf161cdc5441bc9803c
DIFF: https://github.com/llvm/llvm-project/commit/b299ec16661f653df66cdaf161cdc5441bc9803c.diff

LOG: Expose callbacks for encoding of types/attributes

[mlir] Expose a mechanism to provide a callback for encoding types and attributes in MLIR bytecode.

Two callbacks are exposed, one on the BytecodeWriterConfig and one on the ParserConfig. When parsing or printing bytecode, clients can specify a callback that optionally reads or writes the encoding. If the callback fails, the fallback path runs the dialect's default parsers and printers.

The tests show how to leverage this functionality to support back-deployment and backward-compatibility use cases when round-tripping through bytecode a client dialect whose types/attributes depend on upstream dialects.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D153383

Added: 
    mlir/include/mlir/Bytecode/BytecodeReaderConfig.h
    mlir/test/Bytecode/bytecode_callback.mlir
    mlir/test/Bytecode/bytecode_callback_full_override.mlir
    mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
    mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
    mlir/test/lib/IR/TestBytecodeCallbacks.cpp

Modified: 
    mlir/include/mlir/Bytecode/BytecodeImplementation.h
    mlir/include/mlir/Bytecode/BytecodeReader.h
    mlir/include/mlir/Bytecode/BytecodeWriter.h
    mlir/include/mlir/IR/AsmState.h
    mlir/lib/Bytecode/Reader/BytecodeReader.cpp
    mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
    mlir/lib/Bytecode/Writer/IRNumbering.cpp
    mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir
    mlir/test/lib/Dialect/Test/TestDialect.h
    mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/lib/IR/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 9c9aa7a4fc0ed1..bb1f0f717d8001 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -24,6 +24,17 @@
 #include "llvm/ADT/Twine.h"
 
 namespace mlir {
+//===--------------------------------------------------------------------===//
+// Dialect Version Interface.
+//===--------------------------------------------------------------------===//
+
+/// This class is used to represent the version of a dialect, for the purpose
+/// of polymorphic destruction.
+class DialectVersion {
+public:
+  virtual ~DialectVersion() = default;
+};
+
 //===----------------------------------------------------------------------===//
 // DialectBytecodeReader
 //===----------------------------------------------------------------------===//
@@ -38,7 +49,14 @@ class DialectBytecodeReader {
   virtual ~DialectBytecodeReader() = default;
 
   /// Emit an error to the reader.
-  virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
+  virtual InFlightDiagnostic emitError(const Twine &msg = {}) const = 0;
+
+  /// Retrieve the dialect version by name if available.
+  virtual FailureOr<const DialectVersion *>
+  getDialectVersion(StringRef dialectName) const = 0;
+
+  /// Retrieve the context associated to the reader.
+  virtual MLIRContext *getContext() const = 0;
 
   /// Return the bytecode version being read.
   virtual uint64_t getBytecodeVersion() const = 0;
@@ -384,17 +402,6 @@ class DialectBytecodeWriter {
   virtual int64_t getBytecodeVersion() const = 0;
 };
 
-//===--------------------------------------------------------------------===//
-// Dialect Version Interface.
-//===--------------------------------------------------------------------===//
-
-/// This class is used to represent the version of a dialect, for the purpose
-/// of polymorphic destruction.
-class DialectVersion {
-public:
-  virtual ~DialectVersion() = default;
-};
-
 //===----------------------------------------------------------------------===//
 // BytecodeDialectInterface
 //===----------------------------------------------------------------------===//
@@ -409,47 +416,23 @@ class BytecodeDialectInterface
   //===--------------------------------------------------------------------===//
 
   /// Read an attribute belonging to this dialect from the given reader. This
-  /// method should return null in the case of failure.
+  /// method should return null in the case of failure. Optionally, the dialect
+  /// version can be accessed through the reader.
   virtual Attribute readAttribute(DialectBytecodeReader &reader) const {
     reader.emitError() << "dialect " << getDialect()->getNamespace()
                        << " does not support reading attributes from bytecode";
     return Attribute();
   }
 
-  /// Read a versioned attribute encoding belonging to this dialect from the
-  /// given reader. This method should return null in the case of failure, and
-  /// falls back to the non-versioned reader in case the dialect implements
-  /// versioning but it does not support versioned custom encodings for the
-  /// attributes.
-  virtual Attribute readAttribute(DialectBytecodeReader &reader,
-                                  const DialectVersion &version) const {
-    reader.emitError()
-        << "dialect " << getDialect()->getNamespace()
-        << " does not support reading versioned attributes from bytecode";
-    return Attribute();
-  }
-
   /// Read a type belonging to this dialect from the given reader. This method
-  /// should return null in the case of failure.
+  /// should return null in the case of failure. Optionally, the dialect version
+  /// can be accessed through the reader.
   virtual Type readType(DialectBytecodeReader &reader) const {
     reader.emitError() << "dialect " << getDialect()->getNamespace()
                        << " does not support reading types from bytecode";
     return Type();
   }
 
-  /// Read a versioned type encoding belonging to this dialect from the given
-  /// reader. This method should return null in the case of failure, and
-  /// falls back to the non-versioned reader in case the dialect implements
-  /// versioning but it does not support versioned custom encodings for the
-  /// types.
-  virtual Type readType(DialectBytecodeReader &reader,
-                        const DialectVersion &version) const {
-    reader.emitError()
-        << "dialect " << getDialect()->getNamespace()
-        << " does not support reading versioned types from bytecode";
-    return Type();
-  }
-
   //===--------------------------------------------------------------------===//
   // Writing
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Bytecode/BytecodeReader.h b/mlir/include/mlir/Bytecode/BytecodeReader.h
index 206e42870ad85a..9f26506d486eec 100644
--- a/mlir/include/mlir/Bytecode/BytecodeReader.h
+++ b/mlir/include/mlir/Bytecode/BytecodeReader.h
@@ -25,7 +25,6 @@ class SourceMgr;
 } // namespace llvm
 
 namespace mlir {
-
 /// The BytecodeReader allows to load MLIR bytecode files, while keeping the
 /// state explicitly available in order to support lazy loading.
 /// The `finalize` method must be called before destruction.

diff  --git a/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h b/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h
new file mode 100644
index 00000000000000..d623d0da2c0c90
--- /dev/null
+++ b/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h
@@ -0,0 +1,120 @@
+//===- BytecodeReaderConfig.h - MLIR Bytecode Reader Config -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines configuration and callbacks for reading MLIR bytecode.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BYTECODE_BYTECODEREADERCONFIG_H
+#define MLIR_BYTECODE_BYTECODEREADERCONFIG_H
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+class Attribute;
+class DialectBytecodeReader;
+class Type;
+
+/// A class to interact with the attributes and types parser when parsing MLIR
+/// bytecode.
+template <class T>
+class AttrTypeBytecodeReader {
+public:
+  AttrTypeBytecodeReader() = default;
+  virtual ~AttrTypeBytecodeReader() = default;
+
+  virtual LogicalResult read(DialectBytecodeReader &reader,
+                             StringRef dialectName, T &entry) = 0;
+
+  /// Return an Attribute/Type reader implemented via the given callable, whose
+  /// form should match that of the `read` function above.
+  template <typename CallableT,
+            std::enable_if_t<
+                std::is_convertible_v<
+                    CallableT, std::function<LogicalResult(
+                                   DialectBytecodeReader &, StringRef, T &)>>,
+                bool> = true>
+  static std::unique_ptr<AttrTypeBytecodeReader<T>>
+  fromCallable(CallableT &&readFn) {
+    struct Processor : public AttrTypeBytecodeReader<T> {
+      Processor(CallableT &&readFn)
+          : AttrTypeBytecodeReader(), readFn(std::move(readFn)) {}
+      LogicalResult read(DialectBytecodeReader &reader, StringRef dialectName,
+                         T &entry) override {
+        return readFn(reader, dialectName, entry);
+      }
+
+      std::decay_t<CallableT> readFn;
+    };
+    return std::make_unique<Processor>(std::forward<CallableT>(readFn));
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// BytecodeReaderConfig
+//===----------------------------------------------------------------------===//
+
+/// A class containing bytecode-specific configurations of the `ParserConfig`.
+class BytecodeReaderConfig {
+public:
+  BytecodeReaderConfig() = default;
+
+  /// Returns the callbacks available to the parser.
+  ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
+  getAttributeCallbacks() const {
+    return attributeBytecodeParsers;
+  }
+  ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
+  getTypeCallbacks() const {
+    return typeBytecodeParsers;
+  }
+
+  /// Attach a custom bytecode parser callback to the configuration for parsing
+  /// of custom type/attributes encodings.
+  void attachAttributeCallback(
+      std::unique_ptr<AttrTypeBytecodeReader<Attribute>> parser) {
+    attributeBytecodeParsers.emplace_back(std::move(parser));
+  }
+  void
+  attachTypeCallback(std::unique_ptr<AttrTypeBytecodeReader<Type>> parser) {
+    typeBytecodeParsers.emplace_back(std::move(parser));
+  }
+
+  /// Attach a custom bytecode parser callback to the configuration for parsing
+  /// of custom type/attributes encodings.
+  template <typename CallableT>
+  std::enable_if_t<std::is_convertible_v<
+      CallableT, std::function<LogicalResult(DialectBytecodeReader &, StringRef,
+                                             Attribute &)>>>
+  attachAttributeCallback(CallableT &&parserFn) {
+    attachAttributeCallback(AttrTypeBytecodeReader<Attribute>::fromCallable(
+        std::forward<CallableT>(parserFn)));
+  }
+  template <typename CallableT>
+  std::enable_if_t<std::is_convertible_v<
+      CallableT,
+      std::function<LogicalResult(DialectBytecodeReader &, StringRef, Type &)>>>
+  attachTypeCallback(CallableT &&parserFn) {
+    attachTypeCallback(AttrTypeBytecodeReader<Type>::fromCallable(
+        std::forward<CallableT>(parserFn)));
+  }
+
+private:
+  llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
+      attributeBytecodeParsers;
+  llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
+      typeBytecodeParsers;
+};
+
+} // namespace mlir
+
+#endif // MLIR_BYTECODE_BYTECODEREADERCONFIG_H

diff  --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
index c6df1a21a55bb4..e0c46c3dab27a7 100644
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -17,6 +17,55 @@
 
 namespace mlir {
 class Operation;
+class DialectBytecodeWriter;
+
+/// A class to interact with the attributes and types printer when emitting MLIR
+/// bytecode.
+template <class T>
+class AttrTypeBytecodeWriter {
+public:
+  AttrTypeBytecodeWriter() = default;
+  virtual ~AttrTypeBytecodeWriter() = default;
+
+  /// Callback writer API used in IRNumbering, where groups are created and
+  /// type/attribute components are numbered. At this stage, writer is expected
+  /// to be a `NumberingDialectWriter`.
+  virtual LogicalResult write(T entry, std::optional<StringRef> &name,
+                              DialectBytecodeWriter &writer) = 0;
+
+  /// Callback writer API used in BytecodeWriter, where the type/attribute
+  /// encodings are actually emitted. Here, DialectBytecodeWriter is
+  /// expected to be an actual writer. The optional stringref specified by
+  /// the user is ignored, since the group was already specified when numbering
+  /// the IR.
+  LogicalResult write(T entry, DialectBytecodeWriter &writer) {
+    std::optional<StringRef> dummy;
+    return write(entry, dummy, writer);
+  }
+
+  /// Return an Attribute/Type printer implemented via the given callable, whose
+  /// form should match that of the `write` function above.
+  template <typename CallableT,
+            std::enable_if_t<std::is_convertible_v<
+                                 CallableT, std::function<LogicalResult(
+                                                T, std::optional<StringRef> &,
+                                                DialectBytecodeWriter &)>>,
+                             bool> = true>
+  static std::unique_ptr<AttrTypeBytecodeWriter<T>>
+  fromCallable(CallableT &&writeFn) {
+    struct Processor : public AttrTypeBytecodeWriter<T> {
+      Processor(CallableT &&writeFn)
+          : AttrTypeBytecodeWriter(), writeFn(std::move(writeFn)) {}
+      LogicalResult write(T entry, std::optional<StringRef> &name,
+                          DialectBytecodeWriter &writer) override {
+        return writeFn(entry, name, writer);
+      }
+
+      std::decay_t<CallableT> writeFn;
+    };
+    return std::make_unique<Processor>(std::forward<CallableT>(writeFn));
+  }
+};
 
 /// This class contains the configuration used for the bytecode writer. It
 /// controls various aspects of bytecode generation, and contains all of the
@@ -48,6 +97,43 @@ class BytecodeWriterConfig {
   /// Get the set desired bytecode version to emit.
   int64_t getDesiredBytecodeVersion() const;
 
+  //===--------------------------------------------------------------------===//
+  // Types and Attributes encoding
+  //===--------------------------------------------------------------------===//
+
+  /// Retrieve the callbacks.
+  ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
+  getAttributeWriterCallbacks() const;
+  ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
+  getTypeWriterCallbacks() const;
+
+  /// Attach a custom bytecode printer callback to the configuration for the
+  /// emission of custom type/attributes encodings.
+  void attachAttributeCallback(
+      std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback);
+  void
+  attachTypeCallback(std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback);
+
+  /// Attach a custom bytecode printer callback to the configuration for the
+  /// emission of custom type/attributes encodings.
+  template <typename CallableT>
+  std::enable_if_t<std::is_convertible_v<
+      CallableT,
+      std::function<LogicalResult(Attribute, std::optional<StringRef> &,
+                                  DialectBytecodeWriter &)>>>
+  attachAttributeCallback(CallableT &&emitFn) {
+    attachAttributeCallback(AttrTypeBytecodeWriter<Attribute>::fromCallable(
+        std::forward<CallableT>(emitFn)));
+  }
+  template <typename CallableT>
+  std::enable_if_t<std::is_convertible_v<
+      CallableT, std::function<LogicalResult(Type, std::optional<StringRef> &,
+                                             DialectBytecodeWriter &)>>>
+  attachTypeCallback(CallableT &&emitFn) {
+    attachTypeCallback(AttrTypeBytecodeWriter<Type>::fromCallable(
+        std::forward<CallableT>(emitFn)));
+  }
+
   //===--------------------------------------------------------------------===//
   // Resources
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 2abeacb8443280..42cbedcf9f8837 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_IR_ASMSTATE_H_
 #define MLIR_IR_ASMSTATE_H_
 
+#include "mlir/Bytecode/BytecodeReaderConfig.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/MapVector.h"
@@ -475,6 +476,11 @@ class ParserConfig {
   /// Returns if the parser should verify the IR after parsing.
   bool shouldVerifyAfterParse() const { return verifyAfterParse; }
 
+  /// Returns the parsing configurations associated to the bytecode read.
+  BytecodeReaderConfig &getBytecodeReaderConfig() const {
+    return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
+  }
+
   /// Return the resource parser registered to the given name, or nullptr if no
   /// parser with `name` is registered.
   AsmResourceParser *getResourceParser(StringRef name) const {
@@ -509,6 +515,7 @@ class ParserConfig {
   bool verifyAfterParse;
   DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
   FallbackAsmResourceMap *fallbackResourceMap;
+  BytecodeReaderConfig bytecodeReaderConfig;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 0639baf10b0bc0..91e47c4c0e4784 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -451,7 +451,7 @@ struct BytecodeDialect {
   /// Returns failure if the dialect couldn't be loaded *and* the provided
   /// context does not allow unregistered dialects. The provided reader is used
   /// for error emission if necessary.
-  LogicalResult load(DialectReader &reader, MLIRContext *ctx);
+  LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
 
   /// Return the loaded dialect, or nullptr if the dialect is unknown. This can
   /// only be called after `load`.
@@ -505,10 +505,11 @@ struct BytecodeOperationName {
 
 /// Parse a single dialect group encoded in the byte stream.
 static LogicalResult parseDialectGrouping(
-    EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects,
+    EncodingReader &reader,
+    MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
     function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
   // Parse the dialect and the number of entries in the group.
-  BytecodeDialect *dialect;
+  std::unique_ptr<BytecodeDialect> *dialect;
   if (failed(parseEntry(reader, dialects, dialect, "dialect")))
     return failure();
   uint64_t numEntries;
@@ -516,7 +517,7 @@ static LogicalResult parseDialectGrouping(
     return failure();
 
   for (uint64_t i = 0; i < numEntries; ++i)
-    if (failed(entryCallback(dialect)))
+    if (failed(entryCallback(dialect->get())))
       return failure();
   return success();
 }
@@ -532,7 +533,7 @@ class ResourceSectionReader {
   /// Initialize the resource section reader with the given section data.
   LogicalResult
   initialize(Location fileLoc, const ParserConfig &config,
-             MutableArrayRef<BytecodeDialect> dialects,
+             MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
              StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
              ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
              const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
@@ -682,7 +683,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
 
 LogicalResult ResourceSectionReader::initialize(
     Location fileLoc, const ParserConfig &config,
-    MutableArrayRef<BytecodeDialect> dialects,
+    MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
     StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
     ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
     const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
@@ -731,19 +732,19 @@ LogicalResult ResourceSectionReader::initialize(
   // Read the dialect resources from the bytecode.
   MLIRContext *ctx = fileLoc->getContext();
   while (!offsetReader.empty()) {
-    BytecodeDialect *dialect;
+    std::unique_ptr<BytecodeDialect> *dialect;
     if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
-        failed(dialect->load(dialectReader, ctx)))
+        failed((*dialect)->load(dialectReader, ctx)))
       return failure();
-    Dialect *loadedDialect = dialect->getLoadedDialect();
+    Dialect *loadedDialect = (*dialect)->getLoadedDialect();
     if (!loadedDialect) {
       return resourceReader.emitError()
-             << "dialect '" << dialect->name << "' is unknown";
+             << "dialect '" << (*dialect)->name << "' is unknown";
     }
     const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
     if (!handler) {
       return resourceReader.emitError()
-             << "unexpected resources for dialect '" << dialect->name << "'";
+             << "unexpected resources for dialect '" << (*dialect)->name << "'";
     }
 
     // Ensure that each resource is declared before being processed.
@@ -753,7 +754,7 @@ LogicalResult ResourceSectionReader::initialize(
       if (failed(handle)) {
         return resourceReader.emitError()
                << "unknown 'resource' key '" << key << "' for dialect '"
-               << dialect->name << "'";
+               << (*dialect)->name << "'";
       }
       dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
       dialectResources.push_back(*handle);
@@ -796,15 +797,19 @@ class AttrTypeReader {
 
 public:
   AttrTypeReader(StringSectionReader &stringReader,
-                 ResourceSectionReader &resourceReader, Location fileLoc,
-                 uint64_t &bytecodeVersion)
+                 ResourceSectionReader &resourceReader,
+                 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
+                 uint64_t &bytecodeVersion, Location fileLoc,
+                 const ParserConfig &config)
       : stringReader(stringReader), resourceReader(resourceReader),
-        fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {}
+        dialectsMap(dialectsMap), fileLoc(fileLoc),
+        bytecodeVersion(bytecodeVersion), parserConfig(config) {}
 
   /// Initialize the attribute and type information within the reader.
-  LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
-                           ArrayRef<uint8_t> sectionData,
-                           ArrayRef<uint8_t> offsetSectionData);
+  LogicalResult
+  initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
+             ArrayRef<uint8_t> sectionData,
+             ArrayRef<uint8_t> offsetSectionData);
 
   /// Resolve the attribute or type at the given index. Returns nullptr on
   /// failure.
@@ -878,6 +883,10 @@ class AttrTypeReader {
   /// parsing custom encoded attribute/type entries.
   ResourceSectionReader &resourceReader;
 
+  /// The map of the loaded dialects used to retrieve dialect information, such
+  /// as the dialect version.
+  const llvm::StringMap<BytecodeDialect *> &dialectsMap;
+
   /// The set of attribute and type entries.
   SmallVector<AttrEntry> attributes;
   SmallVector<TypeEntry> types;
@@ -887,27 +896,48 @@ class AttrTypeReader {
 
   /// Current bytecode version being used.
   uint64_t &bytecodeVersion;
+
+  /// Reference to the parser configuration.
+  const ParserConfig &parserConfig;
 };
 
 class DialectReader : public DialectBytecodeReader {
 public:
   DialectReader(AttrTypeReader &attrTypeReader,
                 StringSectionReader &stringReader,
-                ResourceSectionReader &resourceReader, EncodingReader &reader,
-                uint64_t &bytecodeVersion)
+                ResourceSectionReader &resourceReader,
+                const llvm::StringMap<BytecodeDialect *> &dialectsMap,
+                EncodingReader &reader, uint64_t &bytecodeVersion)
       : attrTypeReader(attrTypeReader), stringReader(stringReader),
-        resourceReader(resourceReader), reader(reader),
-        bytecodeVersion(bytecodeVersion) {}
+        resourceReader(resourceReader), dialectsMap(dialectsMap),
+        reader(reader), bytecodeVersion(bytecodeVersion) {}
 
-  InFlightDiagnostic emitError(const Twine &msg) override {
+  InFlightDiagnostic emitError(const Twine &msg) const override {
     return reader.emitError(msg);
   }
 
+  FailureOr<const DialectVersion *>
+  getDialectVersion(StringRef dialectName) const override {
+    // First check if the dialect is available in the map.
+    auto dialectEntry = dialectsMap.find(dialectName);
+    if (dialectEntry == dialectsMap.end())
+      return failure();
+    // If the dialect was found, try to load it. This will trigger reading the
+    // bytecode version from the version buffer if it wasn't already processed.
+    // Return failure if either of those two actions could not be completed.
+    if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) ||
+        dialectEntry->getValue()->loadedVersion.get() == nullptr)
+      return failure();
+    return dialectEntry->getValue()->loadedVersion.get();
+  }
+
+  MLIRContext *getContext() const override { return getLoc().getContext(); }
+
   uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
 
-  DialectReader withEncodingReader(EncodingReader &encReader) {
+  DialectReader withEncodingReader(EncodingReader &encReader) const {
     return DialectReader(attrTypeReader, stringReader, resourceReader,
-                         encReader, bytecodeVersion);
+                         dialectsMap, encReader, bytecodeVersion);
   }
 
   Location getLoc() const { return reader.getLoc(); }
@@ -1010,6 +1040,7 @@ class DialectReader : public DialectBytecodeReader {
   AttrTypeReader &attrTypeReader;
   StringSectionReader &stringReader;
   ResourceSectionReader &resourceReader;
+  const llvm::StringMap<BytecodeDialect *> &dialectsMap;
   EncodingReader &reader;
   uint64_t &bytecodeVersion;
 };
@@ -1096,10 +1127,9 @@ class PropertiesSectionReader {
 };
 } // namespace
 
-LogicalResult
-AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
-                           ArrayRef<uint8_t> sectionData,
-                           ArrayRef<uint8_t> offsetSectionData) {
+LogicalResult AttrTypeReader::initialize(
+    MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
+    ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
   EncodingReader offsetReader(offsetSectionData, fileLoc);
 
   // Parse the number of attribute and type entries.
@@ -1151,6 +1181,7 @@ AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
     return offsetReader.emitError(
         "unexpected trailing data in the Attribute/Type offset section");
   }
+
   return success();
 }
 
@@ -1216,32 +1247,54 @@ template <typename T>
 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
                                                EncodingReader &reader,
                                                StringRef entryType) {
-  DialectReader dialectReader(*this, stringReader, resourceReader, reader,
-                              bytecodeVersion);
+  DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
+                              reader, bytecodeVersion);
   if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
     return failure();
+
+  if constexpr (std::is_same_v<T, Type>) {
+    // Try parsing with callbacks first if available.
+    for (const auto &callback :
+         parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
+      if (failed(
+              callback->read(dialectReader, entry.dialect->name, entry.entry)))
+        return failure();
+      // Early return if parsing was successful.
+      if (!!entry.entry)
+        return success();
+
+      // Reset the reader if we failed to parse, so we can fall through the
+      // other parsing functions.
+      reader = EncodingReader(entry.data, reader.getLoc());
+    }
+  } else {
+    // Try parsing with callbacks first if available.
+    for (const auto &callback :
+         parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) {
+      if (failed(
+              callback->read(dialectReader, entry.dialect->name, entry.entry)))
+        return failure();
+      // Early return if parsing was successful.
+      if (!!entry.entry)
+        return success();
+
+      // Reset the reader if we failed to parse, so we can fall through the
+      // other parsing functions.
+      reader = EncodingReader(entry.data, reader.getLoc());
+    }
+  }
+
   // Ensure that the dialect implements the bytecode interface.
   if (!entry.dialect->interface) {
     return reader.emitError("dialect '", entry.dialect->name,
                             "' does not implement the bytecode interface");
   }
 
-  // Ask the dialect to parse the entry. If the dialect is versioned, parse
-  // using the versioned encoding readers.
-  if (entry.dialect->loadedVersion.get()) {
-    if constexpr (std::is_same_v<T, Type>)
-      entry.entry = entry.dialect->interface->readType(
-          dialectReader, *entry.dialect->loadedVersion);
-    else
-      entry.entry = entry.dialect->interface->readAttribute(
-          dialectReader, *entry.dialect->loadedVersion);
+  if constexpr (std::is_same_v<T, Type>)
+    entry.entry = entry.dialect->interface->readType(dialectReader);
+  else
+    entry.entry = entry.dialect->interface->readAttribute(dialectReader);
 
-  } else {
-    if constexpr (std::is_same_v<T, Type>)
-      entry.entry = entry.dialect->interface->readType(dialectReader);
-    else
-      entry.entry = entry.dialect->interface->readAttribute(dialectReader);
-  }
   return success(!!entry.entry);
 }
 
@@ -1262,7 +1315,8 @@ class mlir::BytecodeReader::Impl {
        llvm::MemoryBufferRef buffer,
        const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
       : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
-        attrTypeReader(stringReader, resourceReader, fileLoc, version),
+        attrTypeReader(stringReader, resourceReader, dialectsMap, version,
+                       fileLoc, config),
         // Use the builtin unrealized conversion cast operation to represent
         // forward references to values that aren't yet defined.
         forwardRefOpState(UnknownLoc::get(config.getContext()),
@@ -1528,7 +1582,8 @@ class mlir::BytecodeReader::Impl {
   StringRef producer;
 
   /// The table of IR units referenced within the bytecode file.
-  SmallVector<BytecodeDialect> dialects;
+  SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
+  llvm::StringMap<BytecodeDialect *> dialectsMap;
   SmallVector<BytecodeOperationName> opNames;
 
   /// The reader used to process resources within the bytecode.
@@ -1675,7 +1730,8 @@ LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
 //===----------------------------------------------------------------------===//
 // Dialect Section
 
-LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
+LogicalResult BytecodeDialect::load(const DialectReader &reader,
+                                    MLIRContext *ctx) {
   if (dialect)
     return success();
   Dialect *loadedDialect = ctx->getOrLoadDialect(name);
@@ -1719,13 +1775,15 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
 
   // Parse each of the dialects.
   for (uint64_t i = 0; i < numDialects; ++i) {
+    dialects[i] = std::make_unique<BytecodeDialect>();
     /// Before version kDialectVersioning, there wasn't any versioning available
     /// for dialects, and the entryIdx represent the string itself.
     if (version < bytecode::kDialectVersioning) {
-      if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
+      if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
         return failure();
       continue;
     }
+
     // Parse ID representing dialect and version.
     uint64_t dialectNameIdx;
     bool versionAvailable;
@@ -1733,18 +1791,19 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
                                                  versionAvailable)))
       return failure();
     if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
-                                               dialects[i].name)))
+                                               dialects[i]->name)))
       return failure();
     if (versionAvailable) {
       bytecode::Section::ID sectionID;
-      if (failed(
-              sectionReader.parseSection(sectionID, dialects[i].versionBuffer)))
+      if (failed(sectionReader.parseSection(sectionID,
+                                            dialects[i]->versionBuffer)))
         return failure();
       if (sectionID != bytecode::Section::kDialectVersions) {
         emitError(fileLoc, "expected dialect version section");
         return failure();
       }
     }
+    dialectsMap[dialects[i]->name] = dialects[i].get();
   }
 
   // Parse the operation names, which are grouped by dialect.
@@ -1792,7 +1851,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
   if (!opName->opName) {
     // Load the dialect and its version.
     DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
-                                reader, version);
+                                dialectsMap, reader, version);
     if (failed(opName->dialect->load(dialectReader, getContext())))
       return failure();
     // If the opName is empty, this is because we use to accept names such as
@@ -1835,7 +1894,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection(
 
   // Initialize the resource reader with the resource sections.
   DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
-                              reader, version);
+                              dialectsMap, reader, version);
   return resourceReader.initialize(fileLoc, config, dialects, stringReader,
                                    *resourceData, *resourceOffsetData,
                                    dialectReader, bufferOwnerRef);
@@ -2036,14 +2095,14 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
         "parsed use-list orders were invalid and could not be applied");
 
   // Resolve dialect version.
-  for (const BytecodeDialect &byteCodeDialect : dialects) {
+  for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
     // Parsing is complete, give an opportunity to each dialect to visit the
     // IR and perform upgrades.
-    if (!byteCodeDialect.loadedVersion)
+    if (!byteCodeDialect->loadedVersion)
       continue;
-    if (byteCodeDialect.interface &&
-        failed(byteCodeDialect.interface->upgradeFromVersion(
-            *moduleOp, *byteCodeDialect.loadedVersion)))
+    if (byteCodeDialect->interface &&
+        failed(byteCodeDialect->interface->upgradeFromVersion(
+            *moduleOp, *byteCodeDialect->loadedVersion)))
       return failure();
   }
 
@@ -2196,7 +2255,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
     // interface and control the serialization.
     if (wasRegistered) {
       DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
-                                  reader, version);
+                                  dialectsMap, reader, version);
       if (failed(
               propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
         return failure();

diff  --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index d8f2cb106510d9..75315b5ec75e3d 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -18,15 +18,10 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/CachedHashString.h"
 #include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/raw_ostream.h"
 #include "llvm/Support/Endian.h"
-#include <cstddef>
-#include <cstdint>
-#include <cstring>
+#include "llvm/Support/raw_ostream.h"
 #include <optional>
-#include <sys/types.h>
 
 #define DEBUG_TYPE "mlir-bytecode-writer"
 
@@ -47,6 +42,12 @@ struct BytecodeWriterConfig::Impl {
   /// The producer of the bytecode.
   StringRef producer;
 
+  /// Printer callbacks used to emit custom type and attribute encodings.
+  llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
+      attributeWriterCallbacks;
+  llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
+      typeWriterCallbacks;
+
   /// A collection of non-dialect resource printers.
   SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
 };
@@ -60,6 +61,26 @@ BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map,
 }
 BytecodeWriterConfig::~BytecodeWriterConfig() = default;
 
+ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
+BytecodeWriterConfig::getAttributeWriterCallbacks() const {
+  return impl->attributeWriterCallbacks;
+}
+
+ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
+BytecodeWriterConfig::getTypeWriterCallbacks() const {
+  return impl->typeWriterCallbacks;
+}
+
+void BytecodeWriterConfig::attachAttributeCallback(
+    std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback) {
+  impl->attributeWriterCallbacks.emplace_back(std::move(callback));
+}
+
+void BytecodeWriterConfig::attachTypeCallback(
+    std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback) {
+  impl->typeWriterCallbacks.emplace_back(std::move(callback));
+}
+
 void BytecodeWriterConfig::attachResourcePrinter(
     std::unique_ptr<AsmResourcePrinter> printer) {
   impl->externalResourcePrinters.emplace_back(std::move(printer));
@@ -774,32 +795,50 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
   auto emitAttrOrType = [&](auto &entry) {
     auto entryValue = entry.getValue();
 
-    // First, try to emit this entry using the dialect bytecode interface.
-    bool hasCustomEncoding = false;
-    if (const BytecodeDialectInterface *interface = entry.dialect->interface) {
-      // The writer used when emitting using a custom bytecode encoding.
+    auto emitAttrOrTypeRawImpl = [&]() -> void {
+      RawEmitterOstream(attrTypeEmitter) << entryValue;
+      attrTypeEmitter.emitByte(0);
+    };
+    auto emitAttrOrTypeImpl = [&]() -> bool {
+      // TODO: We don't currently support custom encoded mutable types and
+      // attributes.
+      if (entryValue.template hasTrait<TypeTrait::IsMutable>() ||
+          entryValue.template hasTrait<AttributeTrait::IsMutable>()) {
+        emitAttrOrTypeRawImpl();
+        return false;
+      }
+
       DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter,
                                   numberingState, stringSection);
-
       if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
-        // TODO: We don't currently support custom encoded mutable types.
-        hasCustomEncoding =
-            !entryValue.template hasTrait<TypeTrait::IsMutable>() &&
-            succeeded(interface->writeType(entryValue, dialectWriter));
+        for (const auto &callback : config.typeWriterCallbacks) {
+          if (succeeded(callback->write(entryValue, dialectWriter)))
+            return true;
+        }
+        if (const BytecodeDialectInterface *interface =
+                entry.dialect->interface) {
+          if (succeeded(interface->writeType(entryValue, dialectWriter)))
+            return true;
+        }
       } else {
-        // TODO: We don't currently support custom encoded mutable attributes.
-        hasCustomEncoding =
-            !entryValue.template hasTrait<AttributeTrait::IsMutable>() &&
-            succeeded(interface->writeAttribute(entryValue, dialectWriter));
+        for (const auto &callback : config.attributeWriterCallbacks) {
+          if (succeeded(callback->write(entryValue, dialectWriter)))
+            return true;
+        }
+        if (const BytecodeDialectInterface *interface =
+                entry.dialect->interface) {
+          if (succeeded(interface->writeAttribute(entryValue, dialectWriter)))
+            return true;
+        }
       }
-    }
 
-    // If the entry was not emitted using the dialect interface, emit it using
-    // the textual format.
-    if (!hasCustomEncoding) {
-      RawEmitterOstream(attrTypeEmitter) << entryValue;
-      attrTypeEmitter.emitByte(0);
-    }
+      // If the entry was not emitted using a callback or a dialect interface,
+      // emit it using the textual format.
+      emitAttrOrTypeRawImpl();
+      return false;
+    };
+
+    bool hasCustomEncoding = emitAttrOrTypeImpl();
 
     // Record the offset of this entry.
     uint64_t curOffset = attrTypeEmitter.size();

diff  --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index ef643ca6d74c76..67f929059e4709 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -314,9 +314,22 @@ void IRNumberingState::number(Attribute attr) {
 
   // If this attribute will be emitted using the bytecode format, perform a
   // dummy writing to number any nested components.
-  if (const auto *interface = numbering->dialect->interface) {
-    // TODO: We don't allow custom encodings for mutable attributes right now.
-    if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
+  // TODO: We don't allow custom encodings for mutable attributes right now.
+  if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
+    // Try overriding emission with callbacks.
+    for (const auto &callback : config.getAttributeWriterCallbacks()) {
+      NumberingDialectWriter writer(*this);
+      // The client has the ability to override the group name through the
+      // callback.
+      std::optional<StringRef> groupNameOverride;
+      if (succeeded(callback->write(attr, groupNameOverride, writer))) {
+        if (groupNameOverride.has_value())
+          numbering->dialect = &numberDialect(*groupNameOverride);
+        return;
+      }
+    }
+
+    if (const auto *interface = numbering->dialect->interface) {
       NumberingDialectWriter writer(*this);
       if (succeeded(interface->writeAttribute(attr, writer)))
         return;
@@ -464,9 +477,24 @@ void IRNumberingState::number(Type type) {
 
   // If this type will be emitted using the bytecode format, perform a dummy
   // writing to number any nested components.
-  if (const auto *interface = numbering->dialect->interface) {
-    // TODO: We don't allow custom encodings for mutable types right now.
-    if (!type.hasTrait<TypeTrait::IsMutable>()) {
+  // TODO: We don't allow custom encodings for mutable types right now.
+  if (!type.hasTrait<TypeTrait::IsMutable>()) {
+    // Try overriding emission with callbacks.
+    for (const auto &callback : config.getTypeWriterCallbacks()) {
+      NumberingDialectWriter writer(*this);
+      // The client has the ability to override the group name through the
+      // callback.
+      std::optional<StringRef> groupNameOverride;
+      if (succeeded(callback->write(type, groupNameOverride, writer))) {
+        if (groupNameOverride.has_value())
+          numbering->dialect = &numberDialect(*groupNameOverride);
+        return;
+      }
+    }
+
+    // If this type will be emitted using the bytecode format, perform a
+    // dummy writing to number any nested components.
+    if (const auto *interface = numbering->dialect->interface) {
       NumberingDialectWriter writer(*this);
       if (succeeded(interface->writeType(type, writer)))
         return;

diff  --git a/mlir/test/Bytecode/bytecode_callback.mlir b/mlir/test/Bytecode/bytecode_callback.mlir
new file mode 100644
index 00000000000000..cf3981c86b9442
--- /dev/null
+++ b/mlir/test/Bytecode/bytecode_callback.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2
+// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0
+
+func.func @base_test(%arg0 : i32) -> f32 {
+  %0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
+  %1 = "test.cast"(%0) : (i32) -> f32
+  return %1 : f32
+}
+
+// VERSION_1_2: Overriding IntegerType encoding...
+// VERSION_1_2: Overriding parsing of IntegerType encoding...
+
+// VERSION_2_0-NOT: Overriding IntegerType encoding...
+// VERSION_2_0-NOT: Overriding parsing of IntegerType encoding...

diff  --git a/mlir/test/Bytecode/bytecode_callback_full_override.mlir b/mlir/test/Bytecode/bytecode_callback_full_override.mlir
new file mode 100644
index 00000000000000..21ff947ad389b6
--- /dev/null
+++ b/mlir/test/Bytecode/bytecode_callback_full_override.mlir
@@ -0,0 +1,18 @@
+// RUN: not mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=5" 2>&1 | FileCheck %s
+
+// CHECK-NOT: failed to read bytecode
+func.func @base_test(%arg0 : i32) -> f32 {
+  %0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
+  %1 = "test.cast"(%0) : (i32) -> f32
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: error: unknown attribute code: 99
+// CHECK: failed to read bytecode
+func.func @base_test(%arg0 : !test.i32) -> f32 {
+  %0 = "test.addi"(%arg0, %arg0) : (!test.i32, !test.i32) -> !test.i32
+  %1 = "test.cast"(%0) : (!test.i32) -> f32
+  return %1 : f32
+}

diff  --git a/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
new file mode 100644
index 00000000000000..487972f85af5be
--- /dev/null
+++ b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=3" | FileCheck %s --check-prefix=TEST_3
+// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=4" | FileCheck %s --check-prefix=TEST_4
+
+"test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
+
+// TEST_3: Overriding TestAttrParamsAttr encoding...
+// TEST_3: "test.versionedC"() <{attribute = dense<[42, 24]> : tensor<2xi32>}> : () -> ()
+
+// -----
+
+"test.versionedC"() <{attribute = dense<[42, 24]> : tensor<2xi32>}> : () -> ()
+
+// TEST_4: Overriding parsing of TestAttrParamsAttr encoding...
+// TEST_4: "test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> ()

diff  --git a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
new file mode 100644
index 00000000000000..1e272ec4f3afc2
--- /dev/null
+++ b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=1" | FileCheck %s --check-prefix=TEST_1
+// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=2" | FileCheck %s --check-prefix=TEST_2
+
+func.func @base_test(%arg0: !test.i32, %arg1: f32) {
+  return
+}
+
+// TEST_1: Overriding TestI32Type encoding...
+// TEST_1: func.func @base_test([[ARG0:%.+]]: i32, [[ARG1:%.+]]: f32) {
+
+// -----
+
+func.func @base_test(%arg0: i32, %arg1: f32) {
+  return
+}
+
+// TEST_2: Overriding parsing of TestI32Type encoding...
+// TEST_2: func.func @base_test([[ARG0:%.+]]: !test.i32, [[ARG1:%.+]]: f32) {

diff  --git a/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir
index aba6b3fd1a34aa..87beaa6dd7a056 100644
--- a/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir
+++ b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir
@@ -5,12 +5,12 @@
 // Index
 //===--------------------------------------------------------------------===//
 
-// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc 2>&1 | FileCheck %s --check-prefix=INDEX
+// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=INDEX
 // INDEX: invalid Attribute index: 3
 
 //===--------------------------------------------------------------------===//
 // Trailing Data
 //===--------------------------------------------------------------------===//
 
-// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA
+// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA
 // TRAILING_DATA: trailing characters found after Attribute assembly format: trailing

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 34936783d62ae1..c3235b7b7c68b4 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -14,9 +14,10 @@
 #ifndef MLIR_TESTDIALECT_H
 #define MLIR_TESTDIALECT_H
 
-#include "TestTypes.h"
 #include "TestAttributes.h"
 #include "TestInterfaces.h"
+#include "TestTypes.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/DLTI/Traits.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -57,6 +58,19 @@ class RewritePatternSet;
 #include "TestOpsDialect.h.inc"
 
 namespace test {
+
+//===----------------------------------------------------------------------===//
+// TestDialect version utilities
+//===----------------------------------------------------------------------===//
+
+struct TestDialectVersion : public mlir::DialectVersion {
+  TestDialectVersion() = default;
+  TestDialectVersion(uint32_t _major, uint32_t _minor)
+      : major(_major), minor(_minor){};
+  uint32_t major = 2;
+  uint32_t minor = 0;
+};
+
 // Define some classes to exercises the Properties feature.
 
 struct PropertiesWithCustomPrint {

diff  --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 7315b253df998e..3dfb76fd0f5f7c 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -14,15 +14,6 @@
 using namespace mlir;
 using namespace test;
 
-//===----------------------------------------------------------------------===//
-// TestDialect version utilities
-//===----------------------------------------------------------------------===//
-
-struct TestDialectVersion : public DialectVersion {
-  uint32_t major = 2;
-  uint32_t minor = 0;
-};
-
 //===----------------------------------------------------------------------===//
 // TestDialect Interfaces
 //===----------------------------------------------------------------------===//
@@ -47,7 +38,7 @@ struct TestResourceBlobManagerInterface
 };
 
 namespace {
-enum test_encoding { k_attr_params = 0 };
+enum test_encoding { k_attr_params = 0, k_test_i32 = 99 };
 }
 
 // Test support for interacting with the Bytecode reader/writer.
@@ -56,6 +47,24 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
   TestBytecodeDialectInterface(Dialect *dialect)
       : BytecodeDialectInterface(dialect) {}
 
+  LogicalResult writeType(Type type,
+                          DialectBytecodeWriter &writer) const final {
+    if (auto concreteType = llvm::dyn_cast<TestI32Type>(type)) {
+      writer.writeVarInt(test_encoding::k_test_i32);
+      return success();
+    }
+    return failure();
+  }
+
+  Type readType(DialectBytecodeReader &reader) const final {
+    uint64_t encoding;
+    if (failed(reader.readVarInt(encoding)))
+      return Type();
+    if (encoding == test_encoding::k_test_i32)
+      return TestI32Type::get(getContext());
+    return Type();
+  }
+
   LogicalResult writeAttribute(Attribute attr,
                                DialectBytecodeWriter &writer) const final {
     if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
@@ -67,9 +76,13 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
     return failure();
   }
 
-  Attribute readAttribute(DialectBytecodeReader &reader,
-                          const DialectVersion &version_) const final {
-    const auto &version = static_cast<const TestDialectVersion &>(version_);
+  Attribute readAttribute(DialectBytecodeReader &reader) const final {
+    auto versionOr = reader.getDialectVersion("test");
+    // Assume current version if not available through the reader.
+    const auto version =
+        (succeeded(versionOr))
+            ? *reinterpret_cast<const TestDialectVersion *>(*versionOr)
+            : TestDialectVersion();
     if (version.major < 2)
       return readAttrOldEncoding(reader);
     if (version.major == 2 && version.minor == 0)

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9f897a6a30f541..fb0c54ce7c3b15 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1258,8 +1258,9 @@ def TestOpWithVariadicResultsAndFolder: TEST_Op<"op_with_variadic_results_and_fo
 }
 
 def TestAddIOp : TEST_Op<"addi"> {
-  let arguments = (ins I32:$op1, I32:$op2);
-  let results = (outs I32);
+  let arguments = (ins AnyTypeOf<[I32, TestI32]>:$op1,
+                       AnyTypeOf<[I32, TestI32]>:$op2);
+  let results = (outs AnyTypeOf<[I32, TestI32]>);
 }
 
 def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
@@ -2620,6 +2621,12 @@ def TestVersionedOpB : TEST_Op<"versionedB"> {
   );
 }
 
+def TestVersionedOpC : TEST_Op<"versionedC"> {
+  let arguments = (ins AnyAttrOf<[TestAttrParams,
+                                  I32ElementsAttr]>:$attribute
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // Test Properties
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 15dbd74aec118f..f899d72219d058 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -369,4 +369,8 @@ def TestTypeElseAnchorStruct : Test_Type<"TestTypeElseAnchorStruct"> {
   let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
 }
 
+def TestI32 : Test_Type<"TestI32"> {
+  let mnemonic = "i32";
+}
+
 #endif // TEST_TYPEDEFS

diff  --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index 447a2481e8dbad..1696a14654831b 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestIR
+  TestBytecodeCallbacks.cpp
   TestBuiltinAttributeInterfaces.cpp
   TestBuiltinDistinctAttributes.cpp
   TestClone.cpp

diff  --git a/mlir/test/lib/IR/TestBytecodeCallbacks.cpp b/mlir/test/lib/IR/TestBytecodeCallbacks.cpp
new file mode 100644
index 00000000000000..1464a80865f776
--- /dev/null
+++ b/mlir/test/lib/IR/TestBytecodeCallbacks.cpp
@@ -0,0 +1,372 @@
+//===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks  --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Bytecode/BytecodeReader.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/MemoryBufferRef.h"
+#include "llvm/Support/raw_ostream.h"
+#include <list>
+
+using namespace mlir;
+using namespace llvm;
+
+namespace {
+class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
+public:
+  TestDialectVersionParser(cl::Option &O)
+      : cl::parser<test::TestDialectVersion>(O) {}
+
+  bool parse(cl::Option &O, StringRef /*argName*/, StringRef arg,
+             test::TestDialectVersion &v) {
+    long long major, minor;
+    if (getAsSignedInteger(arg.split(".").first, 10, major))
+      return O.error("Invalid argument '" + arg);
+    if (getAsSignedInteger(arg.split(".").second, 10, minor))
+      return O.error("Invalid argument '" + arg);
+    v = test::TestDialectVersion(major, minor);
+    // Returns true on error.
+    return false;
+  }
+  static void print(raw_ostream &os, const test::TestDialectVersion &v) {
+    os << v.major << "." << v.minor;
+  };
+};
+
+/// This is a test pass which uses callbacks to encode attributes and types in a
+/// custom fashion.
+struct TestBytecodeCallbackPass
+    : public PassWrapper<TestBytecodeCallbackPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeCallbackPass)
+
+  StringRef getArgument() const final { return "test-bytecode-callback"; }
+  StringRef getDescription() const final {
+    return "Test encoding of a dialect type/attributes with a custom callback";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<test::TestDialect>();
+  }
+  TestBytecodeCallbackPass() = default;
+  TestBytecodeCallbackPass(const TestBytecodeCallbackPass &) {}
+
+  void runOnOperation() override {
+    switch (testKind) {
+    case (0):
+      return runTest0(getOperation());
+    case (1):
+      return runTest1(getOperation());
+    case (2):
+      return runTest2(getOperation());
+    case (3):
+      return runTest3(getOperation());
+    case (4):
+      return runTest4(getOperation());
+    case (5):
+      return runTest5(getOperation());
+    default:
+      llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass");
+    }
+  }
+
+  mlir::Pass::Option<test::TestDialectVersion, TestDialectVersionParser>
+      targetVersion{*this, "test-dialect-version",
+                    llvm::cl::desc(
+                        "Specifies the test dialect version to emit and parse"),
+                    cl::init(test::TestDialectVersion())};
+
+  mlir::Pass::Option<int> testKind{
+      *this, "callback-test",
+      llvm::cl::desc("Specifies the test kind to execute"), cl::init(0)};
+
+private:
+  void doRoundtripWithConfigs(Operation *op,
+                              const BytecodeWriterConfig &writeConfig,
+                              const ParserConfig &parseConfig) {
+    std::string bytecode;
+    llvm::raw_string_ostream os(bytecode);
+    if (failed(writeBytecodeToFile(op, os, writeConfig))) {
+      op->emitError() << "failed to write bytecode\n";
+      signalPassFailure();
+      return;
+    }
+    auto newModuleOp = parseSourceString(StringRef(bytecode), parseConfig);
+    if (!newModuleOp.get()) {
+      op->emitError() << "failed to read bytecode\n";
+      signalPassFailure();
+      return;
+    }
+    // Print the module to the output stream, so that we can filecheck the
+    // result.
+    newModuleOp->print(llvm::outs());
+    return;
+  }
+
+  // Test0: let's assume that versions older than 2.0 were relying on a special
+  // integer type of a deprecated dialect called "funky". Assume that its
+  // encoding was made by two varInts, the first was the ID (999) and the second
+  // contained width and signedness info. We can emit it using a callback
+  // writing a custom encoding for the "funky" dialect group, and parse it back
+  // with a custom parser reading the same encoding in the same dialect group.
+  // Note that the ID 999 does not correspond to a valid integer type in the
+  // current encodings of builtin types.
+  void runTest0(Operation *op) {
+    auto newCtx = std::make_shared<MLIRContext>();
+    test::TestDialectVersion targetEmissionVersion = targetVersion;
+    BytecodeWriterConfig writeConfig;
+    writeConfig.attachTypeCallback(
+        [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
+            DialectBytecodeWriter &writer) -> LogicalResult {
+          // Do not override anything if version less than 2.0.
+          if (targetEmissionVersion.major >= 2)
+            return failure();
+
+          // For version less than 2.0, override the encoding of IntegerType.
+          if (auto type = llvm::dyn_cast<IntegerType>(entryValue)) {
+            llvm::outs() << "Overriding IntegerType encoding...\n";
+            dialectGroupName = StringLiteral("funky");
+            writer.writeVarInt(/* IntegerType */ 999);
+            writer.writeVarInt(type.getWidth() << 2 | type.getSignedness());
+            return success();
+          }
+          return failure();
+        });
+    newCtx->appendDialectRegistry(op->getContext()->getDialectRegistry());
+    newCtx->allowUnregisteredDialects();
+    ParserConfig parseConfig(newCtx.get(), /*verifyAfterParse=*/true);
+    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
+        [&](DialectBytecodeReader &reader, StringRef dialectName,
+            Type &entry) -> LogicalResult {
+          // Get test dialect version from the version map.
+          auto versionOr = reader.getDialectVersion("test");
+          assert(succeeded(versionOr) && "expected reader to be able to access "
+                                         "the version for test dialect");
+          const auto *version =
+              reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
+
+          // TODO: once back-deployment is formally supported,
+          // `targetEmissionVersion` will be encoded in the bytecode file, and
+          // exposed through the versionMap. Right now though this is not yet
+          // supported. For the purpose of the test, just use
+          // `targetEmissionVersion`.
+          (void)version;
+          if (targetEmissionVersion.major >= 2)
+            return success();
+
+          // `dialectName` is the name of the group we have the opportunity to
+          // override. In this case, override only the dialect group "funky",
+          // which does not exist in memory.
+          if (dialectName != StringLiteral("funky"))
+            return success();
+
+          uint64_t encoding;
+          if (failed(reader.readVarInt(encoding)) || encoding != 999)
+            return success();
+          llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
+          uint64_t _widthAndSignedness, width;
+          IntegerType::SignednessSemantics signedness;
+          if (succeeded(reader.readVarInt(_widthAndSignedness)) &&
+              ((width = _widthAndSignedness >> 2), true) &&
+              ((signedness = static_cast<IntegerType::SignednessSemantics>(
+                    _widthAndSignedness & 0x3)),
+               true))
+            entry = IntegerType::get(reader.getContext(), width, signedness);
+          // Return nullopt to fall through the rest of the parsing code path.
+          return success();
+        });
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+    return;
+  }
+
+  // Test1: round-trip `op` with a writer-side type callback that emits every
+  // test::TestI32Type through the builtin dialect's encoding of a signless
+  // 32-bit IntegerType. No reader callback is attached; the entry is filed
+  // under the "builtin" group, so the existing builtin reader parses it.
+  void runTest1(Operation *op) {
+    auto *ctx = op->getContext();
+    auto *builtinDialect = ctx->getLoadedDialect<mlir::BuiltinDialect>();
+    BytecodeDialectInterface *iface =
+        builtinDialect->getRegisteredInterface<BytecodeDialectInterface>();
+
+    BytecodeWriterConfig writeConfig;
+    writeConfig.attachTypeCallback(
+        [ctx, iface](Type entryValue,
+                     std::optional<StringRef> &dialectGroupName,
+                     DialectBytecodeWriter &writer) -> LogicalResult {
+          // Only TestI32Type is overridden; report failure for the rest.
+          if (!llvm::isa<test::TestI32Type>(entryValue))
+            return failure();
+
+          llvm::outs() << "Overriding TestI32Type encoding...\n";
+          auto signlessI32 = IntegerType::get(
+              ctx, 32, IntegerType::SignednessSemantics::Signless);
+          // Redirect the entry to the builtin group instead of the group of
+          // the type's own dialect (test).
+          dialectGroupName = StringLiteral("builtin");
+          return iface->writeType(signlessI32, writer);
+        });
+
+    // The type is parsed natively as a builtin, so no reader callback.
+    ParserConfig parseConfig(ctx, /*verifyAfterParse=*/true);
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+  }
+
+  // Test2: the writer is left at its defaults, so builtin IntegerTypes are
+  // emitted as usual. On the reader side a type callback decodes entries of
+  // the "builtin" group with the builtin reader and, whenever the result is
+  // a signless 32-bit integer, hands back a test::TestI32Type instead.
+  void runTest2(Operation *op) {
+    auto *ctx = op->getContext();
+    auto *builtinDialect = ctx->getLoadedDialect<mlir::BuiltinDialect>();
+    BytecodeDialectInterface *iface =
+        builtinDialect->getRegisteredInterface<BytecodeDialectInterface>();
+
+    BytecodeWriterConfig writeConfig;
+    ParserConfig parseConfig(ctx, /*verifyAfterParse=*/true);
+    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
+        [iface](DialectBytecodeReader &reader, StringRef dialectName,
+                Type &entry) -> LogicalResult {
+          // Entries of any other group are left untouched.
+          if (dialectName != StringLiteral("builtin"))
+            return success();
+
+          auto decoded =
+              llvm::dyn_cast_or_null<IntegerType>(iface->readType(reader));
+          const bool isSignlessI32 =
+              decoded && decoded.getWidth() == 32 && decoded.isSignless();
+          if (isSignlessI32) {
+            llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
+            entry = test::TestI32Type::get(reader.getContext());
+          }
+          // Success with `entry` unset when nothing was overridden.
+          return success();
+        });
+
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+  }
+
+  // Test3: round-trip `op` with a writer-side attribute callback that emits
+  // every test::TestAttrParamsAttr as a builtin dense integer elements
+  // attribute of type tensor<2xi32> holding (v0, v1). No reader callback is
+  // attached; the existing builtin reader parses the result.
+  void runTest3(Operation *op) {
+    auto *ctx = op->getContext();
+    auto *builtinDialect = ctx->getLoadedDialect<mlir::BuiltinDialect>();
+    BytecodeDialectInterface *iface =
+        builtinDialect->getRegisteredInterface<BytecodeDialectInterface>();
+    auto i32Type =
+        IntegerType::get(ctx, 32, IntegerType::SignednessSemantics::Signless);
+
+    BytecodeWriterConfig writeConfig;
+    writeConfig.attachAttributeCallback(
+        [iface, i32Type](Attribute entryValue,
+                         std::optional<StringRef> &dialectGroupName,
+                         DialectBytecodeWriter &writer) -> LogicalResult {
+          // Only TestAttrParamsAttr is overridden; report failure otherwise.
+          auto paramsAttr =
+              llvm::dyn_cast<test::TestAttrParamsAttr>(entryValue);
+          if (!paramsAttr)
+            return failure();
+
+          llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n";
+          // Redirect the entry to the builtin group instead of the group of
+          // the attribute's own dialect (test).
+          dialectGroupName = StringLiteral("builtin");
+          auto denseAttr = DenseIntElementsAttr::get(
+              RankedTensorType::get({2}, i32Type),
+              {paramsAttr.getV0(), paramsAttr.getV1()});
+          return iface->writeAttribute(denseAttr, writer);
+        });
+
+    // The attribute is parsed natively as a builtin, so no reader callback.
+    ParserConfig parseConfig(ctx, /*verifyAfterParse=*/false);
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+  }
+
+  // Test4: the writer is left at its defaults, so builtin
+  // DenseIntElementsAttr are emitted as usual. On the reader side an
+  // attribute callback decodes each entry with the builtin reader and, when
+  // the result is a DenseIntElementsAttr of shape <2> with i32 elements,
+  // hands back a test::TestAttrParamsAttr built from its two values.
+  void runTest4(Operation *op) {
+    auto *ctx = op->getContext();
+    auto *builtinDialect = ctx->getLoadedDialect<mlir::BuiltinDialect>();
+    BytecodeDialectInterface *iface =
+        builtinDialect->getRegisteredInterface<BytecodeDialectInterface>();
+    auto i32Type =
+        IntegerType::get(ctx, 32, IntegerType::SignednessSemantics::Signless);
+
+    BytecodeWriterConfig writeConfig;
+    ParserConfig parseConfig(ctx, /*verifyAfterParse=*/false);
+    parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
+        [iface, i32Type](DialectBytecodeReader &reader, StringRef dialectName,
+                         Attribute &entry) -> LogicalResult {
+          // Anything that is not a dense <2xi32> attribute is passed over:
+          // success is returned with `entry` left unset.
+          auto denseAttr = llvm::dyn_cast_or_null<DenseIntElementsAttr>(
+              iface->readAttribute(reader));
+          if (!denseAttr)
+            return success();
+          const bool isPairOfI32 =
+              denseAttr.getType().getShape() == ArrayRef<int64_t>(2) &&
+              denseAttr.getElementType() == i32Type;
+          if (!isPairOfI32)
+            return success();
+
+          llvm::outs()
+              << "Overriding parsing of TestAttrParamsAttr encoding...\n";
+          auto elements = denseAttr.getValues<IntegerAttr>();
+          int v0 = elements[0].getInt();
+          int v1 = elements[1].getInt();
+          entry = test::TestAttrParamsAttr::get(reader.getContext(), v0, v1);
+          return success();
+        });
+
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+  }
+
+  // Test5: take full control of the encoding. Every attribute and type is
+  // written and read exclusively through the builtin dialect's bytecode
+  // interface: the writer callbacks return whatever the builtin writer
+  // returns, and the reader callbacks fail when the builtin reader yields a
+  // null result.
+  void runTest5(Operation *op) {
+    auto *ctx = op->getContext();
+    auto *builtinDialect = ctx->getLoadedDialect<mlir::BuiltinDialect>();
+    BytecodeDialectInterface *iface =
+        builtinDialect->getRegisteredInterface<BytecodeDialectInterface>();
+
+    // Writer side: forward straight to the builtin interface.
+    BytecodeWriterConfig writeConfig;
+    writeConfig.attachAttributeCallback(
+        [iface](Attribute attr, std::optional<StringRef> & /*groupName*/,
+                DialectBytecodeWriter &writer) -> LogicalResult {
+          return iface->writeAttribute(attr, writer);
+        });
+    writeConfig.attachTypeCallback(
+        [iface](Type type, std::optional<StringRef> & /*groupName*/,
+                DialectBytecodeWriter &writer) -> LogicalResult {
+          return iface->writeType(type, writer);
+        });
+
+    // Reader side: accept only what the builtin reader can decode.
+    ParserConfig parseConfig(ctx, /*verifyAfterParse=*/false);
+    parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
+        [iface](DialectBytecodeReader &reader, StringRef /*dialectName*/,
+                Attribute &entry) -> LogicalResult {
+          if (Attribute decoded = iface->readAttribute(reader)) {
+            entry = decoded;
+            return success();
+          }
+          return failure();
+        });
+    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
+        [iface](DialectBytecodeReader &reader, StringRef /*dialectName*/,
+                Type &entry) -> LogicalResult {
+          if (Type decoded = iface->readType(reader)) {
+            entry = decoded;
+            return success();
+          }
+          return failure();
+        });
+
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+  }
+};
+} // namespace
+
+namespace mlir {
+/// Creates the PassRegistration object for TestBytecodeCallbackPass. Invoked
+/// from registerTestPasses() in mlir-opt.
+void registerTestBytecodeCallbackPasses() {
+  PassRegistration<TestBytecodeCallbackPass>{};
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e91cb118461ec5..78bd70b40c91e7 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -43,6 +43,7 @@ void registerSymbolTestPasses();
 void registerRegionTestPasses();
 void registerTestAffineDataCopyPass();
 void registerTestAffineReifyValueBoundsPass();
+void registerTestBytecodeCallbackPasses();
 void registerTestDecomposeAffineOpPass();
 void registerTestAffineLoopUnswitchingPass();
 void registerTestAllReduceLoweringPass();
@@ -167,6 +168,7 @@ void registerTestPasses() {
   registerTestDecomposeAffineOpPass();
   registerTestAffineLoopUnswitchingPass();
   registerTestAllReduceLoweringPass();
+  registerTestBytecodeCallbackPasses();
   registerTestFunc();
   registerTestGpuMemoryPromotionPass();
   registerTestLoopPermutationPass();


        


More information about the Mlir-commits mailing list