[Mlir-commits] [mlir] 5c90e1f - [mlir][bytecode] Return error instead of min version

Jacques Pienaar llvmlistbot at llvm.org
Sun Apr 30 22:11:08 PDT 2023


Author: Jacques Pienaar
Date: 2023-04-30T22:11:02-07:00
New Revision: 5c90e1ffb009478a644f9bd6439f37f96b09e399

URL: https://github.com/llvm/llvm-project/commit/5c90e1ffb009478a644f9bd6439f37f96b09e399
DIFF: https://github.com/llvm/llvm-project/commit/5c90e1ffb009478a644f9bd6439f37f96b09e399.diff

LOG: [mlir][bytecode] Return error instead of min version

Can't return well-formed IR output while also allowing the version to be
bumped up during emission. Previously it would return the minimum version
but potentially invalid IR, which was confusing; instead, return an error
and abort immediately.

Differential Revision: https://reviews.llvm.org/D149569

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/include/mlir/Bytecode/BytecodeWriter.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
    mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 315ec3a846b65..90af14461e29e 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -565,27 +565,16 @@ MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op,
                                                     MlirStringCallback callback,
                                                     void *userData);
 
-struct MlirBytecodeWriterResult {
-  int64_t minVersion;
-};
-typedef struct MlirBytecodeWriterResult MlirBytecodeWriterResult;
-
-inline static bool
-mlirBytecodeWriterResultGetMinVersion(MlirBytecodeWriterResult res) {
-  return res.minVersion;
-}
-
-/// Same as mlirOperationPrint but writing the bytecode format and returns the
-/// minimum bytecode version the consumer needs to support.
-MLIR_CAPI_EXPORTED MlirBytecodeWriterResult mlirOperationWriteBytecode(
-    MlirOperation op, MlirStringCallback callback, void *userData);
-
-/// Same as mlirOperationWriteBytecode but with writer config.
-MLIR_CAPI_EXPORTED MlirBytecodeWriterResult
-mlirOperationWriteBytecodeWithConfig(MlirOperation op,
-                                     MlirBytecodeWriterConfig config,
-                                     MlirStringCallback callback,
-                                     void *userData);
+/// Same as mlirOperationPrint but writing the bytecode format.
+MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op,
+                                                   MlirStringCallback callback,
+                                                   void *userData);
+
+/// Same as mlirOperationWriteBytecode but with writer config and returns
+/// failure only if desired bytecode version could not be honored.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(
+    MlirOperation op, MlirBytecodeWriterConfig config,
+    MlirStringCallback callback, void *userData);
 
 /// Prints an operation to stderr.
 MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op);

diff  --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
index 8877c59dd9da2..bd4087f6c498d 100644
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -75,21 +75,15 @@ class BytecodeWriterConfig {
   std::unique_ptr<Impl> impl;
 };
 
-/// Status of bytecode serialization.
-struct BytecodeWriterResult {
-  /// The minimum version of the reader required to read the serialized file.
-  int64_t minVersion;
-};
-
 //===----------------------------------------------------------------------===//
 // Entry Points
 //===----------------------------------------------------------------------===//
 
 /// Write the bytecode for the given operation to the provided output stream.
 /// For streams where it matters, the given stream should be in "binary" mode.
-BytecodeWriterResult
-writeBytecodeToFile(Operation *op, raw_ostream &os,
-                    const BytecodeWriterConfig &config = {});
+/// It only ever fails if setDesiredBytecodeVersion can't be honored.
+LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os,
+                                  const BytecodeWriterConfig &config = {});
 
 } // namespace mlir
 

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 052998be18ffe..f2e188e78c4a7 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -16,6 +16,7 @@
 #include "mlir-c/Debug.h"
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 
@@ -1134,9 +1135,8 @@ void PyOperationBase::print(py::object fileObject, bool binary,
   mlirOpPrintingFlagsDestroy(flags);
 }
 
