[Mlir-commits] [mlir] 02c2ecb - [mlir:Bytecode] Add initial support for dialect defined attribute/type encodings

River Riddle llvmlistbot at llvm.org
Tue Aug 23 16:56:39 PDT 2022


Author: River Riddle
Date: 2022-08-23T16:56:04-07:00
New Revision: 02c2ecb9c6c355b8b6f650d258077bd9cca0aacf

URL: https://github.com/llvm/llvm-project/commit/02c2ecb9c6c355b8b6f650d258077bd9cca0aacf
DIFF: https://github.com/llvm/llvm-project/commit/02c2ecb9c6c355b8b6f650d258077bd9cca0aacf.diff

LOG: [mlir:Bytecode] Add initial support for dialect defined attribute/type encodings

Dialects can opt in to providing custom encodings by implementing the
`BytecodeDialectInterface`. This interface provides hooks, namely
`readAttribute`/`readType` and `writeAttribute`/`writeType`, that are used by
the bytecode reader and writer. Each hook is handed a reader or writer
implementation that can be used to encode various constructs in the underlying
bytecode format. A unique feature of this interface is that dialects may choose
to encode only a subset of their attributes and types in a custom bytecode
format; anything a write hook declines falls back to the textual format, which
can simplify adding new or experimental components that aren't fully baked.
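
The opt-in idiom can be illustrated with a small self-contained sketch that
does not depend on MLIR. Everything in it is hypothetical: `ToyWriter` and
`ToyReader` stand in for `DialectBytecodeWriter` and `DialectBytecodeReader`,
`ToyType`, `Kind`, and the free functions `writeType`/`readType` stand in for a
dialect's types and hooks, and the LEB128-style varint is an assumption made
for illustration, not necessarily the encoding used by the bytecode format.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for DialectBytecodeWriter: only varints are supported.
struct ToyWriter {
  std::vector<uint8_t> bytes;
  // LEB128-style varint (illustrative assumption): 7 payload bits per byte,
  // high bit set on every byte except the last.
  void writeVarInt(uint64_t value) {
    while (value >= 0x80) {
      bytes.push_back(static_cast<uint8_t>((value & 0x7F) | 0x80));
      value >>= 7;
    }
    bytes.push_back(static_cast<uint8_t>(value));
  }
};

// Hypothetical stand-in for DialectBytecodeReader.
struct ToyReader {
  const std::vector<uint8_t> &bytes;
  std::size_t pos = 0;
  // Returns false if the stream ends before the varint is complete.
  bool readVarInt(uint64_t &result) {
    result = 0;
    for (unsigned shift = 0; pos < bytes.size() && shift < 64; shift += 7) {
      uint8_t byte = bytes[pos++];
      result |= static_cast<uint64_t>(byte & 0x7F) << shift;
      if (!(byte & 0x80))
        return true;
    }
    return false;
  }
};

// Hypothetical dialect types; Experimental has no custom encoding yet.
enum class Kind { Integer, Index, Experimental };
struct ToyType {
  Kind kind;
  uint64_t width = 0;
};

// Leading codes identifying which kind of type follows in the stream.
enum TypeCode : uint64_t { kIntegerType = 0, kIndexType = 1 };

// Write hook: returns false for kinds it does not encode, so the caller can
// fall back to another (e.g. textual) encoding.
bool writeType(const ToyType &type, ToyWriter &writer) {
  switch (type.kind) {
  case Kind::Integer:
    writer.writeVarInt(kIntegerType);
    writer.writeVarInt(type.width);
    return true;
  case Kind::Index:
    writer.writeVarInt(kIndexType);
    return true;
  case Kind::Experimental:
    return false;
  }
  return false;
}

// Read hook: dispatches on the leading code; returns false on failure.
bool readType(ToyReader &reader, ToyType &result) {
  uint64_t code;
  if (!reader.readVarInt(code))
    return false;
  switch (code) {
  case kIntegerType:
    result.kind = Kind::Integer;
    return reader.readVarInt(result.width);
  case kIndexType:
    result = ToyType{Kind::Index};
    return true;
  default:
    return false;
  }
}

// Usage: round-trip an encodable type, then show a declined one.
void demo() {
  ToyWriter writer;
  assert(writeType(ToyType{Kind::Integer, 32}, writer));
  ToyReader reader{writer.bytes};
  ToyType decoded{Kind::Index};
  assert(readType(reader, decoded));
  assert(decoded.kind == Kind::Integer && decoded.width == 32);

  ToyWriter declined;
  assert(!writeType(ToyType{Kind::Experimental}, declined));
}
```

The write hook emits nothing before it declines, so the caller can switch to
the fallback encoding without having to discard partial output.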

Differential Revision: https://reviews.llvm.org/D132498

Added: 
    mlir/include/mlir/Bytecode/BytecodeImplementation.h
    mlir/lib/IR/BuiltinDialectBytecode.cpp
    mlir/lib/IR/BuiltinDialectBytecode.h
    mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
    mlir/test/Dialect/Builtin/Bytecode/types.mlir

Modified: 
    mlir/docs/BytecodeFormat.md
    mlir/include/mlir/IR/DialectInterface.h
    mlir/lib/Bytecode/Reader/BytecodeReader.cpp
    mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
    mlir/lib/Bytecode/Writer/IRNumbering.cpp
    mlir/lib/Bytecode/Writer/IRNumbering.h
    mlir/lib/IR/BuiltinDialect.cpp
    mlir/lib/IR/CMakeLists.txt
    mlir/lib/IR/Dialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
index acb1819c9932c..5260c996a880d 100644
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -207,7 +207,26 @@ reference to the parent dialect instead.
 
 ##### Dialect Defined Encoding
 
-TODO: This is not yet supported.
+In addition to the assembly format fallback, dialects may also provide a custom
+encoding for their attributes and types. Custom encodings are very beneficial in
+that they are significantly smaller and faster to read and write.
+
+Dialects can opt-in to providing custom encodings by implementing the
+`BytecodeDialectInterface`. This interface provides hooks, namely
+`readAttribute`/`readType` and `writeAttribute`/`writeType`, that will be used
+by the bytecode reader and writer. These hooks are provided a reader and writer
+implementation that can be used to encode various constructs in the underlying
+bytecode format. A unique feature of this interface is that dialects may choose
+to only encode a subset of their attributes and types in a custom bytecode
+format, which can simplify adding new or experimental components that aren't
+fully baked.
+
+When implementing the bytecode interface, dialects are responsible for all
+aspects of the encoding. This includes the indicator for which kind of attribute
+or type is being encoded; the bytecode reader will only know that it has
+encountered an attribute or type of a given dialect, it doesn't encode any
+further information. As such, a common encoding idiom is to use a leading
+`varint` code to indicate how the attribute or type was encoded.
 
 ### IR Section
 

diff  --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
new file mode 100644
index 0000000000000..01bada7e0572f
--- /dev/null
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -0,0 +1,220 @@
+//===- BytecodeImplementation.h - MLIR Bytecode Implementation --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines various interfaces and utilities necessary for dialects
+// to hook into bytecode serialization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
+#define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/Twine.h"
+
+namespace mlir {
+//===----------------------------------------------------------------------===//
+// DialectBytecodeReader
+//===----------------------------------------------------------------------===//
+
+/// This class defines a virtual interface for reading a bytecode stream,
+/// providing hooks into the bytecode reader. As such, this class should only be
+/// derived and defined by the main bytecode reader, users (i.e. dialects)
+/// should generally only interact with this class via the
+/// BytecodeDialectInterface below.
+class DialectBytecodeReader {
+public:
+  virtual ~DialectBytecodeReader() = default;
+
+  /// Emit an error to the reader.
+  virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
+
+  //===--------------------------------------------------------------------===//
+  // IR
+  //===--------------------------------------------------------------------===//
+
+  /// Read out a list of elements, invoking the provided callback for each
+  /// element. The callback function may be in any of the following forms:
+  ///   * LogicalResult(T &)
+  ///   * FailureOr<T>()
+  template <typename T, typename CallbackFn>
+  LogicalResult readList(SmallVectorImpl<T> &result, CallbackFn &&callback) {
+    uint64_t size;
+    if (failed(readVarInt(size)))
+      return failure();
+    result.reserve(size);
+
+    for (uint64_t i = 0; i < size; ++i) {
+      // Check if the callback uses FailureOr, or populates the result by
+      // reference.
+      if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) {
+        T element = {};
+        if (failed(callback(element)))
+          return failure();
+        result.emplace_back(std::move(element));
+      } else {
+        FailureOr<T> element = callback();
+        if (failed(element))
+          return failure();
+        result.emplace_back(std::move(*element));
+      }
+    }
+    return success();
+  }
+
+  /// Read a reference to the given attribute.
+  virtual LogicalResult readAttribute(Attribute &result) = 0;
+  template <typename T>
+  LogicalResult readAttributes(SmallVectorImpl<T> &attrs) {
+    return readList(attrs, [this](T &attr) { return readAttribute(attr); });
+  }
+  template <typename T>
+  LogicalResult parseAttribute(T &result) {
+    Attribute baseResult;
+    if (failed(parseAttribute(baseResult)))
+      return failure();
+    if ((result = baseResult.dyn_cast<T>()))
+      return success();
+    return emitError() << "expected attribute of type: "
+                       << llvm::getTypeName<T>() << ", but got: " << baseResult;
+  }
+
+  /// Read a reference to the given type.
+  virtual LogicalResult readType(Type &result) = 0;
+  template <typename T>
+  LogicalResult readTypes(SmallVectorImpl<T> &types) {
+    return readList(types, [this](T &type) { return readType(type); });
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Primitives
+  //===--------------------------------------------------------------------===//
+
+  /// Read a variable width integer.
+  // TODO: Add a signed variant when necessary.
+  virtual LogicalResult readVarInt(uint64_t &result) = 0;
+
+  /// Read a string from the bytecode.
+  virtual LogicalResult readString(StringRef &result) = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// DialectBytecodeWriter
+//===----------------------------------------------------------------------===//
+
+/// This class defines a virtual interface for writing to a bytecode stream,
+/// providing hooks into the bytecode writer. As such, this class should only be
+/// derived and defined by the main bytecode writer, users (i.e. dialects)
+/// should generally only interact with this class via the
+/// BytecodeDialectInterface below.
+class DialectBytecodeWriter {
+public:
+  virtual ~DialectBytecodeWriter() = default;
+
+  //===--------------------------------------------------------------------===//
+  // IR
+  //===--------------------------------------------------------------------===//
+
+  /// Write out a list of elements, invoking the provided callback for each
+  /// element.
+  template <typename RangeT, typename CallbackFn>
+  void writeList(RangeT &&range, CallbackFn &&callback) {
+    writeVarInt(llvm::size(range));
+    for (auto &element : range)
+      callback(element);
+  }
+
+  /// Write a reference to the given attribute.
+  virtual void writeAttribute(Attribute attr) = 0;
+  template <typename T>
+  void writeAttributes(ArrayRef<T> attrs) {
+    writeList(attrs, [this](T attr) { writeAttribute(attr); });
+  }
+
+  /// Write a reference to the given type.
+  virtual void writeType(Type type) = 0;
+  template <typename T>
+  void writeTypes(ArrayRef<T> types) {
+    writeList(types, [this](T type) { writeType(type); });
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Primitives
+  //===--------------------------------------------------------------------===//
+
+  /// Write a variable width integer to the output stream. This should be the
+  /// preferred method for emitting integers whenever possible.
+  // TODO: Add a signed variant when necessary.
+  virtual void writeVarInt(uint64_t value) = 0;
+
+  /// Write a string to the bytecode, which is owned by the caller and is
+  /// guaranteed to not die before the end of the bytecode process. This should
+  /// only be called if such a guarantee can be made, such as when the string is
+  /// owned by an attribute or type.
+  virtual void writeOwnedString(StringRef str) = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// BytecodeDialectInterface
+//===----------------------------------------------------------------------===//
+
+class BytecodeDialectInterface
+    : public DialectInterface::Base<BytecodeDialectInterface> {
+public:
+  using Base::Base;
+
+  //===--------------------------------------------------------------------===//
+  // Reading
+  //===--------------------------------------------------------------------===//
+
+  /// Read an attribute belonging to this dialect from the given reader. This
+  /// method should return null in the case of failure.
+  virtual Attribute readAttribute(DialectBytecodeReader &reader) const {
+    reader.emitError() << "dialect " << getDialect()->getNamespace()
+                       << " does not support reading attributes from bytecode";
+    return Attribute();
+  }
+
+  /// Read a type belonging to this dialect from the given reader. This method
+  /// should return null in the case of failure.
+  virtual Type readType(DialectBytecodeReader &reader) const {
+    reader.emitError() << "dialect " << getDialect()->getNamespace()
+                       << " does not support reading types from bytecode";
+    return Type();
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Writing
+  //===--------------------------------------------------------------------===//
+
+  /// Write the given attribute, which belongs to this dialect, to the given
+  /// writer. This method may return failure to indicate that the given
+  /// attribute could not be encoded, in which case the textual format will be
+  /// used to encode this attribute instead.
+  virtual LogicalResult writeAttribute(Attribute attr,
+                                       DialectBytecodeWriter &writer) const {
+    return failure();
+  }
+
+  /// Write the given type, which belongs to this dialect, to the given writer.
+  /// This method may return failure to indicate that the given type could not
+  /// be encoded, in which case the textual format will be used to encode this
+  /// type instead.
+  virtual LogicalResult writeType(Type type,
+                                  DialectBytecodeWriter &writer) const {
+    return failure();
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H

diff  --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h
index 2b3aa1211455c..d5d36ed0171d2 100644
--- a/mlir/include/mlir/IR/DialectInterface.h
+++ b/mlir/include/mlir/IR/DialectInterface.h
@@ -50,6 +50,9 @@ class DialectInterface {
   /// Return the dialect that this interface represents.
   Dialect *getDialect() const { return dialect; }
 
+  /// Return the context that holds the parent dialect of this interface.
+  MLIRContext *getContext() const;
+
   /// Return the derived interface id.
   TypeID getID() const { return interfaceID; }
 

diff  --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 010e4e492fa64..78d60a0b5a3fe 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Bytecode/BytecodeReader.h"
 #include "../Encoding.h"
 #include "mlir/AsmParser/AsmParser.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/OpImplementation.h"
@@ -66,7 +67,7 @@ class EncodingReader {
 
   /// Emit an error using the given arguments.
   template <typename... Args>
-  LogicalResult emitError(Args &&...args) const {
+  InFlightDiagnostic emitError(Args &&...args) const {
     return ::emitError(fileLoc).append(std::forward<Args>(args)...);
   }
 
@@ -326,6 +327,11 @@ struct BytecodeDialect {
           "-allow-unregistered-dialect with the MLIR tool used.");
     }
     dialect = loadedDialect;
+
+    // If the dialect was actually loaded, check to see if it has a bytecode
+    // interface.
+    if (loadedDialect)
+      interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
     return success();
   }
 
@@ -333,6 +339,11 @@ struct BytecodeDialect {
   /// load, nullptr if we failed to load, otherwise the loaded dialect.
   Optional<Dialect *> dialect;
 
+  /// The bytecode interface of the dialect, or nullptr if the dialect does not
+  /// implement the bytecode interface. This field should only be checked if the
+  /// `dialect` field is non-None.
+  const BytecodeDialectInterface *interface = nullptr;
+
   /// The name of the dialect.
   StringRef name;
 };
@@ -397,7 +408,8 @@ class AttrTypeReader {
   using TypeEntry = Entry<Type>;
 
 public:
-  AttrTypeReader(Location fileLoc) : fileLoc(fileLoc) {}
+  AttrTypeReader(StringSectionReader &stringReader, Location fileLoc)
+      : stringReader(stringReader), fileLoc(fileLoc) {}
 
   /// Initialize the attribute and type information within the reader.
   LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
@@ -456,6 +468,10 @@ class AttrTypeReader {
   LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
                                  StringRef entryType);
 
+  /// The string section reader used to resolve string references when parsing
+  /// custom encoded attribute/type entries.
+  StringSectionReader &stringReader;
+
   /// The set of attribute and type entries.
   SmallVector<AttrEntry> attributes;
   SmallVector<TypeEntry> types;
@@ -463,6 +479,47 @@ class AttrTypeReader {
   /// A location used for error emission.
   Location fileLoc;
 };
+
+class DialectReader : public DialectBytecodeReader {
+public:
+  DialectReader(AttrTypeReader &attrTypeReader,
+                StringSectionReader &stringReader, EncodingReader &reader)
+      : attrTypeReader(attrTypeReader), stringReader(stringReader),
+        reader(reader) {}
+
+  InFlightDiagnostic emitError(const Twine &msg) override {
+    return reader.emitError(msg);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // IR
+  //===--------------------------------------------------------------------===//
+
+  LogicalResult readAttribute(Attribute &result) override {
+    return attrTypeReader.parseAttribute(reader, result);
+  }
+
+  LogicalResult readType(Type &result) override {
+    return attrTypeReader.parseType(reader, result);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Primitives
+  //===--------------------------------------------------------------------===//
+
+  LogicalResult readVarInt(uint64_t &result) override {
+    return reader.parseVarInt(result);
+  }
+
+  LogicalResult readString(StringRef &result) override {
+    return stringReader.parseString(reader, result);
+  }
+
+private:
+  AttrTypeReader &attrTypeReader;
+  StringSectionReader &stringReader;
+  EncodingReader &reader;
+};
 } // namespace
 
 LogicalResult
@@ -486,7 +543,7 @@ AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
     size_t currentIndex = 0, endIndex = range.size();
 
     // Parse an individual entry.
-    auto parseEntryFn = [&](BytecodeDialect *dialect) {
+    auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
       auto &entry = range[currentIndex++];
 
       uint64_t entrySize;
@@ -548,8 +605,7 @@ T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
   }
 
   if (!reader.empty()) {
-    (void)reader.emitError("unexpected trailing bytes after " + entryType +
-                           " entry");
+    reader.emitError("unexpected trailing bytes after " + entryType + " entry");
     return T();
   }
   return entry.entry;
@@ -584,8 +640,22 @@ template <typename T>
 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
                                                EncodingReader &reader,
                                                StringRef entryType) {
-  // FIXME: Add support for reading custom attribute/type encodings.
-  return reader.emitError("unexpected Attribute encoding");
+  if (failed(entry.dialect->load(reader, fileLoc.getContext())))
+    return failure();
+
+  // Ensure that the dialect implements the bytecode interface.
+  if (!entry.dialect->interface) {
+    return reader.emitError("dialect '", entry.dialect->name,
+                            "' does not implement the bytecode interface");
+  }
+
+  // Ask the dialect to parse the entry.
+  DialectReader dialectReader(*this, stringReader, reader);
+  if constexpr (std::is_same_v<T, Type>)
+    entry.entry = entry.dialect->interface->readType(dialectReader);
+  else
+    entry.entry = entry.dialect->interface->readAttribute(dialectReader);
+  return success(!!entry.entry);
 }
 
 //===----------------------------------------------------------------------===//
@@ -597,7 +667,7 @@ namespace {
 class BytecodeReader {
 public:
   BytecodeReader(Location fileLoc, const ParserConfig &config)
-      : config(config), fileLoc(fileLoc), attrTypeReader(fileLoc),
+      : config(config), fileLoc(fileLoc), attrTypeReader(stringReader, fileLoc),
         // Use the builtin unrealized conversion cast operation to represent
         // forward references to values that aren't yet defined.
         forwardRefOpState(UnknownLoc::get(config.getContext()),

diff  --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index ebf827fde41b8..6fc2fb4354db8 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Bytecode/BytecodeWriter.h"
 #include "../Encoding.h"
 #include "IRNumbering.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/OpImplementation.h"
 #include "llvm/ADT/CachedHashString.h"
@@ -358,22 +359,78 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
 //===----------------------------------------------------------------------===//
 // Attributes and Types
 
+namespace {
+class DialectWriter : public DialectBytecodeWriter {
+public:
+  DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState,
+                StringSectionBuilder &stringSection)
+      : emitter(emitter), numberingState(numberingState),
+        stringSection(stringSection) {}
+
+  //===--------------------------------------------------------------------===//
+  // IR
+  //===--------------------------------------------------------------------===//
+
+  void writeAttribute(Attribute attr) override {
+    emitter.emitVarInt(numberingState.getNumber(attr));
+  }
+  void writeType(Type type) override {
+    emitter.emitVarInt(numberingState.getNumber(type));
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Primitives
+  //===--------------------------------------------------------------------===//
+
+  void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); }
+
+  void writeOwnedString(StringRef str) override {
+    emitter.emitVarInt(stringSection.insert(str));
+  }
+
+private:
+  EncodingEmitter &emitter;
+  IRNumberingState &numberingState;
+  StringSectionBuilder &stringSection;
+};
+} // namespace
+
 void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
   EncodingEmitter attrTypeEmitter;
   EncodingEmitter offsetEmitter;
   offsetEmitter.emitVarInt(llvm::size(numberingState.getAttributes()));
   offsetEmitter.emitVarInt(llvm::size(numberingState.getTypes()));
 
+  // The writer used when emitting using a custom bytecode encoding.
+  DialectWriter dialectWriter(attrTypeEmitter, numberingState, stringSection);
+
   // A functor used to emit an attribute or type entry.
   uint64_t prevOffset = 0;
   auto emitAttrOrType = [&](auto &entry) {
-    // TODO: Allow dialects to provide more optimal implementations of attribute
-    // and type encodings.
+    auto entryValue = entry.getValue();
+
+    // First, try to emit this entry using the dialect bytecode interface.
     bool hasCustomEncoding = false;
+    if (const BytecodeDialectInterface *interface = entry.dialect->interface) {
+      if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
+        // TODO: We don't currently support custom encoded mutable types.
+        hasCustomEncoding =
+            !entryValue.template hasTrait<TypeTrait::IsMutable>() &&
+            succeeded(interface->writeType(entryValue, dialectWriter));
+      } else {
+        // TODO: We don't currently support custom encoded mutable attributes.
+        hasCustomEncoding =
+            !entryValue.template hasTrait<AttributeTrait::IsMutable>() &&
+            succeeded(interface->writeAttribute(entryValue, dialectWriter));
+      }
+    }
 
-    // Emit the entry using the textual format.
-    raw_emitter_ostream(attrTypeEmitter) << entry.getValue();
-    attrTypeEmitter.emitByte(0);
+    // If the entry was not emitted using the dialect interface, emit it using
+    // the textual format.
+    if (!hasCustomEncoding) {
+      raw_emitter_ostream(attrTypeEmitter) << entryValue;
+      attrTypeEmitter.emitByte(0);
+    }
 
     // Record the offset of this entry.
     uint64_t curOffset = attrTypeEmitter.size();

diff  --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 61fef0e35cbba..88a69034d557f 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "IRNumbering.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
@@ -14,6 +15,28 @@
 using namespace mlir;
 using namespace mlir::bytecode::detail;
 
+//===----------------------------------------------------------------------===//
+// NumberingDialectWriter
+//===----------------------------------------------------------------------===//
+
+struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
+  NumberingDialectWriter(IRNumberingState &state) : state(state) {}
+
+  void writeAttribute(Attribute attr) override { state.number(attr); }
+  void writeType(Type type) override { state.number(type); }
+
+  /// Stubbed out methods that are not used for numbering.
+  void writeVarInt(uint64_t) override {}
+  void writeOwnedString(StringRef) override {
+    // TODO: It might be nice to prenumber strings and sort by the number of
+    // references. This could potentially be useful for optimizing things like
+    // file locations.
+  }
+
+  /// The parent numbering state that is populated by this writer.
+  IRNumberingState &state;
+};
+
 //===----------------------------------------------------------------------===//
 // IR Numbering
 //===----------------------------------------------------------------------===//
@@ -138,10 +161,22 @@ void IRNumberingState::number(Attribute attr) {
   // have a registered dialect when it got created. We don't want to encode this
   // as the builtin OpaqueAttr, we want to encode it as if the dialect was
   // actually loaded.
-  if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>())
+  if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
     numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
-  else
-    numbering->dialect = &numberDialect(&attr.getDialect());
+    return;
+  }
+  numbering->dialect = &numberDialect(&attr.getDialect());
+
+  // If this attribute will be emitted using the bytecode format, perform a
+  // dummy writing to number any nested components.
+  if (const auto *interface = numbering->dialect->interface) {
+    // TODO: We don't allow custom encodings for mutable attributes right now.
+    if (attr.hasTrait<AttributeTrait::IsMutable>())
+      return;
+
+    NumberingDialectWriter writer(*this);
+    (void)interface->writeAttribute(attr, writer);
+  }
 }
 
 void IRNumberingState::number(Block &block) {
@@ -164,7 +199,7 @@ auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
   DialectNumbering *&numbering = registeredDialects[dialect];
   if (!numbering) {
     numbering = &numberDialect(dialect->getNamespace());
-    numbering->dialect = dialect;
+    numbering->interface = dyn_cast<BytecodeDialectInterface>(dialect);
   }
   return *numbering;
 }
@@ -244,8 +279,20 @@ void IRNumberingState::number(Type type) {
   // registered dialect when it got created. We don't want to encode this as the
   // builtin OpaqueType, we want to encode it as if the dialect was actually
   // loaded.
-  if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>())
+  if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>()) {
     numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
-  else
-    numbering->dialect = &numberDialect(&type.getDialect());
+    return;
+  }
+  numbering->dialect = &numberDialect(&type.getDialect());
+
+  // If this type will be emitted using the bytecode format, perform a dummy
+  // writing to number any nested components.
+  if (const auto *interface = numbering->dialect->interface) {
+    // TODO: We don't allow custom encodings for mutable types right now.
+    if (type.hasTrait<TypeTrait::IsMutable>())
+      return;
+
+    NumberingDialectWriter writer(*this);
+    (void)interface->writeType(type, writer);
+  }
 }

diff  --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
index fd8e3b14c62e5..9f4cbfec2d8d3 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -18,6 +18,7 @@
 #include "llvm/ADT/MapVector.h"
 
 namespace mlir {
+class BytecodeDialectInterface;
 class BytecodeWriterConfig;
 
 namespace bytecode {
@@ -90,8 +91,8 @@ struct DialectNumbering {
   /// The number assigned to the dialect.
   unsigned number;
 
-  /// The loaded dialect, or nullptr if the dialect isn't loaded.
-  Dialect *dialect = nullptr;
+  /// The bytecode dialect interface of the dialect if defined.
+  const BytecodeDialectInterface *interface = nullptr;
 };
 
 //===----------------------------------------------------------------------===//
@@ -147,6 +148,10 @@ class IRNumberingState {
   }
 
 private:
+  /// This class is used to provide a fake dialect writer for numbering nested
+  /// attributes and types.
+  struct NumberingDialectWriter;
+
   /// Number the given IR unit for bytecode emission.
   void number(Attribute attr);
   void number(Block &block);

diff  --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index 7df22a9038f71..6686e7f58c9c9 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/BuiltinDialect.h"
+#include "BuiltinDialectBytecode.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -117,6 +118,7 @@ void BuiltinDialect::initialize() {
 
   auto &blobInterface = addInterface<BuiltinBlobManagerInterface>();
   addInterface<BuiltinOpAsmDialectInterface>(blobInterface);
+  builtin_dialect_detail::addBytecodeInterface(this);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
new file mode 100644
index 0000000000000..619a342e61024
--- /dev/null
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -0,0 +1,269 @@
+//===- BuiltinDialectBytecode.cpp - Builtin Bytecode Implementation -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "BuiltinDialectBytecode.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Encoding
+//===----------------------------------------------------------------------===//
+
+namespace {
+namespace builtin_encoding {
+/// This enum contains marker codes used to indicate which attribute is
+/// currently being decoded, and how it should be decoded. The order of these
+/// codes should generally be unchanged, as any changes will inevitably break
+/// compatibility with older bytecode.
+enum AttributeCode {
+  ///   ArrayAttr {
+  ///     elements: Attribute[]
+  ///   }
+  ///
+  kArrayAttr = 0,
+
+  ///   DictionaryAttr {
+  ///     attrs: <StringAttr, Attribute>[]
+  ///   }
+  kDictionaryAttr = 1,
+
+  ///   StringAttr {
+  ///     string
+  ///   }
+  kStringAttr = 2,
+};
+
+/// This enum contains marker codes used to indicate which type is currently
+/// being decoded, and how it should be decoded. The order of these codes should
+/// generally be unchanged, as any changes will inevitably break compatibility
+/// with older bytecode.
+enum TypeCode {
+  ///   IntegerType {
+  ///     widthAndSignedness: varint // (width << 2) | (signedness)
+  ///   }
+  ///
+  kIntegerType = 0,
+
+  ///   IndexType {
+  ///   }
+  ///
+  kIndexType = 1,
+
+  ///   FunctionType {
+  ///     inputs: Type[],
+  ///     results: Type[]
+  ///   }
+  ///
+  kFunctionType = 2,
+};
+
+} // namespace builtin_encoding
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// BuiltinDialectBytecodeInterface
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class implements the bytecode interface for the builtin dialect.
+struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
+  BuiltinDialectBytecodeInterface(Dialect *dialect)
+      : BytecodeDialectInterface(dialect) {}
+
+  //===--------------------------------------------------------------------===//
+  // Attributes
+
+  Attribute readAttribute(DialectBytecodeReader &reader) const override;
+  ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const;
+  DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const;
+  StringAttr readStringAttr(DialectBytecodeReader &reader) const;
+
+  LogicalResult writeAttribute(Attribute attr,
+                               DialectBytecodeWriter &writer) const override;
+  void write(ArrayAttr attr, DialectBytecodeWriter &writer) const;
+  void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const;
+  void write(StringAttr attr, DialectBytecodeWriter &writer) const;
+
+  //===--------------------------------------------------------------------===//
+  // Types
+
+  Type readType(DialectBytecodeReader &reader) const override;
+  IntegerType readIntegerType(DialectBytecodeReader &reader) const;
+  FunctionType readFunctionType(DialectBytecodeReader &reader) const;
+
+  LogicalResult writeType(Type type,
+                          DialectBytecodeWriter &writer) const override;
+  void write(IntegerType type, DialectBytecodeWriter &writer) const;
+  void write(FunctionType type, DialectBytecodeWriter &writer) const;
+};
+} // namespace
+
+void builtin_dialect_detail::addBytecodeInterface(BuiltinDialect *dialect) {
+  dialect->addInterfaces<BuiltinDialectBytecodeInterface>();
+}
+
+//===----------------------------------------------------------------------===//
+// Attributes: Reader
+
+Attribute BuiltinDialectBytecodeInterface::readAttribute(
+    DialectBytecodeReader &reader) const {
+  uint64_t code;
+  if (failed(reader.readVarInt(code)))
+    return Attribute();
+  switch (code) {
+  case builtin_encoding::kArrayAttr:
+    return readArrayAttr(reader);
+  case builtin_encoding::kDictionaryAttr:
+    return readDictionaryAttr(reader);
+  case builtin_encoding::kStringAttr:
+    return readStringAttr(reader);
+  default:
+    reader.emitError() << "unknown builtin attribute code: " << code;
+    return Attribute();
+  }
+}
+
+ArrayAttr BuiltinDialectBytecodeInterface::readArrayAttr(
+    DialectBytecodeReader &reader) const {
+  SmallVector<Attribute> elements;
+  if (failed(reader.readAttributes(elements)))
+    return ArrayAttr();
+  return ArrayAttr::get(getContext(), elements);
+}
+
+DictionaryAttr BuiltinDialectBytecodeInterface::readDictionaryAttr(
+    DialectBytecodeReader &reader) const {
+  auto readNamedAttr = [&]() -> FailureOr<NamedAttribute> {
+    StringAttr name;
+    Attribute value;
+    if (failed(reader.readAttribute(name)) ||
+        failed(reader.readAttribute(value)))
+      return failure();
+    return NamedAttribute(name, value);
+  };
+  SmallVector<NamedAttribute> attrs;
+  if (failed(reader.readList(attrs, readNamedAttr)))
+    return DictionaryAttr();
+  return DictionaryAttr::get(getContext(), attrs);
+}
+
+StringAttr BuiltinDialectBytecodeInterface::readStringAttr(
+    DialectBytecodeReader &reader) const {
+  StringRef string;
+  if (failed(reader.readString(string)))
+    return StringAttr();
+  return StringAttr::get(getContext(), string);
+}
+
+//===----------------------------------------------------------------------===//
+// Attributes: Writer
+
+LogicalResult BuiltinDialectBytecodeInterface::writeAttribute(
+    Attribute attr, DialectBytecodeWriter &writer) const {
+  return TypeSwitch<Attribute, LogicalResult>(attr)
+      .Case<ArrayAttr, DictionaryAttr, StringAttr>([&](auto attr) {
+        write(attr, writer);
+        return success();
+      })
+      .Default([&](Attribute) { return failure(); });
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    ArrayAttr attr, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kArrayAttr);
+  writer.writeAttributes(attr.getValue());
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    DictionaryAttr attr, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kDictionaryAttr);
+  writer.writeList(attr.getValue(), [&](NamedAttribute attr) {
+    writer.writeAttribute(attr.getName());
+    writer.writeAttribute(attr.getValue());
+  });
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    StringAttr attr, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kStringAttr);
+  writer.writeOwnedString(attr.getValue());
+}
+
+//===----------------------------------------------------------------------===//
+// Types: Reader
+
+Type BuiltinDialectBytecodeInterface::readType(
+    DialectBytecodeReader &reader) const {
+  uint64_t code;
+  if (failed(reader.readVarInt(code)))
+    return Type();
+  switch (code) {
+  case builtin_encoding::kIntegerType:
+    return readIntegerType(reader);
+  case builtin_encoding::kIndexType:
+    return IndexType::get(getContext());
+
+  case builtin_encoding::kFunctionType:
+    return readFunctionType(reader);
+  default:
+    reader.emitError() << "unknown builtin type code: " << code;
+    return Type();
+  }
+}
+
+IntegerType BuiltinDialectBytecodeInterface::readIntegerType(
+    DialectBytecodeReader &reader) const {
+  uint64_t encoding;
+  if (failed(reader.readVarInt(encoding)))
+    return IntegerType();
+  return IntegerType::get(
+      getContext(), encoding >> 2,
+      static_cast<IntegerType::SignednessSemantics>(encoding & 0x3));
+}
+
+FunctionType BuiltinDialectBytecodeInterface::readFunctionType(
+    DialectBytecodeReader &reader) const {
+  SmallVector<Type> inputs, results;
+  if (failed(reader.readTypes(inputs)) || failed(reader.readTypes(results)))
+    return FunctionType();
+  return FunctionType::get(getContext(), inputs, results);
+}
+
+//===----------------------------------------------------------------------===//
+// Types: Writer
+
+LogicalResult BuiltinDialectBytecodeInterface::writeType(
+    Type type, DialectBytecodeWriter &writer) const {
+  return TypeSwitch<Type, LogicalResult>(type)
+      .Case<IntegerType, FunctionType>([&](auto type) {
+        write(type, writer);
+        return success();
+      })
+      .Case([&](IndexType) {
+        return writer.writeVarInt(builtin_encoding::kIndexType), success();
+      })
+      .Default([&](Type) { return failure(); });
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    IntegerType type, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kIntegerType);
+  writer.writeVarInt((type.getWidth() << 2) | type.getSignedness());
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    FunctionType type, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kFunctionType);
+  writer.writeTypes(type.getInputs());
+  writer.writeTypes(type.getResults());
+}

diff  --git a/mlir/lib/IR/BuiltinDialectBytecode.h b/mlir/lib/IR/BuiltinDialectBytecode.h
new file mode 100644
index 0000000000000..775e8e0987184
--- /dev/null
+++ b/mlir/lib/IR/BuiltinDialectBytecode.h
@@ -0,0 +1,26 @@
+//===- BuiltinDialectBytecode.h - MLIR Bytecode Implementation --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines hooks into the builtin dialect bytecode implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H
+#define LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H
+
+namespace mlir {
+class BuiltinDialect;
+
+namespace builtin_dialect_detail {
+/// Add the interfaces necessary for encoding the builtin dialect components in
+/// bytecode.
+void addBytecodeInterface(BuiltinDialect *dialect);
+} // namespace builtin_dialect_detail
+} // namespace mlir
+
+#endif // LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H

diff  --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 72f386c31a241..355ddd4d450ae 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRIR
   BuiltinAttributeInterfaces.cpp
   BuiltinAttributes.cpp
   BuiltinDialect.cpp
+  BuiltinDialectBytecode.cpp
   BuiltinTypes.cpp
   BuiltinTypeInterfaces.cpp
   Diagnostics.cpp

diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index b8f5aa29c31f5..e72e071d8f95a 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -113,6 +113,10 @@ void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
 
 DialectInterface::~DialectInterface() = default;
 
+MLIRContext *DialectInterface::getContext() const {
+  return dialect->getContext();
+}
+
 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
     MLIRContext *ctx, TypeID interfaceKind) {
   for (auto *dialect : ctx->getLoadedDialects()) {

diff  --git a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
new file mode 100644
index 0000000000000..8f91a25768196
--- /dev/null
+++ b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: s390x-
+
+// CHECK-LABEL: @TestArray
+module @TestArray attributes {
+  // CHECK: bytecode.array = [unit]
+  bytecode.array = [unit]
+} {}
+
+// CHECK-LABEL: @TestString
+module @TestString attributes {
+  // CHECK: bytecode.string = "hello"
+  bytecode.string = "hello"
+} {}

diff  --git a/mlir/test/Dialect/Builtin/Bytecode/types.mlir b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
new file mode 100644
index 0000000000000..bb311aff4ae0f
--- /dev/null
+++ b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: s390x-
+
+// CHECK-LABEL: @TestInteger
+module @TestInteger attributes {
+  // CHECK: bytecode.int = i1024,
+  // CHECK: bytecode.int1 = si32,
+  // CHECK: bytecode.int2 = ui512
+  bytecode.int = i1024,
+  bytecode.int1 = si32,
+  bytecode.int2 = ui512
+} {}
+
+// CHECK-LABEL: @TestIndex
+module @TestIndex attributes {
+  // CHECK: bytecode.index = index
+  bytecode.index = index
+} {}
+
+// CHECK-LABEL: @TestFunc
+module @TestFunc attributes {
+  // CHECK: bytecode.func = () -> (),
+  // CHECK: bytecode.func1 = (i1) -> i32
+  bytecode.func = () -> (),
+  bytecode.func1 = (i1) -> (i32)
+} {}


        

