[Mlir-commits] [mlir] [mlir][python] Cache import of ir module in type casters. (PR #160000)
Peter Hawkins
llvmlistbot at llvm.org
Sun Sep 21 12:21:20 PDT 2025
https://github.com/hawkinsp created https://github.com/llvm/llvm-project/pull/160000
In a JAX benchmark that traces a large language model, this change reduces the time spent in nanobind::module_::import_ from 1.2s to 10ms.
>From 4006133819d3bfdf271a86b3c23474e13ede391e Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins at google.com>
Date: Sun, 21 Sep 2025 19:17:23 +0000
Subject: [PATCH] [mlir][python] Cache import of ir module in type casters.
In a JAX benchmark that traces a large language model, this change
reduces the time spent in nanobind::module_::import_ from 1.2s to 10ms.
---
.../mlir/Bindings/Python/NanobindAdaptors.h | 99 +++++++++++++------
1 file changed, 70 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
index 8744d8d0e4bca..aeb51542f9b6d 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -19,7 +19,9 @@
#ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
#define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
+#include <atomic>
#include <cstdint>
+#include <memory>
#include <optional>
#include "mlir-c/Diagnostics.h"
@@ -30,6 +32,56 @@
// clang-format on
#include "llvm/ADT/Twine.h"
+namespace mlir {
+namespace python {
+
+// Safely calls Python initialization code on first use, avoiding deadlocks.
+// The produced object is never deleted, so it is not torn down by static
+// destructors (presumably after the Python interpreter has finalized).
+template <typename T> class SafeInit {
+public:
+  using F = std::unique_ptr<T> (*)();
+
+  explicit SafeInit(F fn) : initFn(fn) {}
+
+  /// Returns the cached object, running the init function on first use.
+  T &Get() {
+    if (T *result = output.load())
+      return *result;
+
+    // Note: initFn() may be called multiple times if, for example, the GIL is
+    // released during its execution. The intended use case is for module
+    // imports which are safe to perform multiple times. We are careful not to
+    // hold a lock across initFn() to avoid lock ordering problems.
+    std::unique_ptr<T> m = initFn();
+    nanobind::ft_lock_guard lock(mu);
+    if (T *result = output.load())
+      return *result; // Lost the race; `m` is freed after `lock` is released.
+    T *p = m.release();
+    output.store(p);
+    return *p;
+  }
+
+private:
+  nanobind::ft_mutex mu;
+  std::atomic<T *> output{nullptr};
+  F initFn;
+};
+
+// Returns the cached `ir` module. Declared `inline` (not in an anonymous
+// namespace) because inline code in this header calls it: internal linkage
+// would be an ODR violation and give every translation unit its own cache.
+inline nanobind::module_ &IrModule() {
+  static SafeInit<nanobind::module_> init([]() {
+    return std::make_unique<nanobind::module_>(
+        nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")));
+  });
+  return init.Get();
+}
+
+} // namespace python
+} // namespace mlir
+
// Raw CAPI type casters need to be declared before use, so always include them
// first.
namespace nanobind {
@@ -75,7 +127,7 @@ struct type_caster<MlirAffineMap> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonAffineMapToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("AffineMap")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
@@ -97,7 +149,7 @@ struct type_caster<MlirAttribute> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonAttributeToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("Attribute")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -128,9 +180,7 @@ struct type_caster<MlirContext> {
// TODO: This raises an error of "No current context" currently.
// Update the implementation to pretty-print the helpful error that the
// core implementations print in this case.
- src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("Context")
- .attr("current");
+ src = mlir::python::IrModule().attr("Context").attr("current");
}
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToContext(capsule->ptr());
@@ -153,7 +203,7 @@ struct type_caster<MlirDialectRegistry> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule = nanobind::steal<nanobind::object>(
mlirPythonDialectRegistryToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("DialectRegistry")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
@@ -167,9 +217,7 @@ struct type_caster<MlirLocation> {
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (src.is_none()) {
// Gets the current thread-bound context.
- src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("Location")
- .attr("current");
+ src = mlir::python::IrModule().attr("Location").attr("current");
}
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToLocation(capsule->ptr());
@@ -181,7 +229,7 @@ struct type_caster<MlirLocation> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonLocationToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("Location")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
@@ -203,7 +251,7 @@ struct type_caster<MlirModule> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonModuleToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("Module")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
@@ -250,7 +298,7 @@ struct type_caster<MlirOperation> {
return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonOperationToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("Operation")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
@@ -274,7 +322,7 @@ struct type_caster<MlirValue> {
return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonValueToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("Value")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -312,7 +360,7 @@ struct type_caster<MlirTypeID> {
return nanobind::none();
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonTypeIDToCapsule(v));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("TypeID")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
@@ -334,7 +382,7 @@ struct type_caster<MlirType> {
cleanup_list *cleanup) noexcept {
nanobind::object capsule =
nanobind::steal<nanobind::object>(mlirPythonTypeToCapsule(t));
- return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ return mlir::python::IrModule()
.attr("Type")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -453,11 +501,9 @@ class mlir_attribute_subclass : public pure_subclass {
mlir_attribute_subclass(nanobind::handle scope, const char *attrClassName,
IsAFunctionTy isaFunction,
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
- : mlir_attribute_subclass(
- scope, attrClassName, isaFunction,
- nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("Attribute"),
- getTypeIDFunction) {}
+ : mlir_attribute_subclass(scope, attrClassName, isaFunction,
+ IrModule().attr("Attribute"),
+ getTypeIDFunction) {}
/// Subclasses with a provided mlir.ir.Attribute super-class. This must
/// be used if the subclass is being defined in the same extension module
@@ -540,11 +586,8 @@ class mlir_type_subclass : public pure_subclass {
mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
IsAFunctionTy isaFunction,
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
- : mlir_type_subclass(
- scope, typeClassName, isaFunction,
- nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("Type"),
- getTypeIDFunction) {}
+ : mlir_type_subclass(scope, typeClassName, isaFunction,
+ IrModule().attr("Type"), getTypeIDFunction) {}
/// Subclasses with a provided mlir.ir.Type super-class. This must
/// be used if the subclass is being defined in the same extension module
@@ -631,10 +674,8 @@ class mlir_value_subclass : public pure_subclass {
/// Subclasses by looking up the super-class dynamically.
mlir_value_subclass(nanobind::handle scope, const char *valueClassName,
IsAFunctionTy isaFunction)
- : mlir_value_subclass(
- scope, valueClassName, isaFunction,
- nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("Value")) {}
+ : mlir_value_subclass(scope, valueClassName, isaFunction,
+ IrModule().attr("Value")) {}
/// Subclasses with a provided mlir.ir.Value super-class. This must
/// be used if the subclass is being defined in the same extension module
More information about the Mlir-commits
mailing list