[Mlir-commits] [mlir] [mlir] Python: write bytecode to a file path (PR #127118)

Nikhil Kalra llvmlistbot at llvm.org
Thu Feb 13 12:22:07 PST 2025


https://github.com/nikalra updated https://github.com/llvm/llvm-project/pull/127118

>From b6d637b02c31f90a1f959d5f0864866e3851ec0e Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Thu, 13 Feb 2025 12:08:49 -0800
Subject: [PATCH 1/3] [mlir] Python: write bytecode to a file path

The current `write_bytecode` implementation requires the serialized module to be duplicated in memory when the Python `bytes` object is created and sent across the binding. For modules with large resources, we may want to avoid this in-memory copy by serializing directly to a file instead of sending bytes across the boundary.
---
 mlir/lib/Bindings/Python/IRCore.cpp      | 40 ++++++++++++++++++------
 mlir/lib/Bindings/Python/NanobindUtils.h | 22 ++++++++++++-
 mlir/python/mlir/_mlir_libs/_mlir/ir.pyi |  6 ++--
 mlir/test/python/ir/operation.py         |  5 +++
 4 files changed, 60 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 827db5f3eba84..fbe54f8d81cf0 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include <optional>
+#include <system_error>
 #include <utility>
 
 #include "Globals.h"
@@ -20,8 +21,10 @@
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "nanobind/nanobind.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -1329,20 +1332,18 @@ void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
                               accum.getUserData());
 }
 
-void PyOperationBase::writeBytecode(const nb::object &fileObject,
-                                    std::optional<int64_t> bytecodeVersion) {
-  PyOperation &operation = getOperation();
-  operation.checkValid();
-  PyFileAccumulator accum(fileObject, /*binary=*/true);
-
+template <typename T>
+static void
+writeBytecodeForOperation(T &accumulator, MlirOperation operation,
+                          const std::optional<int64_t> &bytecodeVersion) {
   if (!bytecodeVersion.has_value())
-    return mlirOperationWriteBytecode(operation, accum.getCallback(),
-                                      accum.getUserData());
+    return mlirOperationWriteBytecode(operation, accumulator.getCallback(),
+                                      accumulator.getUserData());
 
   MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
   mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
   MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
-      operation, config, accum.getCallback(), accum.getUserData());
+      operation, config, accumulator.getCallback(), accumulator.getUserData());
   mlirBytecodeWriterConfigDestroy(config);
   if (mlirLogicalResultIsFailure(res))
     throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
@@ -1351,6 +1352,27 @@ void PyOperationBase::writeBytecode(const nb::object &fileObject,
                               .c_str());
 }
 
