[Mlir-commits] [mlir] 0e0b607 - Implements MLIR Bytecode versioning capability
Mehdi Amini
llvmlistbot at llvm.org
Fri Mar 10 14:29:16 PST 2023
Author: Matteo Franciolini
Date: 2023-03-10T23:28:56+01:00
New Revision: 0e0b6070fd2a2a8f188ddb32aa526beda38190b7
URL: https://github.com/llvm/llvm-project/commit/0e0b6070fd2a2a8f188ddb32aa526beda38190b7
DIFF: https://github.com/llvm/llvm-project/commit/0e0b6070fd2a2a8f188ddb32aa526beda38190b7.diff
LOG: Implements MLIR Bytecode versioning capability
A dialect can opt in to versioning through the `BytecodeDialectInterface`. A
few hooks are exposed to let the dialect manage a version encoded in the
bytecode file. The version is loaded lazily. It can be retrieved while the
input IR is being parsed, and it gives each dialect for which a version is
present the opportunity to perform IR upgrades after parsing through the
`upgradeFromVersion` method. Custom attribute and type encodings can also be
upgraded according to the dialect version using the `readAttribute` and
`readType` methods.
There is no restriction on what kind of information a dialect may encode to
model its versioning. Currently, versioning is supported only for the bytecode
format.
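The following minimal standalone sketch shows how the hooks fit together. It does not depend on MLIR. `DialectVersion` mimics the empty polymorphic base class added in `BytecodeImplementation.h`. `MyDialectVersion`, `readMyVersion`, and `needsUpgrade` are hypothetical names that stand in for a dialect's version class and its `readVersion` / `upgradeFromVersion` overrides. The two-byte {major, minor} encoding is an assumption chosen only for illustration, because the interface leaves the encoding entirely up to the dialect. The numbers in the example echo the 1.12 and 2.0 test file names.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Stand-in for mlir::DialectVersion: an empty base class whose only purpose
// is to let version objects be destroyed through a base-class pointer.
class DialectVersion {
public:
  virtual ~DialectVersion() = default;
};

// Hypothetical version payload for an example dialect.
class MyDialectVersion : public DialectVersion {
public:
  MyDialectVersion(uint64_t majorVersion, uint64_t minorVersion)
      : majorVersion(majorVersion), minorVersion(minorVersion) {}
  uint64_t majorVersion;
  uint64_t minorVersion;
};

// Plays the role of readVersion(): decodes the dialect's opaque version bytes.
// Assumed encoding: exactly two bytes {major, minor}. Returns nullptr on
// malformed input, mirroring how the real hook signals failure.
std::unique_ptr<DialectVersion> readMyVersion(const std::vector<uint8_t> &buf) {
  if (buf.size() != 2)
    return nullptr;
  return std::make_unique<MyDialectVersion>(buf[0], buf[1]);
}

// Plays the role of upgradeFromVersion(): the dialect downcasts to its own
// version type and decides whether IR from that version must be rewritten.
// Assumption for illustration: anything older than 2.0 needs an upgrade.
bool needsUpgrade(const DialectVersion &version) {
  assert(dynamic_cast<const MyDialectVersion *>(&version) &&
         "version object was created by a different dialect");
  const auto &v = static_cast<const MyDialectVersion &>(version);
  return v.majorVersion < 2;
}
```

The reader only ever passes a dialect the version object that the same dialect's `readVersion` produced. The downcast is therefore safe, and the `dynamic_cast` in the assertion simply documents that contract.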
Reviewed By: rriddle, mehdi_amini
Differential Revision: https://reviews.llvm.org/D143647
Added:
mlir/test/Bytecode/versioning/versioned-attr-1.12.mlirbc
mlir/test/Bytecode/versioning/versioned-attr-2.0.mlirbc
mlir/test/Bytecode/versioning/versioned-op-1.12.mlirbc
mlir/test/Bytecode/versioning/versioned-op-2.0.mlirbc
mlir/test/Bytecode/versioning/versioned-op-2.2.mlirbc
mlir/test/Bytecode/versioning/versioned_attr.mlir
mlir/test/Bytecode/versioning/versioned_op.mlir
Modified:
mlir/docs/BytecodeFormat.md
mlir/docs/LangRef.md
mlir/include/mlir/Bytecode/BytecodeImplementation.h
mlir/lib/Bytecode/Encoding.h
mlir/lib/Bytecode/Reader/BytecodeReader.cpp
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
mlir/test/Bytecode/invalid/invalid-structure.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
index 4d061e9f08d36..b4f7400274f43 100644
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -6,7 +6,8 @@ This documents describes the MLIR bytecode format and its encoding.
## Magic Number
-MLIR uses the following four-byte magic number to indicate bytecode files:
+MLIR uses the following four-byte magic number to
+indicate bytecode files:
'\[‘M’<sub>8</sub>, ‘L’<sub>8</sub>, ‘ï’<sub>8</sub>, ‘R’<sub>8</sub>\]'
@@ -157,16 +158,25 @@ dialect_section {
}
op_name_group {
- dialect: varint,
+ dialect: varint // (dialectID << 1) | (hasVersion),
+ version : dialect_version_section
numOpNames: varint,
opNames: varint[]
}
+
+dialect_version_section {
+ size: varint,
+ version: byte[]
+}
+
```
-Dialects are encoded as indexes to the name string within the string section.
-Operation names are encoded in groups by dialect, with each group containing the
-dialect, the number of operation names, and the array of indexes to each name
-within the string section.
+Dialects are encoded as a `varint` containing the index to the name string
+within the string section, plus a flag indicating whether the dialect is
+versioned. Operation names are encoded in groups by dialect, with each group
+containing the dialect, the number of operation names, and the array of indexes
+to each name within the string section. The version is encoded as a nested
+section.
### Attribute/Type Sections
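The dialect entry layout above packs the string-section index and the version flag into a single integer before that integer is varint-encoded. Below is a minimal sketch of the packing step alone. The helper names `packDialectEntry` and `unpackDialectEntry` are hypothetical. In this commit, the reader and writer perform this step through `parseVarIntWithFlag` and `emitVarIntWithFlag`. The byte-level varint encoding that follows the packing is not reproduced here.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Packs a dialect's string-section index together with a "has version" flag,
// following the layout (dialectID << 1) | hasVersion.
uint64_t packDialectEntry(uint64_t nameIndex, bool hasVersion) {
  // The shift discards the top bit, so the index must fit in 63 bits.
  assert(nameIndex < (uint64_t(1) << 63) && "index too large to pack");
  return (nameIndex << 1) | static_cast<uint64_t>(hasVersion);
}

// Inverse of packDialectEntry: returns {nameIndex, hasVersion}.
std::pair<uint64_t, bool> unpackDialectEntry(uint64_t packed) {
  return {packed >> 1, (packed & 1) != 0};
}
```

Storing the flag in the low bit costs one bit per dialect and keeps unversioned entries as compact as before. Only dialects that actually have a version pay for the nested section.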
diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index 07f804018c2aa..467c1987ce1ac 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -845,3 +845,18 @@ The [builtin dialect](Dialects/Builtin.md) defines a set of attribute values
that are directly usable by any other dialect in MLIR. These types cover a range
from primitive integer and floating-point values, attribute dictionaries, dense
multi-dimensional arrays, and more.
+
+### IR Versioning
+
+A dialect can opt in to versioning through the `BytecodeDialectInterface`. A
+few hooks are exposed to let the dialect manage a version encoded in the
+bytecode file. The version is loaded lazily. It can be retrieved while the
+input IR is being parsed, and it gives each dialect for which a version is
+present the opportunity to perform IR upgrades after parsing through the
+`upgradeFromVersion` method. Custom attribute and type encodings can also be
+upgraded according to the dialect version using the `readAttribute` and
+`readType` methods.
+
+There is no restriction on what kind of information a dialect may encode to
+model its versioning. Currently, versioning is supported only for the bytecode
+format.
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 1ae839d5f40e7..ea9bcad735b36 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -235,6 +235,17 @@ class DialectBytecodeWriter {
virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
};
+//===--------------------------------------------------------------------===//
+// Dialect Version Interface.
+//===--------------------------------------------------------------------===//
+
+/// This class is used to represent the version of a dialect, for the purpose
+/// of polymorphic destruction.
+class DialectVersion {
+public:
+ virtual ~DialectVersion() = default;
+};
+
//===----------------------------------------------------------------------===//
// BytecodeDialectInterface
//===----------------------------------------------------------------------===//
@@ -256,6 +267,19 @@ class BytecodeDialectInterface
return Attribute();
}
+ /// Read a versioned attribute encoding belonging to this dialect from the
+ /// given reader. This method should return null in the case of failure, and
+ /// falls back to the non-versioned reader in case the dialect implements
+ /// versioning but it does not support versioned custom encodings for the
+ /// attributes.
+ virtual Attribute readAttribute(DialectBytecodeReader &reader,
+ const DialectVersion &version) const {
+ reader.emitError()
+ << "dialect " << getDialect()->getNamespace()
+ << " does not support reading versioned attributes from bytecode";
+ return Attribute();
+ }
+
/// Read a type belonging to this dialect from the given reader. This method
/// should return null in the case of failure.
virtual Type readType(DialectBytecodeReader &reader) const {
@@ -264,6 +288,19 @@ class BytecodeDialectInterface
return Type();
}
+ /// Read a versioned type encoding belonging to this dialect from the given
+ /// reader. This method should return null in the case of failure, and
+ /// falls back to the non-versioned reader in case the dialect implements
+ /// versioning but it does not support versioned custom encodings for the
+ /// types.
+ virtual Type readType(DialectBytecodeReader &reader,
+ const DialectVersion &version) const {
+ reader.emitError()
+ << "dialect " << getDialect()->getNamespace()
+ << " does not support reading versioned types from bytecode";
+ return Type();
+ }
+
//===--------------------------------------------------------------------===//
// Writing
//===--------------------------------------------------------------------===//
@@ -285,6 +322,27 @@ class BytecodeDialectInterface
DialectBytecodeWriter &writer) const {
return failure();
}
+
+ /// Write the version of this dialect to the given writer.
+ virtual void writeVersion(DialectBytecodeWriter &writer) const {}
+
+ /// Read the version of this dialect from the provided reader and return it
+ /// as a `unique_ptr` to a dialect version object.
+ virtual std::unique_ptr<DialectVersion>
+ readVersion(DialectBytecodeReader &reader) const {
+ reader.emitError("Dialect does not support versioning");
+ return nullptr;
+ }
+
+ /// Hook invoked after parsing has completed, if a version directive was
+ /// present and included an entry for the current dialect. This hook gives
+ /// the dialect the opportunity to visit the IR and upgrade constructs emitted
+ /// by the version of the dialect corresponding to the provided version.
+ virtual LogicalResult
+ upgradeFromVersion(Operation *topLevelOp,
+ const DialectVersion &version) const {
+ return success();
+ }
};
} // namespace mlir
diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h
index ee1789fdcaf4c..0072538154806 100644
--- a/mlir/lib/Bytecode/Encoding.h
+++ b/mlir/lib/Bytecode/Encoding.h
@@ -23,8 +23,11 @@ namespace bytecode {
//===----------------------------------------------------------------------===//
enum {
+ /// The minimum supported version of the bytecode.
+ kMinSupportedVersion = 0,
+
/// The current bytecode version.
- kVersion = 0,
+ kVersion = 1,
/// An arbitrary value used to fill alignment padding.
kAlignmentByte = 0xCB,
@@ -61,8 +64,11 @@ enum ID : uint8_t {
/// section.
kResourceOffset = 6,
+ /// This section contains the versions of each dialect.
+ kDialectVersions = 7,
+
/// The total number of section types.
- kNumSections = 7,
+ kNumSections = 8,
};
} // namespace Section
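Now that `kMinSupportedVersion` and `kVersion` are distinct values, a reader accepts any version in the closed range [kMinSupportedVersion, kVersion]. Below is a minimal sketch of that check. It uses a hypothetical `checkBytecodeVersion` helper that returns an error string instead of emitting a diagnostic. The "newer" message follows the wording in the updated `invalid-structure.mlir` test. The "older" message is phrased against the minimum supported version. The reader change in this commit, by contrast, still mentions the current version in that message.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Values copied from the Encoding.h change above.
constexpr uint64_t kMinSupportedVersion = 0;
constexpr uint64_t kVersion = 1;
static_assert(kMinSupportedVersion <= kVersion, "empty supported range");

// Returns an empty string if `version` is in [kMinSupportedVersion, kVersion],
// otherwise a diagnostic describing why it cannot be read.
std::string checkBytecodeVersion(uint64_t version) {
  if (version < kMinSupportedVersion)
    return "bytecode version " + std::to_string(version) +
           " is older than the minimum supported version " +
           std::to_string(kMinSupportedVersion);
  if (version > kVersion)
    return "bytecode version " + std::to_string(version) +
           " is newer than the current version " + std::to_string(kVersion);
  return "";
}
```

With a minimum of 0, the first branch cannot fire yet. It only becomes relevant once the minimum is raised above 0.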
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 5e71c3a9e5f45..d6f1e18d35ae9 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -47,6 +47,8 @@ static std::string toString(bytecode::Section::ID sectionID) {
return "Resource (5)";
case bytecode::Section::kResourceOffset:
return "ResourceOffset (6)";
+ case bytecode::Section::kDialectVersions:
+ return "DialectVersions (7)";
default:
return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
}
@@ -63,6 +65,7 @@ static bool isSectionOptional(bytecode::Section::ID sectionID) {
return false;
case bytecode::Section::kResource:
case bytecode::Section::kResourceOffset:
+ case bytecode::Section::kDialectVersions:
return true;
default:
llvm_unreachable("unknown section ID");
@@ -350,6 +353,13 @@ class StringSectionReader {
return parseEntry(reader, strings, result, "string");
}
+ /// Parse a shared string from the string section. The shared string is
+ /// encoded using an index to a corresponding string in the string section.
+ LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
+ StringRef &result) {
+ return resolveEntry(reader, strings, index, result, "string");
+ }
+
private:
/// The table of strings referenced within the bytecode file.
SmallVector<StringRef> strings;
@@ -400,31 +410,15 @@ LogicalResult StringSectionReader::initialize(Location fileLoc,
//===----------------------------------------------------------------------===//
namespace {
+class DialectReader;
+
/// This struct represents a dialect entry within the bytecode.
struct BytecodeDialect {
/// Load the dialect into the provided context if it hasn't been loaded yet.
/// Returns failure if the dialect couldn't be loaded *and* the provided
/// context does not allow unregistered dialects. The provided reader is used
/// for error emission if necessary.
- LogicalResult load(EncodingReader &reader, MLIRContext *ctx) {
- if (dialect)
- return success();
- Dialect *loadedDialect = ctx->getOrLoadDialect(name);
- if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
- return reader.emitError(
- "dialect '", name,
- "' is unknown. If this is intended, please call "
- "allowUnregisteredDialects() on the MLIRContext, or use "
- "-allow-unregistered-dialect with the MLIR tool used.");
- }
- dialect = loadedDialect;
-
- // If the dialect was actually loaded, check to see if it has a bytecode
- // interface.
- if (loadedDialect)
- interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
- return success();
- }
+ LogicalResult load(DialectReader &reader, MLIRContext *ctx);
/// Return the loaded dialect, or nullptr if the dialect is unknown. This can
/// only be called after `load`.
@@ -446,6 +440,12 @@ struct BytecodeDialect {
/// The name of the dialect.
StringRef name;
+
+ /// A buffer containing the encoding of the dialect version parsed.
+ ArrayRef<uint8_t> versionBuffer;
+
+ /// Lazily loaded dialect version, decoded from the buffer above.
+ std::unique_ptr<DialectVersion> loadedVersion;
};
/// This struct represents an operation name entry within the bytecode.
@@ -496,7 +496,7 @@ class ResourceSectionReader {
initialize(Location fileLoc, const ParserConfig &config,
MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
- ArrayRef<uint8_t> offsetSectionData,
+ ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
/// Parse a dialect resource handle from the resource section.
@@ -643,7 +643,7 @@ LogicalResult ResourceSectionReader::initialize(
Location fileLoc, const ParserConfig &config,
MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
- ArrayRef<uint8_t> offsetSectionData,
+ ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
EncodingReader resourceReader(sectionData, fileLoc);
EncodingReader offsetReader(offsetSectionData, fileLoc);
@@ -684,7 +684,7 @@ LogicalResult ResourceSectionReader::initialize(
while (!offsetReader.empty()) {
BytecodeDialect *dialect;
if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
- failed(dialect->load(resourceReader, ctx)))
+ failed(dialect->load(dialectReader, ctx)))
return failure();
Dialect *loadedDialect = dialect->getLoadedDialect();
if (!loadedDialect) {
@@ -1051,7 +1051,8 @@ template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
StringRef entryType) {
- if (failed(entry.dialect->load(reader, fileLoc.getContext())))
+ DialectReader dialectReader(*this, stringReader, resourceReader, reader);
+ if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
// Ensure that the dialect implements the bytecode interface.
@@ -1060,12 +1061,22 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
"' does not implement the bytecode interface");
}
- // Ask the dialect to parse the entry.
- DialectReader dialectReader(*this, stringReader, resourceReader, reader);
- if constexpr (std::is_same_v<T, Type>)
- entry.entry = entry.dialect->interface->readType(dialectReader);
- else
- entry.entry = entry.dialect->interface->readAttribute(dialectReader);
+ // Ask the dialect to parse the entry. If the dialect is versioned, parse
+ // using the versioned encoding readers.
+ if (entry.dialect->loadedVersion.get()) {
+ if constexpr (std::is_same_v<T, Type>)
+ entry.entry = entry.dialect->interface->readType(
+ dialectReader, *entry.dialect->loadedVersion);
+ else
+ entry.entry = entry.dialect->interface->readAttribute(
+ dialectReader, *entry.dialect->loadedVersion);
+
+ } else {
+ if constexpr (std::is_same_v<T, Type>)
+ entry.entry = entry.dialect->interface->readType(dialectReader);
+ else
+ entry.entry = entry.dialect->interface->readAttribute(dialectReader);
+ }
return success(!!entry.entry);
}
@@ -1122,7 +1133,8 @@ class BytecodeReader {
// Resource Section
LogicalResult
- parseResourceSection(std::optional<ArrayRef<uint8_t>> resourceData,
+ parseResourceSection(EncodingReader &reader,
+ std::optional<ArrayRef<uint8_t>> resourceData,
std::optional<ArrayRef<uint8_t>> resourceOffsetData);
//===--------------------------------------------------------------------===//
@@ -1306,7 +1318,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
// Process the resource section if present.
if (failed(parseResourceSection(
- sectionDatas[bytecode::Section::kResource],
+ reader, sectionDatas[bytecode::Section::kResource],
sectionDatas[bytecode::Section::kResourceOffset])))
return failure();
@@ -1326,7 +1338,8 @@ LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) {
// Validate the bytecode version.
uint64_t currentVersion = bytecode::kVersion;
- if (version < currentVersion) {
+ uint64_t minSupportedVersion = bytecode::kMinSupportedVersion;
+ if (version < minSupportedVersion) {
return reader.emitError("bytecode version ", version,
" is older than the current version of ",
currentVersion, ", and upgrade is not supported");
@@ -1342,6 +1355,36 @@ LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) {
//===----------------------------------------------------------------------===//
// Dialect Section
+LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
+ if (dialect)
+ return success();
+ Dialect *loadedDialect = ctx->getOrLoadDialect(name);
+ if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
+ return reader.emitError("dialect '")
+ << name
+ << "' is unknown. If this is intended, please call "
+ "allowUnregisteredDialects() on the MLIRContext, or use "
+ "-allow-unregistered-dialect with the MLIR tool used.";
+ }
+ dialect = loadedDialect;
+
+ // If the dialect was actually loaded, check to see if it has a bytecode
+ // interface.
+ if (loadedDialect)
+ interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
+ if (!versionBuffer.empty()) {
+ if (!interface)
+ return reader.emitError("dialect '")
+ << name
+ << "' does not implement the bytecode interface, "
+ "but found a version entry";
+ loadedVersion = interface->readVersion(reader);
+ if (!loadedVersion)
+ return failure();
+ }
+ return success();
+}
+
LogicalResult
BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
EncodingReader sectionReader(sectionData, fileLoc);
@@ -1353,9 +1396,34 @@ BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
dialects.resize(numDialects);
// Parse each of the dialects.
- for (uint64_t i = 0; i < numDialects; ++i)
- if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
+ for (uint64_t i = 0; i < numDialects; ++i) {
+ // Before version 1, there wasn't any versioning available for dialects,
+ // and the entry index represents the string itself.
+ if (version == 0) {
+ if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
+ return failure();
+ continue;
+ }
+ // Parse ID representing dialect and version.
+ uint64_t dialectNameIdx;
+ bool versionAvailable;
+ if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx,
+ versionAvailable)))
+ return failure();
+ if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
+ dialects[i].name)))
return failure();
+ if (versionAvailable) {
+ bytecode::Section::ID sectionID;
+ if (failed(
+ sectionReader.parseSection(sectionID, dialects[i].versionBuffer)))
+ return failure();
+ if (sectionID != bytecode::Section::kDialectVersions) {
+ emitError(fileLoc, "expected dialect version section");
+ return failure();
+ }
+ }
+ }
// Parse the operation names, which are grouped by dialect.
auto parseOpName = [&](BytecodeDialect *dialect) {
@@ -1379,7 +1447,11 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
// Check to see if this operation name has already been resolved. If we
// haven't, load the dialect and build the operation name.
if (!opName->opName) {
- if (failed(opName->dialect->load(reader, getContext())))
+ // Load the dialect and its version.
+ EncodingReader versionReader(opName->dialect->versionBuffer, fileLoc);
+ DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
+ versionReader);
+ if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
getContext());
@@ -1391,7 +1463,7 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
// Resource Section
LogicalResult BytecodeReader::parseResourceSection(
- std::optional<ArrayRef<uint8_t>> resourceData,
+ EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
// Ensure both sections are either present or not.
if (resourceData.has_value() != resourceOffsetData.has_value()) {
@@ -1408,9 +1480,11 @@ LogicalResult BytecodeReader::parseResourceSection(
return success();
// Initialize the resource reader with the resource sections.
+ DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
+ reader);
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
*resourceData, *resourceOffsetData,
- bufferOwnerRef);
+ dialectReader, bufferOwnerRef);
}
//===----------------------------------------------------------------------===//
@@ -1442,6 +1516,18 @@ LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
"not all forward unresolved forward operand references");
}
+ // Resolve dialect version.
+ for (const BytecodeDialect &byteCodeDialect : dialects) {
+ // Parsing is complete, give an opportunity to each dialect to visit the
+ // IR and perform upgrades.
+ if (!byteCodeDialect.loadedVersion)
+ continue;
+ if (byteCodeDialect.interface &&
+ failed(byteCodeDialect.interface->upgradeFromVersion(
+ *moduleOp, *byteCodeDialect.loadedVersion)))
+ return failure();
+ }
+
// Verify that the parsed operations are valid.
if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
return failure();
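At its core, the reader change to the dialect section is a branch on the bytecode version. Below is a simplified standalone sketch of that loop. It assumes unsigned LEB128 varints and a bare `[id byte][length][payload]` nested-section header. Neither of these is MLIR's actual on-disk encoding. `DialectEntry`, `Cursor`, and `parseDialects` are hypothetical names. Only the control flow mirrors `parseDialectSection`: a plain index for version 0, a flagged index plus an optional version section for version 1 and later, and rejection of an unexpected section ID.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Mirrors Section::kDialectVersions from Encoding.h.
constexpr uint8_t kDialectVersionsSection = 7;

// Simplified stand-in for a BytecodeDialect entry.
struct DialectEntry {
  uint64_t nameIndex = 0;
  std::vector<uint8_t> versionBuffer; // stays empty for unversioned dialects
};

// Minimal read cursor. Varints are unsigned LEB128 here, which is NOT the
// varint encoding MLIR bytecode uses; it only keeps the sketch short.
struct Cursor {
  const std::vector<uint8_t> &data;
  size_t pos = 0;

  std::optional<uint64_t> readVarInt() {
    uint64_t result = 0;
    for (unsigned shift = 0; shift < 64; shift += 7) {
      if (pos >= data.size())
        return std::nullopt; // truncated input
      uint8_t byte = data[pos++];
      result |= uint64_t(byte & 0x7f) << shift;
      if (!(byte & 0x80))
        return result;
    }
    return std::nullopt; // too many continuation bytes
  }
};

// Follows the control flow of parseDialectSection: version 0 stores a plain
// string index per dialect; version >= 1 stores (index << 1) | hasVersion,
// followed by a nested version section when the flag is set. The version
// payload is only captured here; decoding it is left for later.
std::optional<std::vector<DialectEntry>>
parseDialects(const std::vector<uint8_t> &bytes, size_t numDialects,
              uint64_t bytecodeVersion) {
  Cursor cur{bytes};
  std::vector<DialectEntry> dialects(numDialects);
  for (DialectEntry &dialect : dialects) {
    std::optional<uint64_t> raw = cur.readVarInt();
    if (!raw)
      return std::nullopt;
    if (bytecodeVersion == 0) {
      dialect.nameIndex = *raw;
      continue;
    }
    dialect.nameIndex = *raw >> 1;
    if ((*raw & 1) == 0)
      continue;
    // Simplified nested section header: [section id byte][length varint].
    if (cur.pos >= bytes.size() || bytes[cur.pos++] != kDialectVersionsSection)
      return std::nullopt; // "expected dialect version section"
    std::optional<uint64_t> length = cur.readVarInt();
    if (!length || *length > bytes.size() - cur.pos)
      return std::nullopt;
    auto first = bytes.begin() + static_cast<std::ptrdiff_t>(cur.pos);
    dialect.versionBuffer.assign(first,
                                 first + static_cast<std::ptrdiff_t>(*length));
    cur.pos += *length;
  }
  return dialects;
}
```

As in the real reader, the version payload is only stored at this stage. Decoding is deferred until the dialect is actually loaded, which is what makes the version lazy.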
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 1fd313453fb24..8433f3d84b852 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -10,13 +10,10 @@
#include "../Encoding.h"
#include "IRNumbering.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
-#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/CachedHashString.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallString.h"
-#include "llvm/Support/Debug.h"
-#include <random>
#define DEBUG_TYPE "mlir-bytecode-writer"
@@ -261,6 +258,116 @@ class EncodingEmitter {
unsigned requiredAlignment = 1;
};
+//===----------------------------------------------------------------------===//
+// StringSectionBuilder
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class is used to simplify the process of emitting the string section.
+class StringSectionBuilder {
+public:
+ /// Add the given string to the string section, and return the index of the
+ /// string within the section.
+ size_t insert(StringRef str) {
+ auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()});
+ return it.first->second;
+ }
+
+ /// Write the current set of strings to the given emitter.
+ void write(EncodingEmitter &emitter) {
+ emitter.emitVarInt(strings.size());
+
+ // Emit the sizes in reverse order, so that we don't need to backpatch an
+ // offset to the string data or have a separate section.
+ for (const auto &it : llvm::reverse(strings))
+ emitter.emitVarInt(it.first.size() + 1);
+ // Emit the string data itself.
+ for (const auto &it : strings)
+ emitter.emitNulTerminatedString(it.first.val());
+ }
+
+private:
+ /// A set of strings referenced within the bytecode. The value of the map is
+ /// unused.
+ llvm::MapVector<llvm::CachedHashStringRef, size_t> strings;
+};
+} // namespace
+
+class DialectWriter : public DialectBytecodeWriter {
+public:
+ DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState,
+ StringSectionBuilder &stringSection)
+ : emitter(emitter), numberingState(numberingState),
+ stringSection(stringSection) {}
+
+ //===--------------------------------------------------------------------===//
+ // IR
+ //===--------------------------------------------------------------------===//
+
+ void writeAttribute(Attribute attr) override {
+ emitter.emitVarInt(numberingState.getNumber(attr));
+ }
+ void writeType(Type type) override {
+ emitter.emitVarInt(numberingState.getNumber(type));
+ }
+
+ void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
+ emitter.emitVarInt(numberingState.getNumber(resource));
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Primitives
+ //===--------------------------------------------------------------------===//
+
+ void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); }
+
+ void writeSignedVarInt(int64_t value) override {
+ emitter.emitSignedVarInt(value);
+ }
+
+ void writeAPIntWithKnownWidth(const APInt &value) override {
+ size_t bitWidth = value.getBitWidth();
+
+ // If the value is a single byte, just emit it directly without going
+ // through a varint.
+ if (bitWidth <= 8)
+ return emitter.emitByte(value.getLimitedValue());
+
+ // If the value fits within a single varint, emit it directly.
+ if (bitWidth <= 64)
+ return emitter.emitSignedVarInt(value.getLimitedValue());
+
+ // Otherwise, we need to encode a variable number of active words. We use
+ // active words instead of the number of total words under the observation
+ // that smaller values will be more common.
+ unsigned numActiveWords = value.getActiveWords();
+ emitter.emitVarInt(numActiveWords);
+
+ const uint64_t *rawValueData = value.getRawData();
+ for (unsigned i = 0; i < numActiveWords; ++i)
+ emitter.emitSignedVarInt(rawValueData[i]);
+ }
+
+ void writeAPFloatWithKnownSemantics(const APFloat &value) override {
+ writeAPIntWithKnownWidth(value.bitcastToAPInt());
+ }
+
+ void writeOwnedString(StringRef str) override {
+ emitter.emitVarInt(stringSection.insert(str));
+ }
+
+ void writeOwnedBlob(ArrayRef<char> blob) override {
+ emitter.emitVarInt(blob.size());
+ emitter.emitOwnedBlob(ArrayRef<uint8_t>(
+ reinterpret_cast<const uint8_t *>(blob.data()), blob.size()));
+ }
+
+private:
+ EncodingEmitter &emitter;
+ IRNumberingState &numberingState;
+ StringSectionBuilder &stringSection;
+};
+
/// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need
/// to go through an intermediate buffer when interacting with code that wants a
/// raw_ostream.
@@ -307,41 +414,6 @@ void EncodingEmitter::emitMultiByteVarInt(uint64_t value) {
emitBytes({reinterpret_cast<uint8_t *>(&value), sizeof(value)});
}
-//===----------------------------------------------------------------------===//
-// StringSectionBuilder
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This class is used to simplify the process of emitting the string section.
-class StringSectionBuilder {
-public:
- /// Add the given string to the string section, and return the index of the
- /// string within the section.
- size_t insert(StringRef str) {
- auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()});
- return it.first->second;
- }
-
- /// Write the current set of strings to the given emitter.
- void write(EncodingEmitter &emitter) {
- emitter.emitVarInt(strings.size());
-
- // Emit the sizes in reverse order, so that we don't need to backpatch an
- // offset to the string data or have a separate section.
- for (const auto &it : llvm::reverse(strings))
- emitter.emitVarInt(it.first.size() + 1);
- // Emit the string data itself.
- for (const auto &it : strings)
- emitter.emitNulTerminatedString(it.first.val());
- }
-
-private:
- /// A set of strings referenced within the bytecode. The value of the map is
- /// unused.
- llvm::MapVector<llvm::CachedHashStringRef, size_t> strings;
-};
-} // namespace
-
//===----------------------------------------------------------------------===//
// Bytecode Writer
//===----------------------------------------------------------------------===//
@@ -464,8 +536,28 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
// Emit the referenced dialects.
auto dialects = numberingState.getDialects();
dialectEmitter.emitVarInt(llvm::size(dialects));
- for (DialectNumbering &dialect : dialects)
- dialectEmitter.emitVarInt(stringSection.insert(dialect.name));
+ for (DialectNumbering &dialect : dialects) {
+ // Write the string section and get the ID.
+ size_t nameID = stringSection.insert(dialect.name);
+
+ // Try writing the version to the versionEmitter.
+ EncodingEmitter versionEmitter;
+ if (dialect.interface) {
+ // The writer used when emitting using a custom bytecode encoding.
+ DialectWriter versionWriter(versionEmitter, numberingState,
+ stringSection);
+ dialect.interface->writeVersion(versionWriter);
+ }
+
+ // If the version emitter is empty, version is not available. We can encode
+ // this in the dialect ID, so if there is no version, we don't write the
+ // section.
+ size_t versionAvailable = versionEmitter.size() > 0;
+ dialectEmitter.emitVarIntWithFlag(nameID, versionAvailable);
+ if (versionAvailable)
+ dialectEmitter.emitSection(bytecode::Section::kDialectVersions,
+ std::move(versionEmitter));
+ }
// Emit the referenced operation names grouped by dialect.
auto emitOpName = [&](OpNameNumbering &name) {
@@ -479,83 +571,6 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
//===----------------------------------------------------------------------===//
// Attributes and Types
-namespace {
-class DialectWriter : public DialectBytecodeWriter {
-public:
- DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState,
- StringSectionBuilder &stringSection)
- : emitter(emitter), numberingState(numberingState),
- stringSection(stringSection) {}
-
- //===--------------------------------------------------------------------===//
- // IR
- //===--------------------------------------------------------------------===//
-
- void writeAttribute(Attribute attr) override {
- emitter.emitVarInt(numberingState.getNumber(attr));
- }
- void writeType(Type type) override {
- emitter.emitVarInt(numberingState.getNumber(type));
- }
-
- void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
- emitter.emitVarInt(numberingState.getNumber(resource));
- }
-
- //===--------------------------------------------------------------------===//
- // Primitives
- //===--------------------------------------------------------------------===//
-
- void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); }
-
- void writeSignedVarInt(int64_t value) override {
- emitter.emitSignedVarInt(value);
- }
-
- void writeAPIntWithKnownWidth(const APInt &value) override {
- size_t bitWidth = value.getBitWidth();
-
- // If the value is a single byte, just emit it directly without going
- // through a varint.
- if (bitWidth <= 8)
- return emitter.emitByte(value.getLimitedValue());
-
- // If the value fits within a single varint, emit it directly.
- if (bitWidth <= 64)
- return emitter.emitSignedVarInt(value.getLimitedValue());
-
- // Otherwise, we need to encode a variable number of active words. We use
- // active words instead of the number of total words under the observation
- // that smaller values will be more common.
- unsigned numActiveWords = value.getActiveWords();
- emitter.emitVarInt(numActiveWords);
-
- const uint64_t *rawValueData = value.getRawData();
- for (unsigned i = 0; i < numActiveWords; ++i)
- emitter.emitSignedVarInt(rawValueData[i]);
- }
-
- void writeAPFloatWithKnownSemantics(const APFloat &value) override {
- writeAPIntWithKnownWidth(value.bitcastToAPInt());
- }
-
- void writeOwnedString(StringRef str) override {
- emitter.emitVarInt(stringSection.insert(str));
- }
-
- void writeOwnedBlob(ArrayRef<char> blob) override {
- emitter.emitVarInt(blob.size());
- emitter.emitOwnedBlob(ArrayRef<uint8_t>(
- reinterpret_cast<const uint8_t *>(blob.data()), blob.size()));
- }
-
-private:
- EncodingEmitter &emitter;
- IRNumberingState &numberingState;
- StringSectionBuilder &stringSection;
-};
-} // namespace
-
void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
EncodingEmitter attrTypeEmitter;
EncodingEmitter offsetEmitter;
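On the writer side, `writeDialectSection` first lets the dialect write its version into a scratch `EncodingEmitter`. Whether anything was written then serves as the flag. Below is a minimal sketch of the same idea. Unsigned LEB128 again stands in for MLIR's varint encoding, and the section header is simplified. `emitDialectEntry` and `WriteVersionFn` are hypothetical names. The callback plays the role of `BytecodeDialectInterface::writeVersion`.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

using Bytes = std::vector<uint8_t>;

// Mirrors Section::kDialectVersions from Encoding.h.
constexpr uint8_t kDialectVersionsSection = 7;

// Unsigned LEB128, standing in for MLIR's own varint encoding.
void emitVarInt(Bytes &out, uint64_t value) {
  do {
    uint8_t byte = static_cast<uint8_t>(value & 0x7f);
    value >>= 7;
    if (value != 0)
      byte |= 0x80; // more bytes follow
    out.push_back(byte);
  } while (value != 0);
}

// Plays the role of BytecodeDialectInterface::writeVersion: appends the
// dialect's version bytes to `scratch`, or nothing if it is unversioned.
using WriteVersionFn = std::function<void(Bytes &scratch)>;

// Same shape as the loop body in writeDialectSection: write the version into
// a scratch buffer first, derive the flag from whether anything was written,
// and emit the nested version section only when the flag is set.
void emitDialectEntry(Bytes &out, uint64_t nameID,
                      const WriteVersionFn &writeVersion) {
  Bytes scratch;
  if (writeVersion)
    writeVersion(scratch);
  bool hasVersion = !scratch.empty();

  assert(nameID < (uint64_t(1) << 63) && "index too large to pack");
  emitVarInt(out, (nameID << 1) | static_cast<uint64_t>(hasVersion));
  if (!hasVersion)
    return;
  // Simplified nested section header: [section id byte][length varint].
  out.push_back(kDialectVersionsSection);
  emitVarInt(out, scratch.size());
  out.insert(out.end(), scratch.begin(), scratch.end());
}
```

An unversioned dialect costs nothing beyond the flag bit. This is why the default `writeVersion` can simply write nothing.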
diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir
index 8b70a3ff25d9a..d98c6c6191b87 100644
--- a/mlir/test/Bytecode/invalid/invalid-structure.mlir
+++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir
@@ -9,7 +9,7 @@
//===--------------------------------------------------------------------===//
// RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION
-// VERSION: bytecode version 127 is newer than the current version 0
+// VERSION: bytecode version 127 is newer than the current version 1
//===--------------------------------------------------------------------===//
// Producer
diff --git a/mlir/test/Bytecode/versioning/versioned-attr-1.12.mlirbc b/mlir/test/Bytecode/versioning/versioned-attr-1.12.mlirbc
new file mode 100644
index 0000000000000..3645a2f37c4ba
Binary files /dev/null and b/mlir/test/Bytecode/versioning/versioned-attr-1.12.mlirbc differ
diff --git a/mlir/test/Bytecode/versioning/versioned-attr-2.0.mlirbc b/mlir/test/Bytecode/versioning/versioned-attr-2.0.mlirbc
new file mode 100644
index 0000000000000..83ff044d7f907
Binary files /dev/null and b/mlir/test/Bytecode/versioning/versioned-attr-2.0.mlirbc differ
diff --git a/mlir/test/Bytecode/versioning/versioned-op-1.12.mlirbc b/mlir/test/Bytecode/versioning/versioned-op-1.12.mlirbc
new file mode 100644
index 0000000000000..688948257008f
Binary files /dev/null and b/mlir/test/Bytecode/versioning/versioned-op-1.12.mlirbc differ
diff --git a/mlir/test/Bytecode/versioning/versioned-op-2.0.mlirbc b/mlir/test/Bytecode/versioning/versioned-op-2.0.mlirbc
new file mode 100644
index 0000000000000..fdc8e0d28b2f0
Binary files /dev/null and b/mlir/test/Bytecode/versioning/versioned-op-2.0.mlirbc differ
diff --git a/mlir/test/Bytecode/versioning/versioned-op-2.2.mlirbc b/mlir/test/Bytecode/versioning/versioned-op-2.2.mlirbc
new file mode 100644
index 0000000000000..2dc4719ca4975
Binary files /dev/null and b/mlir/test/Bytecode/versioning/versioned-op-2.2.mlirbc differ
diff --git a/mlir/test/Bytecode/versioning/versioned_attr.mlir b/mlir/test/Bytecode/versioning/versioned_attr.mlir
new file mode 100644
index 0000000000000..98756cc66a99c
--- /dev/null
+++ b/mlir/test/Bytecode/versioning/versioned_attr.mlir
@@ -0,0 +1,29 @@
+// This file contains a test case representative of a dialect parsing an
+// attribute with versioned custom encoding.
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: target=s390x-{{.*}}
+
+//===--------------------------------------------------------------------===//
+// Test attribute upgrade
+//===--------------------------------------------------------------------===//
+
+// COM: bytecode contains
+// COM: module {
+// COM: version: 1.12
+// COM: "test.versionedB"() {attribute = #test.attr_params<24, 42>} : () -> ()
+// COM: }
+// RUN: mlir-opt %S/versioned-attr-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1
+// CHECK1: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> ()
+
+//===--------------------------------------------------------------------===//
+// Test attribute at the current version (no upgrade)
+//===--------------------------------------------------------------------===//
+
+// COM: bytecode contains
+// COM: module {
+// COM: version: 2.0
+// COM: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> ()
+// COM: }
+// RUN: mlir-opt %S/versioned-attr-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2
+// CHECK2: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> ()
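Both RUN lines are expected to print `#test.attr_params<42, 24>`, even though the COM lines for the 1.12 file show `<24, 42>`. The following is a minimal standalone model of the two decoding paths in `TestBytecodeDialectInterface::readAttribute` that shows why. It assumes the attribute payload is the three varints `{k_attr_params, first, second}`, and it models that payload as a `std::vector<uint64_t>` instead of real bytecode. It also assumes, as the COM lines suggest, that the 1.12 file stores `{0, 24, 42}`. `Version`, `AttrParams`, and `readAttrParams` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-ins for TestDialectVersion and TestAttrParamsAttr.
struct Version {
  uint32_t major = 2;
  uint32_t minor = 0;
};

struct AttrParams {
  int v0;
  int v1;
  bool operator==(const AttrParams &other) const {
    return v0 == other.v0 && v1 == other.v1;
  }
};

constexpr uint64_t kAttrParams = 0; // Mirrors test_encoding::k_attr_params.

// Decodes a payload of the form {kAttrParams, a, b}. Versions before 2.0
// store v1 first, version 2.0 stores v0 first, and newer versions are
// rejected (the real hook returns a null Attribute in that case).
std::optional<AttrParams> readAttrParams(const std::vector<uint64_t> &payload,
                                         const Version &version) {
  if (payload.size() != 3 || payload[0] != kAttrParams)
    return std::nullopt;
  int first = static_cast<int>(payload[1]);
  int second = static_cast<int>(payload[2]);
  if (version.major < 2)
    return AttrParams{second, first}; // Old encoding: v1, then v0.
  if (version.major == 2 && version.minor == 0)
    return AttrParams{first, second}; // New encoding: v0, then v1.
  return std::nullopt;                // Future versions are not readable.
}
```

The version decides only the field order. The encoding tag and the varint payload are the same in both layouts, which is why a single `readAttribute` hook can serve both files.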
diff --git a/mlir/test/Bytecode/versioning/versioned_op.mlir b/mlir/test/Bytecode/versioning/versioned_op.mlir
new file mode 100644
index 0000000000000..b3141de67823f
--- /dev/null
+++ b/mlir/test/Bytecode/versioning/versioned_op.mlir
@@ -0,0 +1,41 @@
+// This file contains test cases related to the dialect post-parsing upgrade
+// mechanism.
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: target=s390x-{{.*}}
+
+//===--------------------------------------------------------------------===//
+// Test generic
+//===--------------------------------------------------------------------===//
+
+// COM: bytecode contains
+// COM: module {
+// COM: version: 2.0
+// COM: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+// COM: }
+// RUN: mlir-opt %S/versioned-op-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1
+// CHECK1: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+
+//===--------------------------------------------------------------------===//
+// Test upgrade
+//===--------------------------------------------------------------------===//
+
+// COM: bytecode contains
+// COM: module {
+// COM: version: 1.12
+// COM: "test.versionedA"() {dimensions = 123 : i64} : () -> ()
+// COM: }
+// RUN: mlir-opt %S/versioned-op-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2
+// CHECK2: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+
+//===--------------------------------------------------------------------===//
+// Test forbidden downgrade
+//===--------------------------------------------------------------------===//
+
+// COM: bytecode contains
+// COM: module {
+// COM: version: 2.2
+// COM: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+// COM: }
+// RUN: not mlir-opt %S/versioned-op-2.2.mlirbc 2>&1 | FileCheck %s --check-prefix=ERR_NEW_VERSION
+// ERR_NEW_VERSION: current test dialect version is 2.0, can't parse version: 2.2
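The three RUN lines above each exercise one branch of `upgradeFromVersion`: a no-op at 2.0, a rename-and-default upgrade before 2.0, and an error after 2.0. The minimal standalone model below captures that logic. It represents an op's attribute dictionary as a `std::map<std::string, int64_t>` with booleans stored as 0/1, in place of a real MLIR `Operation` walk. `OpAttrs`, `Version`, and `upgradeOpA` are hypothetical names. The real hook also restricts the rewrite to `TestVersionedOpA` ops.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <string>

// Hypothetical model of an op's attribute dictionary; booleans are 0/1.
using OpAttrs = std::map<std::string, int64_t>;

struct Version {
  uint32_t major;
  uint32_t minor;
};

// Returns an error message on failure and std::nullopt on success.
std::optional<std::string> upgradeOpA(OpAttrs &attrs, const Version &v) {
  if (v.major == 2 && v.minor == 0)
    return std::nullopt; // Current version: nothing to do.
  if (v.major > 2 || (v.major == 2 && v.minor > 0))
    return "current test dialect version is 2.0, can't parse version: " +
           std::to_string(v.major) + "." + std::to_string(v.minor);
  // Pre-2.0 IR: rename "dimensions" to "dims" and default "modifier" to false.
  auto it = attrs.find("dimensions");
  if (it != attrs.end()) {
    int64_t dims = it->second;
    attrs.erase(it);
    attrs["dims"] = dims;
  }
  attrs["modifier"] = 0;
  return std::nullopt;
}
```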
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 09f2fcc108ee8..160637a128f56 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -10,15 +10,14 @@
#include "TestAttributes.h"
#include "TestInterfaces.h"
#include "TestTypes.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
@@ -32,9 +31,9 @@
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
-#include <optional>
#include <numeric>
+#include <optional>
// Include this before the using namespace lines below to
// test that we don't have namespace dependencies.
@@ -47,6 +46,15 @@ void test::registerTestDialect(DialectRegistry &registry) {
registry.insert<TestDialect>();
}
+//===----------------------------------------------------------------------===//
+// TestDialect version utilities
+//===----------------------------------------------------------------------===//
+
+struct TestDialectVersion : public DialectVersion {
+ uint32_t major = 2;
+ uint32_t minor = 0;
+};
+
//===----------------------------------------------------------------------===//
// TestDialect Interfaces
//===----------------------------------------------------------------------===//
@@ -70,6 +78,107 @@ struct TestResourceBlobManagerInterface
TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase;
};
+namespace {
+enum test_encoding { k_attr_params = 0 };
+}
+
+// Test support for interacting with the Bytecode reader/writer.
+struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
+ using BytecodeDialectInterface::BytecodeDialectInterface;
+ TestBytecodeDialectInterface(Dialect *dialect)
+ : BytecodeDialectInterface(dialect) {}
+
+ LogicalResult writeAttribute(Attribute attr,
+ DialectBytecodeWriter &writer) const final {
+ if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
+ writer.writeVarInt(test_encoding::k_attr_params);
+ writer.writeVarInt(concreteAttr.getV0());
+ writer.writeVarInt(concreteAttr.getV1());
+ return success();
+ }
+ // Not custom encoded: fail so the writer falls back to the generic encoding.
+ return failure();
+ }
+
+ Attribute readAttribute(DialectBytecodeReader &reader,
+ const DialectVersion &version_) const final {
+ const auto &version = static_cast<const TestDialectVersion &>(version_);
+ if (version.major < 2)
+ return readAttrOldEncoding(reader);
+ if (version.major == 2 && version.minor == 0)
+ return readAttrNewEncoding(reader);
+ // Forbid reading future versions by returning nullptr.
+ return Attribute();
+ }
+
+ // Emit a specific version of the dialect.
+ void writeVersion(DialectBytecodeWriter &writer) const final {
+ auto version = TestDialectVersion();
+ writer.writeVarInt(version.major); // major
+ writer.writeVarInt(version.minor); // minor
+ }
+
+ std::unique_ptr<DialectVersion>
+ readVersion(DialectBytecodeReader &reader) const final {
+ uint64_t major, minor;
+ if (failed(reader.readVarInt(major)) || failed(reader.readVarInt(minor)))
+ return nullptr;
+ auto version = std::make_unique<TestDialectVersion>();
+ version->major = major;
+ version->minor = minor;
+ return version;
+ }
+
+ LogicalResult upgradeFromVersion(Operation *topLevelOp,
+ const DialectVersion &version_) const final {
+ const auto &version = static_cast<const TestDialectVersion &>(version_);
+ if ((version.major == 2) && (version.minor == 0))
+ return success();
+ if (version.major > 2 || (version.major == 2 && version.minor > 0)) {
+ return topLevelOp->emitError()
+ << "current test dialect version is 2.0, can't parse version: "
+ << version.major << "." << version.minor;
+ }
+ // Prior to version 2.0, the old op supported only a single attribute called
+ // "dimensions". We can perform the upgrade.
+ topLevelOp->walk([](TestVersionedOpA op) {
+ if (auto dims = op->getAttr("dimensions")) {
+ op->removeAttr("dimensions");
+ op->setAttr("dims", dims);
+ }
+ op->setAttr("modifier", BoolAttr::get(op->getContext(), false));
+ });
+ return success();
+ }
+
+private:
+ Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const {
+ uint64_t encoding;
+ if (failed(reader.readVarInt(encoding)) ||
+ encoding != test_encoding::k_attr_params)
+ return Attribute();
+ // The new encoding has v0 first, v1 second.
+ uint64_t v0, v1;
+ if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1)))
+ return Attribute();
+ return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
+ static_cast<int>(v1));
+ }
+
+ Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const {
+ uint64_t encoding;
+ if (failed(reader.readVarInt(encoding)) ||
+ encoding != test_encoding::k_attr_params)
+ return Attribute();
+ // The old encoding has v1 first, v0 second.
+ uint64_t v0, v1;
+ if (failed(reader.readVarInt(v1)) || failed(reader.readVarInt(v0)))
+ return Attribute();
+ return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
+ static_cast<int>(v1));
+ }
+};
+
// Test support for interacting with the AsmPrinter.
struct TestOpAsmInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
@@ -367,7 +476,7 @@ void TestDialect::initialize() {
addInterface<TestOpAsmInterface>(blobInterface);
addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
- TestReductionPatternInterface>();
+ TestReductionPatternInterface, TestBytecodeDialectInterface>();
allowUnknownOperations();
// Instantiate our fallback op interface that we'll use on specific
@@ -1103,9 +1212,7 @@ OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
return getOperand();
}
-OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) {
- return getValue();
-}
+OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index b4447eb8cc142..3f642b8a87ea2 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3149,4 +3149,43 @@ def TestCSEOfSingleBlockOp : TEST_Op<"cse_of_single_block_op",
}];
}
+//===----------------------------------------------------------------------===//
+// Test Ops to upgrade based on the dialect versions
+//===----------------------------------------------------------------------===//
+
+def TestVersionedOpA : TEST_Op<"versionedA"> {
+ // A previous version of the dialect (let's say 1.*) supported an attribute
+ // named "dimensions":
+ // let arguments = (ins
+ // AnyI64Attr:$dimensions
+ // );
+
+ // In the current version (2.0) "dimensions" was renamed to "dims", and a new
+ // boolean attribute "modifier" was added. The previous version of the op
+ // corresponds to "modifier=false". We support loading old IR through
+ // upgrading, see `upgradeFromVersion()` in `TestBytecodeDialectInterface`.
+ let arguments = (ins
+ AnyI64Attr:$dims,
+ BoolAttr:$modifier
+ );
+}
+
+def TestVersionedOpB : TEST_Op<"versionedB"> {
+ // In a previous version of the dialect (let's say 1.*), we encoded TestAttrParams
+ // with a custom encoding:
+ //
+ // #test.attr_params<X, Y> -> { varInt: Y, varInt: X }
+ //
+ // In the current version (2.0) the encoding changed and the two parameters of
+ // the attribute are swapped:
+ //
+ // #test.attr_params<X, Y> -> { varInt: X, varInt: Y }
+ //
+ // We support loading old IR through a custom readAttribute method, see
+ // `readAttribute()` in `TestBytecodeDialectInterface`.
+ let arguments = (ins
+ TestAttrParams:$attribute
+ );
+}
+
#endif // TEST_OPS
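`writeVersion` and `readVersion` in `TestBytecodeDialectInterface` serialize the version as two varints, major first and then minor. `readAttribute` and `upgradeFromVersion` express the ordering with nested `major`/`minor` checks. A hypothetical alternative, which the patch does not use, is to give the version struct a lexicographic comparison via `std::tie`. That collapses the "newer than 2.0" test into a single expression. The sketch below models the varint stream as a `std::vector<uint64_t>`. `VersionSketch`, `serializeVersion`, `deserializeVersion`, and `isNewerThanCurrent` are illustrative names, not MLIR APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <tuple>
#include <vector>

// Hypothetical stand-in for TestDialectVersion with lexicographic ordering.
// Default-constructed, it represents the current version, 2.0.
struct VersionSketch {
  uint32_t major = 2;
  uint32_t minor = 0;

  bool operator<(const VersionSketch &other) const {
    return std::tie(major, minor) < std::tie(other.major, other.minor);
  }
  bool operator==(const VersionSketch &other) const {
    return std::tie(major, minor) == std::tie(other.major, other.minor);
  }
};

// Mirrors the order used by writeVersion(): major first, then minor.
void serializeVersion(const VersionSketch &v, std::vector<uint64_t> &stream) {
  stream.push_back(v.major);
  stream.push_back(v.minor);
}

// Mirrors readVersion(): fails (std::nullopt) if either field is missing.
std::optional<VersionSketch>
deserializeVersion(const std::vector<uint64_t> &stream) {
  if (stream.size() < 2)
    return std::nullopt;
  VersionSketch v;
  v.major = static_cast<uint32_t>(stream[0]);
  v.minor = static_cast<uint32_t>(stream[1]);
  return v;
}

// True if IR written at `v` is newer than the current version (2.0).
bool isNewerThanCurrent(const VersionSketch &v) { return VersionSketch{} < v; }
```

With this ordering, both the "future version" branch of `readAttribute` and the error branch of `upgradeFromVersion` reduce to `isNewerThanCurrent(version)`. The tradeoff is that the version's meaning now lives in one comparison operator instead of being spelled out at each use site.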