[Mlir-commits] [mlir] [mlir][bytecode] Implements downgrade hook for back deployment of MLIR dialects (PR #70724)
Matteo Franciolini
llvmlistbot at llvm.org
Mon Oct 30 13:40:39 PDT 2023
https://github.com/mfrancio created https://github.com/llvm/llvm-project/pull/70724
When emitting bytecode, clients can specify a target dialect version to emit in `BytecodeWriterConfig`. This exposes a target dialect version to the DialectBytecodeWriter, which can be queried by name and used to back-deploy attributes, types, and properties. IR constructs can be downgraded according to the specified version using a new downgradeToVersion hook exposed through the `BytecodeDialectInterface`.
>From 334be0771e06811ee69a541d3a7042e624cddfcb Mon Sep 17 00:00:00 2001
From: Matteo Franciolini <m_franciolini at apple.com>
Date: Fri, 27 Oct 2023 17:20:21 -0700
Subject: [PATCH] [mlir][bytecode] Implements downgrade hook for back
deployment of MLIR dialects
When emitting bytecode, clients can specify a target dialect version to emit in `BytecodeWriterConfig`. This exposes a target dialect version to the DialectBytecodeWriter, which can be queried by name and used to back-deploy attributes, types, and properties. IR constructs can be downgraded according to the specified version using a new downgradeToVersion hook exposed through the `BytecodeDialectInterface`.
---
.../mlir/Bytecode/BytecodeImplementation.h | 16 +++++
mlir/include/mlir/Bytecode/BytecodeWriter.h | 11 +++-
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp | 53 ++++++++++++++--
mlir/lib/Bytecode/Writer/IRNumbering.cpp | 27 ++++++---
mlir/test/Bytecode/bytecode_callback.mlir | 4 +-
.../bytecode_callback_full_override.mlir | 2 +-
...tecode_callback_with_custom_attribute.mlir | 4 +-
.../bytecode_callback_with_custom_type.mlir | 4 +-
mlir/test/Bytecode/bytecode_downgrade.mlir | 10 ++++
mlir/test/lib/Dialect/Test/TestDialect.cpp | 11 ++++
.../Dialect/Test/TestDialectInterfaces.cpp | 40 ++++++++++++-
mlir/test/lib/IR/CMakeLists.txt | 2 +-
...allbacks.cpp => TestBytecodeRoundtrip.cpp} | 60 ++++++++++++-------
mlir/tools/mlir-opt/mlir-opt.cpp | 4 +-
14 files changed, 201 insertions(+), 47 deletions(-)
create mode 100644 mlir/test/Bytecode/bytecode_downgrade.mlir
rename mlir/test/lib/IR/{TestBytecodeCallbacks.cpp => TestBytecodeRoundtrip.cpp} (88%)
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index bb1f0f717d80017..c9729a72a7bac4a 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -400,6 +400,10 @@ class DialectBytecodeWriter {
/// Return the bytecode version being emitted for.
virtual int64_t getBytecodeVersion() const = 0;
+
+ /// Retrieve the dialect version by name if available.
+ virtual FailureOr<const DialectVersion *>
+ getDialectVersion(StringRef dialectName) const = 0;
};
//===----------------------------------------------------------------------===//
@@ -475,6 +479,18 @@ class BytecodeDialectInterface
const DialectVersion &version) const {
return success();
}
+
+ /// Hook invoked before writing the top-level operation to bytecode when a
+ /// target emission version was requested for this dialect through
+ /// `BytecodeWriterConfig::setDialectVersion`. Implementations may visit and
+ /// rewrite `topLevelOp` in place to downgrade this dialect's constructs to
+ /// `version`; returning failure aborts bytecode emission. The default
+ /// implementation makes no changes and succeeds.
+ virtual LogicalResult
+ downgradeToVersion(Operation *topLevelOp,
+ const DialectVersion &version) const {
+ return success();
+ }
};
/// Helper for resource handle reading that returns LogicalResult.
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
index e0c46c3dab27a7b..90251068c8303f7 100644
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -16,8 +16,9 @@
#include "mlir/IR/AsmState.h"
namespace mlir {
-class Operation;
class DialectBytecodeWriter;
+class DialectVersion;
+class Operation;
/// A class to interact with the attributes and types printer when emitting MLIR
/// bytecode.
@@ -97,6 +98,14 @@ class BytecodeWriterConfig {
/// Get the set desired bytecode version to emit.
int64_t getDesiredBytecodeVersion() const;
+ /// A map containing the dialect versions to emit.
+ llvm::StringMap<std::unique_ptr<DialectVersion>> &
+ getDialectVersionMap() const;
+
+ /// Set a given dialect version to emit on the map.
+ void setDialectVersion(StringRef dialectName,
+ std::unique_ptr<DialectVersion> dialectVersion) const;
+
//===--------------------------------------------------------------------===//
// Types and Attributes encoding
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 5628ff6c54af67f..f9d60e66796cf60 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -39,6 +39,9 @@ struct BytecodeWriterConfig::Impl {
/// Note: This only differs from kVersion if a specific version is set.
int64_t bytecodeVersion = bytecode::kVersion;
+ /// A map containing dialect version information for each dialect to emit.
+ llvm::StringMap<std::unique_ptr<DialectVersion>> dialectVersionMap;
+
/// The producer of the bytecode.
StringRef producer;
@@ -94,6 +97,19 @@ int64_t BytecodeWriterConfig::getDesiredBytecodeVersion() const {
return impl->bytecodeVersion;
}
+/// Return the map from dialect name to the target emission version. Although
+/// this method is `const`, the returned reference is mutable: the map lives
+/// behind the `impl` pointer rather than in the config object itself.
+llvm::StringMap<std::unique_ptr<DialectVersion>> &
+BytecodeWriterConfig::getDialectVersionMap() const {
+ return impl->dialectVersionMap;
+}
+
+/// Record `dialectVersion` as the target emission version for `dialectName`.
+/// A dialect may be assigned a version only once (checked by the assert).
+// NOTE(review): with assertions disabled, a second call presumably keeps the
+// first entry, since StringMap::insert does not overwrite — confirm.
+void BytecodeWriterConfig::setDialectVersion(
+ llvm::StringRef dialectName,
+ std::unique_ptr<DialectVersion> dialectVersion) const {
+ assert(!impl->dialectVersionMap.contains(dialectName) &&
+ "cannot override a previously set dialect version");
+ impl->dialectVersionMap.insert({dialectName, std::move(dialectVersion)});
+}
+
//===----------------------------------------------------------------------===//
// EncodingEmitter
//===----------------------------------------------------------------------===//
@@ -340,12 +356,16 @@ class StringSectionBuilder {
} // namespace
class DialectWriter : public DialectBytecodeWriter {
+ using DialectVersionMapT = llvm::StringMap<std::unique_ptr<DialectVersion>>;
+
public:
DialectWriter(int64_t bytecodeVersion, EncodingEmitter &emitter,
IRNumberingState &numberingState,
- StringSectionBuilder &stringSection)
+ StringSectionBuilder &stringSection,
+ const DialectVersionMapT &dialectVersionMap)
: bytecodeVersion(bytecodeVersion), emitter(emitter),
- numberingState(numberingState), stringSection(stringSection) {}
+ numberingState(numberingState), stringSection(stringSection),
+ dialectVersionMap(dialectVersionMap) {}
//===--------------------------------------------------------------------===//
// IR
@@ -421,11 +441,20 @@ class DialectWriter : public DialectBytecodeWriter {
int64_t getBytecodeVersion() const override { return bytecodeVersion; }
+ FailureOr<const DialectVersion *>
+ getDialectVersion(StringRef dialectName) const override {
+ auto dialectEntry = dialectVersionMap.find(dialectName);
+ if (dialectEntry == dialectVersionMap.end())
+ return failure();
+ return dialectEntry->getValue().get();
+ }
+
private:
int64_t bytecodeVersion;
EncodingEmitter &emitter;
IRNumberingState &numberingState;
StringSectionBuilder &stringSection;
+ const DialectVersionMapT &dialectVersionMap;
};
namespace {
@@ -458,7 +487,8 @@ class PropertiesSectionBuilder {
EncodingEmitter emitter;
DialectWriter propertiesWriter(config.bytecodeVersion, emitter,
- numberingState, stringSection);
+ numberingState, stringSection,
+ config.dialectVersionMap);
auto iface = cast<BytecodeOpInterface>(op);
iface.writeProperties(propertiesWriter);
scratch.clear();
@@ -751,7 +781,8 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
if (dialect.interface) {
// The writer used when emitting using a custom bytecode encoding.
DialectWriter versionWriter(config.bytecodeVersion, versionEmitter,
- numberingState, stringSection);
+ numberingState, stringSection,
+ config.dialectVersionMap);
dialect.interface->writeVersion(versionWriter);
}
@@ -809,7 +840,8 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
}
DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter,
- numberingState, stringSection);
+ numberingState, stringSection,
+ config.dialectVersionMap);
if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
for (const auto &callback : config.typeWriterCallbacks) {
if (succeeded(callback->write(entryValue, dialectWriter)))
@@ -1256,6 +1288,17 @@ void BytecodeWriter::writePropertiesSection(EncodingEmitter &emitter) {
LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
const BytecodeWriterConfig &config) {
+ // Before creating the bytecode writer, give each dialect that registered a
+ // target emission version the opportunity to downgrade the IR. Dialects that
+ // are not registered in the context (getOrLoadDialect returns null) or that
+ // lack a BytecodeDialectInterface are skipped. It is the client dialect's
+ // responsibility to handle potentially destructive changes made to the IR.
+ for (auto &item : config.getDialectVersionMap()) {
+ Dialect *currentDialect = op->getContext()->getOrLoadDialect(item.getKey());
+ auto *iface = llvm::dyn_cast_if_present<BytecodeDialectInterface>(currentDialect);
+ if (iface && failed(iface->downgradeToVersion(op, *item.getValue())))
+ return failure();
+ }
BytecodeWriter writer(op, config);
return writer.write(op, os);
}
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 74c45723c222d27..036a9477cce6c1d 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -12,7 +12,6 @@
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
-#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
using namespace mlir::bytecode::detail;
@@ -22,7 +21,10 @@ using namespace mlir::bytecode::detail;
//===----------------------------------------------------------------------===//
struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
- NumberingDialectWriter(IRNumberingState &state) : state(state) {}
+ NumberingDialectWriter(
+ IRNumberingState &state,
+ llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap)
+ : state(state), dialectVersionMap(dialectVersionMap) {}
void writeAttribute(Attribute attr) override { state.number(attr); }
void writeOptionalAttribute(Attribute attr) override {
@@ -51,8 +53,19 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
return state.getDesiredBytecodeVersion();
}
+ FailureOr<const DialectVersion *>
+ getDialectVersion(StringRef dialectName) const override {
+ auto dialectEntry = dialectVersionMap.find(dialectName);
+ if (dialectEntry == dialectVersionMap.end())
+ return failure();
+ return dialectEntry->getValue().get();
+ }
+
/// The parent numbering state that is populated by this writer.
IRNumberingState &state;
+
+ /// A map containing dialect version information for each dialect to emit.
+ llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap;
};
//===----------------------------------------------------------------------===//
@@ -318,7 +331,7 @@ void IRNumberingState::number(Attribute attr) {
if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
// Try overriding emission with callbacks.
for (const auto &callback : config.getAttributeWriterCallbacks()) {
- NumberingDialectWriter writer(*this);
+ NumberingDialectWriter writer(*this, config.getDialectVersionMap());
// The client has the ability to override the group name through the
// callback.
std::optional<StringRef> groupNameOverride;
@@ -330,7 +343,7 @@ void IRNumberingState::number(Attribute attr) {
}
if (const auto *interface = numbering->dialect->interface) {
- NumberingDialectWriter writer(*this);
+ NumberingDialectWriter writer(*this, config.getDialectVersionMap());
if (succeeded(interface->writeAttribute(attr, writer)))
return;
}
@@ -426,7 +439,7 @@ void IRNumberingState::number(Operation &op) {
if (op.isRegistered()) {
// Operation that have properties *must* implement this interface.
auto iface = cast<BytecodeOpInterface>(op);
- NumberingDialectWriter writer(*this);
+ NumberingDialectWriter writer(*this, config.getDialectVersionMap());
iface.writeProperties(writer);
} else {
// Unregistered op are storing properties as an optional attribute.
@@ -481,7 +494,7 @@ void IRNumberingState::number(Type type) {
if (!type.hasTrait<TypeTrait::IsMutable>()) {
// Try overriding emission with callbacks.
for (const auto &callback : config.getTypeWriterCallbacks()) {
- NumberingDialectWriter writer(*this);
+ NumberingDialectWriter writer(*this, config.getDialectVersionMap());
// The client has the ability to override the group name through the
// callback.
std::optional<StringRef> groupNameOverride;
@@ -495,7 +508,7 @@ void IRNumberingState::number(Type type) {
// If this attribute will be emitted using the bytecode format, perform a
// dummy writing to number any nested components.
if (const auto *interface = numbering->dialect->interface) {
- NumberingDialectWriter writer(*this);
+ NumberingDialectWriter writer(*this, config.getDialectVersionMap());
if (succeeded(interface->writeType(type, writer)))
return;
}
diff --git a/mlir/test/Bytecode/bytecode_callback.mlir b/mlir/test/Bytecode/bytecode_callback.mlir
index cf3981c86b94428..3b537bb974da9ca 100644
--- a/mlir/test/Bytecode/bytecode_callback.mlir
+++ b/mlir/test/Bytecode/bytecode_callback.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2
-// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0
+// RUN: mlir-opt %s --test-bytecode-roundtrip="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2
+// RUN: mlir-opt %s --test-bytecode-roundtrip="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0
func.func @base_test(%arg0 : i32) -> f32 {
%0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
diff --git a/mlir/test/Bytecode/bytecode_callback_full_override.mlir b/mlir/test/Bytecode/bytecode_callback_full_override.mlir
index 21ff947ad389b6a..792f8f6e2959b70 100644
--- a/mlir/test/Bytecode/bytecode_callback_full_override.mlir
+++ b/mlir/test/Bytecode/bytecode_callback_full_override.mlir
@@ -1,4 +1,4 @@
-// RUN: not mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=5" 2>&1 | FileCheck %s
+// RUN: not mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=5" 2>&1 | FileCheck %s
// CHECK-NOT: failed to read bytecode
func.func @base_test(%arg0 : i32) -> f32 {
diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
index 487972f85af5be0..30f25d9865792a8 100644
--- a/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
+++ b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=3" | FileCheck %s --check-prefix=TEST_3
-// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=4" | FileCheck %s --check-prefix=TEST_4
+// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=3" | FileCheck %s --check-prefix=TEST_3
+// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=4" | FileCheck %s --check-prefix=TEST_4
"test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
index 1e272ec4f3afc28..aba194f1681a24b 100644
--- a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
+++ b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=1" | FileCheck %s --check-prefix=TEST_1
-// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=2" | FileCheck %s --check-prefix=TEST_2
+// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=1" | FileCheck %s --check-prefix=TEST_1
+// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=2" | FileCheck %s --check-prefix=TEST_2
func.func @base_test(%arg0: !test.i32, %arg1: f32) {
return
diff --git a/mlir/test/Bytecode/bytecode_downgrade.mlir b/mlir/test/Bytecode/bytecode_downgrade.mlir
new file mode 100644
index 000000000000000..fd98b9d7ee6a58c
--- /dev/null
+++ b/mlir/test/Bytecode/bytecode_downgrade.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s --test-bytecode-roundtrip="test-dialect-version=1.2 test-kind=6" -verify-diagnostics | FileCheck %s
+
+module {
+ "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> ()
+}
+
+// COM: the property downgrader is executed twice: first for IR numbering and then for emission.
+// CHECK: downgrading op...
+// CHECK: downgrading op properties...
+// CHECK: downgrading op properties...
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index deef4a9683e7421..61758085ba08172 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1358,6 +1358,17 @@ TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
void TestVersionedOpA::writeProperties(::mlir::DialectBytecodeWriter &writer) {
auto &prop = getProperties();
writer.writeAttribute(prop.dims);
+
+ auto maybeVersion = writer.getDialectVersion("test");
+ if (succeeded(maybeVersion)) {
+ // If version is less than 2.0, there is no additional attribute to write.
+ const auto *version =
+ static_cast<const TestDialectVersion *>(*maybeVersion);
+ if (version->major_ < 2) {
+ llvm::outs() << "downgrading op properties...\n";
+ return;
+ }
+ }
writer.writeAttribute(prop.modifier);
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 950af85007475b9..6ddccdf759de7b8 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -93,9 +93,16 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
// Emit a specific version of the dialect.
void writeVersion(DialectBytecodeWriter &writer) const final {
- auto version = TestDialectVersion();
- writer.writeVarInt(version.major_); // major
- writer.writeVarInt(version.minor_); // minor
+ // Construct the current dialect version.
+ test::TestDialectVersion versionToEmit;
+
+ // If a target version was set in the writer config, emit that one instead.
+ auto versionOr = writer.getDialectVersion("test");
+ if (succeeded(versionOr))
+ versionToEmit =
+ *static_cast<const test::TestDialectVersion *>(*versionOr);
+ writer.writeVarInt(versionToEmit.major_); // major
+ writer.writeVarInt(versionToEmit.minor_); // minor
}
std::unique_ptr<DialectVersion>
@@ -130,6 +137,33 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
return success();
}
+ /// Downgrade the test dialect constructs under `topLevelOp` to `version`.
+ /// 2.0 needs no changes, versions newer than 2.0 are rejected with an error,
+ /// and for older versions every TestVersionedOpA must have a false modifier.
+ LogicalResult downgradeToVersion(Operation *topLevelOp,
+ const DialectVersion &version_) const final {
+ // NOTE(review): assumes the config registered a TestDialectVersion for
+ // "test"; the static_cast is unchecked.
+ const auto &version = static_cast<const TestDialectVersion &>(version_);
+ if ((version.major_ == 2) && (version.minor_ == 0))
+ return success();
+ if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) {
+ return topLevelOp->emitError() << "current test dialect version is 2.0, "
+ "can't downgrade to version: "
+ << version.major_ << "." << version.minor_;
+ }
+ // Prior to version 2.0, the old op supported only a single attribute called
+ // "dimensions". We need to check that the modifier is false, otherwise we
+ // can't do the downgrade. This walk only validates and logs; the modifier
+ // is actually omitted later by TestVersionedOpA::writeProperties. The walk
+ // stops at the first offending op.
+ auto status = topLevelOp->walk([&](TestVersionedOpA op) {
+ auto &prop = op.getProperties();
+ if (prop.modifier.getValue()) {
+ op->emitOpError() << "cannot downgrade to version " << version.major_
+ << "." << version.minor_
+ << " since the modifier is not compatible";
+ return WalkResult::interrupt();
+ }
+ llvm::outs() << "downgrading op...\n";
+ return WalkResult::advance();
+ });
+ return failure(status.wasInterrupted());
+ }
+
private:
Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const {
uint64_t encoding;
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index fb0eac399245578..69c63fd7e524b6f 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -1,6 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestIR
- TestBytecodeCallbacks.cpp
+ TestBytecodeRoundtrip.cpp
TestBuiltinAttributeInterfaces.cpp
TestBuiltinDistinctAttributes.cpp
TestClone.cpp
diff --git a/mlir/test/lib/IR/TestBytecodeCallbacks.cpp b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
similarity index 88%
rename from mlir/test/lib/IR/TestBytecodeCallbacks.cpp
rename to mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
index cc884c7cc76f37a..cd01533c7930e72 100644
--- a/mlir/test/lib/IR/TestBytecodeCallbacks.cpp
+++ b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
@@ -45,22 +45,23 @@ class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
/// This is a test pass which uses callbacks to encode attributes and types in a
/// custom fashion.
-struct TestBytecodeCallbackPass
- : public PassWrapper<TestBytecodeCallbackPass, OperationPass<ModuleOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeCallbackPass)
+struct TestBytecodeRoundtripPass
+ : public PassWrapper<TestBytecodeRoundtripPass, OperationPass<ModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeRoundtripPass)
- StringRef getArgument() const final { return "test-bytecode-callback"; }
+ StringRef getArgument() const final { return "test-bytecode-roundtrip"; }
StringRef getDescription() const final {
- return "Test encoding of a dialect type/attributes with a custom callback";
+ return "Test pass to implement bytecode roundtrip tests.";
}
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<test::TestDialect>();
}
- TestBytecodeCallbackPass() = default;
- TestBytecodeCallbackPass(const TestBytecodeCallbackPass &) {}
+ TestBytecodeRoundtripPass() = default;
+ TestBytecodeRoundtripPass(const TestBytecodeRoundtripPass &) {}
void runOnOperation() override {
switch (testKind) {
+ // Tests 0-5 implement a custom roundtrip with callbacks.
case (0):
return runTest0(getOperation());
case (1):
@@ -73,6 +74,10 @@ struct TestBytecodeCallbackPass
return runTest4(getOperation());
case (5):
return runTest5(getOperation());
+ case (6):
+ // test-kind 6 is a plain roundtrip with downgrade/upgrade to/from
+ // `targetVersion`.
+ return runTest6(getOperation());
default:
llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass");
}
@@ -85,8 +90,8 @@ struct TestBytecodeCallbackPass
cl::init(test::TestDialectVersion())};
mlir::Pass::Option<int> testKind{
- *this, "callback-test",
- llvm::cl::desc("Specifies the test kind to execute"), cl::init(0)};
+ *this, "test-kind", llvm::cl::desc("Specifies the test kind to execute"),
+ cl::init(0)};
private:
void doRoundtripWithConfigs(Operation *op,
@@ -122,11 +127,20 @@ struct TestBytecodeCallbackPass
auto newCtx = std::make_shared<MLIRContext>();
test::TestDialectVersion targetEmissionVersion = targetVersion;
BytecodeWriterConfig writeConfig;
+ // Set the emission version for the test dialect.
+ writeConfig.setDialectVersion(
+ "test",
+ std::make_unique<test::TestDialectVersion>(targetEmissionVersion));
writeConfig.attachTypeCallback(
[&](Type entryValue, std::optional<StringRef> &dialectGroupName,
DialectBytecodeWriter &writer) -> LogicalResult {
- // Do not override anything if version less than 2.0.
- if (targetEmissionVersion.major_ >= 2)
+ // Do not override anything if version greater than 2.0.
+ auto versionOr = writer.getDialectVersion("test");
+ assert(succeeded(versionOr) && "expected reader to be able to access "
+ "the version for test dialect");
+ const auto *version =
+ reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
+ if (version->major_ >= 2)
return failure();
// For version less than 2.0, override the encoding of IntegerType.
@@ -151,14 +165,7 @@ struct TestBytecodeCallbackPass
"the version for test dialect");
const auto *version =
reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
-
- // TODO: once back-deployment is formally supported,
- // `targetEmissionVersion` will be encoded in the bytecode file, and
- // exposed through the versionMap. Right now though this is not yet
- // supported. For the purpose of the test, just use
- // `targetEmissionVersion`.
- (void)version;
- if (targetEmissionVersion.major_ >= 2)
+ if (version->major_ >= 2)
return success();
// `dialectName` is the name of the group we have the opportunity to
@@ -355,11 +362,22 @@ struct TestBytecodeCallbackPass
});
doRoundtripWithConfigs(op, writeConfig, parseConfig);
}
+
+ // Test6: Plain roundtrip (no custom callbacks) with a downgrade to the
+ // specified `targetVersion` for the test dialect.
+ void runTest6(Operation *op) {
+ BytecodeWriterConfig writeConfig;
+ test::TestDialectVersion targetEmissionVersion = targetVersion;
+ // Registering a version for "test" exposes it to writers through
+ // getDialectVersion; presumably doRoundtripWithConfigs emits through
+ // writeBytecodeToFile, which also runs the dialect's downgrade hook.
+ writeConfig.setDialectVersion(
+ "test",
+ std::make_unique<test::TestDialectVersion>(targetEmissionVersion));
+ ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
+ doRoundtripWithConfigs(op, writeConfig, parseConfig);
+ }
};
} // namespace
namespace mlir {
-void registerTestBytecodeCallbackPasses() {
- PassRegistration<TestBytecodeCallbackPass>();
+void registerTestBytecodeRoundtripPasses() {
+ PassRegistration<TestBytecodeRoundtripPass>();
}
} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index b7647d7de78a10e..3e3223b48505601 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -44,7 +44,7 @@ void registerSymbolTestPasses();
void registerRegionTestPasses();
void registerTestAffineDataCopyPass();
void registerTestAffineReifyValueBoundsPass();
-void registerTestBytecodeCallbackPasses();
+void registerTestBytecodeRoundtripPasses();
void registerTestDecomposeAffineOpPass();
void registerTestAffineLoopUnswitchingPass();
void registerTestAllReduceLoweringPass();
@@ -167,7 +167,7 @@ void registerTestPasses() {
registerTestDecomposeAffineOpPass();
registerTestAffineLoopUnswitchingPass();
registerTestAllReduceLoweringPass();
- registerTestBytecodeCallbackPasses();
+ registerTestBytecodeRoundtripPasses();
registerTestFunc();
registerTestGpuMemoryPromotionPass();
registerTestLoopPermutationPass();
More information about the Mlir-commits
mailing list