[Mlir-commits] [mlir] [mlir] Python: write bytecode to a file path (PR #127118)
Nikhil Kalra
llvmlistbot at llvm.org
Thu Feb 13 12:09:58 PST 2025
https://github.com/nikalra created https://github.com/llvm/llvm-project/pull/127118
The current `write_bytecode` implementation requires the serialized module to be duplicated in memory, because a Python `bytes` object is created and passed across the binding. For modules with large resources, we may want to avoid this in-memory copy by serializing directly to a file instead of sending bytes across the boundary. With this change, `write_bytecode` also accepts a file path (`str`) in place of a file-like object.
>From b6d637b02c31f90a1f959d5f0864866e3851ec0e Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Thu, 13 Feb 2025 12:08:49 -0800
Subject: [PATCH] [mlir] Python: write bytecode to a file path
The current `write_bytecode` implementation requires the serialized module to be duplicated in memory, because a Python `bytes` object is created and passed across the binding. For modules with large resources, we may want to avoid this in-memory copy by serializing directly to a file instead of sending bytes across the boundary. `write_bytecode` now also accepts a file path (`str`); failures to open or write that file are raised as `ValueError` instead of aborting the process.
---
mlir/lib/Bindings/Python/IRCore.cpp | 53 ++++++++++++++++++++----
mlir/lib/Bindings/Python/NanobindUtils.h | 22 +++++++++-
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 6 +--
mlir/test/python/ir/operation.py | 7 ++++
4 files changed, 75 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 827db5f3eba84..fbe54f8d81cf0 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include <optional>
+#include <system_error>
#include <utility>
#include "Globals.h"
@@ -20,8 +21,9 @@
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
namespace nb = nanobind;
using namespace nb::literals;
@@ -1329,20 +1331,18 @@ void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
accum.getUserData());
}
-void PyOperationBase::writeBytecode(const nb::object &fileObject,
- std::optional<int64_t> bytecodeVersion) {
- PyOperation &operation = getOperation();
- operation.checkValid();
- PyFileAccumulator accum(fileObject, /*binary=*/true);
-
+template <typename T>
+static void
+writeBytecodeForOperation(T &accumulator, MlirOperation operation,
+ const std::optional<int64_t> &bytecodeVersion) {
if (!bytecodeVersion.has_value())
- return mlirOperationWriteBytecode(operation, accum.getCallback(),
- accum.getUserData());
+ return mlirOperationWriteBytecode(operation, accumulator.getCallback(),
+ accumulator.getUserData());
MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
- operation, config, accum.getCallback(), accum.getUserData());
+ operation, config, accumulator.getCallback(), accumulator.getUserData());
mlirBytecodeWriterConfigDestroy(config);
if (mlirLogicalResultIsFailure(res))
throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
@@ -1351,6 +1351,41 @@ void PyOperationBase::writeBytecode(const nb::object &fileObject,
.c_str());
}
+void PyOperationBase::writeBytecode(const nb::object &fileObject,
+ std::optional<int64_t> bytecodeVersion) {
+ PyOperation &operation = getOperation();
+ operation.checkValid();
+
+ std::string filePath;
+ if (nb::try_cast<std::string>(fileObject, filePath)) {
+ std::error_code ec;
+ llvm::raw_fd_ostream ostream(filePath, ec);
+ if (ec) {
+ throw nb::value_error(("Unable to open file for writing: " + filePath +
+ ": " + ec.message())
+ .c_str());
+ }
+
+ OstreamAccumulator accum(ostream);
+ writeBytecodeForOperation(accum, operation, bytecodeVersion);
+
+ // raw_fd_ostream calls report_fatal_error from its destructor if an I/O
+ // error is still pending, which would abort the interpreter. Close the
+ // stream explicitly and surface a failed write as a Python exception.
+ ostream.close();
+ if (ostream.has_error()) {
+ std::string message = ostream.error().message();
+ ostream.clear_error();
+ throw nb::value_error(("Unable to write bytecode to " + filePath + ": " +
+ message)
+ .c_str());
+ }
+ } else {
+ PyFileAccumulator accum(fileObject, /*binary=*/true);
+ writeBytecodeForOperation(accum, operation, bytecodeVersion);
+ }
+}
+
void PyOperationBase::walk(
std::function<MlirWalkResult(MlirOperation)> callback,
MlirWalkOrder walkOrder) {
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h
index ee193cf9f8ef8..ca9aa064219cd 100644
--- a/mlir/lib/Bindings/Python/NanobindUtils.h
+++ b/mlir/lib/Bindings/Python/NanobindUtils.h
@@ -13,8 +13,10 @@
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/DataTypes.h"
+#include "llvm/Support/raw_ostream.h"
template <>
struct std::iterator_traits<nanobind::detail::fast_iterator> {
@@ -128,7 +130,7 @@ struct PyPrintAccumulator {
}
};
-/// Accumulates int a python file-like object, either writing text (default)
+/// Accumulates into a python file-like object, either writing text (default)
/// or binary.
class PyFileAccumulator {
public:
@@ -158,6 +160,24 @@ class PyFileAccumulator {
bool binary;
};
+/// Accumulates into an LLVM raw_ostream, which must outlive the accumulator.
+class OstreamAccumulator {
+public:
+ explicit OstreamAccumulator(llvm::raw_ostream &ostream) : ostream(ostream) {}
+
+ void *getUserData() { return this; }
+
+ MlirStringCallback getCallback() {
+ return [](MlirStringRef part, void *userData) {
+ OstreamAccumulator *accum = static_cast<OstreamAccumulator *>(userData);
+ accum->ostream << llvm::StringRef(part.data, part.length);
+ };
+ }
+
+private:
+ llvm::raw_ostream &ostream;
+};
+
/// Accumulates into a python string from a method that is expected to make
/// one (no more, no less) call to the callback (asserts internally on
/// violation).
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index ab975a6954044..c93de2fe3154e 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -47,7 +47,7 @@ import collections
from collections.abc import Callable, Sequence
import io
from pathlib import Path
-from typing import Any, ClassVar, TypeVar, overload
+from typing import Any, BinaryIO, ClassVar, TypeVar, overload
__all__ = [
"AffineAddExpr",
@@ -285,12 +285,12 @@ class _OperationBase:
"""
Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
"""
- def write_bytecode(self, file: Any, desired_version: int | None = None) -> None:
+ def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None:
"""
Write the bytecode form of the operation to a file like object.
Args:
- file: The file like object to write to.
+ file: The file like object or path to write to.
desired_version: The version of bytecode to emit.
Returns:
The bytecode writer status.
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index c2d3aed8808b4..43836abb74f5e 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -3,6 +3,8 @@
import gc
import io
import itertools
+import os
+from tempfile import TemporaryDirectory
from mlir.ir import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
@@ -617,6 +619,11 @@ def testOperationPrint():
module.operation.write_bytecode(bytecode_stream, desired_version=1)
bytecode = bytecode_stream.getvalue()
assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
+ with TemporaryDirectory() as tmpdir:
+ bytecode_path = os.path.join(tmpdir, "module.mlirbc")
+ module.operation.write_bytecode(bytecode_path, desired_version=1)
+ with open(bytecode_path, "rb") as bytecode_file:
+ assert bytecode_file.read().startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
ctx2 = Context()
module_roundtrip = Module.parse(bytecode, ctx2)
f = io.StringIO()
More information about the Mlir-commits
mailing list