[Mlir-commits] [mlir] [mlir][bytecode] Implements downgrade hook for back deployment of MLIR dialects (PR #70724)

Matteo Franciolini llvmlistbot at llvm.org
Mon Oct 30 14:26:49 PDT 2023


https://github.com/mfrancio updated https://github.com/llvm/llvm-project/pull/70724

>From 3c913a5bd1caaef487d43720259c8e60b0c24d35 Mon Sep 17 00:00:00 2001
From: Matteo Franciolini <m_franciolini at apple.com>
Date: Fri, 27 Oct 2023 17:20:21 -0700
Subject: [PATCH] [mlir][bytecode] Implements downgrade hook for back
 deployment of MLIR dialects

When emitting bytecode, clients can specify a target dialect version to emit in `BytecodeWriterConfig`. This exposes a target dialect version to the DialectBytecodeWriter, which can be queried by name and used to back-deploy attributes, types, and properties. IR constructs can be downgraded according to the specified version using a new downgradeToVersion hook exposed through the `BytecodeDialectInterface`.
---
 .../mlir/Bytecode/BytecodeImplementation.h    | 16 +++++
 mlir/include/mlir/Bytecode/BytecodeWriter.h   | 11 +++-
 mlir/lib/Bytecode/Writer/BytecodeWriter.cpp   | 57 ++++++++++++++++--
 mlir/lib/Bytecode/Writer/IRNumbering.cpp      | 27 ++++++---
 mlir/test/Bytecode/bytecode_callback.mlir     |  4 +-
 .../bytecode_callback_full_override.mlir      |  2 +-
 ...tecode_callback_with_custom_attribute.mlir |  4 +-
 .../bytecode_callback_with_custom_type.mlir   |  4 +-
 mlir/test/Bytecode/bytecode_downgrade.mlir    | 10 ++++
 mlir/test/lib/Dialect/Test/TestDialect.cpp    | 11 ++++
 .../Dialect/Test/TestDialectInterfaces.cpp    | 40 ++++++++++++-
 mlir/test/lib/IR/CMakeLists.txt               |  2 +-
 ...allbacks.cpp => TestBytecodeRoundtrip.cpp} | 60 ++++++++++++-------
 mlir/tools/mlir-opt/mlir-opt.cpp              |  4 +-
 14 files changed, 205 insertions(+), 47 deletions(-)
 create mode 100644 mlir/test/Bytecode/bytecode_downgrade.mlir
 rename mlir/test/lib/IR/{TestBytecodeCallbacks.cpp => TestBytecodeRoundtrip.cpp} (88%)

diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index bb1f0f717d80017..c9729a72a7bac4a 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -400,6 +400,10 @@ class DialectBytecodeWriter {
 
   /// Return the bytecode version being emitted for.
   virtual int64_t getBytecodeVersion() const = 0;
+
+  /// Retrieve the dialect version by name if available.
+  virtual FailureOr<const DialectVersion *>
+  getDialectVersion(StringRef dialectName) const = 0;
 };
 
 //===----------------------------------------------------------------------===//
@@ -475,6 +479,18 @@ class BytecodeDialectInterface
                      const DialectVersion &version) const {
     return success();
   }
