[Mlir-commits] [mlir] [mlir][python] Cache import of ir module in type casters. (PR #160000)

Peter Hawkins llvmlistbot at llvm.org
Sun Sep 21 12:21:20 PDT 2025


https://github.com/hawkinsp created https://github.com/llvm/llvm-project/pull/160000

In a JAX benchmark that traces a large language model, this change reduces the time spent in nanobind::module_::import_ from 1.2s to 10ms.

From 4006133819d3bfdf271a86b3c23474e13ede391e Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins at google.com>
Date: Sun, 21 Sep 2025 19:17:23 +0000
Subject: [PATCH] [mlir][python] Cache import of ir module in type casters.

In a JAX benchmark that traces a large language model, this change
reduces the time spent in nanobind::module_::import_ from 1.2s to 10ms.
---
 .../mlir/Bindings/Python/NanobindAdaptors.h   | 99 +++++++++++++------
 1 file changed, 70 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
index 8744d8d0e4bca..aeb51542f9b6d 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -19,7 +19,9 @@
 #ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
 #define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
 
+#include <atomic>
 #include <cstdint>
+#include <memory>
 #include <optional>
 
 #include "mlir-c/Diagnostics.h"
@@ -30,6 +32,56 @@
 // clang-format on
 #include "llvm/ADT/Twine.h"
 
+namespace mlir {
+namespace python {
+
+// Runs Python initialization code on first use and caches the result without
+// holding a lock across the initializer. The cached object is never deleted.
+template <typename T> class SafeInit {
+public:
+  using F = std::unique_ptr<T> (*)();
+
+  explicit SafeInit(F init_fn) : init_fn_(init_fn) {}
+
+  // Returns the cached value, creating and publishing it on first use.
+  T &Get() {
+    if (T *result = output_.load()) {
+      return *result;
+    }
+
+    // Note: init_fn() may be called multiple times if, for example, the GIL is
+    // released during its execution. The intended use case is for module
+    // imports which are safe to perform multiple times. We are careful not to
+    // hold a lock across init_fn() to avoid lock ordering problems.
+    std::unique_ptr<T> m = init_fn_();
+    nanobind::ft_lock_guard lock(mu_);
+    if (T *result = output_.load()) {
+      return *result; // Another caller published first; `m` is dropped.
+    }
+    T *p = m.release();
+    output_.store(p);
+    return *p;
+  }
+
+private:
+  nanobind::ft_mutex mu_;
+  std::atomic<T *> output_{nullptr};
+  F init_fn_;
+};
+
+// Returns the cached `ir` module. `inline` (not an unnamed namespace) so that
+// inline code in this header names one entity in every TU, per the ODR.
+inline nanobind::module_ &IrModule() {
+  static SafeInit<nanobind::module_> init([]() {
+    return std::make_unique<nanobind::module_>(
+        nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")));
+  });
+  return init.Get();
+}
+
+} // namespace python
+} // namespace mlir
+
 // Raw CAPI type casters need to be declared before use, so always include them
 // first.
 namespace nanobind {
@@ -75,7 +127,7 @@ struct type_caster<MlirAffineMap> {
                          cleanup_list *cleanup) noexcept {
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonAffineMapToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("AffineMap")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .release();
@@ -97,7 +149,7 @@ struct type_caster<MlirAttribute> {
                          cleanup_list *cleanup) noexcept {
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonAttributeToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("Attribute")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -128,9 +180,7 @@ struct type_caster<MlirContext> {
       // TODO: This raises an error of "No current context" currently.
       // Update the implementation to pretty-print the helpful error that the
       // core implementations print in this case.
-      src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                .attr("Context")
-                .attr("current");
+      src = mlir::python::IrModule().attr("Context").attr("current");
     }
     std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
     value = mlirPythonCapsuleToContext(capsule->ptr());
@@ -153,7 +203,7 @@ struct type_caster<MlirDialectRegistry> {
                          cleanup_list *cleanup) noexcept {
     nanobind::object capsule = nanobind::steal<nanobind::object>(
         mlirPythonDialectRegistryToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("DialectRegistry")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .release();
@@ -167,9 +217,7 @@ struct type_caster<MlirLocation> {
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
     if (src.is_none()) {
       // Gets the current thread-bound context.
-      src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                .attr("Location")
-                .attr("current");
+      src = mlir::python::IrModule().attr("Location").attr("current");
     }
     if (auto capsule = mlirApiObjectToCapsule(src)) {
       value = mlirPythonCapsuleToLocation(capsule->ptr());
@@ -181,7 +229,7 @@ struct type_caster<MlirLocation> {
                          cleanup_list *cleanup) noexcept {
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonLocationToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("Location")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .release();
@@ -203,7 +251,7 @@ struct type_caster<MlirModule> {
                          cleanup_list *cleanup) noexcept {
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonModuleToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("Module")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .release();
@@ -250,7 +298,7 @@ struct type_caster<MlirOperation> {
       return nanobind::none();
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonOperationToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("Operation")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .release();
@@ -274,7 +322,7 @@ struct type_caster<MlirValue> {
       return nanobind::none();
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonValueToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("Value")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -312,7 +360,7 @@ struct type_caster<MlirTypeID> {
       return nanobind::none();
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonTypeIDToCapsule(v));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("TypeID")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .release();
@@ -334,7 +382,7 @@ struct type_caster<MlirType> {
                          cleanup_list *cleanup) noexcept {
     nanobind::object capsule =
         nanobind::steal<nanobind::object>(mlirPythonTypeToCapsule(t));
-    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+    return mlir::python::IrModule()
         .attr("Type")
         .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
         .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -453,11 +501,9 @@ class mlir_attribute_subclass : public pure_subclass {
   mlir_attribute_subclass(nanobind::handle scope, const char *attrClassName,
                           IsAFunctionTy isaFunction,
                           GetTypeIDFunctionTy getTypeIDFunction = nullptr)
-      : mlir_attribute_subclass(
-            scope, attrClassName, isaFunction,
-            nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                .attr("Attribute"),
-            getTypeIDFunction) {}
+      : mlir_attribute_subclass(scope, attrClassName, isaFunction,
+                                IrModule().attr("Attribute"),
+                                getTypeIDFunction) {}
 
   /// Subclasses with a provided mlir.ir.Attribute super-class. This must
   /// be used if the subclass is being defined in the same extension module
@@ -540,11 +586,8 @@ class mlir_type_subclass : public pure_subclass {
   mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
                      IsAFunctionTy isaFunction,
                      GetTypeIDFunctionTy getTypeIDFunction = nullptr)
-      : mlir_type_subclass(
-            scope, typeClassName, isaFunction,
-            nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                .attr("Type"),
-            getTypeIDFunction) {}
+      : mlir_type_subclass(scope, typeClassName, isaFunction,
+                           IrModule().attr("Type"), getTypeIDFunction) {}
 
   /// Subclasses with a provided mlir.ir.Type super-class. This must
   /// be used if the subclass is being defined in the same extension module
@@ -631,10 +674,8 @@ class mlir_value_subclass : public pure_subclass {
   /// Subclasses by looking up the super-class dynamically.
   mlir_value_subclass(nanobind::handle scope, const char *valueClassName,
                       IsAFunctionTy isaFunction)
-      : mlir_value_subclass(
-            scope, valueClassName, isaFunction,
-            nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                .attr("Value")) {}
+      : mlir_value_subclass(scope, valueClassName, isaFunction,
+                            IrModule().attr("Value")) {}
 
   /// Subclasses with a provided mlir.ir.Value super-class. This must
   /// be used if the subclass is being defined in the same extension module



More information about the Mlir-commits mailing list