[Mlir-commits] [mlir] 0610e2f - [mlir][bytecode] Allow client to specify a desired version.
Jacques Pienaar
llvmlistbot at llvm.org
Sat Apr 29 05:36:01 PDT 2023
Author: Jacques Pienaar
Date: 2023-04-29T05:35:53-07:00
New Revision: 0610e2f6a2d42d83ff7a75729b3afa45d75729cc
URL: https://github.com/llvm/llvm-project/commit/0610e2f6a2d42d83ff7a75729b3afa45d75729cc
DIFF: https://github.com/llvm/llvm-project/commit/0610e2f6a2d42d83ff7a75729b3afa45d75729cc.diff
LOG: [mlir][bytecode] Allow client to specify a desired version.
Add method to set a desired bytecode file format to generate. Change
write method to be able to return status including the minimum bytecode
version needed by reader. This enables generating an older version of
the bytecode (not dialect ops, attributes or types). But this does not
guarantee that an older version can always be generated, e.g., if a
dialect uses a new encoding only available at later bytecode version.
This clamps setting to at most current version.
Differential Revision: https://reviews.llvm.org/D146555
Added:
mlir/test/Bytecode/versioning/versioned_bytecode.mlir
Modified:
mlir/include/mlir-c/IR.h
mlir/include/mlir/Bytecode/BytecodeImplementation.h
mlir/include/mlir/Bytecode/BytecodeWriter.h
mlir/include/mlir/CAPI/IR.h
mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
mlir/lib/Bytecode/Writer/IRNumbering.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
mlir/test/python/ir/operation.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index b45b955363f67..315ec3a846b65 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -48,6 +48,7 @@ extern "C" {
}; \
typedef struct name name
+DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void);
DEFINE_C_API_STRUCT(MlirContext, void);
DEFINE_C_API_STRUCT(MlirDialect, void);
DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
@@ -408,6 +409,24 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
MLIR_CAPI_EXPORTED void
mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);
+//===----------------------------------------------------------------------===//
+// Bytecode writer config API.
+//===----------------------------------------------------------------------===//
+
+/// Creates a new bytecode writer config with defaults for customization.
+/// Must be freed with a call to mlirBytecodeWriterConfigDestroy().
+MLIR_CAPI_EXPORTED MlirBytecodeWriterConfig
+mlirBytecodeWriterConfigCreate(void);
+
+/// Destroys a config created with mlirBytecodeWriterConfigCreate.
+MLIR_CAPI_EXPORTED void
+mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config);
+
+/// Sets the desired bytecode version to emit (clamped to current version).
+MLIR_CAPI_EXPORTED void
+mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig config,
+ int64_t version);
+
//===----------------------------------------------------------------------===//
// Operation API.
//===----------------------------------------------------------------------===//
@@ -546,10 +565,27 @@ MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op,
MlirStringCallback callback,
void *userData);
-/// Same as mlirOperationPrint but writing the bytecode format out.
-MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op,
- MlirStringCallback callback,
- void *userData);
+/// Result of serializing an operation to bytecode.
+struct MlirBytecodeWriterResult {
+ /// Minimum bytecode version a reader must support to read the output.
+ int64_t minVersion;
+};
+typedef struct MlirBytecodeWriterResult MlirBytecodeWriterResult;
+
+/// Returns the minimum bytecode version required to read the output.
+inline static int64_t
+mlirBytecodeWriterResultGetMinVersion(MlirBytecodeWriterResult res) {
+ return res.minVersion;
+}
+
+/// Same as mlirOperationPrint but writing the bytecode format and returns the
+/// minimum bytecode version the consumer needs to support.
+MLIR_CAPI_EXPORTED MlirBytecodeWriterResult mlirOperationWriteBytecode(
+ MlirOperation op, MlirStringCallback callback, void *userData);
+
+/// Same as mlirOperationWriteBytecode but with writer config.
+MLIR_CAPI_EXPORTED MlirBytecodeWriterResult
+mlirOperationWriteBytecodeWithConfig(MlirOperation op,
+ MlirBytecodeWriterConfig config,
+ MlirStringCallback callback,
+ void *userData);
/// Prints an operation to stderr.
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op);
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 6e7b9ff26342e..60f5475609ac7 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -233,6 +233,9 @@ class DialectBytecodeWriter {
/// guaranteed to not die before the end of the bytecode process. The blob is
/// written as-is, with no additional compression or compaction.
virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
+
+ /// Return the bytecode version being emitted for. Dialects may use this to
+ /// choose an encoding that a reader of that version understands.
+ virtual int64_t getBytecodeVersion() const = 0;
};
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
index fb4329e4d66f0..8877c59dd9da2 100644
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -40,6 +40,12 @@ class BytecodeWriterConfig {
/// Return an instance of the internal implementation.
const Impl &getImpl() const { return *impl; }
+ /// Set the desired bytecode version to emit. Versions newer than the current
+ /// bytecode version (bytecode::kVersion) are clamped to the current version.
+ /// The desired version may not be honored depending on the features used;
+ /// the version actually required is returned by the bytecode writer entry
+ /// point.
+ void setDesiredBytecodeVersion(int64_t bytecodeVersion);
+
//===--------------------------------------------------------------------===//
// Resources
//===--------------------------------------------------------------------===//
@@ -69,14 +75,21 @@ class BytecodeWriterConfig {
std::unique_ptr<Impl> impl;
};
+/// Status of bytecode serialization.
+struct BytecodeWriterResult {
+ /// The minimum version of the reader required to read the serialized file.
+ int64_t minVersion;
+};
+
//===----------------------------------------------------------------------===//
// Entry Points
//===----------------------------------------------------------------------===//
/// Write the bytecode for the given operation to the provided output stream.
/// For streams where it matters, the given stream should be in "binary" mode.
+/// Returns the minimum reader version required to read the emitted bytecode.
-void writeBytecodeToFile(Operation *op, raw_ostream &os,
- const BytecodeWriterConfig &config = {});
+BytecodeWriterResult
+writeBytecodeToFile(Operation *op, raw_ostream &os,
+ const BytecodeWriterConfig &config = {});
} // namespace mlir
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 2f32c76e12f06..b8ccec896c27b 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -15,11 +15,13 @@
#ifndef MLIR_CAPI_IR_H
#define MLIR_CAPI_IR_H
+#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
+DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig)
DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 1581f6b62c23a..0cba2147a6e08 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -90,6 +90,15 @@ class MlirOptMainConfig {
}
StringRef getIrdlFile() const { return irdlFileFlag; }
+ /// Set the bytecode version to emit.
+ MlirOptMainConfig &setEmitBytecodeVersion(int64_t version) {
+ emitBytecodeVersion = version;
+ return *this;
+ }
+ /// Return the bytecode version to emit, or std::nullopt if none was set.
+ std::optional<int64_t> bytecodeVersionToEmit() const {
+ return emitBytecodeVersion;
+ }
+
/// Set the callback to populate the pass manager.
MlirOptMainConfig &
setPassPipelineSetupFn(std::function<LogicalResult(PassManager &)> callback) {
@@ -168,6 +177,9 @@ class MlirOptMainConfig {
/// Location Breakpoints to filter the action logging.
std::vector<tracing::BreakpointManager *> logActionLocationFilter;
+ /// Emit bytecode at given version.
+ std::optional<int64_t> emitBytecodeVersion = std::nullopt;
+
/// The callback to populate the pass manager.
std::function<LogicalResult(PassManager &)> passPipelineCallback;
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 81c5cd2183107..052998be18ffe 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -124,6 +124,9 @@ static const char kOperationPrintBytecodeDocstring[] =
Args:
file: The file like object to write to.
+ desired_version: The version of bytecode to emit.
+Returns:
+ The bytecode writer status.
)";
static const char kOperationStrDunderDocstring[] =
@@ -1131,12 +1134,21 @@ void PyOperationBase::print(py::object fileObject, bool binary,
mlirOpPrintingFlagsDestroy(flags);
}
-void PyOperationBase::writeBytecode(const py::object &fileObject) {
+// Writes this operation as bytecode to `fileObject`, optionally targeting a
+// specific bytecode version, and returns the writer result.
+MlirBytecodeWriterResult
+PyOperationBase::writeBytecode(const py::object &fileObject,
+ std::optional<int64_t> bytecodeVersion) {
PyOperation &operation = getOperation();
operation.checkValid();
PyFileAccumulator accum(fileObject, /*binary=*/true);
- mlirOperationWriteBytecode(operation, accum.getCallback(),
- accum.getUserData());
+
+ if (!bytecodeVersion.has_value())
+ return mlirOperationWriteBytecode(operation, accum.getCallback(),
+ accum.getUserData());
+
+ // The config is owned here and must be released once writing is done,
+ // mirroring how print() destroys its printing flags.
+ MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
+ mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
+ MlirBytecodeWriterResult result = mlirOperationWriteBytecodeWithConfig(
+ operation, config, accum.getCallback(), accum.getUserData());
+ mlirBytecodeWriterConfigDestroy(config);
+ return result;
}
py::object PyOperationBase::getAsm(bool binary,
@@ -2757,6 +2769,7 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("use_local_scope") = false,
py::arg("assume_verified") = false, kOperationPrintDocstring)
.def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
+ py::arg("desired_version") = py::none(),
kOperationPrintBytecodeDocstring)
.def("get_asm", &PyOperationBase::getAsm,
// Careful: Lots of arguments must match up with get_asm method.
@@ -3365,6 +3378,10 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("from_op"), py::arg("all_sym_uses_visible"),
py::arg("callback"));
+ py::class_<MlirBytecodeWriterResult>(m, "BytecodeResult", py::module_local())
+ .def("min_version",
+ [](MlirBytecodeWriterResult &res) { return res.minVersion; });
+
// Container bindings.
PyBlockArgumentList::bind(m);
PyBlockIterator::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 920e6f467c5bc..56bb834b4eac9 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -554,7 +554,9 @@ class PyOperationBase {
bool assumeVerified);
// Implement the bound 'writeBytecode' method.
- void writeBytecode(const pybind11::object &fileObject);
+ MlirBytecodeWriterResult
+ writeBytecode(const pybind11::object &fileObject,
+ std::optional<int64_t> bytecodeVersion);
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 8433f3d84b852..95729a2a4fa5c 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -27,6 +27,10 @@ using namespace mlir::bytecode::detail;
struct BytecodeWriterConfig::Impl {
Impl(StringRef producer) : producer(producer) {}
+ /// Version to use when writing.
+ /// Note: This only differs from kVersion if a specific version is set.
+ int64_t bytecodeVersion = bytecode::kVersion;
+
/// The producer of the bytecode.
StringRef producer;
@@ -48,6 +52,12 @@ void BytecodeWriterConfig::attachResourcePrinter(
impl->externalResourcePrinters.emplace_back(std::move(printer));
}
+void BytecodeWriterConfig::setDesiredBytecodeVersion(int64_t bytecodeVersion) {
+ // Clamp into [0, kVersion]: this writer cannot produce newer versions, and a
+ // negative version cannot be meaningfully encoded as the header varint.
+ impl->bytecodeVersion = std::max<int64_t>(
+ 0, std::min<int64_t>(bytecodeVersion, bytecode::kVersion));
+}
+
//===----------------------------------------------------------------------===//
// EncodingEmitter
//===----------------------------------------------------------------------===//
@@ -295,7 +305,8 @@ class StringSectionBuilder {
class DialectWriter : public DialectBytecodeWriter {
public:
- DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState,
+ DialectWriter(int64_t bytecodeVersion, EncodingEmitter &emitter,
+ IRNumberingState &numberingState,
StringSectionBuilder &stringSection)
: emitter(emitter), numberingState(numberingState),
stringSection(stringSection) {}
@@ -362,7 +373,10 @@ class DialectWriter : public DialectBytecodeWriter {
reinterpret_cast<const uint8_t *>(blob.data()), blob.size()));
}
+ int64_t getBytecodeVersion() const override { return bytecodeVersion; }
+
private:
+ int64_t bytecodeVersion;
EncodingEmitter &emitter;
IRNumberingState &numberingState;
StringSectionBuilder &stringSection;
@@ -421,11 +435,11 @@ void EncodingEmitter::emitMultiByteVarInt(uint64_t value) {
namespace {
class BytecodeWriter {
public:
- BytecodeWriter(Operation *op) : numberingState(op) {}
+ BytecodeWriter(Operation *op, const BytecodeWriterConfig::Impl &config)
+ : numberingState(op), config(config) {}
/// Write the bytecode for the given root operation.
- void write(Operation *rootOp, raw_ostream &os,
- const BytecodeWriterConfig::Impl &config);
+ void write(Operation *rootOp, raw_ostream &os);
private:
//===--------------------------------------------------------------------===//
@@ -449,8 +463,7 @@ class BytecodeWriter {
//===--------------------------------------------------------------------===//
// Resources
- void writeResourceSection(Operation *op, EncodingEmitter &emitter,
- const BytecodeWriterConfig::Impl &config);
+ void writeResourceSection(Operation *op, EncodingEmitter &emitter);
//===--------------------------------------------------------------------===//
// Strings
@@ -465,11 +478,13 @@ class BytecodeWriter {
/// The IR numbering state generated for the root operation.
IRNumberingState numberingState;
+
+ /// Configuration dictating bytecode emission.
+ const BytecodeWriterConfig::Impl &config;
};
} // namespace
-void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
- const BytecodeWriterConfig::Impl &config) {
+void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
EncodingEmitter emitter;
// Emit the bytecode file header. This is how we identify the output as a
@@ -477,7 +492,7 @@ void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
emitter.emitString("ML\xefR");
// Emit the bytecode version.
- emitter.emitVarInt(bytecode::kVersion);
+ emitter.emitVarInt(config.bytecodeVersion);
// Emit the producer.
emitter.emitNulTerminatedString(config.producer);
@@ -492,7 +507,7 @@ void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
writeIRSection(emitter, rootOp);
// Emit the resources section.
- writeResourceSection(rootOp, emitter, config);
+ writeResourceSection(rootOp, emitter);
// Emit the string section.
writeStringSection(emitter);
@@ -540,12 +555,17 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
// Write the string section and get the ID.
size_t nameID = stringSection.insert(dialect.name);
+ if (config.bytecodeVersion == 0) {
+ dialectEmitter.emitVarInt(nameID);
+ continue;
+ }
+
// Try writing the version to the versionEmitter.
EncodingEmitter versionEmitter;
if (dialect.interface) {
// The writer used when emitting using a custom bytecode encoding.
- DialectWriter versionWriter(versionEmitter, numberingState,
- stringSection);
+ DialectWriter versionWriter(config.bytecodeVersion, versionEmitter,
+ numberingState, stringSection);
dialect.interface->writeVersion(versionWriter);
}
@@ -586,8 +606,8 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
bool hasCustomEncoding = false;
if (const BytecodeDialectInterface *interface = entry.dialect->interface) {
// The writer used when emitting using a custom bytecode encoding.
- DialectWriter dialectWriter(attrTypeEmitter, numberingState,
- stringSection);
+ DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter,
+ numberingState, stringSection);
if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
// TODO: We don't currently support custom encoded mutable types.
@@ -787,9 +807,8 @@ class ResourceBuilder : public AsmResourceBuilder {
};
} // namespace
-void BytecodeWriter::writeResourceSection(
- Operation *op, EncodingEmitter &emitter,
- const BytecodeWriterConfig::Impl &config) {
+void BytecodeWriter::writeResourceSection(Operation *op,
+ EncodingEmitter &emitter) {
EncodingEmitter resourceEmitter;
EncodingEmitter resourceOffsetEmitter;
uint64_t prevOffset = 0;
@@ -868,8 +887,12 @@ void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
// Entry Points
//===----------------------------------------------------------------------===//
-void mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
- const BytecodeWriterConfig &config) {
- BytecodeWriter writer(op);
- writer.write(op, os, config.getImpl());
+BytecodeWriterResult
+mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
+ const BytecodeWriterConfig &config) {
+ const BytecodeWriterConfig::Impl &impl = config.getImpl();
+ BytecodeWriter writer(op, impl);
+ writer.write(op, os);
+ // The writer does not yet report features that would need a newer reader,
+ // so the minimum reader version is the requested version (the same value
+ // that was written into the file header).
+ BytecodeWriterResult result;
+ result.minVersion = impl.bytecodeVersion;
+ return result;
}
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 06778b9c7c75e..7f56e9a94d299 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -7,11 +7,13 @@
//===----------------------------------------------------------------------===//
#include "IRNumbering.h"
+#include "../Encoding.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
+#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
using namespace mlir::bytecode::detail;
@@ -41,6 +43,10 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
}
void writeOwnedBlob(ArrayRef<char> blob) override {}
+ /// The numbering writer is not told which bytecode version is being emitted,
+ /// so querying it here is treated as a programming error.
+ /// NOTE(review): this writer is presumably what drives the dialect encoding
+ /// hooks during numbering, so a dialect that picks its encoding via
+ /// getBytecodeVersion() would hit this unreachable. Consider plumbing the
+ /// desired version into IRNumberingState instead — confirm against callers.
+ int64_t getBytecodeVersion() const override {
+ llvm_unreachable("unexpected querying of version in IRNumbering");
+ }
+
/// The parent numbering state that is populated by this writer.
IRNumberingState &state;
};
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 0bbcb3083062f..03f154965e965 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -145,6 +145,23 @@ void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
unwrap(flags)->assumeVerified();
}
+//===----------------------------------------------------------------------===//
+// Bytecode writer config API.
+//===----------------------------------------------------------------------===//
+
+/// Allocates a default-constructed config; release it with
+/// mlirBytecodeWriterConfigDestroy.
+MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate() {
+ return wrap(new BytecodeWriterConfig());
+}
+
+void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config) {
+ delete unwrap(config);
+}
+
+void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig config,
+ int64_t version) {
+ unwrap(config)->setDesiredBytecodeVersion(version);
+}
+
//===----------------------------------------------------------------------===//
// Location API.
//===----------------------------------------------------------------------===//
@@ -507,10 +524,25 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
unwrap(op)->print(stream, *unwrap(flags));
}
-void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
- void *userData) {
+MlirBytecodeWriterResult mlirOperationWriteBytecode(MlirOperation op,
+ MlirStringCallback callback,
+ void *userData) {
+ detail::CallbackOstream stream(callback, userData);
+ MlirBytecodeWriterResult res;
+ BytecodeWriterResult r = writeBytecodeToFile(unwrap(op), stream);
+ res.minVersion = r.minVersion;
+ return res;
+}
+
+MlirBytecodeWriterResult mlirOperationWriteBytecodeWithConfig(
+ MlirOperation op, MlirBytecodeWriterConfig config,
+ MlirStringCallback callback, void *userData) {
detail::CallbackOstream stream(callback, userData);
- writeBytecodeToFile(unwrap(op), stream);
+ BytecodeWriterResult r =
+ writeBytecodeToFile(unwrap(op), stream, *unwrap(config));
+ MlirBytecodeWriterResult res;
+ res.minVersion = r.minVersion;
+ return res;
}
void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 06039b9fa1ec1..3b2b5ed178f1a 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -52,6 +52,22 @@ using namespace mlir;
using namespace llvm;
namespace {
+/// Command line parser for `-emit-bytecode-version`: accepts a base-10 signed
+/// integer and stores it into the optional, reporting an error otherwise.
+class BytecodeVersionParser : public cl::parser<std::optional<int64_t>> {
+ using Base = cl::parser<std::optional<int64_t>>;
+
+public:
+ BytecodeVersionParser(cl::Option &O) : Base(O) {}
+
+ /// Returns false on success; on failure returns the result of O.error().
+ bool parse(cl::Option &O, StringRef /*argName*/, StringRef arg,
+ std::optional<int64_t> &v) {
+ long long parsed;
+ bool failed = getAsSignedInteger(arg, 10, parsed);
+ if (!failed) {
+ v = parsed;
+ return false;
+ }
+ return O.error("Invalid argument '" + arg +
+ "', only integer is supported.");
+ }
+};
+
/// This class is intended to manage the handling of command line options for
/// creating a *-opt config. This is a singleton.
struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
@@ -74,6 +90,13 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
"emit-bytecode", cl::desc("Emit bytecode when generating output"),
cl::location(emitBytecodeFlag), cl::init(false));
+ // Optional target bytecode version; stored externally in
+ // emitBytecodeVersion and unset (std::nullopt) by default.
+ static cl::opt<std::optional<int64_t>, /*ExternalStorage=*/true,
+ BytecodeVersionParser>
+ bytecodeVersion(
+ "emit-bytecode-version",
+ cl::desc("Use specified bytecode version when generating output"),
+ cl::location(emitBytecodeVersion), cl::init(std::nullopt));
+
static cl::opt<std::string, /*ExternalStorage=*/true> irdlFile(
"irdl-file",
cl::desc("IRDL file to register before processing the input"),
@@ -241,13 +264,23 @@ performActions(raw_ostream &os,
TimingScope outputTiming = timing.nest("Output");
if (config.shouldEmitBytecode()) {
BytecodeWriterConfig writerConfig(fallbackResourceMap);
+ if (auto v = config.bytecodeVersionToEmit()) {
+ writerConfig.setDesiredBytecodeVersion(*v);
+ // Returns failure if requested version couldn't be used for opt tools.
+ return success(
+ writeBytecodeToFile(op.get(), os, writerConfig).minVersion <= *v);
+ }
writeBytecodeToFile(op.get(), os, writerConfig);
- } else {
- AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
- &fallbackResourceMap);
- op.get()->print(os, asmState);
- os << '\n';
+ return success();
}
+
+ if (config.bytecodeVersionToEmit().has_value())
+ return emitError(UnknownLoc::get(pm.getContext()))
+ << "bytecode version while not emitting bytecode";
+ AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
+ &fallbackResourceMap);
+ op.get()->print(os, asmState);
+ os << '\n';
return success();
}
diff --git a/mlir/test/Bytecode/versioning/versioned_bytecode.mlir b/mlir/test/Bytecode/versioning/versioned_bytecode.mlir
new file mode 100644
index 0000000000000..bf08a23c03ae0
--- /dev/null
+++ b/mlir/test/Bytecode/versioning/versioned_bytecode.mlir
@@ -0,0 +1,14 @@
+// Verify that downgrading to bytecode version 0 preserves the IR: reading back
+// a version-0 emission must print the same text as reading the original file.
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: target=s390x-{{.*}}
+
+//===--------------------------------------------------------------------===//
+// Test roundtrip
+//===--------------------------------------------------------------------===//
+
+// RUN: mlir-opt %S/versioned-op-1.12.mlirbc -emit-bytecode \
+// RUN: -emit-bytecode-version=0 | mlir-opt -o %t.1 && \
+// RUN: mlir-opt %S/versioned-op-1.12.mlirbc -o %t.2 && \
+// RUN: diff %t.1 %t.2
+
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 2088e1633c66c..1afe1d9ed0da8 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -571,7 +571,8 @@ def testOperationPrint():
# Test roundtrip to bytecode.
bytecode_stream = io.BytesIO()
- module.operation.write_bytecode(bytecode_stream)
+ result = module.operation.write_bytecode(bytecode_stream, desired_version=1)
+ assert result.min_version() == 1, "Requested version not serialized to"
bytecode = bytecode_stream.getvalue()
assert bytecode.startswith(b'ML\xefR'), "Expected bytecode to start with MLïR"
module_roundtrip = Module.parse(bytecode, ctx)
More information about the Mlir-commits
mailing list