+
+  /// Hook invoked before writing the top level operation to bytecode, once for
+  /// each dialect whose target emission version was set on the
+  /// `BytecodeWriterConfig` (see `setDialectVersion`). Implementations may walk
+  /// and rewrite `topLevelOp` to downgrade this dialect's constructs to
+  /// `version`; returning failure aborts bytecode emission. The default
+  /// implementation performs no downgrade and succeeds.
+  virtual LogicalResult
+  downgradeToVersion(Operation *topLevelOp,
+                     const DialectVersion &version) const {
+    return success();
+  }
 };
 
 /// Helper for resource handle reading that returns LogicalResult.
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
index e0c46c3dab27a7b..90251068c8303f7 100644
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -16,8 +16,9 @@
 #include "mlir/IR/AsmState.h"
 
 namespace mlir {
-class Operation;
 class DialectBytecodeWriter;
+class DialectVersion;
+class Operation;
 
 /// A class to interact with the attributes and types printer when emitting MLIR
 /// bytecode.
@@ -97,6 +98,14 @@ class BytecodeWriterConfig {
   /// Get the set desired bytecode version to emit.
   int64_t getDesiredBytecodeVersion() const;
 
+  /// A map containing the dialect versions to emit.
+  llvm::StringMap<std::unique_ptr<DialectVersion>> &
+  getDialectVersionMap() const;
+
+  /// Set a given dialect version to emit on the map.
+  void setDialectVersion(StringRef dialectName,
+                         std::unique_ptr<DialectVersion> dialectVersion) const;
+
   //===--------------------------------------------------------------------===//
   // Types and Attributes encoding
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 5628ff6c54af67f..33312e9eba06e04 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -39,6 +39,9 @@ struct BytecodeWriterConfig::Impl {
   /// Note: This only differs from kVersion if a specific version is set.
   int64_t bytecodeVersion = bytecode::kVersion;
 
+  /// A map containing dialect version information for each dialect to emit.
+  llvm::StringMap<std::unique_ptr<DialectVersion>> dialectVersionMap;
+
   /// The producer of the bytecode.
   StringRef producer;
 
@@ -94,6 +97,19 @@ int64_t BytecodeWriterConfig::getDesiredBytecodeVersion() const {
   return impl->bytecodeVersion;
 }
 
+llvm::StringMap<std::unique_ptr<DialectVersion>> &
+BytecodeWriterConfig::getDialectVersionMap() const {
+  return impl->dialectVersionMap;
+}
+
+void BytecodeWriterConfig::setDialectVersion(
+    llvm::StringRef dialectName,
+    std::unique_ptr<DialectVersion> dialectVersion) const {
+  assert(!impl->dialectVersionMap.contains(dialectName) &&
+         "cannot override a previously set dialect version");
+  impl->dialectVersionMap.insert({dialectName, std::move(dialectVersion)});
+}
+
 //===----------------------------------------------------------------------===//
 // EncodingEmitter
 //===----------------------------------------------------------------------===//
@@ -340,12 +356,16 @@ class StringSectionBuilder {
 } // namespace
 
 class DialectWriter : public DialectBytecodeWriter {
+  using DialectVersionMapT = llvm::StringMap<std::unique_ptr<DialectVersion>>;
+
 public:
   DialectWriter(int64_t bytecodeVersion, EncodingEmitter &emitter,
                 IRNumberingState &numberingState,
-                StringSectionBuilder &stringSection)
+                StringSectionBuilder &stringSection,
+                const DialectVersionMapT &dialectVersionMap)
       : bytecodeVersion(bytecodeVersion), emitter(emitter),
-        numberingState(numberingState), stringSection(stringSection) {}
+        numberingState(numberingState), stringSection(stringSection),
+        dialectVersionMap(dialectVersionMap) {}
 
   //===--------------------------------------------------------------------===//
   // IR
@@ -421,11 +441,20 @@ class DialectWriter : public DialectBytecodeWriter {
 
   int64_t getBytecodeVersion() const override { return bytecodeVersion; }
 
+  FailureOr<const DialectVersion *>
+  getDialectVersion(StringRef dialectName) const override {
+    auto dialectEntry = dialectVersionMap.find(dialectName);
+    if (dialectEntry == dialectVersionMap.end())
+      return failure();
+    return dialectEntry->getValue().get();
+  }
+
 private:
   int64_t bytecodeVersion;
   EncodingEmitter &emitter;
   IRNumberingState &numberingState;
   StringSectionBuilder &stringSection;
+  const DialectVersionMapT &dialectVersionMap;
 };
 
 namespace {
@@ -458,7 +487,8 @@ class PropertiesSectionBuilder {
 
     EncodingEmitter emitter;
     DialectWriter propertiesWriter(config.bytecodeVersion, emitter,
-                                   numberingState, stringSection);
+                                   numberingState, stringSection,
+                                   config.dialectVersionMap);
     auto iface = cast<BytecodeOpInterface>(op);
     iface.writeProperties(propertiesWriter);
     scratch.clear();
@@ -751,7 +781,8 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
     if (dialect.interface) {
       // The writer used when emitting using a custom bytecode encoding.
       DialectWriter versionWriter(config.bytecodeVersion, versionEmitter,
-                                  numberingState, stringSection);
+                                  numberingState, stringSection,
+                                  config.dialectVersionMap);
       dialect.interface->writeVersion(versionWriter);
     }
 
@@ -809,7 +840,8 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
       }
 
       DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter,
-                                  numberingState, stringSection);
+                                  numberingState, stringSection,
+                                  config.dialectVersionMap);
       if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
         for (const auto &callback : config.typeWriterCallbacks) {
           if (succeeded(callback->write(entryValue, dialectWriter)))
@@ -1256,6 +1288,21 @@ void BytecodeWriter::writePropertiesSection(EncodingEmitter &emitter) {
 
 LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
                                         const BytecodeWriterConfig &config) {
+  // Before creating the bytecode writer, let each dialect with a requested
+  // target emission version downgrade the IR. Downgrades may be destructive;
+  // handling that is the responsibility of the client dialect.
+  for (auto &item : config.getDialectVersionMap()) {
+    Dialect *currentDialect = op->getContext()->getOrLoadDialect(item.getKey());
+    if (!currentDialect)
+      return op->emitError() << "version requested on unloadable dialect '"
+                             << item.getKey() << "'";
+    auto *iface = llvm::dyn_cast<BytecodeDialectInterface>(currentDialect);
+    if (!iface)
+      return op->emitError() << "dialect '" << item.getKey()
+                             << "' does not implement BytecodeDialectInterface";
+    if (failed(iface->downgradeToVersion(op, *item.getValue())))
+      return failure();
+  }
   BytecodeWriter writer(op, config);
   return writer.write(op, os);
 }
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 74c45723c222d27..036a9477cce6c1d 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -12,7 +12,6 @@
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
-#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 using namespace mlir::bytecode::detail;
@@ -22,7 +21,10 @@ using namespace mlir::bytecode::detail;
 //===----------------------------------------------------------------------===//
 
 struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
-  NumberingDialectWriter(IRNumberingState &state) : state(state) {}
+  NumberingDialectWriter(
+      IRNumberingState &state,
+      llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap)
+      : state(state), dialectVersionMap(dialectVersionMap) {}
 
   void writeAttribute(Attribute attr) override { state.number(attr); }
   void writeOptionalAttribute(Attribute attr) override {
@@ -51,8 +53,19 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
     return state.getDesiredBytecodeVersion();
   }
 
+  FailureOr<const DialectVersion *>
+  getDialectVersion(StringRef dialectName) const override {
+    auto dialectEntry = dialectVersionMap.find(dialectName);
+    if (dialectEntry == dialectVersionMap.end())
+      return failure();
+    return dialectEntry->getValue().get();
+  }
+
   /// The parent numbering state that is populated by this writer.
   IRNumberingState &state;
+
+  /// A map containing dialect version information for each dialect to emit.
+  llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap;
 };
 
 //===----------------------------------------------------------------------===//
