[Mlir-commits] [mlir] a60e8a2 - [mlir] Python: write bytecode to a file path (#127118)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 24 17:51:53 PST 2025


Author: Nikhil Kalra
Date: 2025-02-24T17:51:49-08:00
New Revision: a60e8a2c2579252d66a0656c387af29475e9b908

URL: https://github.com/llvm/llvm-project/commit/a60e8a2c2579252d66a0656c387af29475e9b908
DIFF: https://github.com/llvm/llvm-project/commit/a60e8a2c2579252d66a0656c387af29475e9b908.diff

LOG: [mlir] Python: write bytecode to a file path (#127118)

The current `write_bytecode` implementation requires the serialized
module to be duplicated in memory when the Python `bytes` object is
created and sent across the binding boundary. For modules with large
resources, we may want to avoid this in-memory copy by serializing
directly to a file instead of sending bytes across the boundary.

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/NanobindUtils.h
    mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 827db5f3eba84..b13a429d4a3c0 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -6,12 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <optional>
-#include <utility>
-
 #include "Globals.h"
 #include "IRModule.h"
 #include "NanobindUtils.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Debug.h"
 #include "mlir-c/Diagnostics.h"
@@ -19,9 +17,14 @@
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "nanobind/nanobind.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <optional>
+#include <system_error>
+#include <utility>
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -1329,11 +1332,11 @@ void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
                               accum.getUserData());
 }
 
-void PyOperationBase::writeBytecode(const nb::object &fileObject,
+void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
                                     std::optional<int64_t> bytecodeVersion) {
   PyOperation &operation = getOperation();
   operation.checkValid();
-  PyFileAccumulator accum(fileObject, /*binary=*/true);
+  PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);
 
   if (!bytecodeVersion.has_value())
     return mlirOperationWriteBytecode(operation, accum.getCallback(),

diff  --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h
index ee193cf9f8ef8..64ea4329f65f1 100644
--- a/mlir/lib/Bindings/Python/NanobindUtils.h
+++ b/mlir/lib/Bindings/Python/NanobindUtils.h
@@ -13,8 +13,13 @@
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/DataTypes.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <string>
+#include <variant>
 
 template <>
 struct std::iterator_traits<nanobind::detail::fast_iterator> {
@@ -128,33 +133,59 @@ struct PyPrintAccumulator {
   }
 };
 
-/// Accumulates int a python file-like object, either writing text (default)
-/// or binary.
+/// Accumulates into a file, writing text or binary as selected by `binary`.
+/// The file may be a Python file-like object or a path to a file.
 class PyFileAccumulator {
 public:
-  PyFileAccumulator(const nanobind::object &fileObject, bool binary)
-      : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
+  PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
+      : binary(binary) {
+    std::string filePath;
+    if (nanobind::try_cast<std::string>(fileOrStringObject, filePath)) {
+      std::error_code ec;
+      writeTarget.emplace<llvm::raw_fd_ostream>(filePath, ec);
+      if (ec) {
+        throw nanobind::value_error(
+            (std::string("Unable to open file for writing: ") + ec.message())
+                .c_str());
+      }
+    } else {
+      writeTarget.emplace<nanobind::object>(fileOrStringObject.attr("write"));
+    }
+  }
+
+  MlirStringCallback getCallback() {
+    return writeTarget.index() == 0 ? getPyWriteCallback()
+                                    : getOstreamCallback();
+  }
 
   void *getUserData() { return this; }
 
-  MlirStringCallback getCallback() {
+private:
+  MlirStringCallback getPyWriteCallback() {
     return [](MlirStringRef part, void *userData) {
       nanobind::gil_scoped_acquire acquire;
       PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
       if (accum->binary) {
         // Note: Still has to copy and not avoidable with this API.
         nanobind::bytes pyBytes(part.data, part.length);
-        accum->pyWriteFunction(pyBytes);
+        std::get<nanobind::object>(accum->writeTarget)(pyBytes);
       } else {
         nanobind::str pyStr(part.data,
                             part.length); // Decodes as UTF-8 by default.
-        accum->pyWriteFunction(pyStr);
+        std::get<nanobind::object>(accum->writeTarget)(pyStr);
       }
     };
   }
 
-private:
-  nanobind::object pyWriteFunction;
+  MlirStringCallback getOstreamCallback() {
+    return [](MlirStringRef part, void *userData) {
+      PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
+      std::get<llvm::raw_fd_ostream>(accum->writeTarget)
+          .write(part.data, part.length);
+    };
+  }
+
+  std::variant<nanobind::object, llvm::raw_fd_ostream> writeTarget;
   bool binary;
 };
 

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index ab975a6954044..c93de2fe3154e 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -47,7 +47,7 @@ import collections
 from collections.abc import Callable, Sequence
 import io
 from pathlib import Path
-from typing import Any, ClassVar, TypeVar, overload
+from typing import Any, BinaryIO, ClassVar, TypeVar, overload
 
 __all__ = [
     "AffineAddExpr",
@@ -285,12 +285,12 @@ class _OperationBase:
         """
         Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
         """
-    def write_bytecode(self, file: Any, desired_version: int | None = None) -> None:
+    def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None:
         """
         Write the bytecode form of the operation to a file like object.
 
         Args:
-          file: The file like object to write to.
+          file: The file like object or path to write to.
           desired_version: The version of bytecode to emit.
         Returns:
           The bytecode writer status.

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index c2d3aed8808b4..090d0030fb062 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -3,6 +3,7 @@
 import gc
 import io
 import itertools
+from tempfile import NamedTemporaryFile
 from mlir.ir import *
 from mlir.dialects.builtin import ModuleOp
 from mlir.dialects import arith
@@ -617,6 +618,12 @@ def testOperationPrint():
     module.operation.write_bytecode(bytecode_stream, desired_version=1)
     bytecode = bytecode_stream.getvalue()
     assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
+    with NamedTemporaryFile() as tmpfile:
+        module.operation.write_bytecode(str(tmpfile.name), desired_version=1)
+        tmpfile.seek(0)
+        assert tmpfile.read().startswith(
+            b"ML\xefR"
+        ), "Expected bytecode to start with MLïR"
     ctx2 = Context()
     module_roundtrip = Module.parse(bytecode, ctx2)
     f = io.StringIO()


        


More information about the Mlir-commits mailing list