[Mlir-commits] [mlir] b299ec1 - Expose callbacks for encoding of types/attributes
Mehdi Amini
llvmlistbot at llvm.org
Fri Jul 28 10:44:31 PDT 2023
Author: Mehdi Amini
Date: 2023-07-28T10:44:02-07:00
New Revision: b299ec16661f653df66cdaf161cdc5441bc9803c
URL: https://github.com/llvm/llvm-project/commit/b299ec16661f653df66cdaf161cdc5441bc9803c
DIFF: https://github.com/llvm/llvm-project/commit/b299ec16661f653df66cdaf161cdc5441bc9803c.diff
LOG: Expose callbacks for encoding of types/attributes
[mlir] Expose a mechanism to provide a callback for encoding types and attributes in MLIR bytecode.
Two callbacks are exposed, respectively, to the BytecodeWriterConfig and to the ParserConfig. During bytecode parsing/printing, clients can specify a callback that optionally reads/writes the encoding. If a callback does not handle an entry, the fallback path executes the dialect's default parsers and printers.
The tests show how to leverage this functionality to support back-deployment and backward-compatibility use cases when round-tripping to bytecode a client dialect whose types/attributes depend on upstream ones.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D153383
Added:
mlir/include/mlir/Bytecode/BytecodeReaderConfig.h
mlir/test/Bytecode/bytecode_callback.mlir
mlir/test/Bytecode/bytecode_callback_full_override.mlir
mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
mlir/test/lib/IR/TestBytecodeCallbacks.cpp
Modified:
mlir/include/mlir/Bytecode/BytecodeImplementation.h
mlir/include/mlir/Bytecode/BytecodeReader.h
mlir/include/mlir/Bytecode/BytecodeWriter.h
mlir/include/mlir/IR/AsmState.h
mlir/lib/Bytecode/Reader/BytecodeReader.cpp
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
mlir/lib/Bytecode/Writer/IRNumbering.cpp
mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir
mlir/test/lib/Dialect/Test/TestDialect.h
mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/lib/IR/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 9c9aa7a4fc0ed1..bb1f0f717d8001 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -24,6 +24,17 @@
#include "llvm/ADT/Twine.h"
namespace mlir {
+//===--------------------------------------------------------------------===//
+// Dialect Version Interface.
+//===--------------------------------------------------------------------===//
+
+/// Base class used to represent the version of a dialect, for the purpose
+/// of polymorphic destruction: concrete dialect-specific version classes can
+/// be destroyed through a `DialectVersion` pointer. A loaded version is
+/// exposed to dialect readers through
+/// `DialectBytecodeReader::getDialectVersion`.
+class DialectVersion {
+public:
+ virtual ~DialectVersion() = default;
+};
+
//===----------------------------------------------------------------------===//
// DialectBytecodeReader
//===----------------------------------------------------------------------===//
@@ -38,7 +49,14 @@ class DialectBytecodeReader {
virtual ~DialectBytecodeReader() = default;
/// Emit an error to the reader.
- virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
+ virtual InFlightDiagnostic emitError(const Twine &msg = {}) const = 0;
+
+ /// Retrieve the dialect version by name if available.
+ virtual FailureOr<const DialectVersion *>
+ getDialectVersion(StringRef dialectName) const = 0;
+
+ /// Retrieve the context associated to the reader.
+ virtual MLIRContext *getContext() const = 0;
/// Return the bytecode version being read.
virtual uint64_t getBytecodeVersion() const = 0;
@@ -384,17 +402,6 @@ class DialectBytecodeWriter {
virtual int64_t getBytecodeVersion() const = 0;
};
-//===--------------------------------------------------------------------===//
-// Dialect Version Interface.
-//===--------------------------------------------------------------------===//
-
-/// This class is used to represent the version of a dialect, for the purpose
-/// of polymorphic destruction.
-class DialectVersion {
-public:
- virtual ~DialectVersion() = default;
-};
-
//===----------------------------------------------------------------------===//
// BytecodeDialectInterface
//===----------------------------------------------------------------------===//
@@ -409,47 +416,23 @@ class BytecodeDialectInterface
//===--------------------------------------------------------------------===//
/// Read an attribute belonging to this dialect from the given reader. This
- /// method should return null in the case of failure.
+ /// method should return null in the case of failure. Optionally, the dialect
+ /// version can be accessed through the reader.
virtual Attribute readAttribute(DialectBytecodeReader &reader) const {
reader.emitError() << "dialect " << getDialect()->getNamespace()
<< " does not support reading attributes from bytecode";
return Attribute();
}
- /// Read a versioned attribute encoding belonging to this dialect from the
- /// given reader. This method should return null in the case of failure, and
- /// falls back to the non-versioned reader in case the dialect implements
- /// versioning but it does not support versioned custom encodings for the
- /// attributes.
- virtual Attribute readAttribute(DialectBytecodeReader &reader,
- const DialectVersion &version) const {
- reader.emitError()
- << "dialect " << getDialect()->getNamespace()
- << " does not support reading versioned attributes from bytecode";
- return Attribute();
- }
-
/// Read a type belonging to this dialect from the given reader. This method
- /// should return null in the case of failure.
+ /// should return null in the case of failure. Optionally, the dialect version
+ /// can be accessed through the reader (see `getDialectVersion`).
virtual Type readType(DialectBytecodeReader &reader) const {
reader.emitError() << "dialect " << getDialect()->getNamespace()
<< " does not support reading types from bytecode";
return Type();
}
- /// Read a versioned type encoding belonging to this dialect from the given
- /// reader. This method should return null in the case of failure, and
- /// falls back to the non-versioned reader in case the dialect implements
- /// versioning but it does not support versioned custom encodings for the
- /// types.
- virtual Type readType(DialectBytecodeReader &reader,
- const DialectVersion &version) const {
- reader.emitError()
- << "dialect " << getDialect()->getNamespace()
- << " does not support reading versioned types from bytecode";
- return Type();
- }
-
//===--------------------------------------------------------------------===//
// Writing
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Bytecode/BytecodeReader.h b/mlir/include/mlir/Bytecode/BytecodeReader.h
index 206e42870ad85a..9f26506d486eec 100644
--- a/mlir/include/mlir/Bytecode/BytecodeReader.h
+++ b/mlir/include/mlir/Bytecode/BytecodeReader.h
@@ -25,7 +25,6 @@ class SourceMgr;
} // namespace llvm
namespace mlir {
-
/// The BytecodeReader allows to load MLIR bytecode files, while keeping the
/// state explicitly available in order to support lazy loading.
/// The `finalize` method must be called before destruction.
diff --git a/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h b/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h
new file mode 100644
index 00000000000000..d623d0da2c0c90
--- /dev/null
+++ b/mlir/include/mlir/Bytecode/BytecodeReaderConfig.h
@@ -0,0 +1,120 @@
+//===- BytecodeReaderConfig.h - MLIR Bytecode Reader Config -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines the configuration and callback interfaces used to
+// customize how attributes and types are read from MLIR bytecode.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BYTECODE_BYTECODEREADERCONFIG_H
+#define MLIR_BYTECODE_BYTECODEREADERCONFIG_H
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+class Attribute;
+class DialectBytecodeReader;
+class Type;
+
+/// A class to interact with the attributes and types parser when parsing MLIR
+/// bytecode.
+template <class T>
+class AttrTypeBytecodeReader {
+public:
+  AttrTypeBytecodeReader() = default;
+  virtual ~AttrTypeBytecodeReader() = default;
+
+  /// Try to read an entry of type `T` belonging to the dialect `dialectName`
+  /// from `reader`. Returning failure aborts parsing of the entry. Returning
+  /// success while leaving `entry` null means "not handled": the reader is
+  /// reset and the next callback (and eventually the dialect's own bytecode
+  /// interface) is tried.
+  virtual LogicalResult read(DialectBytecodeReader &reader,
+                             StringRef dialectName, T &entry) = 0;
+
+  /// Return an Attribute/Type reader implemented via the given callable, whose
+  /// form should match that of the `read` function above.
+  template <typename CallableT,
+            std::enable_if_t<
+                std::is_convertible_v<
+                    CallableT, std::function<LogicalResult(
+                                   DialectBytecodeReader &, StringRef, T &)>>,
+                bool> = true>
+  static std::unique_ptr<AttrTypeBytecodeReader<T>>
+  fromCallable(CallableT &&readFn) {
+    struct Processor : public AttrTypeBytecodeReader<T> {
+      // `CallableT` may be deduced as an lvalue reference, so forward instead
+      // of moving: an lvalue callable is copied (leaving the caller's object
+      // intact) while an rvalue callable is moved into the processor.
+      Processor(CallableT &&fn)
+          : AttrTypeBytecodeReader(), readFn(std::forward<CallableT>(fn)) {}
+      LogicalResult read(DialectBytecodeReader &reader, StringRef dialectName,
+                         T &entry) override {
+        return readFn(reader, dialectName, entry);
+      }
+
+      std::decay_t<CallableT> readFn;
+    };
+    return std::make_unique<Processor>(std::forward<CallableT>(readFn));
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// BytecodeReaderConfig
+//===----------------------------------------------------------------------===//
+
+/// Bytecode-specific settings carried by `ParserConfig`. It owns the ordered
+/// lists of user-provided reader callbacks for custom attribute and type
+/// encodings.
+class BytecodeReaderConfig {
+public:
+  BytecodeReaderConfig() = default;
+
+  /// Accessors for the attribute / type reader callbacks registered so far,
+  /// in attachment order.
+  ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
+  getAttributeCallbacks() const {
+    return attrReaders;
+  }
+  ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
+  getTypeCallbacks() const {
+    return typeReaders;
+  }
+
+  /// Take ownership of a reader used to decode custom attribute / type
+  /// encodings; it is appended after any previously attached reader.
+  void attachAttributeCallback(
+      std::unique_ptr<AttrTypeBytecodeReader<Attribute>> parser) {
+    attrReaders.push_back(std::move(parser));
+  }
+  void
+  attachTypeCallback(std::unique_ptr<AttrTypeBytecodeReader<Type>> parser) {
+    typeReaders.push_back(std::move(parser));
+  }
+
+  /// Convenience overloads: wrap any callable with the signature
+  /// `LogicalResult(DialectBytecodeReader &, StringRef, Attribute/Type &)`
+  /// into an `AttrTypeBytecodeReader` and attach it.
+  template <typename CallableT>
+  std::enable_if_t<std::is_convertible_v<
+      CallableT, std::function<LogicalResult(DialectBytecodeReader &, StringRef,
+                                             Attribute &)>>>
+  attachAttributeCallback(CallableT &&parserFn) {
+    auto wrapped = AttrTypeBytecodeReader<Attribute>::fromCallable(
+        std::forward<CallableT>(parserFn));
+    attachAttributeCallback(std::move(wrapped));
+  }
+  template <typename CallableT>
+  std::enable_if_t<std::is_convertible_v<
+      CallableT,
+      std::function<LogicalResult(DialectBytecodeReader &, StringRef, Type &)>>>
+  attachTypeCallback(CallableT &&parserFn) {
+    auto wrapped = AttrTypeBytecodeReader<Type>::fromCallable(
+        std::forward<CallableT>(parserFn));
+    attachTypeCallback(std::move(wrapped));
+  }
+
+private:
+  /// Readers tried, in order, before the dialect's own attribute decoding.
+  llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
+      attrReaders;
+  /// Readers tried, in order, before the dialect's own type decoding.
+  llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Type>>> typeReaders;
+};
+
+} // namespace mlir
+
+#endif // MLIR_BYTECODE_BYTECODEREADERCONFIG_H
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
index c6df1a21a55bb4..e0c46c3dab27a7 100644
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -17,6 +17,55 @@
namespace mlir {
class Operation;
+class DialectBytecodeWriter;
+
+/// A class to interact with the attributes and types printer when emitting MLIR
+/// bytecode.
+template <class T>
+class AttrTypeBytecodeWriter {
+public:
+  AttrTypeBytecodeWriter() = default;
+  virtual ~AttrTypeBytecodeWriter() = default;
+
+  /// Callback writer API used in IRNumbering, where groups are created and
+  /// type/attribute components are numbered. At this stage, writer is expected
+  /// to be a `NumberingDialectWriter`. The callback may set `name` to choose
+  /// the group the entry is emitted in.
+  virtual LogicalResult write(T entry, std::optional<StringRef> &name,
+                              DialectBytecodeWriter &writer) = 0;
+
+  /// Callback writer API used in BytecodeWriter, where the custom encoding is
+  /// actually emitted. Here, DialectBytecodeWriter is expected to be an actual
+  /// writer. Any group name set by the callback is ignored, since the group
+  /// was already specified when numbering the IR.
+  LogicalResult write(T entry, DialectBytecodeWriter &writer) {
+    std::optional<StringRef> dummy;
+    return write(entry, dummy, writer);
+  }
+
+  /// Return an Attribute/Type printer implemented via the given callable, whose
+  /// form should match that of the `write` function above.
+  template <typename CallableT,
+            std::enable_if_t<std::is_convertible_v<
+                                 CallableT, std::function<LogicalResult(
+                                                T, std::optional<StringRef> &,
+                                                DialectBytecodeWriter &)>>,
+                             bool> = true>
+  static std::unique_ptr<AttrTypeBytecodeWriter<T>>
+  fromCallable(CallableT &&writeFn) {
+    struct Processor : public AttrTypeBytecodeWriter<T> {
+      // `CallableT` may be deduced as an lvalue reference, so forward instead
+      // of moving: an lvalue callable is copied (leaving the caller's object
+      // intact) while an rvalue callable is moved into the processor.
+      Processor(CallableT &&fn)
+          : AttrTypeBytecodeWriter(), writeFn(std::forward<CallableT>(fn)) {}
+      LogicalResult write(T entry, std::optional<StringRef> &name,
+                          DialectBytecodeWriter &writer) override {
+        return writeFn(entry, name, writer);
+      }
+
+      std::decay_t<CallableT> writeFn;
+    };
+    return std::make_unique<Processor>(std::forward<CallableT>(writeFn));
+  }
+};
/// This class contains the configuration used for the bytecode writer. It
/// controls various aspects of bytecode generation, and contains all of the
@@ -48,6 +97,43 @@ class BytecodeWriterConfig {
/// Get the set desired bytecode version to emit.
int64_t getDesiredBytecodeVersion() const;
+ //===--------------------------------------------------------------------===//
+ // Types and Attributes encoding
+ //===--------------------------------------------------------------------===//
+
+ /// Retrieve the attribute / type writer callbacks attached to this
+ /// configuration.
+ ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
+ getAttributeWriterCallbacks() const;
+ ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
+ getTypeWriterCallbacks() const;
+
+ /// Attach a custom bytecode printer callback to the configuration for the
+ /// emission of custom type/attributes encodings. Ownership of the callback
+ /// is transferred to the configuration.
+ void attachAttributeCallback(
+ std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback);
+ void
+ attachTypeCallback(std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback);
+
+ /// Convenience overloads: wrap a callable of the form
+ /// `LogicalResult(Attribute/Type, std::optional<StringRef> &,
+ /// DialectBytecodeWriter &)` via `AttrTypeBytecodeWriter::fromCallable` and
+ /// attach it.
+ template <typename CallableT>
+ std::enable_if_t<std::is_convertible_v<
+ CallableT,
+ std::function<LogicalResult(Attribute, std::optional<StringRef> &,
+ DialectBytecodeWriter &)>>>
+ attachAttributeCallback(CallableT &&emitFn) {
+ attachAttributeCallback(AttrTypeBytecodeWriter<Attribute>::fromCallable(
+ std::forward<CallableT>(emitFn)));
+ }
+ template <typename CallableT>
+ std::enable_if_t<std::is_convertible_v<
+ CallableT, std::function<LogicalResult(Type, std::optional<StringRef> &,
+ DialectBytecodeWriter &)>>>
+ attachTypeCallback(CallableT &&emitFn) {
+ attachTypeCallback(AttrTypeBytecodeWriter<Type>::fromCallable(
+ std::forward<CallableT>(emitFn)));
+ }
+
//===--------------------------------------------------------------------===//
// Resources
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 2abeacb8443280..42cbedcf9f8837 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -14,6 +14,7 @@
#ifndef MLIR_IR_ASMSTATE_H_
#define MLIR_IR_ASMSTATE_H_
+#include "mlir/Bytecode/BytecodeReaderConfig.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/MapVector.h"
@@ -475,6 +476,11 @@ class ParserConfig {
/// Returns if the parser should verify the IR after parsing.
bool shouldVerifyAfterParse() const { return verifyAfterParse; }
+ /// Returns the bytecode-specific parsing configuration (custom attribute and
+ /// type reader callbacks).
+ /// NOTE(review): this hands out a mutable reference from a `const` method via
+ /// `const_cast`, so the configuration can be modified even through a
+ /// `const ParserConfig &`. Confirm this is intended before relying on it.
+ BytecodeReaderConfig &getBytecodeReaderConfig() const {
+ return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
+ }
+
/// Return the resource parser registered to the given name, or nullptr if no
/// parser with `name` is registered.
AsmResourceParser *getResourceParser(StringRef name) const {
@@ -509,6 +515,7 @@ class ParserConfig {
bool verifyAfterParse;
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
FallbackAsmResourceMap *fallbackResourceMap;
+ BytecodeReaderConfig bytecodeReaderConfig;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 0639baf10b0bc0..91e47c4c0e4784 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -451,7 +451,7 @@ struct BytecodeDialect {
/// Returns failure if the dialect couldn't be loaded *and* the provided
/// context does not allow unregistered dialects. The provided reader is used
/// for error emission if necessary.
- LogicalResult load(DialectReader &reader, MLIRContext *ctx);
+ LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
/// Return the loaded dialect, or nullptr if the dialect is unknown. This can
/// only be called after `load`.
@@ -505,10 +505,11 @@ struct BytecodeOperationName {
/// Parse a single dialect group encoded in the byte stream.
static LogicalResult parseDialectGrouping(
- EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects,
+ EncodingReader &reader,
+ MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
// Parse the dialect and the number of entries in the group.
- BytecodeDialect *dialect;
+ std::unique_ptr<BytecodeDialect> *dialect;
if (failed(parseEntry(reader, dialects, dialect, "dialect")))
return failure();
uint64_t numEntries;
@@ -516,7 +517,7 @@ static LogicalResult parseDialectGrouping(
return failure();
for (uint64_t i = 0; i < numEntries; ++i)
- if (failed(entryCallback(dialect)))
+ if (failed(entryCallback(dialect->get())))
return failure();
return success();
}
@@ -532,7 +533,7 @@ class ResourceSectionReader {
/// Initialize the resource section reader with the given section data.
LogicalResult
initialize(Location fileLoc, const ParserConfig &config,
- MutableArrayRef<BytecodeDialect> dialects,
+ MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
@@ -682,7 +683,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
LogicalResult ResourceSectionReader::initialize(
Location fileLoc, const ParserConfig &config,
- MutableArrayRef<BytecodeDialect> dialects,
+ MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
@@ -731,19 +732,19 @@ LogicalResult ResourceSectionReader::initialize(
// Read the dialect resources from the bytecode.
MLIRContext *ctx = fileLoc->getContext();
while (!offsetReader.empty()) {
- BytecodeDialect *dialect;
+ std::unique_ptr<BytecodeDialect> *dialect;
if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
- failed(dialect->load(dialectReader, ctx)))
+ failed((*dialect)->load(dialectReader, ctx)))
return failure();
- Dialect *loadedDialect = dialect->getLoadedDialect();
+ Dialect *loadedDialect = (*dialect)->getLoadedDialect();
if (!loadedDialect) {
return resourceReader.emitError()
- << "dialect '" << dialect->name << "' is unknown";
+ << "dialect '" << (*dialect)->name << "' is unknown";
}
const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
if (!handler) {
return resourceReader.emitError()
- << "unexpected resources for dialect '" << dialect->name << "'";
+ << "unexpected resources for dialect '" << (*dialect)->name << "'";
}
// Ensure that each resource is declared before being processed.
@@ -753,7 +754,7 @@ LogicalResult ResourceSectionReader::initialize(
if (failed(handle)) {
return resourceReader.emitError()
<< "unknown 'resource' key '" << key << "' for dialect '"
- << dialect->name << "'";
+ << (*dialect)->name << "'";
}
dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
dialectResources.push_back(*handle);
@@ -796,15 +797,19 @@ class AttrTypeReader {
public:
AttrTypeReader(StringSectionReader &stringReader,
- ResourceSectionReader &resourceReader, Location fileLoc,
- uint64_t &bytecodeVersion)
+ ResourceSectionReader &resourceReader,
+ const llvm::StringMap<BytecodeDialect *> &dialectsMap,
+ uint64_t &bytecodeVersion, Location fileLoc,
+ const ParserConfig &config)
: stringReader(stringReader), resourceReader(resourceReader),
- fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {}
+ dialectsMap(dialectsMap), fileLoc(fileLoc),
+ bytecodeVersion(bytecodeVersion), parserConfig(config) {}
/// Initialize the attribute and type information within the reader.
- LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
- ArrayRef<uint8_t> sectionData,
- ArrayRef<uint8_t> offsetSectionData);
+ LogicalResult
+ initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
+ ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData);
/// Resolve the attribute or type at the given index. Returns nullptr on
/// failure.
@@ -878,6 +883,10 @@ class AttrTypeReader {
/// parsing custom encoded attribute/type entries.
ResourceSectionReader &resourceReader;
+ /// The map of the loaded dialects used to retrieve dialect information, such
+ /// as the dialect version.
+ const llvm::StringMap<BytecodeDialect *> &dialectsMap;
+
/// The set of attribute and type entries.
SmallVector<AttrEntry> attributes;
SmallVector<TypeEntry> types;
@@ -887,27 +896,48 @@ class AttrTypeReader {
/// Current bytecode version being used.
uint64_t &bytecodeVersion;
+
+ /// Reference to the parser configuration.
+ const ParserConfig &parserConfig;
};
class DialectReader : public DialectBytecodeReader {
public:
DialectReader(AttrTypeReader &attrTypeReader,
StringSectionReader &stringReader,
- ResourceSectionReader &resourceReader, EncodingReader &reader,
- uint64_t &bytecodeVersion)
+ ResourceSectionReader &resourceReader,
+ const llvm::StringMap<BytecodeDialect *> &dialectsMap,
+ EncodingReader &reader, uint64_t &bytecodeVersion)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
- resourceReader(resourceReader), reader(reader),
- bytecodeVersion(bytecodeVersion) {}
+ resourceReader(resourceReader), dialectsMap(dialectsMap),
+ reader(reader), bytecodeVersion(bytecodeVersion) {}
- InFlightDiagnostic emitError(const Twine &msg) override {
+ InFlightDiagnostic emitError(const Twine &msg) const override {
return reader.emitError(msg);
}
+  FailureOr<const DialectVersion *>
+  getDialectVersion(StringRef dialectName) const override {
+    // Only dialects recorded in `dialectsMap` (populated while parsing the
+    // dialect section) can be queried.
+    auto it = dialectsMap.find(dialectName);
+    if (it == dialectsMap.end())
+      return failure();
+
+    // Loading the dialect reads its version from the version buffer if that
+    // has not happened yet. A dialect that cannot be loaded and a dialect
+    // without a version are both reported as failure.
+    BytecodeDialect *dialect = it->getValue();
+    if (failed(dialect->load(*this, getContext())))
+      return failure();
+    const DialectVersion *version = dialect->loadedVersion.get();
+    if (!version)
+      return failure();
+    return version;
+  }
+
+  MLIRContext *getContext() const override { return getLoc().getContext(); }
+
uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
- DialectReader withEncodingReader(EncodingReader &encReader) {
+ DialectReader withEncodingReader(EncodingReader &encReader) const {
return DialectReader(attrTypeReader, stringReader, resourceReader,
- encReader, bytecodeVersion);
+ dialectsMap, encReader, bytecodeVersion);
}
Location getLoc() const { return reader.getLoc(); }
@@ -1010,6 +1040,7 @@ class DialectReader : public DialectBytecodeReader {
AttrTypeReader &attrTypeReader;
StringSectionReader &stringReader;
ResourceSectionReader &resourceReader;
+ const llvm::StringMap<BytecodeDialect *> &dialectsMap;
EncodingReader &reader;
uint64_t &bytecodeVersion;
};
@@ -1096,10 +1127,9 @@ class PropertiesSectionReader {
};
} // namespace
-LogicalResult
-AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
- ArrayRef<uint8_t> sectionData,
- ArrayRef<uint8_t> offsetSectionData) {
+LogicalResult AttrTypeReader::initialize(
+ MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
+ ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
EncodingReader offsetReader(offsetSectionData, fileLoc);
// Parse the number of attribute and type entries.
@@ -1151,6 +1181,7 @@ AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
return offsetReader.emitError(
"unexpected trailing data in the Attribute/Type offset section");
}
+
return success();
}
@@ -1216,32 +1247,54 @@ template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
StringRef entryType) {
- DialectReader dialectReader(*this, stringReader, resourceReader, reader,
- bytecodeVersion);
+ DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
+ reader, bytecodeVersion);
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
+
+ if constexpr (std::is_same_v<T, Type>) {
+ // Try parsing with callbacks first if available.
+ for (const auto &callback :
+ parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
+ if (failed(
+ callback->read(dialectReader, entry.dialect->name, entry.entry)))
+ return failure();
+ // Early return if parsing was successful.
+ if (!!entry.entry)
+ return success();
+
+ // Reset the reader if we failed to parse, so we can fall through the
+ // other parsing functions.
+ reader = EncodingReader(entry.data, reader.getLoc());
+ }
+ } else {
+ // Try parsing with callbacks first if available.
+ for (const auto &callback :
+ parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) {
+ if (failed(
+ callback->read(dialectReader, entry.dialect->name, entry.entry)))
+ return failure();
+ // Early return if parsing was successful.
+ if (!!entry.entry)
+ return success();
+
+ // Reset the reader if we failed to parse, so we can fall through the
+ // other parsing functions.
+ reader = EncodingReader(entry.data, reader.getLoc());
+ }
+ }
+
// Ensure that the dialect implements the bytecode interface.
if (!entry.dialect->interface) {
return reader.emitError("dialect '", entry.dialect->name,
"' does not implement the bytecode interface");
}
- // Ask the dialect to parse the entry. If the dialect is versioned, parse
- // using the versioned encoding readers.
- if (entry.dialect->loadedVersion.get()) {
- if constexpr (std::is_same_v<T, Type>)
- entry.entry = entry.dialect->interface->readType(
- dialectReader, *entry.dialect->loadedVersion);
- else
- entry.entry = entry.dialect->interface->readAttribute(
- dialectReader, *entry.dialect->loadedVersion);
+ if constexpr (std::is_same_v<T, Type>)
+ entry.entry = entry.dialect->interface->readType(dialectReader);
+ else
+ entry.entry = entry.dialect->interface->readAttribute(dialectReader);
- } else {
- if constexpr (std::is_same_v<T, Type>)
- entry.entry = entry.dialect->interface->readType(dialectReader);
- else
- entry.entry = entry.dialect->interface->readAttribute(dialectReader);
- }
return success(!!entry.entry);
}
@@ -1262,7 +1315,8 @@ class mlir::BytecodeReader::Impl {
llvm::MemoryBufferRef buffer,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
: config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
- attrTypeReader(stringReader, resourceReader, fileLoc, version),
+ attrTypeReader(stringReader, resourceReader, dialectsMap, version,
+ fileLoc, config),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),
@@ -1528,7 +1582,8 @@ class mlir::BytecodeReader::Impl {
StringRef producer;
/// The table of IR units referenced within the bytecode file.
- SmallVector<BytecodeDialect> dialects;
+ SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
+ llvm::StringMap<BytecodeDialect *> dialectsMap;
SmallVector<BytecodeOperationName> opNames;
/// The reader used to process resources within the bytecode.
@@ -1675,7 +1730,8 @@ LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
//===----------------------------------------------------------------------===//
// Dialect Section
-LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
+LogicalResult BytecodeDialect::load(const DialectReader &reader,
+ MLIRContext *ctx) {
if (dialect)
return success();
Dialect *loadedDialect = ctx->getOrLoadDialect(name);
@@ -1719,13 +1775,15 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
// Parse each of the dialects.
for (uint64_t i = 0; i < numDialects; ++i) {
+ dialects[i] = std::make_unique<BytecodeDialect>();
/// Before version kDialectVersioning, there wasn't any versioning available
/// for dialects, and the entryIdx represent the string itself.
if (version < bytecode::kDialectVersioning) {
- if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
+ if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
return failure();
continue;
}
+
// Parse ID representing dialect and version.
uint64_t dialectNameIdx;
bool versionAvailable;
@@ -1733,18 +1791,19 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
versionAvailable)))
return failure();
if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
- dialects[i].name)))
+ dialects[i]->name)))
return failure();
if (versionAvailable) {
bytecode::Section::ID sectionID;
- if (failed(
- sectionReader.parseSection(sectionID, dialects[i].versionBuffer)))
+ if (failed(sectionReader.parseSection(sectionID,
+ dialects[i]->versionBuffer)))
return failure();
if (sectionID != bytecode::Section::kDialectVersions) {
emitError(fileLoc, "expected dialect version section");
return failure();
}
}
+ dialectsMap[dialects[i]->name] = dialects[i].get();
}
// Parse the operation names, which are grouped by dialect.
@@ -1792,7 +1851,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
if (!opName->opName) {
// Load the dialect and its version.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
- reader, version);
+ dialectsMap, reader, version);
if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
// If the opName is empty, this is because we use to accept names such as
@@ -1835,7 +1894,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection(
// Initialize the resource reader with the resource sections.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
- reader, version);
+ dialectsMap, reader, version);
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
*resourceData, *resourceOffsetData,
dialectReader, bufferOwnerRef);
@@ -2036,14 +2095,14 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
"parsed use-list orders were invalid and could not be applied");
// Resolve dialect version.
- for (const BytecodeDialect &byteCodeDialect : dialects) {
+ for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
// Parsing is complete, give an opportunity to each dialect to visit the
// IR and perform upgrades.
- if (!byteCodeDialect.loadedVersion)
+ if (!byteCodeDialect->loadedVersion)
continue;
- if (byteCodeDialect.interface &&
- failed(byteCodeDialect.interface->upgradeFromVersion(
- *moduleOp, *byteCodeDialect.loadedVersion)))
+ if (byteCodeDialect->interface &&
+ failed(byteCodeDialect->interface->upgradeFromVersion(
+ *moduleOp, *byteCodeDialect->loadedVersion)))
return failure();
}
@@ -2196,7 +2255,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
// interface and control the serialization.
if (wasRegistered) {
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
- reader, version);
+ dialectsMap, reader, version);
if (failed(
propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
return failure();
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index d8f2cb106510d9..75315b5ec75e3d 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -18,15 +18,10 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/CachedHashString.h"
#include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/Endian.h"
-#include <cstddef>
-#include <cstdint>
-#include <cstring>
+#include "llvm/Support/raw_ostream.h"
#include <optional>
-#include <sys/types.h>
#define DEBUG_TYPE "mlir-bytecode-writer"
@@ -47,6 +42,12 @@ struct BytecodeWriterConfig::Impl {
/// The producer of the bytecode.
StringRef producer;
+ /// Printer callbacks used to emit custom type and attribute encodings.
+ llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
+ attributeWriterCallbacks;
+ llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
+ typeWriterCallbacks;
+
/// A collection of non-dialect resource printers.
SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
};
@@ -60,6 +61,26 @@ BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map,
}
BytecodeWriterConfig::~BytecodeWriterConfig() = default;
+/// Returns a read-only view of the attribute writer callbacks, in the order
+/// they were attached.
+ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
+BytecodeWriterConfig::getAttributeWriterCallbacks() const {
+  const auto &callbacks = impl->attributeWriterCallbacks;
+  return callbacks;
+}
+
+/// Returns a read-only view of the type writer callbacks, in the order they
+/// were attached.
+ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
+BytecodeWriterConfig::getTypeWriterCallbacks() const {
+  const auto &callbacks = impl->typeWriterCallbacks;
+  return callbacks;
+}
+
+/// Takes ownership of `callback` and appends it to the attribute callbacks;
+/// the writer walks this list front to back.
+void BytecodeWriterConfig::attachAttributeCallback(
+    std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback) {
+  auto &callbacks = impl->attributeWriterCallbacks;
+  callbacks.push_back(std::move(callback));
+}
+
+/// Takes ownership of `callback` and appends it to the type callbacks; the
+/// writer walks this list front to back.
+void BytecodeWriterConfig::attachTypeCallback(
+    std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback) {
+  auto &callbacks = impl->typeWriterCallbacks;
+  callbacks.push_back(std::move(callback));
+}
+
void BytecodeWriterConfig::attachResourcePrinter(
std::unique_ptr<AsmResourcePrinter> printer) {
impl->externalResourcePrinters.emplace_back(std::move(printer));
@@ -774,32 +795,50 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
auto emitAttrOrType = [&](auto &entry) {
auto entryValue = entry.getValue();
- // First, try to emit this entry using the dialect bytecode interface.
- bool hasCustomEncoding = false;
- if (const BytecodeDialectInterface *interface = entry.dialect->interface) {
- // The writer used when emitting using a custom bytecode encoding.
+ auto emitAttrOrTypeRawImpl = [&]() -> void {
+ RawEmitterOstream(attrTypeEmitter) << entryValue;
+ attrTypeEmitter.emitByte(0);
+ };
+ auto emitAttrOrTypeImpl = [&]() -> bool {
+ // TODO: We don't currently support custom encoded mutable types and
+ // attributes.
+ if (entryValue.template hasTrait<TypeTrait::IsMutable>() ||
+ entryValue.template hasTrait<AttributeTrait::IsMutable>()) {
+ emitAttrOrTypeRawImpl();
+ return false;
+ }
+
DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter,
numberingState, stringSection);
-
if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
- // TODO: We don't currently support custom encoded mutable types.
- hasCustomEncoding =
- !entryValue.template hasTrait<TypeTrait::IsMutable>() &&
- succeeded(interface->writeType(entryValue, dialectWriter));
+ for (const auto &callback : config.typeWriterCallbacks) {
+ if (succeeded(callback->write(entryValue, dialectWriter)))
+ return true;
+ }
+ if (const BytecodeDialectInterface *interface =
+ entry.dialect->interface) {
+ if (succeeded(interface->writeType(entryValue, dialectWriter)))
+ return true;
+ }
} else {
- // TODO: We don't currently support custom encoded mutable attributes.
- hasCustomEncoding =
- !entryValue.template hasTrait<AttributeTrait::IsMutable>() &&
- succeeded(interface->writeAttribute(entryValue, dialectWriter));
+ for (const auto &callback : config.attributeWriterCallbacks) {
+ if (succeeded(callback->write(entryValue, dialectWriter)))
+ return true;
+ }
+ if (const BytecodeDialectInterface *interface =
+ entry.dialect->interface) {
+ if (succeeded(interface->writeAttribute(entryValue, dialectWriter)))
+ return true;
+ }
}
- }
- // If the entry was not emitted using the dialect interface, emit it using
- // the textual format.
- if (!hasCustomEncoding) {
- RawEmitterOstream(attrTypeEmitter) << entryValue;
- attrTypeEmitter.emitByte(0);
- }
+ // If the entry was not emitted using a callback or a dialect interface,
+ // emit it using the textual format.
+ emitAttrOrTypeRawImpl();
+ return false;
+ };
+
+ bool hasCustomEncoding = emitAttrOrTypeImpl();
// Record the offset of this entry.
uint64_t curOffset = attrTypeEmitter.size();
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index ef643ca6d74c76..67f929059e4709 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -314,9 +314,22 @@ void IRNumberingState::number(Attribute attr) {
// If this attribute will be emitted using the bytecode format, perform a
// dummy writing to number any nested components.
- if (const auto *interface = numbering->dialect->interface) {
- // TODO: We don't allow custom encodings for mutable attributes right now.
- if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
+ // TODO: We don't allow custom encodings for mutable attributes right now.
+ if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
+ // Try overriding emission with callbacks.
+ for (const auto &callback : config.getAttributeWriterCallbacks()) {
+ NumberingDialectWriter writer(*this);
+ // The client has the ability to override the group name through the
+ // callback.
+ std::optional<StringRef> groupNameOverride;
+ if (succeeded(callback->write(attr, groupNameOverride, writer))) {
+ if (groupNameOverride.has_value())
+ numbering->dialect = &numberDialect(*groupNameOverride);
+ return;
+ }
+ }
+
+ if (const auto *interface = numbering->dialect->interface) {
NumberingDialectWriter writer(*this);
if (succeeded(interface->writeAttribute(attr, writer)))
return;
@@ -464,9 +477,24 @@ void IRNumberingState::number(Type type) {
// If this type will be emitted using the bytecode format, perform a dummy
// writing to number any nested components.
- if (const auto *interface = numbering->dialect->interface) {
- // TODO: We don't allow custom encodings for mutable types right now.
- if (!type.hasTrait<TypeTrait::IsMutable>()) {
+ // TODO: We don't allow custom encodings for mutable types right now.
+ if (!type.hasTrait<TypeTrait::IsMutable>()) {
+ // Try overriding emission with callbacks.
+ for (const auto &callback : config.getTypeWriterCallbacks()) {
+ NumberingDialectWriter writer(*this);
+ // The client has the ability to override the group name through the
+ // callback.
+ std::optional<StringRef> groupNameOverride;
+ if (succeeded(callback->write(type, groupNameOverride, writer))) {
+ if (groupNameOverride.has_value())
+ numbering->dialect = &numberDialect(*groupNameOverride);
+ return;
+ }
+ }
+
+ // If this attribute will be emitted using the bytecode format, perform a
+ // dummy writing to number any nested components.
+ if (const auto *interface = numbering->dialect->interface) {
NumberingDialectWriter writer(*this);
if (succeeded(interface->writeType(type, writer)))
return;
diff --git a/mlir/test/Bytecode/bytecode_callback.mlir b/mlir/test/Bytecode/bytecode_callback.mlir
new file mode 100644
index 00000000000000..cf3981c86b9442
--- /dev/null
+++ b/mlir/test/Bytecode/bytecode_callback.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2
+// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0
+
+// Default callback scenario (test 0). Round-trips through bytecode targeting
+// test dialect versions 1.2 and 2.0. Only the pre-2.0 run is expected to
+// route builtin integer types through the custom "funky" group encoding.
+func.func @base_test(%arg0 : i32) -> f32 {
+ %0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
+ %1 = "test.cast"(%0) : (i32) -> f32
+ return %1 : f32
+}
+
+// VERSION_1_2: Overriding IntegerType encoding...
+// VERSION_1_2: Overriding parsing of IntegerType encoding...
+
+// VERSION_2_0-NOT: Overriding IntegerType encoding...
+// VERSION_2_0-NOT: Overriding parsing of IntegerType encoding...
diff --git a/mlir/test/Bytecode/bytecode_callback_full_override.mlir b/mlir/test/Bytecode/bytecode_callback_full_override.mlir
new file mode 100644
index 00000000000000..21ff947ad389b6
--- /dev/null
+++ b/mlir/test/Bytecode/bytecode_callback_full_override.mlir
@@ -0,0 +1,18 @@
+// RUN: not mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=5" 2>&1 | FileCheck %s
+
+// Scenario 5 routes every type/attribute through the builtin dialect's
+// bytecode interface. The first split uses only builtin types and must
+// round-trip. The second uses !test.i32, whose custom test-dialect encoding
+// (code 99) the builtin reader does not recognize, so reading must fail.
+// CHECK-NOT: failed to read bytecode
+func.func @base_test(%arg0 : i32) -> f32 {
+ %0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
+ %1 = "test.cast"(%0) : (i32) -> f32
+ return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: error: unknown attribute code: 99
+// CHECK: failed to read bytecode
+func.func @base_test(%arg0 : !test.i32) -> f32 {
+ %0 = "test.addi"(%arg0, %arg0) : (!test.i32, !test.i32) -> !test.i32
+ %1 = "test.cast"(%0) : (!test.i32) -> f32
+ return %1 : f32
+}
diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
new file mode 100644
index 00000000000000..487972f85af5be
--- /dev/null
+++ b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=3" | FileCheck %s --check-prefix=TEST_3
+// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=4" | FileCheck %s --check-prefix=TEST_4
+
+// Scenario 3 writes #test.attr_params as a builtin 2xi32 dense attribute.
+// Scenario 4 parses a builtin 2xi32 dense attribute back as #test.attr_params.
+"test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
+
+// TEST_3: Overriding TestAttrParamsAttr encoding...
+// TEST_3: "test.versionedC"() <{attribute = dense<[42, 24]> : tensor<2xi32>}> : () -> ()
+
+// -----
+
+"test.versionedC"() <{attribute = dense<[42, 24]> : tensor<2xi32>}> : () -> ()
+
+// TEST_4: Overriding parsing of TestAttrParamsAttr encoding...
+// TEST_4: "test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
new file mode 100644
index 00000000000000..1e272ec4f3afc2
--- /dev/null
+++ b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=1" | FileCheck %s --check-prefix=TEST_1
+// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=2" | FileCheck %s --check-prefix=TEST_2
+
+// Scenario 1 writes !test.i32 using the builtin i32 encoding, so it reads
+// back as i32. Scenario 2 parses builtin signless i32 back as !test.i32.
+func.func @base_test(%arg0: !test.i32, %arg1: f32) {
+ return
+}
+
+// TEST_1: Overriding TestI32Type encoding...
+// TEST_1: func.func @base_test([[ARG0:%.+]]: i32, [[ARG1:%.+]]: f32) {
+
+// -----
+
+func.func @base_test(%arg0: i32, %arg1: f32) {
+ return
+}
+
+// TEST_2: Overriding parsing of TestI32Type encoding...
+// TEST_2: func.func @base_test([[ARG0:%.+]]: !test.i32, [[ARG1:%.+]]: f32) {
diff --git a/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir
index aba6b3fd1a34aa..87beaa6dd7a056 100644
--- a/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir
+++ b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir
@@ -5,12 +5,12 @@
// Index
//===--------------------------------------------------------------------===//
-// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc 2>&1 | FileCheck %s --check-prefix=INDEX
+// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=INDEX
// INDEX: invalid Attribute index: 3
//===--------------------------------------------------------------------===//
// Trailing Data
//===--------------------------------------------------------------------===//
-// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA
+// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA
// TRAILING_DATA: trailing characters found after Attribute assembly format: trailing
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 34936783d62ae1..c3235b7b7c68b4 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -14,9 +14,10 @@
#ifndef MLIR_TESTDIALECT_H
#define MLIR_TESTDIALECT_H
-#include "TestTypes.h"
#include "TestAttributes.h"
#include "TestInterfaces.h"
+#include "TestTypes.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/DLTI/Traits.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -57,6 +58,19 @@ class RewritePatternSet;
#include "TestOpsDialect.h.inc"
namespace test {
+
+//===----------------------------------------------------------------------===//
+// TestDialect version utilities
+//===----------------------------------------------------------------------===//
+
+/// Version of the test dialect as recorded in bytecode: a major/minor pair.
+/// A default-constructed value is the current version, 2.0.
+struct TestDialectVersion : public mlir::DialectVersion {
+  TestDialectVersion() = default;
+  // Brace-initialize the members: some C libraries (older glibc via
+  // <sys/sysmacros.h>) define function-like `major`/`minor` macros, which
+  // `major(...)` in a mem-initializer would trigger.
+  TestDialectVersion(uint32_t majorVersion, uint32_t minorVersion)
+      : major{majorVersion}, minor{minorVersion} {}
+  uint32_t major = 2;
+  uint32_t minor = 0;
+};
+
// Define some classes to exercises the Properties feature.
struct PropertiesWithCustomPrint {
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 7315b253df998e..3dfb76fd0f5f7c 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -14,15 +14,6 @@
using namespace mlir;
using namespace test;
-//===----------------------------------------------------------------------===//
-// TestDialect version utilities
-//===----------------------------------------------------------------------===//
-
-struct TestDialectVersion : public DialectVersion {
- uint32_t major = 2;
- uint32_t minor = 0;
-};
-
//===----------------------------------------------------------------------===//
// TestDialect Interfaces
//===----------------------------------------------------------------------===//
@@ -47,7 +38,7 @@ struct TestResourceBlobManagerInterface
};
namespace {
-enum test_encoding { k_attr_params = 0 };
+/// Codes used by the test dialect's custom bytecode encodings. k_test_i32 is
+/// the single varint written by writeType for TestI32Type and matched by
+/// readType.
+enum test_encoding { k_attr_params = 0, k_test_i32 = 99 };
}
// Test support for interacting with the Bytecode reader/writer.
@@ -56,6 +47,24 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
TestBytecodeDialectInterface(Dialect *dialect)
: BytecodeDialectInterface(dialect) {}
+  /// Custom bytecode encoding for test dialect types. Only TestI32Type has
+  /// one: the single varint `k_test_i32`. Returning failure for every other
+  /// type lets the writer fall back to its default (textual) encoding.
+  LogicalResult writeType(Type type,
+                          DialectBytecodeWriter &writer) const final {
+    if (!llvm::isa<TestI32Type>(type))
+      return failure();
+    writer.writeVarInt(test_encoding::k_test_i32);
+    return success();
+  }
+
+  /// Inverse of writeType: reads one varint and materializes TestI32Type when
+  /// it equals `k_test_i32`. Any read failure or other code yields a null Type.
+  Type readType(DialectBytecodeReader &reader) const final {
+    uint64_t code;
+    bool isTestI32 = succeeded(reader.readVarInt(code)) &&
+                     code == test_encoding::k_test_i32;
+    return isTestI32 ? Type(TestI32Type::get(getContext())) : Type();
+  }
+
LogicalResult writeAttribute(Attribute attr,
DialectBytecodeWriter &writer) const final {
if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
@@ -67,9 +76,13 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
return failure();
}
- Attribute readAttribute(DialectBytecodeReader &reader,
- const DialectVersion &version_) const final {
- const auto &version = static_cast<const TestDialectVersion &>(version_);
+ Attribute readAttribute(DialectBytecodeReader &reader) const final {
+ auto versionOr = reader.getDialectVersion("test");
+ // Assume current version if not available through the reader.
+ const auto version =
+ (succeeded(versionOr))
+ ? *reinterpret_cast<const TestDialectVersion *>(*versionOr)
+ : TestDialectVersion();
if (version.major < 2)
return readAttrOldEncoding(reader);
if (version.major == 2 && version.minor == 0)
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9f897a6a30f541..fb0c54ce7c3b15 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1258,8 +1258,9 @@ def TestOpWithVariadicResultsAndFolder: TEST_Op<"op_with_variadic_results_and_fo
}
def TestAddIOp : TEST_Op<"addi"> {
- let arguments = (ins I32:$op1, I32:$op2);
- let results = (outs I32);
+  // Operands and result accept either builtin i32 or !test.i32, so the
+  // bytecode callback tests can substitute one type for the other.
+ let arguments = (ins AnyTypeOf<[I32, TestI32]>:$op1,
+ AnyTypeOf<[I32, TestI32]>:$op2);
+ let results = (outs AnyTypeOf<[I32, TestI32]>);
}
def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
@@ -2620,6 +2621,12 @@ def TestVersionedOpB : TEST_Op<"versionedB"> {
);
}
+// Op whose single attribute may be either #test.attr_params or an i32
+// elements attribute. The bytecode callback tests use it to rewrite one form
+// into the other.
+def TestVersionedOpC : TEST_Op<"versionedC"> {
+ let arguments = (ins AnyAttrOf<[TestAttrParams,
+ I32ElementsAttr]>:$attribute
+ );
+}
+
//===----------------------------------------------------------------------===//
// Test Properties
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 15dbd74aec118f..f899d72219d058 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -369,4 +369,8 @@ def TestTypeElseAnchorStruct : Test_Type<"TestTypeElseAnchorStruct"> {
let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
}
+// Parameterless test type spelled `!test.i32`. The bytecode callback tests
+// use it as a stand-in for the builtin i32 type.
+def TestI32 : Test_Type<"TestI32"> {
+ let mnemonic = "i32";
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index 447a2481e8dbad..1696a14654831b 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestIR
+ TestBytecodeCallbacks.cpp
TestBuiltinAttributeInterfaces.cpp
TestBuiltinDistinctAttributes.cpp
TestClone.cpp
diff --git a/mlir/test/lib/IR/TestBytecodeCallbacks.cpp b/mlir/test/lib/IR/TestBytecodeCallbacks.cpp
new file mode 100644
index 00000000000000..1464a80865f776
--- /dev/null
+++ b/mlir/test/lib/IR/TestBytecodeCallbacks.cpp
@@ -0,0 +1,372 @@
+//===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Bytecode/BytecodeReader.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/MemoryBufferRef.h"
+#include "llvm/Support/raw_ostream.h"
+#include <list>
+
+using namespace mlir;
+using namespace llvm;
+
+namespace {
+/// Command-line parser for `test::TestDialectVersion` values written as
+/// `<major>.<minor>` (e.g. `1.2`).
+class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
+public:
+  TestDialectVersionParser(cl::Option &O)
+      : cl::parser<test::TestDialectVersion>(O) {}
+
+  /// Parses `arg` into `v`. Both components must be non-negative decimal
+  /// integers that fit in 32 bits. Following the cl::parser convention, this
+  /// returns true on error (after reporting it through `O`) and false on
+  /// success.
+  bool parse(cl::Option &O, StringRef /*argName*/, StringRef arg,
+             test::TestDialectVersion &v) {
+    auto [majorStr, minorStr] = arg.split('.');
+    uint32_t major, minor;
+    // StringRef::getAsInteger returns true on failure. This includes
+    // non-numeric text, a negative value, or a value that does not fit in
+    // uint32_t, so out-of-range input is rejected instead of wrapped.
+    if (majorStr.getAsInteger(10, major) || minorStr.getAsInteger(10, minor))
+      return O.error("Invalid argument '" + arg + "'");
+    v = test::TestDialectVersion(major, minor);
+    return false;
+  }
+
+  /// Prints `v` in the same `<major>.<minor>` form accepted by parse().
+  static void print(raw_ostream &os, const test::TestDialectVersion &v) {
+    os << v.major << "." << v.minor;
+  }
+};
+
+/// This is a test pass which uses callbacks to encode attributes and types in a
+/// custom fashion.
+///
+/// The `callback-test` option selects which scenario `runOnOperation` runs
+/// (see runTest0 ... runTest5). Each scenario does the same three steps:
+/// write the module to bytecode with a BytecodeWriterConfig, parse it back
+/// with a ParserConfig, and print the result to stdout for FileCheck.
+struct TestBytecodeCallbackPass
+    : public PassWrapper<TestBytecodeCallbackPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeCallbackPass)
+
+  StringRef getArgument() const final { return "test-bytecode-callback"; }
+  StringRef getDescription() const final {
+    return "Test encoding of a dialect type/attributes with a custom callback";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<test::TestDialect>();
+  }
+  TestBytecodeCallbackPass() = default;
+  TestBytecodeCallbackPass(const TestBytecodeCallbackPass &pass)
+      : PassWrapper(pass) {}
+
+  void runOnOperation() override {
+    int kind = testKind;
+    switch (kind) {
+    case 0:
+      return runTest0(getOperation());
+    case 1:
+      return runTest1(getOperation());
+    case 2:
+      return runTest2(getOperation());
+    case 3:
+      return runTest3(getOperation());
+    case 4:
+      return runTest4(getOperation());
+    case 5:
+      return runTest5(getOperation());
+    default:
+      // The kind comes from a command-line option, so an unknown value is a
+      // user error. Report it rather than reaching llvm_unreachable, which is
+      // undefined behavior in release builds.
+      getOperation()->emitError()
+          << "unhandled test kind for TestBytecodeCallbacks pass: " << kind;
+      signalPassFailure();
+      return;
+    }
+  }
+
+  mlir::Pass::Option<test::TestDialectVersion, TestDialectVersionParser>
+      targetVersion{*this, "test-dialect-version",
+                    llvm::cl::desc(
+                        "Specifies the test dialect version to emit and parse"),
+                    cl::init(test::TestDialectVersion())};
+
+  mlir::Pass::Option<int> testKind{
+      *this, "callback-test",
+      llvm::cl::desc("Specifies the test kind to execute"), cl::init(0)};
+
+private:
+  /// Writes `op` to bytecode with `writeConfig`, parses it back with
+  /// `parseConfig`, and prints the parsed module to stdout. If either step
+  /// fails, emits an error on `op` and signals pass failure.
+  void doRoundtripWithConfigs(Operation *op,
+                              const BytecodeWriterConfig &writeConfig,
+                              const ParserConfig &parseConfig) {
+    std::string bytecode;
+    llvm::raw_string_ostream os(bytecode);
+    if (failed(writeBytecodeToFile(op, os, writeConfig))) {
+      op->emitError() << "failed to write bytecode\n";
+      signalPassFailure();
+      return;
+    }
+    auto newModuleOp = parseSourceString(StringRef(bytecode), parseConfig);
+    if (!newModuleOp.get()) {
+      op->emitError() << "failed to read bytecode\n";
+      signalPassFailure();
+      return;
+    }
+    // Print the module to the output stream, so that we can filecheck the
+    // result.
+    newModuleOp->print(llvm::outs());
+  }
+
+  // Test0: let's assume that versions older than 2.0 were relying on a special
+  // integer type of a deprecated dialect called "funky". Assume that its
+  // encoding was made by two varInts, the first was the ID (999) and the second
+  // contained width and signedness info. We can emit it using a callback
+  // writing a custom encoding for the "funky" dialect group, and parse it back
+  // with a custom parser reading the same encoding in the same dialect group.
+  // Note that the ID 999 does not correspond to a valid integer type in the
+  // current encodings of builtin types.
+  void runTest0(Operation *op) {
+    auto newCtx = std::make_shared<MLIRContext>();
+    test::TestDialectVersion targetEmissionVersion = targetVersion;
+    BytecodeWriterConfig writeConfig;
+    writeConfig.attachTypeCallback(
+        [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
+            DialectBytecodeWriter &writer) -> LogicalResult {
+          // Do not override anything for version 2.0 or newer.
+          if (targetEmissionVersion.major >= 2)
+            return failure();
+
+          // For versions older than 2.0, override the encoding of IntegerType.
+          if (auto type = llvm::dyn_cast<IntegerType>(entryValue)) {
+            llvm::outs() << "Overriding IntegerType encoding...\n";
+            dialectGroupName = StringLiteral("funky");
+            writer.writeVarInt(/* IntegerType */ 999);
+            writer.writeVarInt(type.getWidth() << 2 | type.getSignedness());
+            return success();
+          }
+          return failure();
+        });
+    newCtx->appendDialectRegistry(op->getContext()->getDialectRegistry());
+    newCtx->allowUnregisteredDialects();
+    ParserConfig parseConfig(newCtx.get(), /*verifyAfterParse=*/true);
+    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
+        [&](DialectBytecodeReader &reader, StringRef dialectName,
+            Type &entry) -> LogicalResult {
+          // The reader is expected to expose the test dialect version.
+          auto versionOr = reader.getDialectVersion("test");
+          assert(succeeded(versionOr) && "expected reader to be able to access "
+                                         "the version for test dialect");
+          // TODO: once back-deployment is formally supported,
+          // `targetEmissionVersion` will be encoded in the bytecode file, and
+          // exposed through the versionMap (presumably retrievable as a
+          // `test::TestDialectVersion` via static_cast of `*versionOr`). Right
+          // now though this is not yet supported. For the purpose of the test,
+          // just use `targetEmissionVersion`.
+          (void)versionOr;
+          if (targetEmissionVersion.major >= 2)
+            return success();
+
+          // `dialectName` is the name of the group we have the opportunity to
+          // override. In this case, override only the dialect group "funky",
+          // which has no corresponding dialect in memory.
+          if (dialectName != StringLiteral("funky"))
+            return success();
+
+          uint64_t encoding;
+          if (failed(reader.readVarInt(encoding)) || encoding != 999)
+            return success();
+          llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
+          uint64_t widthAndSignedness;
+          if (succeeded(reader.readVarInt(widthAndSignedness))) {
+            // Inverse of the writer callback above: the width is stored above
+            // the two low bits, which hold the signedness.
+            unsigned width = widthAndSignedness >> 2;
+            auto signedness = static_cast<IntegerType::SignednessSemantics>(
+                widthAndSignedness & 0x3);
+            entry = IntegerType::get(reader.getContext(), width, signedness);
+          }
+          // Returning success() while leaving `entry` null falls through to
+          // the rest of the parsing code path.
+          return success();
+        });
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+  }
+
+  // Test1: When writing bytecode, we override the encoding of TestI32Type with
+  // the encoding of builtin IntegerType. We can natively parse this without
+  // the use of a callback, relying on the existing builtin reader mechanism.
+  void runTest1(Operation *op) {
+    auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
+    BytecodeDialectInterface *iface =
+        builtin->getRegisteredInterface<BytecodeDialectInterface>();
+    BytecodeWriterConfig writeConfig;
+    writeConfig.attachTypeCallback(
+        [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
+            DialectBytecodeWriter &writer) -> LogicalResult {
+          // Emit TestI32Type using the builtin dialect encoding of i32.
+          if (llvm::isa<test::TestI32Type>(entryValue)) {
+            llvm::outs() << "Overriding TestI32Type encoding...\n";
+            auto builtinI32Type =
+                IntegerType::get(op->getContext(), 32,
+                                 IntegerType::SignednessSemantics::Signless);
+            // Specify that this type will need to be written as part of the
+            // builtin group. This will override the default dialect group of
+            // the type (test).
+            dialectGroupName = StringLiteral("builtin");
+            if (succeeded(iface->writeType(builtinI32Type, writer)))
+              return success();
+          }
+          return failure();
+        });
+    // We natively parse the type as a builtin, so no callback needed.
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+  }
+
+  // Test2: When writing bytecode, we write standard builtin IntegerTypes. At
+  // parsing, we use the encoding of IntegerType to intercept all i32. Then,
+  // instead of creating i32s, we assemble TestI32Type and return it.
+  void runTest2(Operation *op) {
+    auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
+    BytecodeDialectInterface *iface =
+        builtin->getRegisteredInterface<BytecodeDialectInterface>();
+    BytecodeWriterConfig writeConfig;
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
+    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
+        [&](DialectBytecodeReader &reader, StringRef dialectName,
+            Type &entry) -> LogicalResult {
+          if (dialectName != StringLiteral("builtin"))
+            return success();
+          Type builtinType = iface->readType(reader);
+          if (auto integerType =
+                  llvm::dyn_cast_or_null<IntegerType>(builtinType)) {
+            if (integerType.getWidth() == 32 && integerType.isSignless()) {
+              llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
+              entry = test::TestI32Type::get(reader.getContext());
+            }
+          }
+          return success();
+        });
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+  }
+
+  // Test3: When writing bytecode, we override the encoding of
+  // TestAttrParamsAttr with the encoding of builtin DenseIntElementsAttr. We
+  // can natively parse this without the use of a callback, relying on the
+  // existing builtin reader mechanism.
+  void runTest3(Operation *op) {
+    auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
+    BytecodeDialectInterface *iface =
+        builtin->getRegisteredInterface<BytecodeDialectInterface>();
+    auto i32Type = IntegerType::get(op->getContext(), 32,
+                                    IntegerType::SignednessSemantics::Signless);
+    BytecodeWriterConfig writeConfig;
+    writeConfig.attachAttributeCallback(
+        [&](Attribute entryValue, std::optional<StringRef> &dialectGroupName,
+            DialectBytecodeWriter &writer) -> LogicalResult {
+          // Emit TestAttrParamsAttr as a builtin 2xi32 DenseIntElementsAttr.
+          if (auto testParamAttrs =
+                  llvm::dyn_cast<test::TestAttrParamsAttr>(entryValue)) {
+            llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n";
+            // Specify that this attribute will need to be written as part of
+            // the builtin group. This will override the default dialect group
+            // of the attribute (test).
+            dialectGroupName = StringLiteral("builtin");
+            auto denseAttr = DenseIntElementsAttr::get(
+                RankedTensorType::get({2}, i32Type),
+                {testParamAttrs.getV0(), testParamAttrs.getV1()});
+            if (succeeded(iface->writeAttribute(denseAttr, writer)))
+              return success();
+          }
+          return failure();
+        });
+    // We natively parse the attribute as a builtin, so no callback needed.
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+  }
+
+  // Test4: When writing bytecode, we write standard builtin
+  // DenseIntElementsAttr. At parsing, we use the encoding of
+  // DenseIntElementsAttr to intercept all ElementsAttr that have shaped type of
+  // <2xi32>. Instead of assembling a DenseIntElementsAttr, we assemble
+  // TestAttrParamsAttr and return it.
+  void runTest4(Operation *op) {
+    auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
+    BytecodeDialectInterface *iface =
+        builtin->getRegisteredInterface<BytecodeDialectInterface>();
+    auto i32Type = IntegerType::get(op->getContext(), 32,
+                                    IntegerType::SignednessSemantics::Signless);
+    BytecodeWriterConfig writeConfig;
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
+    parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
+        [&](DialectBytecodeReader &reader, StringRef dialectName,
+            Attribute &entry) -> LogicalResult {
+          // Override only the case where the builtin reader yields a
+          // DenseIntElementsAttr of shape <2> with i32 elements. Fall through
+          // on all other cases so the normal codepath still parses the other
+          // attributes.
+          Attribute builtinAttr = iface->readAttribute(reader);
+          if (auto denseAttr =
+                  llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) {
+            if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) &&
+                denseAttr.getElementType() == i32Type) {
+              llvm::outs()
+                  << "Overriding parsing of TestAttrParamsAttr encoding...\n";
+              int v0 = denseAttr.getValues<IntegerAttr>()[0].getInt();
+              int v1 = denseAttr.getValues<IntegerAttr>()[1].getInt();
+              entry =
+                  test::TestAttrParamsAttr::get(reader.getContext(), v0, v1);
+            }
+          }
+          return success();
+        });
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+  }
+
+  // Test5: When writing bytecode, we want TestDialect to use nothing else than
+  // the builtin types and attributes and take full control of the encoding,
+  // returning failure if any type or attribute is not part of builtin.
+  void runTest5(Operation *op) {
+    auto builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
+    BytecodeDialectInterface *iface =
+        builtin->getRegisteredInterface<BytecodeDialectInterface>();
+    BytecodeWriterConfig writeConfig;
+    writeConfig.attachAttributeCallback(
+        [&](Attribute attr, std::optional<StringRef> &dialectGroupName,
+            DialectBytecodeWriter &writer) -> LogicalResult {
+          return iface->writeAttribute(attr, writer);
+        });
+    writeConfig.attachTypeCallback(
+        [&](Type type, std::optional<StringRef> &dialectGroupName,
+            DialectBytecodeWriter &writer) -> LogicalResult {
+          return iface->writeType(type, writer);
+        });
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
+    parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
+        [&](DialectBytecodeReader &reader, StringRef dialectName,
+            Attribute &entry) -> LogicalResult {
+          Attribute builtinAttr = iface->readAttribute(reader);
+          if (!builtinAttr)
+            return failure();
+          entry = builtinAttr;
+          return success();
+        });
+    parseConfig.getBytecodeReaderConfig().attachTypeCallback(
+        [&](DialectBytecodeReader &reader, StringRef dialectName,
+            Type &entry) -> LogicalResult {
+          Type builtinType = iface->readType(reader);
+          if (!builtinType)
+            return failure();
+          entry = builtinType;
+          return success();
+        });
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+  }
+};
+} // namespace
+
+namespace mlir {
+/// Registers TestBytecodeCallbackPass (`test-bytecode-callback`) through
+/// PassRegistration. This is called from mlir-opt's test pass registration.
+void registerTestBytecodeCallbackPasses() {
+ PassRegistration<TestBytecodeCallbackPass>();
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e91cb118461ec5..78bd70b40c91e7 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -43,6 +43,7 @@ void registerSymbolTestPasses();
void registerRegionTestPasses();
void registerTestAffineDataCopyPass();
void registerTestAffineReifyValueBoundsPass();
+void registerTestBytecodeCallbackPasses();
void registerTestDecomposeAffineOpPass();
void registerTestAffineLoopUnswitchingPass();
void registerTestAllReduceLoweringPass();
@@ -167,6 +168,7 @@ void registerTestPasses() {
registerTestDecomposeAffineOpPass();
registerTestAffineLoopUnswitchingPass();
registerTestAllReduceLoweringPass();
+ registerTestBytecodeCallbackPasses();
registerTestFunc();
registerTestGpuMemoryPromotionPass();
registerTestLoopPermutationPass();
More information about the Mlir-commits
mailing list