+void PyOperationBase::writeBytecode(const nb::object &fileObject,
+                                    std::optional<int64_t> bytecodeVersion) {
+  PyOperation &operation = getOperation();
+  operation.checkValid();
+
+  std::string filePath;
+  if (nb::try_cast<std::string>(fileObject, filePath)) {
+    std::error_code ec;
+    llvm::raw_fd_ostream ostream(filePath, ec);
+    if (ec) {
+      throw nb::value_error("Unable to open file for writing");
+    }
+
+    OstreamAccumulator accum(ostream);
+    writeBytecodeForOperation(accum, operation, bytecodeVersion);
+  } else {
+    PyFileAccumulator accum(fileObject, /*binary=*/true);
+    writeBytecodeForOperation(accum, operation, bytecodeVersion);
+  }
+}
+
 void PyOperationBase::walk(
     std::function<MlirWalkResult(MlirOperation)> callback,
     MlirWalkOrder walkOrder) {
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h
index ee193cf9f8ef8..ca9aa064219cd 100644
--- a/mlir/lib/Bindings/Python/NanobindUtils.h
+++ b/mlir/lib/Bindings/Python/NanobindUtils.h
@@ -13,8 +13,10 @@
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/DataTypes.h"
+#include "llvm/Support/raw_ostream.h"
 
 template <>
 struct std::iterator_traits<nanobind::detail::fast_iterator> {
@@ -128,7 +130,7 @@ struct PyPrintAccumulator {
   }
 };
 
-/// Accumulates int a python file-like object, either writing text (default)
+/// Accumulates into a python file-like object, either writing text (default)
 /// or binary.
 class PyFileAccumulator {
 public:
@@ -158,6 +160,24 @@ class PyFileAccumulator {
   bool binary;
 };
 
+/// Accumulates into a LLVM ostream.
+class OstreamAccumulator {
+public:
+  OstreamAccumulator(llvm::raw_ostream &ostream) : ostream(ostream) {}
+
+  void *getUserData() { return this; }
+
+  MlirStringCallback getCallback() {
+    return [](MlirStringRef part, void *userData) {
+      OstreamAccumulator *accum = static_cast<OstreamAccumulator *>(userData);
+      accum->ostream << llvm::StringRef(part.data, part.length);
+    };
+  }
+
+private:
+  llvm::raw_ostream &ostream;
+};
+
 /// Accumulates into a python string from a method that is expected to make
 /// one (no more, no less) call to the callback (asserts internally on
 /// violation).
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index ab975a6954044..c93de2fe3154e 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -47,7 +47,7 @@ import collections
 from collections.abc import Callable, Sequence
 import io
 from pathlib import Path
-from typing import Any, ClassVar, TypeVar, overload
+from typing import Any, BinaryIO, ClassVar, TypeVar, overload
 
 __all__ = [
     "AffineAddExpr",
@@ -285,12 +285,12 @@ class _OperationBase:
         """
         Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
         """
-    def write_bytecode(self, file: Any, desired_version: int | None = None) -> None:
+    def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None:
         """
         Write the bytecode form of the operation to a file like object.
 
         Args:
-          file: The file like object to write to.
+          file: The file like object or path to write to.
           desired_version: The version of bytecode to emit.
         Returns:
           The bytecode writer status.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index c2d3aed8808b4..43836abb74f5e 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -3,6 +3,7 @@
 import gc
 import io
 import itertools
+from tempfile import NamedTemporaryFile
 from mlir.ir import *
 from mlir.dialects.builtin import ModuleOp
 from mlir.dialects import arith
@@ -617,6 +618,10 @@ def testOperationPrint():
     module.operation.write_bytecode(bytecode_stream, desired_version=1)
     bytecode = bytecode_stream.getvalue()
     assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
+    with NamedTemporaryFile() as tmpfile:
+        module.operation.write_bytecode(str(tmpfile.name), desired_version=1)
+        tmpfile.seek(0)
+        assert tmpfile.read().startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
     ctx2 = Context()
     module_roundtrip = Module.parse(bytecode, ctx2)
     f = io.StringIO()

>From 3be708c7f330c57b101bfba9233323cd736237c6 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Thu, 13 Feb 2025 12:19:57 -0800
Subject: [PATCH 2/3] clang-format

---
 mlir/lib/Bindings/Python/IRCore.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index fbe54f8d81cf0..180d2fd62d5ca 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -13,6 +13,7 @@
 #include "Globals.h"
 #include "IRModule.h"
 #include "NanobindUtils.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Debug.h"
 #include "mlir-c/Diagnostics.h"
@@ -20,7 +21,6 @@
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "nanobind/nanobind.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"

>From 05454842d33376f66f60f98d0394b7c795433bab Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Thu, 13 Feb 2025 12:21:50 -0800
Subject: [PATCH 3/3] py formatter

---
 mlir/test/python/ir/operation.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 43836abb74f5e..090d0030fb062 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -621,7 +621,9 @@ def testOperationPrint():
     with NamedTemporaryFile() as tmpfile:
         module.operation.write_bytecode(str(tmpfile.name), desired_version=1)
         tmpfile.seek(0)
-        assert tmpfile.read().startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
+        assert tmpfile.read().startswith(
+            b"ML\xefR"
+        ), "Expected bytecode to start with MLïR"
     ctx2 = Context()
     module_roundtrip = Module.parse(bytecode, ctx2)
     f = io.StringIO()



More information about the Mlir-commits mailing list