[Mlir-commits] [mlir] 02c2ecb - [mlir:Bytecode] Add initial support for dialect defined attribute/type encodings
River Riddle
llvmlistbot at llvm.org
Tue Aug 23 16:56:39 PDT 2022
Author: River Riddle
Date: 2022-08-23T16:56:04-07:00
New Revision: 02c2ecb9c6c355b8b6f650d258077bd9cca0aacf
URL: https://github.com/llvm/llvm-project/commit/02c2ecb9c6c355b8b6f650d258077bd9cca0aacf
DIFF: https://github.com/llvm/llvm-project/commit/02c2ecb9c6c355b8b6f650d258077bd9cca0aacf.diff
LOG: [mlir:Bytecode] Add initial support for dialect defined attribute/type encodings
Dialects can opt in to providing custom encodings by implementing the
`BytecodeDialectInterface`. This interface provides hooks, namely
`readAttribute`/`readType` and `writeAttribute`/`writeType`, that will be used
by the bytecode reader and writer. These hooks are provided a reader and writer
implementation that can be used to encode various constructs in the underlying
bytecode format. A unique feature of this interface is that dialects may choose
to only encode a subset of their attributes and types in a custom bytecode
format, which can simplify adding new or experimental components that aren't
fully baked.
Differential Revision: https://reviews.llvm.org/D132498
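
As a rough illustration of the opt-in behaviour described above, the following is a minimal standalone sketch, not code from this commit. It models a write hook that encodes only some kinds and reports failure for the rest (so the caller could fall back to text), together with the leading-code idiom that the documentation added in this commit describes. Everything here is hypothetical: `ToyType`, `ToyTypeCode`, `writeToyType` and `readToyType` are made-up names, and the LEB128-style varint is used only for illustration and is not assumed to match the MLIR bytecode wire format.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins for a dialect's types; none of these are MLIR classes.
struct ToyInteger { uint64_t width; };
struct ToyIndex {};
struct ToyExperimental { std::string text; };
using ToyType = std::variant<ToyInteger, ToyIndex, ToyExperimental>;

// Leading codes that say which kind follows. Values should stay stable.
enum ToyTypeCode : uint64_t { kToyInteger = 0, kToyIndex = 1 };

// LEB128-style varint writer, for illustration only.
void writeVarInt(std::vector<uint8_t> &out, uint64_t value) {
  while (value >= 0x80) {
    out.push_back(static_cast<uint8_t>((value & 0x7F) | 0x80));
    value >>= 7;
  }
  out.push_back(static_cast<uint8_t>(value));
}

// Matching varint reader; returns false on truncated or overlong input.
bool readVarInt(const std::vector<uint8_t> &in, size_t &pos, uint64_t &result) {
  result = 0;
  for (unsigned shift = 0; shift < 64 && pos < in.size(); shift += 7) {
    uint8_t byte = in[pos++];
    result |= static_cast<uint64_t>(byte & 0x7F) << shift;
    if (!(byte & 0x80))
      return true;
  }
  return false;
}

// Plays the role of a write hook: returns false, without writing anything,
// for kinds that have no custom encoding so the caller can fall back to text.
bool writeToyType(const ToyType &type, std::vector<uint8_t> &out) {
  if (const auto *integer = std::get_if<ToyInteger>(&type)) {
    writeVarInt(out, kToyInteger);
    writeVarInt(out, integer->width);
    return true;
  }
  if (std::holds_alternative<ToyIndex>(type)) {
    writeVarInt(out, kToyIndex);
    return true;
  }
  return false; // ToyExperimental: not encoded yet.
}

// Plays the role of a read hook: the leading code selects the decoder, and an
// empty optional signals failure.
std::optional<ToyType> readToyType(const std::vector<uint8_t> &in, size_t &pos) {
  uint64_t code;
  if (!readVarInt(in, pos, code))
    return std::nullopt;
  switch (code) {
  case kToyInteger: {
    uint64_t width;
    if (!readVarInt(in, pos, width))
      return std::nullopt;
    return ToyType(ToyInteger{width});
  }
  case kToyIndex:
    return ToyType(ToyIndex{});
  default:
    return std::nullopt; // Unknown code.
  }
}
```

In this sketch a caller would try `writeToyType` first and emit the textual form only when it returns false, which is the same shape as the fallback this commit adds to the bytecode writer.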
Added:
mlir/include/mlir/Bytecode/BytecodeImplementation.h
mlir/lib/IR/BuiltinDialectBytecode.cpp
mlir/lib/IR/BuiltinDialectBytecode.h
mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
mlir/test/Dialect/Builtin/Bytecode/types.mlir
Modified:
mlir/docs/BytecodeFormat.md
mlir/include/mlir/IR/DialectInterface.h
mlir/lib/Bytecode/Reader/BytecodeReader.cpp
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
mlir/lib/Bytecode/Writer/IRNumbering.cpp
mlir/lib/Bytecode/Writer/IRNumbering.h
mlir/lib/IR/BuiltinDialect.cpp
mlir/lib/IR/CMakeLists.txt
mlir/lib/IR/Dialect.cpp
Removed:
################################################################################
diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
index acb1819c9932c..5260c996a880d 100644
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -207,7 +207,26 @@ reference to the parent dialect instead.
##### Dialect Defined Encoding
-TODO: This is not yet supported.
+In addition to the assembly format fallback, dialects may also provide a custom
+encoding for their attributes and types. Custom encodings are typically much
+smaller than the textual form, and faster to read and write.
+
+Dialects can opt in to providing custom encodings by implementing the
+`BytecodeDialectInterface`. This interface provides hooks, namely
+`readAttribute`/`readType` and `writeAttribute`/`writeType`, that will be used
+by the bytecode reader and writer. These hooks are provided a reader and writer
+implementation that can be used to encode various constructs in the underlying
+bytecode format. A unique feature of this interface is that dialects may choose
+to only encode a subset of their attributes and types in a custom bytecode
+format, which can simplify adding new or experimental components that aren't
+fully baked.
+
+When implementing the bytecode interface, dialects are responsible for all
+aspects of the encoding. This includes the indicator for which kind of attribute
+or type is being encoded. The bytecode reader only knows that it has encountered
+an attribute or type of a given dialect; the format itself does not record any
+further information. As such, a common encoding idiom is to use a leading
+`varint` code to indicate how the attribute or type was encoded.
### IR Section
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
new file mode 100644
index 0000000000000..01bada7e0572f
--- /dev/null
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -0,0 +1,220 @@
+//===- BytecodeImplementation.h - MLIR Bytecode Implementation --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines various interfaces and utilities necessary for dialects
+// to hook into bytecode serialization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
+#define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/Twine.h"
+
+namespace mlir {
+//===----------------------------------------------------------------------===//
+// DialectBytecodeReader
+//===----------------------------------------------------------------------===//
+
+/// This class defines a virtual interface for reading a bytecode stream,
+/// providing hooks into the bytecode reader. As such, this class should only be
+/// derived and defined by the main bytecode reader, users (i.e. dialects)
+/// should generally only interact with this class via the
+/// BytecodeDialectInterface below.
+class DialectBytecodeReader {
+public:
+ virtual ~DialectBytecodeReader() = default;
+
+ /// Emit an error to the reader.
+ virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
+
+ //===--------------------------------------------------------------------===//
+ // IR
+ //===--------------------------------------------------------------------===//
+
+ /// Read out a list of elements, invoking the provided callback for each
+ /// element. The callback function may be in any of the following forms:
+ /// * LogicalResult(T &)
+ /// * FailureOr<T>()
+ template <typename T, typename CallbackFn>
+ LogicalResult readList(SmallVectorImpl<T> &result, CallbackFn &&callback) {
+ uint64_t size;
+ if (failed(readVarInt(size)))
+ return failure();
+ result.reserve(size);
+
+ for (uint64_t i = 0; i < size; ++i) {
+ // Check if the callback uses FailureOr, or populates the result by
+ // reference.
+ if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) {
+ T element = {};
+ if (failed(callback(element)))
+ return failure();
+ result.emplace_back(std::move(element));
+ } else {
+ FailureOr<T> element = callback();
+ if (failed(element))
+ return failure();
+ result.emplace_back(std::move(*element));
+ }
+ }
+ return success();
+ }
+
+ /// Read a reference to the given attribute.
+ virtual LogicalResult readAttribute(Attribute &result) = 0;
+ template <typename T>
+ LogicalResult readAttributes(SmallVectorImpl<T> &attrs) {
+ return readList(attrs, [this](T &attr) { return readAttribute(attr); });
+ }
+ template <typename T>
+ LogicalResult readAttribute(T &result) {
+ Attribute baseResult;
+ if (failed(readAttribute(baseResult)))
+ return failure();
+ if ((result = baseResult.dyn_cast<T>()))
+ return success();
+ return emitError() << "expected attribute of type: "
+ << llvm::getTypeName<T>() << ", but got: " << baseResult;
+ }
+
+ /// Read a reference to the given type.
+ virtual LogicalResult readType(Type &result) = 0;
+ template <typename T>
+ LogicalResult readTypes(SmallVectorImpl<T> &types) {
+ return readList(types, [this](T &type) { return readType(type); });
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Primitives
+ //===--------------------------------------------------------------------===//
+
+ /// Read a variable width integer.
+ // TODO: Add a signed variant when necessary.
+ virtual LogicalResult readVarInt(uint64_t &result) = 0;
+
+ /// Read a string from the bytecode.
+ virtual LogicalResult readString(StringRef &result) = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// DialectBytecodeWriter
+//===----------------------------------------------------------------------===//
+
+/// This class defines a virtual interface for writing to a bytecode stream,
+/// providing hooks into the bytecode writer. As such, this class should only be
+/// derived and defined by the main bytecode writer, users (i.e. dialects)
+/// should generally only interact with this class via the
+/// BytecodeDialectInterface below.
+class DialectBytecodeWriter {
+public:
+ virtual ~DialectBytecodeWriter() = default;
+
+ //===--------------------------------------------------------------------===//
+ // IR
+ //===--------------------------------------------------------------------===//
+
+ /// Write out a list of elements, invoking the provided callback for each
+ /// element.
+ template <typename RangeT, typename CallbackFn>
+ void writeList(RangeT &&range, CallbackFn &&callback) {
+ writeVarInt(llvm::size(range));
+ for (auto &element : range)
+ callback(element);
+ }
+
+ /// Write a reference to the given attribute.
+ virtual void writeAttribute(Attribute attr) = 0;
+ template <typename T>
+ void writeAttributes(ArrayRef<T> attrs) {
+ writeList(attrs, [this](T attr) { writeAttribute(attr); });
+ }
+
+ /// Write a reference to the given type.
+ virtual void writeType(Type type) = 0;
+ template <typename T>
+ void writeTypes(ArrayRef<T> types) {
+ writeList(types, [this](T type) { writeType(type); });
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Primitives
+ //===--------------------------------------------------------------------===//
+
+ /// Write a variable width integer to the output stream. This should be the
+ /// preferred method for emitting integers whenever possible.
+ // TODO: Add a signed variant when necessary.
+ virtual void writeVarInt(uint64_t value) = 0;
+
+ /// Write a string to the bytecode, which is owned by the caller and is
+ /// guaranteed to not die before the end of the bytecode process. This should
+ /// only be called if such a guarantee can be made, such as when the string is
+ /// owned by an attribute or type.
+ virtual void writeOwnedString(StringRef str) = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// BytecodeDialectInterface
+//===----------------------------------------------------------------------===//
+
+class BytecodeDialectInterface
+ : public DialectInterface::Base<BytecodeDialectInterface> {
+public:
+ using Base::Base;
+
+ //===--------------------------------------------------------------------===//
+ // Reading
+ //===--------------------------------------------------------------------===//
+
+ /// Read an attribute belonging to this dialect from the given reader. This
+ /// method should return null in the case of failure.
+ virtual Attribute readAttribute(DialectBytecodeReader &reader) const {
+ reader.emitError() << "dialect " << getDialect()->getNamespace()
+ << " does not support reading attributes from bytecode";
+ return Attribute();
+ }
+
+ /// Read a type belonging to this dialect from the given reader. This method
+ /// should return null in the case of failure.
+ virtual Type readType(DialectBytecodeReader &reader) const {
+ reader.emitError() << "dialect " << getDialect()->getNamespace()
+ << " does not support reading types from bytecode";
+ return Type();
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Writing
+ //===--------------------------------------------------------------------===//
+
+ /// Write the given attribute, which belongs to this dialect, to the given
+ /// writer. This method may return failure to indicate that the given
+ /// attribute could not be encoded, in which case the textual format will be
+ /// used to encode this attribute instead.
+ virtual LogicalResult writeAttribute(Attribute attr,
+ DialectBytecodeWriter &writer) const {
+ return failure();
+ }
+
+ /// Write the given type, which belongs to this dialect, to the given writer.
+ /// This method may return failure to indicate that the given type could not
+ /// be encoded, in which case the textual format will be used to encode this
+ /// type instead.
+ virtual LogicalResult writeType(Type type,
+ DialectBytecodeWriter &writer) const {
+ return failure();
+ }
+};
+
+} // namespace mlir
+
+#endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
diff --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h
index 2b3aa1211455c..d5d36ed0171d2 100644
--- a/mlir/include/mlir/IR/DialectInterface.h
+++ b/mlir/include/mlir/IR/DialectInterface.h
@@ -50,6 +50,9 @@ class DialectInterface {
/// Return the dialect that this interface represents.
Dialect *getDialect() const { return dialect; }
+ /// Return the context that holds the parent dialect of this interface.
+ MLIRContext *getContext() const;
+
/// Return the derived interface id.
TypeID getID() const { return interfaceID; }
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 010e4e492fa64..78d60a0b5a3fe 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -12,6 +12,7 @@
#include "mlir/Bytecode/BytecodeReader.h"
#include "../Encoding.h"
#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h"
@@ -66,7 +67,7 @@ class EncodingReader {
/// Emit an error using the given arguments.
template <typename... Args>
- LogicalResult emitError(Args &&...args) const {
+ InFlightDiagnostic emitError(Args &&...args) const {
return ::emitError(fileLoc).append(std::forward<Args>(args)...);
}
@@ -326,6 +327,11 @@ struct BytecodeDialect {
"-allow-unregistered-dialect with the MLIR tool used.");
}
dialect = loadedDialect;
+
+ // If the dialect was actually loaded, check to see if it has a bytecode
+ // interface.
+ if (loadedDialect)
+ interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
return success();
}
@@ -333,6 +339,11 @@ struct BytecodeDialect {
/// load, nullptr if we failed to load, otherwise the loaded dialect.
Optional<Dialect *> dialect;
+ /// The bytecode interface of the dialect, or nullptr if the dialect does not
+ /// implement the bytecode interface. This field should only be checked if the
+ /// `dialect` field is non-None.
+ const BytecodeDialectInterface *interface = nullptr;
+
/// The name of the dialect.
StringRef name;
};
@@ -397,7 +408,8 @@ class AttrTypeReader {
using TypeEntry = Entry<Type>;
public:
- AttrTypeReader(Location fileLoc) : fileLoc(fileLoc) {}
+ AttrTypeReader(StringSectionReader &stringReader, Location fileLoc)
+ : stringReader(stringReader), fileLoc(fileLoc) {}
/// Initialize the attribute and type information within the reader.
LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
@@ -456,6 +468,10 @@ class AttrTypeReader {
LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
StringRef entryType);
+ /// The string section reader used to resolve string references when parsing
+ /// custom encoded attribute/type entries.
+ StringSectionReader &stringReader;
+
/// The set of attribute and type entries.
SmallVector<AttrEntry> attributes;
SmallVector<TypeEntry> types;
@@ -463,6 +479,47 @@ class AttrTypeReader {
/// A location used for error emission.
Location fileLoc;
};
+
+class DialectReader : public DialectBytecodeReader {
+public:
+ DialectReader(AttrTypeReader &attrTypeReader,
+ StringSectionReader &stringReader, EncodingReader &reader)
+ : attrTypeReader(attrTypeReader), stringReader(stringReader),
+ reader(reader) {}
+
+ InFlightDiagnostic emitError(const Twine &msg) override {
+ return reader.emitError(msg);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // IR
+ //===--------------------------------------------------------------------===//
+
+ LogicalResult readAttribute(Attribute &result) override {
+ return attrTypeReader.parseAttribute(reader, result);
+ }
+
+ LogicalResult readType(Type &result) override {
+ return attrTypeReader.parseType(reader, result);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Primitives
+ //===--------------------------------------------------------------------===//
+
+ LogicalResult readVarInt(uint64_t &result) override {
+ return reader.parseVarInt(result);
+ }
+
+ LogicalResult readString(StringRef &result) override {
+ return stringReader.parseString(reader, result);
+ }
+
+private:
+ AttrTypeReader &attrTypeReader;
+ StringSectionReader &stringReader;
+ EncodingReader &reader;
+};
} // namespace
LogicalResult
@@ -486,7 +543,7 @@ AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
size_t currentIndex = 0, endIndex = range.size();
// Parse an individual entry.
- auto parseEntryFn = [&](BytecodeDialect *dialect) {
+ auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
auto &entry = range[currentIndex++];
uint64_t entrySize;
@@ -548,8 +605,7 @@ T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
}
if (!reader.empty()) {
- (void)reader.emitError("unexpected trailing bytes after " + entryType +
- " entry");
+ reader.emitError("unexpected trailing bytes after " + entryType + " entry");
return T();
}
return entry.entry;
@@ -584,8 +640,22 @@ template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
StringRef entryType) {
- // FIXME: Add support for reading custom attribute/type encodings.
- return reader.emitError("unexpected Attribute encoding");
+ if (failed(entry.dialect->load(reader, fileLoc.getContext())))
+ return failure();
+
+ // Ensure that the dialect implements the bytecode interface.
+ if (!entry.dialect->interface) {
+ return reader.emitError("dialect '", entry.dialect->name,
+ "' does not implement the bytecode interface");
+ }
+
+ // Ask the dialect to parse the entry.
+ DialectReader dialectReader(*this, stringReader, reader);
+ if constexpr (std::is_same_v<T, Type>)
+ entry.entry = entry.dialect->interface->readType(dialectReader);
+ else
+ entry.entry = entry.dialect->interface->readAttribute(dialectReader);
+ return success(!!entry.entry);
}
//===----------------------------------------------------------------------===//
@@ -597,7 +667,7 @@ namespace {
class BytecodeReader {
public:
BytecodeReader(Location fileLoc, const ParserConfig &config)
- : config(config), fileLoc(fileLoc), attrTypeReader(fileLoc),
+ : config(config), fileLoc(fileLoc), attrTypeReader(stringReader, fileLoc),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index ebf827fde41b8..6fc2fb4354db8 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -9,6 +9,7 @@
#include "mlir/Bytecode/BytecodeWriter.h"
#include "../Encoding.h"
#include "IRNumbering.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/CachedHashString.h"
@@ -358,22 +359,78 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
//===----------------------------------------------------------------------===//
// Attributes and Types
+namespace {
+class DialectWriter : public DialectBytecodeWriter {
+public:
+ DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState,
+ StringSectionBuilder &stringSection)
+ : emitter(emitter), numberingState(numberingState),
+ stringSection(stringSection) {}
+
+ //===--------------------------------------------------------------------===//
+ // IR
+ //===--------------------------------------------------------------------===//
+
+ void writeAttribute(Attribute attr) override {
+ emitter.emitVarInt(numberingState.getNumber(attr));
+ }
+ void writeType(Type type) override {
+ emitter.emitVarInt(numberingState.getNumber(type));
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Primitives
+ //===--------------------------------------------------------------------===//
+
+ void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); }
+
+ void writeOwnedString(StringRef str) override {
+ emitter.emitVarInt(stringSection.insert(str));
+ }
+
+private:
+ EncodingEmitter &emitter;
+ IRNumberingState &numberingState;
+ StringSectionBuilder &stringSection;
+};
+} // namespace
+
void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
EncodingEmitter attrTypeEmitter;
EncodingEmitter offsetEmitter;
offsetEmitter.emitVarInt(llvm::size(numberingState.getAttributes()));
offsetEmitter.emitVarInt(llvm::size(numberingState.getTypes()));
+ // The writer used when emitting using a custom bytecode encoding.
+ DialectWriter dialectWriter(attrTypeEmitter, numberingState, stringSection);
+
// A functor used to emit an attribute or type entry.
uint64_t prevOffset = 0;
auto emitAttrOrType = [&](auto &entry) {
- // TODO: Allow dialects to provide more optimal implementations of attribute
- // and type encodings.
+ auto entryValue = entry.getValue();
+
+ // First, try to emit this entry using the dialect bytecode interface.
bool hasCustomEncoding = false;
+ if (const BytecodeDialectInterface *interface = entry.dialect->interface) {
+ if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
+ // TODO: We don't currently support custom encoded mutable types.
+ hasCustomEncoding =
+ !entryValue.template hasTrait<TypeTrait::IsMutable>() &&
+ succeeded(interface->writeType(entryValue, dialectWriter));
+ } else {
+ // TODO: We don't currently support custom encoded mutable attributes.
+ hasCustomEncoding =
+ !entryValue.template hasTrait<AttributeTrait::IsMutable>() &&
+ succeeded(interface->writeAttribute(entryValue, dialectWriter));
+ }
+ }
- // Emit the entry using the textual format.
- raw_emitter_ostream(attrTypeEmitter) << entry.getValue();
- attrTypeEmitter.emitByte(0);
+ // If the entry was not emitted using the dialect interface, emit it using
+ // the textual format.
+ if (!hasCustomEncoding) {
+ raw_emitter_ostream(attrTypeEmitter) << entryValue;
+ attrTypeEmitter.emitByte(0);
+ }
// Record the offset of this entry.
uint64_t curOffset = attrTypeEmitter.size();
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 61fef0e35cbba..88a69034d557f 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "IRNumbering.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
@@ -14,6 +15,28 @@
using namespace mlir;
using namespace mlir::bytecode::detail;
+//===----------------------------------------------------------------------===//
+// NumberingDialectWriter
+//===----------------------------------------------------------------------===//
+
+struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
+ NumberingDialectWriter(IRNumberingState &state) : state(state) {}
+
+ void writeAttribute(Attribute attr) override { state.number(attr); }
+ void writeType(Type type) override { state.number(type); }
+
+ /// Stubbed out methods that are not used for numbering.
+ void writeVarInt(uint64_t) override {}
+ void writeOwnedString(StringRef) override {
+ // TODO: It might be nice to prenumber strings and sort by the number of
+ // references. This could potentially be useful for optimizing things like
+ // file locations.
+ }
+
+ /// The parent numbering state that is populated by this writer.
+ IRNumberingState &state;
+};
+
//===----------------------------------------------------------------------===//
// IR Numbering
//===----------------------------------------------------------------------===//
@@ -138,10 +161,22 @@ void IRNumberingState::number(Attribute attr) {
// have a registered dialect when it got created. We don't want to encode this
// as the builtin OpaqueAttr, we want to encode it as if the dialect was
// actually loaded.
- if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>())
+ if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
- else
- numbering->dialect = &numberDialect(&attr.getDialect());
+ return;
+ }
+ numbering->dialect = &numberDialect(&attr.getDialect());
+
+ // If this attribute will be emitted using the bytecode format, perform a
+ // dummy write to number any nested components.
+ if (const auto *interface = numbering->dialect->interface) {
+ // TODO: We don't allow custom encodings for mutable attributes right now.
+ if (attr.hasTrait<AttributeTrait::IsMutable>())
+ return;
+
+ NumberingDialectWriter writer(*this);
+ (void)interface->writeAttribute(attr, writer);
+ }
}
void IRNumberingState::number(Block &block) {
@@ -164,7 +199,7 @@ auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
DialectNumbering *&numbering = registeredDialects[dialect];
if (!numbering) {
numbering = &numberDialect(dialect->getNamespace());
- numbering->dialect = dialect;
+ numbering->interface = dyn_cast<BytecodeDialectInterface>(dialect);
}
return *numbering;
}
@@ -244,8 +279,20 @@ void IRNumberingState::number(Type type) {
// registered dialect when it got created. We don't want to encode this as the
// builtin OpaqueType, we want to encode it as if the dialect was actually
// loaded.
- if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>())
+ if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>()) {
numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
- else
- numbering->dialect = &numberDialect(&type.getDialect());
+ return;
+ }
+ numbering->dialect = &numberDialect(&type.getDialect());
+
+ // If this type will be emitted using the bytecode format, perform a dummy
+ // write to number any nested components.
+ if (const auto *interface = numbering->dialect->interface) {
+ // TODO: We don't allow custom encodings for mutable types right now.
+ if (type.hasTrait<TypeTrait::IsMutable>())
+ return;
+
+ NumberingDialectWriter writer(*this);
+ (void)interface->writeType(type, writer);
+ }
}
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
index fd8e3b14c62e5..9f4cbfec2d8d3 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -18,6 +18,7 @@
#include "llvm/ADT/MapVector.h"
namespace mlir {
+class BytecodeDialectInterface;
class BytecodeWriterConfig;
namespace bytecode {
@@ -90,8 +91,8 @@ struct DialectNumbering {
/// The number assigned to the dialect.
unsigned number;
- /// The loaded dialect, or nullptr if the dialect isn't loaded.
- Dialect *dialect = nullptr;
+ /// The bytecode dialect interface of the dialect if defined.
+ const BytecodeDialectInterface *interface = nullptr;
};
//===----------------------------------------------------------------------===//
@@ -147,6 +148,10 @@ class IRNumberingState {
}
private:
+ /// This class is used to provide a fake dialect writer for numbering nested
+ /// attributes and types.
+ struct NumberingDialectWriter;
+
/// Number the given IR unit for bytecode emission.
void number(Attribute attr);
void number(Block &block);
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index 7df22a9038f71..6686e7f58c9c9 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinDialect.h"
+#include "BuiltinDialectBytecode.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -117,6 +118,7 @@ void BuiltinDialect::initialize() {
auto &blobInterface = addInterface<BuiltinBlobManagerInterface>();
addInterface<BuiltinOpAsmDialectInterface>(blobInterface);
+ builtin_dialect_detail::addBytecodeInterface(this);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
new file mode 100644
index 0000000000000..619a342e61024
--- /dev/null
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -0,0 +1,269 @@
+//===- BuiltinDialectBytecode.cpp - Builtin Bytecode Implementation -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "BuiltinDialectBytecode.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Encoding
+//===----------------------------------------------------------------------===//
+
+namespace {
+namespace builtin_encoding {
+/// This enum contains marker codes used to indicate which attribute is
+/// currently being decoded, and how it should be decoded. The order of these
+/// codes should generally be unchanged, as any changes will inevitably break
+/// compatibility with older bytecode.
+enum AttributeCode {
+ /// ArrayAttr {
+ /// elements: Attribute[]
+ /// }
+ ///
+ kArrayAttr = 0,
+
+ /// DictionaryAttr {
+ /// attrs: <StringAttr, Attribute>[]
+ /// }
+ kDictionaryAttr = 1,
+
+ /// StringAttr {
+ /// string
+ /// }
+ kStringAttr = 2,
+};
+
+/// This enum contains marker codes used to indicate which type is currently
+/// being decoded, and how it should be decoded. The order of these codes should
+/// generally be unchanged, as any changes will inevitably break compatibility
+/// with older bytecode.
+enum TypeCode {
+ /// IntegerType {
+ /// widthAndSignedness: varint // (width << 2) | (signedness)
+ /// }
+ ///
+ kIntegerType = 0,
+
+ /// IndexType {
+ /// }
+ ///
+ kIndexType = 1,
+
+ /// FunctionType {
+ /// inputs: Type[],
+ /// results: Type[]
+ /// }
+ ///
+ kFunctionType = 2,
+};
+
+} // namespace builtin_encoding
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// BuiltinDialectBytecodeInterface
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class implements the bytecode interface for the builtin dialect.
+struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
+ BuiltinDialectBytecodeInterface(Dialect *dialect)
+ : BytecodeDialectInterface(dialect) {}
+
+ //===--------------------------------------------------------------------===//
+ // Attributes
+
+ Attribute readAttribute(DialectBytecodeReader &reader) const override;
+ ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const;
+ DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const;
+ StringAttr readStringAttr(DialectBytecodeReader &reader) const;
+
+ LogicalResult writeAttribute(Attribute attr,
+ DialectBytecodeWriter &writer) const override;
+ void write(ArrayAttr attr, DialectBytecodeWriter &writer) const;
+ void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const;
+ void write(StringAttr attr, DialectBytecodeWriter &writer) const;
+
+ //===--------------------------------------------------------------------===//
+ // Types
+
+ Type readType(DialectBytecodeReader &reader) const override;
+ IntegerType readIntegerType(DialectBytecodeReader &reader) const;
+ FunctionType readFunctionType(DialectBytecodeReader &reader) const;
+
+ LogicalResult writeType(Type type,
+ DialectBytecodeWriter &writer) const override;
+ void write(IntegerType type, DialectBytecodeWriter &writer) const;
+ void write(FunctionType type, DialectBytecodeWriter &writer) const;
+};
+} // namespace
+
+void builtin_dialect_detail::addBytecodeInterface(BuiltinDialect *dialect) {
+ dialect->addInterfaces<BuiltinDialectBytecodeInterface>();
+}
+
+//===----------------------------------------------------------------------===//
+// Attributes: Reader
+
+Attribute BuiltinDialectBytecodeInterface::readAttribute(
+ DialectBytecodeReader &reader) const {
+ uint64_t code;
+ if (failed(reader.readVarInt(code)))
+ return Attribute();
+ switch (code) {
+ case builtin_encoding::kArrayAttr:
+ return readArrayAttr(reader);
+ case builtin_encoding::kDictionaryAttr:
+ return readDictionaryAttr(reader);
+ case builtin_encoding::kStringAttr:
+ return readStringAttr(reader);
+ default:
+ reader.emitError() << "unknown builtin attribute code: " << code;
+ return Attribute();
+ }
+}
+
+ArrayAttr BuiltinDialectBytecodeInterface::readArrayAttr(
+ DialectBytecodeReader &reader) const {
+ SmallVector<Attribute> elements;
+ if (failed(reader.readAttributes(elements)))
+ return ArrayAttr();
+ return ArrayAttr::get(getContext(), elements);
+}
+
+DictionaryAttr BuiltinDialectBytecodeInterface::readDictionaryAttr(
+ DialectBytecodeReader &reader) const {
+ auto readNamedAttr = [&]() -> FailureOr<NamedAttribute> {
+ StringAttr name;
+ Attribute value;
+ if (failed(reader.readAttribute(name)) ||
+ failed(reader.readAttribute(value)))
+ return failure();
+ return NamedAttribute(name, value);
+ };
+ SmallVector<NamedAttribute> attrs;
+ if (failed(reader.readList(attrs, readNamedAttr)))
+ return DictionaryAttr();
+ return DictionaryAttr::get(getContext(), attrs);
+}
+
+StringAttr BuiltinDialectBytecodeInterface::readStringAttr(
+ DialectBytecodeReader &reader) const {
+ StringRef string;
+ if (failed(reader.readString(string)))
+ return StringAttr();
+ return StringAttr::get(getContext(), string);
+}
+
+//===----------------------------------------------------------------------===//
+// Attributes: Writer
+
+LogicalResult BuiltinDialectBytecodeInterface::writeAttribute(
+ Attribute attr, DialectBytecodeWriter &writer) const {
+ return TypeSwitch<Attribute, LogicalResult>(attr)
+ .Case<ArrayAttr, DictionaryAttr, StringAttr>([&](auto attr) {
+ write(attr, writer);
+ return success();
+ })
+ .Default([&](Attribute) { return failure(); });
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ ArrayAttr attr, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kArrayAttr);
+ writer.writeAttributes(attr.getValue());
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ DictionaryAttr attr, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kDictionaryAttr);
+ writer.writeList(attr.getValue(), [&](NamedAttribute attr) {
+ writer.writeAttribute(attr.getName());
+ writer.writeAttribute(attr.getValue());
+ });
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ StringAttr attr, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kStringAttr);
+ writer.writeOwnedString(attr.getValue());
+}
+
+//===----------------------------------------------------------------------===//
+// Types: Reader
+
+Type BuiltinDialectBytecodeInterface::readType(
+ DialectBytecodeReader &reader) const {
+ uint64_t code;
+ if (failed(reader.readVarInt(code)))
+ return Type();
+ switch (code) {
+ case builtin_encoding::kIntegerType:
+ return readIntegerType(reader);
+ case builtin_encoding::kIndexType:
+ return IndexType::get(getContext());
+
+ case builtin_encoding::kFunctionType:
+ return readFunctionType(reader);
+ default:
+ reader.emitError() << "unknown builtin type code: " << code;
+ return Type();
+ }
+}
+
+IntegerType BuiltinDialectBytecodeInterface::readIntegerType(
+ DialectBytecodeReader &reader) const {
+ uint64_t encoding;
+ if (failed(reader.readVarInt(encoding)))
+ return IntegerType();
+ return IntegerType::get(
+ getContext(), encoding >> 2,
+ static_cast<IntegerType::SignednessSemantics>(encoding & 0x3));
+}
+
+FunctionType BuiltinDialectBytecodeInterface::readFunctionType(
+ DialectBytecodeReader &reader) const {
+ SmallVector<Type> inputs, results;
+ if (failed(reader.readTypes(inputs)) || failed(reader.readTypes(results)))
+ return FunctionType();
+ return FunctionType::get(getContext(), inputs, results);
+}
+
+//===----------------------------------------------------------------------===//
+// Types: Writer
+
+LogicalResult BuiltinDialectBytecodeInterface::writeType(
+ Type type, DialectBytecodeWriter &writer) const {
+ return TypeSwitch<Type, LogicalResult>(type)
+ .Case<IntegerType, FunctionType>([&](auto type) {
+ write(type, writer);
+ return success();
+ })
+ .Case([&](IndexType) {
+ return writer.writeVarInt(builtin_encoding::kIndexType), success();
+ })
+ .Default([&](Type) { return failure(); });
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ IntegerType type, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kIntegerType);
+ writer.writeVarInt((type.getWidth() << 2) | type.getSignedness());
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ FunctionType type, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kFunctionType);
+ writer.writeTypes(type.getInputs());
+ writer.writeTypes(type.getResults());
+}
diff --git a/mlir/lib/IR/BuiltinDialectBytecode.h b/mlir/lib/IR/BuiltinDialectBytecode.h
new file mode 100644
index 0000000000000..775e8e0987184
--- /dev/null
+++ b/mlir/lib/IR/BuiltinDialectBytecode.h
@@ -0,0 +1,26 @@
+//===- BuiltinDialectBytecode.h - MLIR Bytecode Implementation --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines hooks into the builtin dialect bytecode implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H
+#define LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H
+
+namespace mlir {
+class BuiltinDialect;
+
+namespace builtin_dialect_detail {
+/// Add the interfaces necessary for encoding the builtin dialect components in
+/// bytecode.
+void addBytecodeInterface(BuiltinDialect *dialect);
+} // namespace builtin_dialect_detail
+} // namespace mlir
+
+#endif // LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 72f386c31a241..355ddd4d450ae 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRIR
BuiltinAttributeInterfaces.cpp
BuiltinAttributes.cpp
BuiltinDialect.cpp
+ BuiltinDialectBytecode.cpp
BuiltinTypes.cpp
BuiltinTypeInterfaces.cpp
Diagnostics.cpp
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index b8f5aa29c31f5..e72e071d8f95a 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -113,6 +113,10 @@ void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
DialectInterface::~DialectInterface() = default;
+MLIRContext *DialectInterface::getContext() const {
+ return dialect->getContext();
+}
+
DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
MLIRContext *ctx, TypeID interfaceKind) {
for (auto *dialect : ctx->getLoadedDialects()) {
diff --git a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
new file mode 100644
index 0000000000000..8f91a25768196
--- /dev/null
+++ b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: s390x-
+
+// CHECK-LABEL: @TestArray
+module @TestArray attributes {
+ // CHECK: bytecode.array = [unit]
+ bytecode.array = [unit]
+} {}
+
+// CHECK-LABEL: @TestString
+module @TestString attributes {
+ // CHECK: bytecode.string = "hello"
+ bytecode.string = "hello"
+} {}
diff --git a/mlir/test/Dialect/Builtin/Bytecode/types.mlir b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
new file mode 100644
index 0000000000000..bb311aff4ae0f
--- /dev/null
+++ b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: s390x-
+
+// CHECK-LABEL: @TestInteger
+module @TestInteger attributes {
+ // CHECK: bytecode.int = i1024,
+ // CHECK: bytecode.int1 = si32,
+ // CHECK: bytecode.int2 = ui512
+ bytecode.int = i1024,
+ bytecode.int1 = si32,
+ bytecode.int2 = ui512
+} {}
+
+// CHECK-LABEL: @TestIndex
+module @TestIndex attributes {
+ // CHECK: bytecode.index = index
+ bytecode.index = index
+} {}
+
+// CHECK-LABEL: @TestFunc
+module @TestFunc attributes {
+ // CHECK: bytecode.func = () -> (),
+ // CHECK: bytecode.func1 = (i1) -> i32
+ bytecode.func = () -> (),
+ bytecode.func1 = (i1) -> (i32)
+} {}
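
Note on the IntegerType encoding in the diff above: the writer stores a single varint computed as `(width << 2) | signedness`, and the reader recovers the fields with `encoding >> 2` and `encoding & 0x3`. The following is a minimal standalone sketch of that packing only, not code from the commit. The `Signedness` enum, the `IntegerTypeFields` struct, and both helper functions are hypothetical names introduced for illustration. The enum values are assumed to mirror `IntegerType::SignednessSemantics` (signless, signed, unsigned) and to fit in two bits. The varint serialization itself is left out.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for IntegerType::SignednessSemantics. The values are
// an assumption; they only need to fit in the two low bits of the encoding.
enum class Signedness : uint64_t { Signless = 0, Signed = 1, Unsigned = 2 };

// Hypothetical holder for the two fields recovered by the reader.
struct IntegerTypeFields {
  uint64_t width;
  Signedness signedness;
};

// Pack the fields the way the writer in the diff does:
// the width occupies the high bits, the signedness the two low bits.
inline uint64_t packIntegerType(uint64_t width, Signedness signedness) {
  return (width << 2) | static_cast<uint64_t>(signedness);
}

// Unpack the fields the way the reader in the diff does.
inline IntegerTypeFields unpackIntegerType(uint64_t encoding) {
  return {encoding >> 2, static_cast<Signedness>(encoding & 0x3)};
}
```

Under these assumptions, `si32` from types.mlir packs to `(32 << 2) | 1 = 129`, and `unpackIntegerType(129)` returns width 32 with `Signedness::Signed`. Because the signedness is confined to two bits, any value that a varint can carry after the shift can be used as a width, which is consistent with the test exercising `i1024` and `ui512`.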