[Mlir-commits] [mlir] [mlir][bytecode] Implements back deployment capability for MLIR dialects (PR #70724)

Matteo Franciolini llvmlistbot at llvm.org
Tue Oct 31 13:01:02 PDT 2023


https://github.com/mfrancio updated https://github.com/llvm/llvm-project/pull/70724

>From 7c8639cd81ad146586cd1a7a4adf43288dd90e25 Mon Sep 17 00:00:00 2001
From: Matteo Franciolini <m_franciolini at apple.com>
Date: Fri, 27 Oct 2023 17:20:21 -0700
Subject: [PATCH] [mlir][bytecode] Implements downgrade hook for back
 deployment of MLIR dialects

When emitting bytecode, clients can specify a target dialect version to emit in `BytecodeWriterConfig`. This exposes a target dialect version to the DialectBytecodeWriter, which can be queried by name and used to back-deploy attributes, types, and properties. IR constructs can be downgraded to the specified version before emission; the test pass demonstrates this with a `downgradeToVersion` helper (test kind 6 of `--test-bytecode-roundtrip`).
---
 .../mlir/Bytecode/BytecodeImplementation.h    |  13 +++
 mlir/include/mlir/Bytecode/BytecodeWriter.h   |  16 ++-
 mlir/lib/Bytecode/Writer/BytecodeWriter.cpp   |  42 +++++++-
 mlir/lib/Bytecode/Writer/IRNumbering.cpp      |  27 +++--
 mlir/test/Bytecode/bytecode_callback.mlir     |   4 +-
 .../bytecode_callback_full_override.mlir      |   2 +-
 ...tecode_callback_with_custom_attribute.mlir |   4 +-
 .../bytecode_callback_with_custom_type.mlir   |   4 +-
 mlir/test/Bytecode/bytecode_downgrade.mlir    |  10 ++
 mlir/test/lib/Dialect/Test/TestDialect.cpp    |  15 ++-
 .../Dialect/Test/TestDialectInterfaces.cpp    |  15 ++-
 mlir/test/lib/IR/CMakeLists.txt               |   2 +-
 ...allbacks.cpp => TestBytecodeRoundtrip.cpp} | 100 ++++++++++++++----
 mlir/tools/mlir-opt/mlir-opt.cpp              |   4 +-
 14 files changed, 207 insertions(+), 51 deletions(-)
 create mode 100644 mlir/test/Bytecode/bytecode_downgrade.mlir
 rename mlir/test/lib/IR/{TestBytecodeCallbacks.cpp => TestBytecodeRoundtrip.cpp} (80%)

diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index bb1f0f717d80017..56e9228404b74b5 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -54,6 +54,10 @@ class DialectBytecodeReader {
   /// Retrieve the dialect version by name if available.
   virtual FailureOr<const DialectVersion *>
   getDialectVersion(StringRef dialectName) const = 0;
+  template <class T>
+  FailureOr<const DialectVersion *> getDialectVersion() const {
+    return getDialectVersion(T::getDialectNamespace());
+  }
 
   /// Retrieve the context associated to the reader.
   virtual MLIRContext *getContext() const = 0;
@@ -400,6 +404,15 @@ class DialectBytecodeWriter {
 
   /// Return the bytecode version being emitted for.
   virtual int64_t getBytecodeVersion() const = 0;
+
+  /// Retrieve the dialect version by name if available.
+  virtual FailureOr<const DialectVersion *>
+  getDialectVersion(StringRef dialectName) const = 0;
+
+  template <class T>
+  FailureOr<const DialectVersion *> getDialectVersion() const {
+    return getDialectVersion(T::getDialectNamespace());
+  }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
index e0c46c3dab27a7b..2cc489778d4de32 100644
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -16,8 +16,9 @@
 #include "mlir/IR/AsmState.h"
 
 namespace mlir {
-class Operation;
 class DialectBytecodeWriter;
+class DialectVersion;
+class Operation;
 
 /// A class to interact with the attributes and types printer when emitting MLIR
 /// bytecode.
@@ -97,6 +98,19 @@ class BytecodeWriterConfig {
   /// Get the set desired bytecode version to emit.
   int64_t getDesiredBytecodeVersion() const;
 
+  /// A map containing the dialect versions to emit.
+  llvm::StringMap<std::unique_ptr<DialectVersion>> &
+  getDialectVersionMap() const;
+
+  /// Set the version to emit for dialect `T`. A version may only be set once
+  /// per dialect (asserted in the non-template overload).
+  template <class T>
+  void setDialectVersion(std::unique_ptr<DialectVersion> dialectVersion) const {
+    setDialectVersion(T::getDialectNamespace(), std::move(dialectVersion));
+  }
+  void setDialectVersion(StringRef dialectName,
+                         std::unique_ptr<DialectVersion> dialectVersion) const;
+
   //===--------------------------------------------------------------------===//
   // Types and Attributes encoding
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 5628ff6c54af67f..01dcea1ca3848eb 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -39,6 +39,9 @@ struct BytecodeWriterConfig::Impl {
   /// Note: This only differs from kVersion if a specific version is set.
   int64_t bytecodeVersion = bytecode::kVersion;
 
+  /// A map containing dialect version information for each dialect to emit.
+  llvm::StringMap<std::unique_ptr<DialectVersion>> dialectVersionMap;
+
   /// The producer of the bytecode.
   StringRef producer;
 
@@ -94,6 +97,19 @@ int64_t BytecodeWriterConfig::getDesiredBytecodeVersion() const {
   return impl->bytecodeVersion;
 }
 
+llvm::StringMap<std::unique_ptr<DialectVersion>> &
+BytecodeWriterConfig::getDialectVersionMap() const {
+  return impl->dialectVersionMap;
+}
+
+void BytecodeWriterConfig::setDialectVersion(
+    llvm::StringRef dialectName,
+    std::unique_ptr<DialectVersion> dialectVersion) const {
+  assert(!impl->dialectVersionMap.contains(dialectName) &&
+         "cannot override a previously set dialect version");
+  impl->dialectVersionMap.insert({dialectName, std::move(dialectVersion)});
+}
+
 //===----------------------------------------------------------------------===//
 // EncodingEmitter
 //===----------------------------------------------------------------------===//
@@ -340,12 +356,16 @@ class StringSectionBuilder {
 } // namespace
 
 class DialectWriter : public DialectBytecodeWriter {
+  using DialectVersionMapT = llvm::StringMap<std::unique_ptr<DialectVersion>>;
+
 public:
   DialectWriter(int64_t bytecodeVersion, EncodingEmitter &emitter,
                 IRNumberingState &numberingState,
-                StringSectionBuilder &stringSection)
+                StringSectionBuilder &stringSection,
+                const DialectVersionMapT &dialectVersionMap)
       : bytecodeVersion(bytecodeVersion), emitter(emitter),
-        numberingState(numberingState), stringSection(stringSection) {}
+        numberingState(numberingState), stringSection(stringSection),
+        dialectVersionMap(dialectVersionMap) {}
 
   //===--------------------------------------------------------------------===//
   // IR
@@ -421,11 +441,20 @@ class DialectWriter : public DialectBytecodeWriter {
 
   int64_t getBytecodeVersion() const override { return bytecodeVersion; }
 
+  FailureOr<const DialectVersion *>
+  getDialectVersion(StringRef dialectName) const override {
+    auto dialectEntry = dialectVersionMap.find(dialectName);
+    if (dialectEntry == dialectVersionMap.end())
+      return failure();
+    return dialectEntry->getValue().get();
+  }
+
 private:
   int64_t bytecodeVersion;
   EncodingEmitter &emitter;
   IRNumberingState &numberingState;
   StringSectionBuilder &stringSection;
+  const DialectVersionMapT &dialectVersionMap;
 };
 
 namespace {
@@ -458,7 +487,8 @@ class PropertiesSectionBuilder {
 
     EncodingEmitter emitter;
     DialectWriter propertiesWriter(config.bytecodeVersion, emitter,
-                                   numberingState, stringSection);
+                                   numberingState, stringSection,
+                                   config.dialectVersionMap);
     auto iface = cast<BytecodeOpInterface>(op);
     iface.writeProperties(propertiesWriter);
     scratch.clear();
@@ -751,7 +781,8 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
     if (dialect.interface) {
       // The writer used when emitting using a custom bytecode encoding.
       DialectWriter versionWriter(config.bytecodeVersion, versionEmitter,
-                                  numberingState, stringSection);
+                                  numberingState, stringSection,
+                                  config.dialectVersionMap);
       dialect.interface->writeVersion(versionWriter);
     }
 
@@ -809,7 +840,8 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
       }
 
       DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter,
-                                  numberingState, stringSection);
+                                  numberingState, stringSection,
+                                  config.dialectVersionMap);
       if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
         for (const auto &callback : config.typeWriterCallbacks) {
           if (succeeded(callback->write(entryValue, dialectWriter)))
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 74c45723c222d27..036a9477cce6c1d 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -12,7 +12,6 @@
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
-#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 using namespace mlir::bytecode::detail;
@@ -22,7 +21,10 @@ using namespace mlir::bytecode::detail;
 //===----------------------------------------------------------------------===//
 
 struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
-  NumberingDialectWriter(IRNumberingState &state) : state(state) {}
+  NumberingDialectWriter(
+      IRNumberingState &state,
+      llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap)
+      : state(state), dialectVersionMap(dialectVersionMap) {}
 
   void writeAttribute(Attribute attr) override { state.number(attr); }
   void writeOptionalAttribute(Attribute attr) override {
@@ -51,8 +53,19 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
     return state.getDesiredBytecodeVersion();
   }
 
+  FailureOr<const DialectVersion *>
+  getDialectVersion(StringRef dialectName) const override {
+    auto dialectEntry = dialectVersionMap.find(dialectName);
+    if (dialectEntry == dialectVersionMap.end())
+      return failure();
+    return dialectEntry->getValue().get();
+  }
+
   /// The parent numbering state that is populated by this writer.
   IRNumberingState &state;
+
+  /// A map containing dialect version information for each dialect to emit.
+  llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap;
 };
 
 //===----------------------------------------------------------------------===//
@@ -318,7 +331,7 @@ void IRNumberingState::number(Attribute attr) {
   if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
     // Try overriding emission with callbacks.
     for (const auto &callback : config.getAttributeWriterCallbacks()) {
-      NumberingDialectWriter writer(*this);
+      NumberingDialectWriter writer(*this, config.getDialectVersionMap());
       // The client has the ability to override the group name through the
       // callback.
       std::optional<StringRef> groupNameOverride;
@@ -330,7 +343,7 @@ void IRNumberingState::number(Attribute attr) {
     }
 
     if (const auto *interface = numbering->dialect->interface) {
-      NumberingDialectWriter writer(*this);
+      NumberingDialectWriter writer(*this, config.getDialectVersionMap());
       if (succeeded(interface->writeAttribute(attr, writer)))
         return;
     }
@@ -426,7 +439,7 @@ void IRNumberingState::number(Operation &op) {
     if (op.isRegistered()) {
       // Operation that have properties *must* implement this interface.
       auto iface = cast<BytecodeOpInterface>(op);
-      NumberingDialectWriter writer(*this);
+      NumberingDialectWriter writer(*this, config.getDialectVersionMap());
       iface.writeProperties(writer);
     } else {
       // Unregistered op are storing properties as an optional attribute.
@@ -481,7 +494,7 @@ void IRNumberingState::number(Type type) {
   if (!type.hasTrait<TypeTrait::IsMutable>()) {
     // Try overriding emission with callbacks.
     for (const auto &callback : config.getTypeWriterCallbacks()) {
-      NumberingDialectWriter writer(*this);
+      NumberingDialectWriter writer(*this, config.getDialectVersionMap());
       // The client has the ability to override the group name through the
       // callback.
       std::optional<StringRef> groupNameOverride;
@@ -495,7 +508,7 @@ void IRNumberingState::number(Type type) {
     // If this attribute will be emitted using the bytecode format, perform a
     // dummy writing to number any nested components.
     if (const auto *interface = numbering->dialect->interface) {
-      NumberingDialectWriter writer(*this);
+      NumberingDialectWriter writer(*this, config.getDialectVersionMap());
       if (succeeded(interface->writeType(type, writer)))
         return;
     }
diff --git a/mlir/test/Bytecode/bytecode_callback.mlir b/mlir/test/Bytecode/bytecode_callback.mlir
index cf3981c86b94428..3b537bb974da9ca 100644
--- a/mlir/test/Bytecode/bytecode_callback.mlir
+++ b/mlir/test/Bytecode/bytecode_callback.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2
-// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0
+// RUN: mlir-opt %s --test-bytecode-roundtrip="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2
+// RUN: mlir-opt %s --test-bytecode-roundtrip="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0
 
 func.func @base_test(%arg0 : i32) -> f32 {
   %0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
diff --git a/mlir/test/Bytecode/bytecode_callback_full_override.mlir b/mlir/test/Bytecode/bytecode_callback_full_override.mlir
index 21ff947ad389b6a..792f8f6e2959b70 100644
--- a/mlir/test/Bytecode/bytecode_callback_full_override.mlir
+++ b/mlir/test/Bytecode/bytecode_callback_full_override.mlir
@@ -1,4 +1,4 @@
-// RUN: not mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=5" 2>&1 | FileCheck %s
+// RUN: not mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=5" 2>&1 | FileCheck %s
 
 // CHECK-NOT: failed to read bytecode
 func.func @base_test(%arg0 : i32) -> f32 {
diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
index 487972f85af5be0..30f25d9865792a8 100644
--- a/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
+++ b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=3" | FileCheck %s --check-prefix=TEST_3
-// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=4" | FileCheck %s --check-prefix=TEST_4
+// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=3" | FileCheck %s --check-prefix=TEST_3
+// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=4" | FileCheck %s --check-prefix=TEST_4
 
 "test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
 
diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
index 1e272ec4f3afc28..aba194f1681a24b 100644
--- a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
+++ b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=1" | FileCheck %s --check-prefix=TEST_1
-// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=2" | FileCheck %s --check-prefix=TEST_2
+// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=1" | FileCheck %s --check-prefix=TEST_1
+// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=2" | FileCheck %s --check-prefix=TEST_2
 
 func.func @base_test(%arg0: !test.i32, %arg1: f32) {
   return
diff --git a/mlir/test/Bytecode/bytecode_downgrade.mlir b/mlir/test/Bytecode/bytecode_downgrade.mlir
new file mode 100644
index 000000000000000..fd98b9d7ee6a58c
--- /dev/null
+++ b/mlir/test/Bytecode/bytecode_downgrade.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s --test-bytecode-roundtrip="test-dialect-version=1.2 test-kind=6" -verify-diagnostics | FileCheck %s
+
+module {
+  "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> ()
+}
+
+// COM: the property downgrader is executed twice: first for IR numbering and then for emission.
+// CHECK: downgrading op...
+// CHECK: downgrading op properties...
+// CHECK: downgrading op properties...
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index deef4a9683e7421..21400a60e653215 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1339,7 +1339,7 @@ TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
 
   // Check if we have a version. If not, assume we are parsing the current
   // version.
-  auto maybeVersion = reader.getDialectVersion("test");
+  auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
   if (succeeded(maybeVersion)) {
     // If version is less than 2.0, there is no additional attribute to parse.
     // We can materialize missing properties post parsing before verification.
@@ -1358,6 +1358,17 @@ TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
 void TestVersionedOpA::writeProperties(::mlir::DialectBytecodeWriter &writer) {
   auto &prop = getProperties();
   writer.writeAttribute(prop.dims);
+
+  auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
+  if (succeeded(maybeVersion)) {
+    // If version is less than 2.0, there is no additional attribute to write.
+    const auto *version =
+        static_cast<const TestDialectVersion *>(*maybeVersion);
+    if (version->major_ < 2) {
+      llvm::outs() << "downgrading op properties...\n";
+      return;
+    }
+  }
   writer.writeAttribute(prop.modifier);
 }
 
@@ -1369,7 +1380,7 @@ ::mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
 
   // Check if we have a version. If not, assume we are parsing the current
   // version.
-  auto maybeVersion = reader.getDialectVersion("test");
+  auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
   bool needToParseAnotherInt = true;
   if (succeeded(maybeVersion)) {
     // If version is less than 2.0, there is no additional attribute to parse.
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 950af85007475b9..ab7d2486db9aec7 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -77,7 +77,7 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
   }
 
   Attribute readAttribute(DialectBytecodeReader &reader) const final {
-    auto versionOr = reader.getDialectVersion("test");
+    auto versionOr = reader.getDialectVersion<test::TestDialect>();
     // Assume current version if not available through the reader.
     const auto version =
         (succeeded(versionOr))
@@ -93,9 +93,16 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
 
   // Emit a specific version of the dialect.
   void writeVersion(DialectBytecodeWriter &writer) const final {
-    auto version = TestDialectVersion();
-    writer.writeVarInt(version.major_); // major
-    writer.writeVarInt(version.minor_); // minor
+    // Default to the current dialect version.
+    test::TestDialectVersion versionToEmit;
+
+    // Override with the target version if one was set on the writer config.
+    auto versionOr = writer.getDialectVersion<test::TestDialect>();
+    if (succeeded(versionOr))
+      versionToEmit =
+          *static_cast<const test::TestDialectVersion *>(*versionOr);
+    writer.writeVarInt(versionToEmit.major_); // major
+    writer.writeVarInt(versionToEmit.minor_); // minor
   }
 
   std::unique_ptr<DialectVersion>
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index fb0eac399245578..69c63fd7e524b6f 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -1,6 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestIR
-  TestBytecodeCallbacks.cpp
+  TestBytecodeRoundtrip.cpp
   TestBuiltinAttributeInterfaces.cpp
   TestBuiltinDistinctAttributes.cpp
   TestClone.cpp
diff --git a/mlir/test/lib/IR/TestBytecodeCallbacks.cpp b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
similarity index 80%
rename from mlir/test/lib/IR/TestBytecodeCallbacks.cpp
rename to mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
index cc884c7cc76f37a..beecc57d7cdd0d7 100644
--- a/mlir/test/lib/IR/TestBytecodeCallbacks.cpp
+++ b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
@@ -45,22 +45,28 @@ class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
 
 /// This is a test pass which uses callbacks to encode attributes and types in a
 /// custom fashion.
-struct TestBytecodeCallbackPass
-    : public PassWrapper<TestBytecodeCallbackPass, OperationPass<ModuleOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeCallbackPass)
+struct TestBytecodeRoundtripPass
+    : public PassWrapper<TestBytecodeRoundtripPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeRoundtripPass)
 
-  StringRef getArgument() const final { return "test-bytecode-callback"; }
+  StringRef getArgument() const final { return "test-bytecode-roundtrip"; }
   StringRef getDescription() const final {
-    return "Test encoding of a dialect type/attributes with a custom callback";
+    return "Test pass to implement bytecode roundtrip tests.";
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<test::TestDialect>();
   }
-  TestBytecodeCallbackPass() = default;
-  TestBytecodeCallbackPass(const TestBytecodeCallbackPass &) {}
+  TestBytecodeRoundtripPass() = default;
+  TestBytecodeRoundtripPass(const TestBytecodeRoundtripPass &) {}
+
+  LogicalResult initialize(MLIRContext *context) override {
+    testDialect = context->getOrLoadDialect<test::TestDialect>();
+    return success();
+  }
 
   void runOnOperation() override {
     switch (testKind) {
+      // Tests 0-5 implement a custom roundtrip with callbacks.
     case (0):
       return runTest0(getOperation());
     case (1):
@@ -73,6 +79,10 @@ struct TestBytecodeCallbackPass
       return runTest4(getOperation());
     case (5):
       return runTest5(getOperation());
+    case (6):
+      // test-kind 6 is a plain roundtrip with downgrade/upgrade to/from
+      // `targetVersion`.
+      return runTest6(getOperation());
     default:
       llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass");
     }
@@ -85,8 +95,8 @@ struct TestBytecodeCallbackPass
                     cl::init(test::TestDialectVersion())};
 
   mlir::Pass::Option<int> testKind{
-      *this, "callback-test",
-      llvm::cl::desc("Specifies the test kind to execute"), cl::init(0)};
+      *this, "test-kind", llvm::cl::desc("Specifies the test kind to execute"),
+      cl::init(0)};
 
 private:
   void doRoundtripWithConfigs(Operation *op,
@@ -122,11 +132,19 @@ struct TestBytecodeCallbackPass
     auto newCtx = std::make_shared<MLIRContext>();
     test::TestDialectVersion targetEmissionVersion = targetVersion;
     BytecodeWriterConfig writeConfig;
+    // Set the emission version for the test dialect.
+    writeConfig.setDialectVersion<test::TestDialect>(
+        std::make_unique<test::TestDialectVersion>(targetEmissionVersion));
     writeConfig.attachTypeCallback(
         [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
             DialectBytecodeWriter &writer) -> LogicalResult {
-          // Do not override anything if version less than 2.0.
-          if (targetEmissionVersion.major_ >= 2)
+          // Do not override anything if version is 2.0 or greater.
+          auto versionOr = writer.getDialectVersion<test::TestDialect>();
+          assert(succeeded(versionOr) && "expected writer to be able to access "
+                                         "the version for test dialect");
+          const auto *version =
+              static_cast<const test::TestDialectVersion *>(*versionOr);
+          if (version->major_ >= 2)
             return failure();
 
           // For version less than 2.0, override the encoding of IntegerType.
@@ -146,19 +164,12 @@ struct TestBytecodeCallbackPass
         [&](DialectBytecodeReader &reader, StringRef dialectName,
             Type &entry) -> LogicalResult {
           // Get test dialect version from the version map.
-          auto versionOr = reader.getDialectVersion("test");
+          auto versionOr = reader.getDialectVersion<test::TestDialect>();
           assert(succeeded(versionOr) && "expected reader to be able to access "
                                          "the version for test dialect");
           const auto *version =
               reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
-
-          // TODO: once back-deployment is formally supported,
-          // `targetEmissionVersion` will be encoded in the bytecode file, and
-          // exposed through the versionMap. Right now though this is not yet
-          // supported. For the purpose of the test, just use
-          // `targetEmissionVersion`.
-          (void)version;
-          if (targetEmissionVersion.major_ >= 2)
+          if (version->major_ >= 2)
             return success();
 
           // `dialectName` is the name of the group we have the opportunity to
@@ -355,11 +366,56 @@ struct TestBytecodeCallbackPass
         });
     doRoundtripWithConfigs(op, writeConfig, parseConfig);
   }
+
+  LogicalResult downgradeToVersion(Operation *op,
+                                   const test::TestDialectVersion &version) {
+    if (version.major_ == 2 && version.minor_ == 0)
+      return success();
+    if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) {
+      return op->emitError() << "current test dialect version is 2.0, "
+                                "can't downgrade to version: "
+                             << version.major_ << "." << version.minor_;
+    }
+    // Prior to version 2.0, the old op supported only a single attribute
+    // called "dimensions". We need to check that the modifier is false,
+    // otherwise we can't do the downgrade.
+    auto status = op->walk([&](test::TestVersionedOpA versionedOp) {
+      auto &prop = versionedOp.getProperties();
+      if (prop.modifier.getValue()) {
+        versionedOp.emitOpError()
+            << "cannot downgrade to version " << version.major_ << "."
+            << version.minor_ << " since the modifier is not compatible";
+        return WalkResult::interrupt();
+      }
+      llvm::outs() << "downgrading op...\n";
+      return WalkResult::advance();
+    });
+    return failure(status.wasInterrupted());
+  }
+
+  // Test6: Downgrade IR to `targetVersion`, write to bytecode. Then, read and
+  // upgrade IR when back in memory. The module is expected to be unmodified at
+  // the end of the function.
+  void runTest6(Operation *op) {
+    test::TestDialectVersion targetEmissionVersion = targetVersion;
+
+    // Downgrade IR constructs before writing; fail the pass if impossible.
+    if (failed(downgradeToVersion(op, targetEmissionVersion)))
+      return signalPassFailure();
+
+    BytecodeWriterConfig writeConfig;
+    writeConfig.setDialectVersion<test::TestDialect>(
+        std::make_unique<test::TestDialectVersion>(targetEmissionVersion));
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+  }
+
+  test::TestDialect *testDialect = nullptr; // Set in initialize().
 };
 } // namespace
 
 namespace mlir {
-void registerTestBytecodeCallbackPasses() {
-  PassRegistration<TestBytecodeCallbackPass>();
+void registerTestBytecodeRoundtripPasses() {
+  PassRegistration<TestBytecodeRoundtripPass>();
 }
 } // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index b7647d7de78a10e..3e3223b48505601 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -44,7 +44,7 @@ void registerSymbolTestPasses();
 void registerRegionTestPasses();
 void registerTestAffineDataCopyPass();
 void registerTestAffineReifyValueBoundsPass();
-void registerTestBytecodeCallbackPasses();
+void registerTestBytecodeRoundtripPasses();
 void registerTestDecomposeAffineOpPass();
 void registerTestAffineLoopUnswitchingPass();
 void registerTestAllReduceLoweringPass();
@@ -167,7 +167,7 @@ void registerTestPasses() {
   registerTestDecomposeAffineOpPass();
   registerTestAffineLoopUnswitchingPass();
   registerTestAllReduceLoweringPass();
-  registerTestBytecodeCallbackPasses();
+  registerTestBytecodeRoundtripPasses();
   registerTestFunc();
   registerTestGpuMemoryPromotionPass();
   registerTestLoopPermutationPass();



More information about the Mlir-commits mailing list