@@ -318,7 +331,7 @@ void IRNumberingState::number(Attribute attr) {
   if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
     // Try overriding emission with callbacks.
     for (const auto &callback : config.getAttributeWriterCallbacks()) {
-      NumberingDialectWriter writer(*this);
+      NumberingDialectWriter writer(*this, config.getDialectVersionMap());
       // The client has the ability to override the group name through the
       // callback.
       std::optional<StringRef> groupNameOverride;
@@ -330,7 +343,7 @@ void IRNumberingState::number(Attribute attr) {
     }
 
     if (const auto *interface = numbering->dialect->interface) {
-      NumberingDialectWriter writer(*this);
+      NumberingDialectWriter writer(*this, config.getDialectVersionMap());
       if (succeeded(interface->writeAttribute(attr, writer)))
         return;
     }
@@ -426,7 +439,7 @@ void IRNumberingState::number(Operation &op) {
     if (op.isRegistered()) {
       // Operation that have properties *must* implement this interface.
       auto iface = cast<BytecodeOpInterface>(op);
-      NumberingDialectWriter writer(*this);
+      NumberingDialectWriter writer(*this, config.getDialectVersionMap());
       iface.writeProperties(writer);
     } else {
       // Unregistered op are storing properties as an optional attribute.
@@ -481,7 +494,7 @@ void IRNumberingState::number(Type type) {
   if (!type.hasTrait<TypeTrait::IsMutable>()) {
     // Try overriding emission with callbacks.
     for (const auto &callback : config.getTypeWriterCallbacks()) {
-      NumberingDialectWriter writer(*this);
+      NumberingDialectWriter writer(*this, config.getDialectVersionMap());
       // The client has the ability to override the group name through the
       // callback.
       std::optional<StringRef> groupNameOverride;
@@ -495,7 +508,7 @@ void IRNumberingState::number(Type type) {
     // If this attribute will be emitted using the bytecode format, perform a
     // dummy writing to number any nested components.
     if (const auto *interface = numbering->dialect->interface) {
-      NumberingDialectWriter writer(*this);
+      NumberingDialectWriter writer(*this, config.getDialectVersionMap());
       if (succeeded(interface->writeType(type, writer)))
         return;
     }
diff --git a/mlir/test/Bytecode/bytecode_callback.mlir b/mlir/test/Bytecode/bytecode_callback.mlir
index cf3981c86b94428..3b537bb974da9ca 100644
--- a/mlir/test/Bytecode/bytecode_callback.mlir
+++ b/mlir/test/Bytecode/bytecode_callback.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2
-// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0
+// RUN: mlir-opt %s --test-bytecode-roundtrip="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2
+// RUN: mlir-opt %s --test-bytecode-roundtrip="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0
 
 func.func @base_test(%arg0 : i32) -> f32 {
   %0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
diff --git a/mlir/test/Bytecode/bytecode_callback_full_override.mlir b/mlir/test/Bytecode/bytecode_callback_full_override.mlir
index 21ff947ad389b6a..792f8f6e2959b70 100644
--- a/mlir/test/Bytecode/bytecode_callback_full_override.mlir
+++ b/mlir/test/Bytecode/bytecode_callback_full_override.mlir
@@ -1,4 +1,4 @@
-// RUN: not mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=5" 2>&1 | FileCheck %s
+// RUN: not mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=5" 2>&1 | FileCheck %s
 
 // CHECK-NOT: failed to read bytecode
 func.func @base_test(%arg0 : i32) -> f32 {
diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
index 487972f85af5be0..30f25d9865792a8 100644
--- a/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
+++ b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=3" | FileCheck %s --check-prefix=TEST_3
-// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=4" | FileCheck %s --check-prefix=TEST_4
+// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=3" | FileCheck %s --check-prefix=TEST_3
+// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=4" | FileCheck %s --check-prefix=TEST_4
 
 "test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
 
diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
index 1e272ec4f3afc28..aba194f1681a24b 100644
--- a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
+++ b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=1" | FileCheck %s --check-prefix=TEST_1
-// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=2" | FileCheck %s --check-prefix=TEST_2
+// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=1" | FileCheck %s --check-prefix=TEST_1
+// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=2" | FileCheck %s --check-prefix=TEST_2
 
 func.func @base_test(%arg0: !test.i32, %arg1: f32) {
   return
diff --git a/mlir/test/Bytecode/bytecode_downgrade.mlir b/mlir/test/Bytecode/bytecode_downgrade.mlir
new file mode 100644
index 000000000000000..fd98b9d7ee6a58c
--- /dev/null
+++ b/mlir/test/Bytecode/bytecode_downgrade.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s --test-bytecode-roundtrip="test-dialect-version=1.2 test-kind=6" -verify-diagnostics | FileCheck %s
+
+module {
+  "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> ()
+}
+
+// COM: the property downgrader is executed twice: first for IR numbering and then for emission.
+// CHECK: downgrading op...
+// CHECK: downgrading op properties...
+// CHECK: downgrading op properties...
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index deef4a9683e7421..61758085ba08172 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1358,6 +1358,17 @@ TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
 void TestVersionedOpA::writeProperties(::mlir::DialectBytecodeWriter &writer) {
   auto &prop = getProperties();
   writer.writeAttribute(prop.dims);
+
+  auto maybeVersion = writer.getDialectVersion("test");
+  if (succeeded(maybeVersion)) {
+    // Before version 2.0 the op had no `modifier` property, so skip writing it.
+    const auto *version =
+        static_cast<const TestDialectVersion *>(*maybeVersion);
+    if (version->major_ < 2) {
+      llvm::outs() << "downgrading op properties...\n";
+      return;
+    }
+  }
   writer.writeAttribute(prop.modifier);
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 950af85007475b9..6ddccdf759de7b8 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -93,9 +93,16 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
 
   // Emit a specific version of the dialect.
   void writeVersion(DialectBytecodeWriter &writer) const final {
-    auto version = TestDialectVersion();
-    writer.writeVarInt(version.major_); // major
-    writer.writeVarInt(version.minor_); // minor
+    // Default to the current dialect version.
+    test::TestDialectVersion versionToEmit;
+
+    // Emit the target version instead if one was set on the writer config.
+    auto versionOr = writer.getDialectVersion("test");
+    if (succeeded(versionOr))
+      versionToEmit =
+          *static_cast<const test::TestDialectVersion *>(*versionOr);
+    writer.writeVarInt(versionToEmit.major_); // major
+    writer.writeVarInt(versionToEmit.minor_); // minor
   }
 
   std::unique_ptr<DialectVersion>
@@ -130,6 +137,33 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
     return success();
   }
 
+  LogicalResult downgradeToVersion(Operation *topLevelOp,
+                                   const DialectVersion &version_) const final {
+    const auto &version = static_cast<const TestDialectVersion &>(version_);
+    if ((version.major_ == 2) && (version.minor_ == 0))
+      return success();
+    if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) {
+      return topLevelOp->emitError() << "current test dialect version is 2.0, "
+                                        "can't downgrade to version: "
+                                     << version.major_ << "." << version.minor_;
+    }
+    // Prior version 2.0, the old op supported only a single attribute called
+    // "dimensions". We need to check that the modifier is false, otherwise we
+    // can't do the downgrade.
+    auto status = topLevelOp->walk([&](TestVersionedOpA op) {
+      auto &prop = op.getProperties();
+      if (prop.modifier.getValue()) {
+        op->emitOpError() << "cannot downgrade to version " << version.major_
+                          << "." << version.minor_
+                          << " since the modifier is not compatible";
+        return WalkResult::interrupt();
+      }
+      llvm::outs() << "downgrading op...\n";
+      return WalkResult::advance();
+    });
+    return failure(status.wasInterrupted());
+  }
+
 private:
   Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const {
     uint64_t encoding;
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index fb0eac399245578..69c63fd7e524b6f 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -1,6 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestIR
-  TestBytecodeCallbacks.cpp
+  TestBytecodeRoundtrip.cpp
   TestBuiltinAttributeInterfaces.cpp
   TestBuiltinDistinctAttributes.cpp
   TestClone.cpp
diff --git a/mlir/test/lib/IR/TestBytecodeCallbacks.cpp b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
similarity index 88%
rename from mlir/test/lib/IR/TestBytecodeCallbacks.cpp
rename to mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
index cc884c7cc76f37a..cd01533c7930e72 100644
--- a/mlir/test/lib/IR/TestBytecodeCallbacks.cpp
+++ b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
@@ -45,22 +45,23 @@ class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
 
 /// This is a test pass which uses callbacks to encode attributes and types in a
 /// custom fashion.
-struct TestBytecodeCallbackPass
-    : public PassWrapper<TestBytecodeCallbackPass, OperationPass<ModuleOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeCallbackPass)
+struct TestBytecodeRoundtripPass
+    : public PassWrapper<TestBytecodeRoundtripPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeRoundtripPass)
 
-  StringRef getArgument() const final { return "test-bytecode-callback"; }
+  StringRef getArgument() const final { return "test-bytecode-roundtrip"; }
   StringRef getDescription() const final {
-    return "Test encoding of a dialect type/attributes with a custom callback";
+    return "Test pass to implement bytecode roundtrip tests.";
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<test::TestDialect>();
   }
-  TestBytecodeCallbackPass() = default;
-  TestBytecodeCallbackPass(const TestBytecodeCallbackPass &) {}
+  TestBytecodeRoundtripPass() = default;
+  TestBytecodeRoundtripPass(const TestBytecodeRoundtripPass &) {}
 
   void runOnOperation() override {
     switch (testKind) {
+      // Tests 0-5 implement a custom roundtrip with callbacks.
     case (0):
       return runTest0(getOperation());
     case (1):
@@ -73,6 +74,10 @@ struct TestBytecodeCallbackPass
       return runTest4(getOperation());
     case (5):
       return runTest5(getOperation());
+    case (6):
+      // test-kind 6 is a plain roundtrip with downgrade/upgrade to/from
+      // `targetVersion`.
+      return runTest6(getOperation());
     default:
       llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass");
     }
@@ -85,8 +90,8 @@ struct TestBytecodeCallbackPass
                     cl::init(test::TestDialectVersion())};
 
   mlir::Pass::Option<int> testKind{
-      *this, "callback-test",
-      llvm::cl::desc("Specifies the test kind to execute"), cl::init(0)};
+      *this, "test-kind", llvm::cl::desc("Specifies the test kind to execute"),
+      cl::init(0)};
 
 private:
   void doRoundtripWithConfigs(Operation *op,
@@ -122,11 +127,20 @@ struct TestBytecodeCallbackPass
     auto newCtx = std::make_shared<MLIRContext>();
     test::TestDialectVersion targetEmissionVersion = targetVersion;
     BytecodeWriterConfig writeConfig;
+    // Set the emission version for the test dialect.
+    writeConfig.setDialectVersion(
+        "test",
+        std::make_unique<test::TestDialectVersion>(targetEmissionVersion));
     writeConfig.attachTypeCallback(
         [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
             DialectBytecodeWriter &writer) -> LogicalResult {
-          // Do not override anything if version less than 2.0.
-          if (targetEmissionVersion.major_ >= 2)
+          // Do not override anything if version greater than 2.0.
+          auto versionOr = writer.getDialectVersion("test");
+          assert(succeeded(versionOr) && "expected reader to be able to access "
+                                         "the version for test dialect");
+          const auto *version =
+              reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
+          if (version->major_ >= 2)
             return failure();
 
           // For version less than 2.0, override the encoding of IntegerType.
@@ -151,14 +165,7 @@ struct TestBytecodeCallbackPass
                                          "the version for test dialect");
           const auto *version =
               reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
-
-          // TODO: once back-deployment is formally supported,
-          // `targetEmissionVersion` will be encoded in the bytecode file, and
-          // exposed through the versionMap. Right now though this is not yet
-          // supported. For the purpose of the test, just use
-          // `targetEmissionVersion`.
-          (void)version;
-          if (targetEmissionVersion.major_ >= 2)
+          if (version->major_ >= 2)
             return success();
 
           // `dialectName` is the name of the group we have the opportunity to
@@ -355,11 +362,22 @@ struct TestBytecodeCallbackPass
         });
     doRoundtripWithConfigs(op, writeConfig, parseConfig);
   }
+
+  // Test6: Plain roundtrip that downgrades the test dialect to `targetVersion`.
+  void runTest6(Operation *op) {
+    BytecodeWriterConfig writeConfig;
+    test::TestDialectVersion targetEmissionVersion = targetVersion;
+    writeConfig.setDialectVersion(
+        "test",
+        std::make_unique<test::TestDialectVersion>(targetEmissionVersion));
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+  }
 };
 } // namespace
 
 namespace mlir {
-void registerTestBytecodeCallbackPasses() {
-  PassRegistration<TestBytecodeCallbackPass>();
+void registerTestBytecodeRoundtripPasses() {
+  PassRegistration<TestBytecodeRoundtripPass>();
 }
 } // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index b7647d7de78a10e..3e3223b48505601 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -44,7 +44,7 @@ void registerSymbolTestPasses();
 void registerRegionTestPasses();
 void registerTestAffineDataCopyPass();
 void registerTestAffineReifyValueBoundsPass();
-void registerTestBytecodeCallbackPasses();
+void registerTestBytecodeRoundtripPasses();
 void registerTestDecomposeAffineOpPass();
 void registerTestAffineLoopUnswitchingPass();
 void registerTestAllReduceLoweringPass();
@@ -167,7 +167,7 @@ void registerTestPasses() {
   registerTestDecomposeAffineOpPass();
   registerTestAffineLoopUnswitchingPass();
   registerTestAllReduceLoweringPass();
-  registerTestBytecodeCallbackPasses();
+  registerTestBytecodeRoundtripPasses();
   registerTestFunc();
   registerTestGpuMemoryPromotionPass();
   registerTestLoopPermutationPass();



More information about the Mlir-commits mailing list