[Mlir-commits] [mlir] [mlir] Python: write bytecode to a file path (PR #127118)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 13 12:10:32 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Nikhil Kalra (nikalra)

<details>
<summary>Changes</summary>

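The current `write_bytecode` implementation requires the serialized module to be duplicated in memory, because a Python `bytes` object must be created for the data and sent over the binding. For modules with large resources, we may want to avoid this in-memory copy. This change therefore lets `write_bytecode` also accept a file path (`str`) in addition to a file-like object. When a path is given, the bytecode is serialized directly to that file through an `llvm::raw_fd_ostream` instead of crossing the Python boundary as `bytes`.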

---
Full diff: https://github.com/llvm/llvm-project/pull/127118.diff


4 Files Affected:

- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+31-9) 
- (modified) mlir/lib/Bindings/Python/NanobindUtils.h (+21-1) 
- (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+3-3) 
- (modified) mlir/test/python/ir/operation.py (+5) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 827db5f3eba84..fbe54f8d81cf0 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include <optional>
+#include <system_error>
 #include <utility>
 
 #include "Globals.h"
@@ -20,8 +21,10 @@
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "nanobind/nanobind.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -1329,20 +1332,18 @@ void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
                               accum.getUserData());
 }
 
-void PyOperationBase::writeBytecode(const nb::object &fileObject,
-                                    std::optional<int64_t> bytecodeVersion) {
-  PyOperation &operation = getOperation();
-  operation.checkValid();
-  PyFileAccumulator accum(fileObject, /*binary=*/true);
-
+template <typename T>
+static void
+writeBytecodeForOperation(T &accumulator, MlirOperation operation,
+                          const std::optional<int64_t> &bytecodeVersion) {
   if (!bytecodeVersion.has_value())
-    return mlirOperationWriteBytecode(operation, accum.getCallback(),
-                                      accum.getUserData());
+    return mlirOperationWriteBytecode(operation, accumulator.getCallback(),
+                                      accumulator.getUserData());
 
   MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
   mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
   MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
-      operation, config, accum.getCallback(), accum.getUserData());
+      operation, config, accumulator.getCallback(), accumulator.getUserData());
   mlirBytecodeWriterConfigDestroy(config);
   if (mlirLogicalResultIsFailure(res))
     throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
@@ -1351,6 +1352,27 @@ void PyOperationBase::writeBytecode(const nb::object &fileObject,
                               .c_str());
 }
 
+void PyOperationBase::writeBytecode(const nb::object &fileObject,
+                                    std::optional<int64_t> bytecodeVersion) {
+  PyOperation &operation = getOperation();
+  operation.checkValid();
+
+  std::string filePath;
+  if (nb::try_cast<std::string>(fileObject, filePath)) {
+    std::error_code ec;
+    llvm::raw_fd_ostream ostream(filePath, ec);
+    if (ec) {
+      throw nb::value_error("Unable to open file for writing");
+    }
+
+    OstreamAccumulator accum(ostream);
+    writeBytecodeForOperation(accum, operation, bytecodeVersion);
+  } else {
+    PyFileAccumulator accum(fileObject, /*binary=*/true);
+    writeBytecodeForOperation(accum, operation, bytecodeVersion);
+  }
+}
+
 void PyOperationBase::walk(
     std::function<MlirWalkResult(MlirOperation)> callback,
     MlirWalkOrder walkOrder) {
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h
index ee193cf9f8ef8..ca9aa064219cd 100644
--- a/mlir/lib/Bindings/Python/NanobindUtils.h
+++ b/mlir/lib/Bindings/Python/NanobindUtils.h
@@ -13,8 +13,10 @@
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/DataTypes.h"
+#include "llvm/Support/raw_ostream.h"
 
 template <>
 struct std::iterator_traits<nanobind::detail::fast_iterator> {
@@ -128,7 +130,7 @@ struct PyPrintAccumulator {
   }
 };
 
-/// Accumulates int a python file-like object, either writing text (default)
+/// Accumulates into a python file-like object, either writing text (default)
 /// or binary.
 class PyFileAccumulator {
 public:
@@ -158,6 +160,24 @@ class PyFileAccumulator {
   bool binary;
 };
 
+/// Accumulates into a LLVM ostream.
+class OstreamAccumulator {
+public:
+  OstreamAccumulator(llvm::raw_ostream &ostream) : ostream(ostream) {}
+
+  void *getUserData() { return this; }
+
+  MlirStringCallback getCallback() {
+    return [](MlirStringRef part, void *userData) {
+      OstreamAccumulator *accum = static_cast<OstreamAccumulator *>(userData);
+      accum->ostream << llvm::StringRef(part.data, part.length);
+    };
+  }
+
+private:
+  llvm::raw_ostream &ostream;
+};
+
 /// Accumulates into a python string from a method that is expected to make
 /// one (no more, no less) call to the callback (asserts internally on
 /// violation).
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index ab975a6954044..c93de2fe3154e 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -47,7 +47,7 @@ import collections
 from collections.abc import Callable, Sequence
 import io
 from pathlib import Path
-from typing import Any, ClassVar, TypeVar, overload
+from typing import Any, BinaryIO, ClassVar, TypeVar, overload
 
 __all__ = [
     "AffineAddExpr",
@@ -285,12 +285,12 @@ class _OperationBase:
         """
         Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
         """
-    def write_bytecode(self, file: Any, desired_version: int | None = None) -> None:
+    def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None:
         """
         Write the bytecode form of the operation to a file like object.
 
         Args:
-          file: The file like object to write to.
+          file: The file like object or path to write to.
           desired_version: The version of bytecode to emit.
         Returns:
           The bytecode writer status.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index c2d3aed8808b4..43836abb74f5e 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -3,6 +3,7 @@
 import gc
 import io
 import itertools
+from tempfile import NamedTemporaryFile
 from mlir.ir import *
 from mlir.dialects.builtin import ModuleOp
 from mlir.dialects import arith
@@ -617,6 +618,10 @@ def testOperationPrint():
     module.operation.write_bytecode(bytecode_stream, desired_version=1)
     bytecode = bytecode_stream.getvalue()
     assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
+    with NamedTemporaryFile() as tmpfile:
+        module.operation.write_bytecode(str(tmpfile.name), desired_version=1)
+        tmpfile.seek(0)
+        assert tmpfile.read().startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
     ctx2 = Context()
     module_roundtrip = Module.parse(bytecode, ctx2)
     f = io.StringIO()

``````````

</details>


https://github.com/llvm/llvm-project/pull/127118


More information about the Mlir-commits mailing list