-MlirBytecodeWriterResult
-PyOperationBase::writeBytecode(const py::object &fileObject,
-                               std::optional<int64_t> bytecodeVersion) {
+void PyOperationBase::writeBytecode(const py::object &fileObject,
+                                    std::optional<int64_t> bytecodeVersion) {
   PyOperation &operation = getOperation();
   operation.checkValid();
   PyFileAccumulator accum(fileObject, /*binary=*/true);
@@ -1147,8 +1147,12 @@ PyOperationBase::writeBytecode(const py::object &fileObject,
 
   MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
   mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
-  return mlirOperationWriteBytecodeWithConfig(
+  MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
       operation, config, accum.getCallback(), accum.getUserData());
+  if (mlirLogicalResultIsFailure(res))
+    throw py::value_error((Twine("Unable to honor desired bytecode version ") +
+                           Twine(*bytecodeVersion))
+                              .str());
 }
 
 py::object PyOperationBase::getAsm(bool binary,
@@ -3378,10 +3382,6 @@ void mlir::python::populateIRCore(py::module &m) {
                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
                   py::arg("callback"));
 
-  py::class_<MlirBytecodeWriterResult>(m, "BytecodeResult", py::module_local())
-      .def("min_version",
-           [](MlirBytecodeWriterResult &res) { return res.minVersion; });
-
   // Container bindings.
   PyBlockArgumentList::bind(m);
   PyBlockIterator::bind(m);

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 56bb834b4eac9..ade790ba0ed13 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -554,9 +554,8 @@ class PyOperationBase {
                           bool assumeVerified);
 
   // Implement the bound 'writeBytecode' method.
-  MlirBytecodeWriterResult
-  writeBytecode(const pybind11::object &fileObject,
-                std::optional<int64_t> bytecodeVersion);
+  void writeBytecode(const pybind11::object &fileObject,
+                     std::optional<int64_t> bytecodeVersion);
 
   /// Moves the operation before or after the other operation.
   void moveAfter(PyOperationBase &other);

diff  --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 95729a2a4fa5c..801f3022d0e47 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -887,12 +887,10 @@ void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
 // Entry Points
 //===----------------------------------------------------------------------===//
 
-BytecodeWriterResult
-mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
-                          const BytecodeWriterConfig &config) {
+LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
+                                        const BytecodeWriterConfig &config) {
   BytecodeWriter writer(op, config.getImpl());
   writer.write(op, os);
-  // Return the bytecode version emitted - currently there is no additional
-  // feedback as to minimum beyond the requested one.
-  return {config.getImpl().bytecodeVersion};
+  // Currently there is no failure case.
+  return success();
 }

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 03f154965e965..0069bf10263a2 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -524,25 +524,18 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
   unwrap(op)->print(stream, *unwrap(flags));
 }
 
-MlirBytecodeWriterResult mlirOperationWriteBytecode(MlirOperation op,
-                                                    MlirStringCallback callback,
-                                                    void *userData) {
+void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
+                                void *userData) {
   detail::CallbackOstream stream(callback, userData);
-  MlirBytecodeWriterResult res;
-  BytecodeWriterResult r = writeBytecodeToFile(unwrap(op), stream);
-  res.minVersion = r.minVersion;
-  return res;
+  // As no desired version is set, no failure can occur.
+  (void)writeBytecodeToFile(unwrap(op), stream);
 }
 
-MlirBytecodeWriterResult mlirOperationWriteBytecodeWithConfig(
+MlirLogicalResult mlirOperationWriteBytecodeWithConfig(
     MlirOperation op, MlirBytecodeWriterConfig config,
     MlirStringCallback callback, void *userData) {
   detail::CallbackOstream stream(callback, userData);
-  BytecodeWriterResult r =
-      writeBytecodeToFile(unwrap(op), stream, *unwrap(config));
-  MlirBytecodeWriterResult res;
-  res.minVersion = r.minVersion;
-  return res;
+  return wrap(writeBytecodeToFile(unwrap(op), stream, *unwrap(config)));
 }
 
 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }

diff  --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
index 8597423946c79..5e4dc07c4c3f1 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -885,7 +885,8 @@ MLIRDocument::convertToBytecode() {
 
     std::string rawBytecodeBuffer;
     llvm::raw_string_ostream os(rawBytecodeBuffer);
-    writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
+    // No desired bytecode version set, so no need to check for error.
+    (void)writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
     result.output = llvm::encodeBase64(rawBytecodeBuffer);
   }
   return result;

diff  --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 3b2b5ed178f1a..18060edb2ed7a 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -264,14 +264,9 @@ performActions(raw_ostream &os,
   TimingScope outputTiming = timing.nest("Output");
   if (config.shouldEmitBytecode()) {
     BytecodeWriterConfig writerConfig(fallbackResourceMap);
-    if (auto v = config.bytecodeVersionToEmit()) {
+    if (auto v = config.bytecodeVersionToEmit())
       writerConfig.setDesiredBytecodeVersion(*v);
-      // Returns failure if requested version couldn't be used for opt tools.
-      return success(
-          writeBytecodeToFile(op.get(), os, writerConfig).minVersion <= *v);
-    }
-    writeBytecodeToFile(op.get(), os, writerConfig);
-    return success();
+    return writeBytecodeToFile(op.get(), os, writerConfig);
   }
 
   if (config.bytecodeVersionToEmit().has_value())

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 1afe1d9ed0da8..ea84d113a9488 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -571,8 +571,7 @@ def testOperationPrint():
 
   # Test roundtrip to bytecode.
   bytecode_stream = io.BytesIO()
-  result = module.operation.write_bytecode(bytecode_stream, desired_version=1)
-  assert result.min_version() == 1, "Requested version not serialized to"
+  module.operation.write_bytecode(bytecode_stream, desired_version=1)
   bytecode = bytecode_stream.getvalue()
   assert bytecode.startswith(b'ML\xefR'), "Expected bytecode to start with MLïR"
   module_roundtrip = Module.parse(bytecode, ctx)


        


More information about the Mlir-commits mailing list