[llvm] [mlir] [mlir python] Port in-tree dialects to nanobind. (PR #119924)

Jacques Pienaar via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 20 17:54:23 PST 2024


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/119924

>From 59a598ce89b586742abec3c701dc66775e0ce0ff Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins at google.com>
Date: Fri, 13 Dec 2024 20:40:36 +0000
Subject: [PATCH] [mlir python] Port in-tree dialects to nanobind.

This is a companion to #118583, although it can be landed independently
because, since #117922, dialects do not have to use the same Python
binding framework as the Python core code.

This PR ports all of the in-tree dialect and pass extensions to nanobind,
with the exception of those that remain for testing pybind11 support. It
would make sense to merge this PR after merging #118583, if we have
agreed that we are migrating the core to nanobind.

This PR also:
* removes CollectDiagnosticsToStringScope from NanobindAdaptors.h. This
  was overlooked in a previous PR and it is duplicated in Diagnostics.h.

---------

Co-authored-by: Jacques Pienaar <jpienaar at google.com>
---
 mlir/cmake/modules/AddMLIRPython.cmake        | 12 +++
 mlir/cmake/modules/MLIRDetectPythonEnv.cmake  | 12 +++
 .../python/StandaloneExtensionNanobind.cpp    |  3 +-
 mlir/include/mlir/Bindings/Python/Nanobind.h  | 37 ++++++++
 .../mlir/Bindings/Python/NanobindAdaptors.h   | 38 +--------
 mlir/lib/Bindings/Python/AsyncPasses.cpp      |  5 +-
 mlir/lib/Bindings/Python/DialectGPU.cpp       | 44 +++++-----
 mlir/lib/Bindings/Python/DialectLLVM.cpp      | 54 ++++++------
 mlir/lib/Bindings/Python/DialectLinalg.cpp    | 11 +--
 mlir/lib/Bindings/Python/DialectNVGPU.cpp     | 20 ++---
 mlir/lib/Bindings/Python/DialectPDL.cpp       | 43 +++++-----
 mlir/lib/Bindings/Python/DialectQuant.cpp     | 79 +++++++++--------
 .../Bindings/Python/DialectSparseTensor.cpp   | 45 +++++-----
 mlir/lib/Bindings/Python/DialectTransform.cpp | 48 +++++------
 .../Bindings/Python/ExecutionEngineModule.cpp | 85 ++++++++++---------
 mlir/lib/Bindings/Python/GPUPasses.cpp        |  5 +-
 mlir/lib/Bindings/Python/IRAffine.cpp         |  5 +-
 mlir/lib/Bindings/Python/IRAttributes.cpp     |  8 +-
 mlir/lib/Bindings/Python/IRCore.cpp           |  8 +-
 mlir/lib/Bindings/Python/IRInterfaces.cpp     |  5 +-
 mlir/lib/Bindings/Python/IRModule.cpp         |  4 +-
 mlir/lib/Bindings/Python/IRModule.h           |  4 +-
 mlir/lib/Bindings/Python/IRTypes.cpp          |  6 --
 mlir/lib/Bindings/Python/LinalgPasses.cpp     |  4 +-
 mlir/lib/Bindings/Python/MainModule.cpp       |  3 +-
 mlir/lib/Bindings/Python/NanobindUtils.h      |  5 +-
 mlir/lib/Bindings/Python/Pass.cpp             |  5 +-
 .../Bindings/Python/RegisterEverything.cpp    |  5 +-
 mlir/lib/Bindings/Python/Rewrite.cpp          |  3 +-
 .../Bindings/Python/SparseTensorPasses.cpp    |  4 +-
 .../Bindings/Python/TransformInterpreter.cpp  | 44 +++++-----
 mlir/python/CMakeLists.txt                    | 22 +++--
 .../python/dialects/sparse_tensor/dialect.py  |  2 +-
 .../python/lib/PythonTestModuleNanobind.cpp   |  4 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     | 23 ++---
 35 files changed, 351 insertions(+), 354 deletions(-)
 create mode 100644 mlir/include/mlir/Bindings/Python/Nanobind.h

diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 67619a90c90be9..53a70139fd5a68 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -661,6 +661,18 @@ function(add_mlir_python_extension libname extname)
       NB_DOMAIN mlir
       ${ARG_SOURCES}
     )
+
+    if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
+      # Avoids warnings from upstream nanobind.
+      target_compile_options(nanobind-static
+        PRIVATE
+          -Wno-cast-qual
+          -Wno-zero-length-array
+          -Wno-nested-anon-types
+          -Wno-c++98-compat-extra-semi
+          -Wno-covered-switch-default
+      )
+    endif()
   endif()
 
   # The extension itself must be compiled with RTTI and exceptions enabled.
diff --git a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
index d6bb65c64b8292..3a87d39c28a061 100644
--- a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
+++ b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
@@ -95,5 +95,17 @@ function(mlir_detect_nanobind_install)
     endif()
     message(STATUS "found (${PACKAGE_DIR})")
     set(nanobind_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
+    execute_process(
+      COMMAND "${Python3_EXECUTABLE}"
+      -c "import nanobind;print(nanobind.include_dir(), end='')"
+      WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+      RESULT_VARIABLE STATUS
+      OUTPUT_VARIABLE PACKAGE_DIR
+      ERROR_QUIET)
+    if(NOT STATUS EQUAL "0")
+      message(STATUS "not found (install via 'pip install nanobind' or set nanobind_DIR)")
+      return()
+    endif()
+    set(nanobind_INCLUDE_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
   endif()
 endfunction()
diff --git a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
index 6d83dc585dcd1d..189ebac368bf59 100644
--- a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
+++ b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
@@ -9,9 +9,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <nanobind/nanobind.h>
-
 #include "Standalone-c/Dialects.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h
new file mode 100644
index 00000000000000..ca942c83d3e2fa
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/Nanobind.h
@@ -0,0 +1,37 @@
+//===- Nanobind.h - Trampoline header with ignored warnings ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This file is a trampoline for the nanobind headers that disables warnings
+// reported by the LLVM/MLIR build. Doing so here avoids adding complexity on
+// the build-system side.
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_NANOBIND_H
+#define MLIR_BINDINGS_PYTHON_NANOBIND_H
+
+#if defined(__clang__) || defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wzero-length-array"
+#pragma GCC diagnostic ignored "-Wcast-qual"
+#pragma GCC diagnostic ignored "-Wnested-anon-types"
+#pragma GCC diagnostic ignored "-Wc++98-compat-extra-semi"
+#pragma GCC diagnostic ignored "-Wcovered-switch-default"
+#endif
+#include <nanobind/nanobind.h>
+#include <nanobind/ndarray.h>
+#include <nanobind/stl/function.h>
+#include <nanobind/stl/optional.h>
+#include <nanobind/stl/pair.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/string_view.h>
+#include <nanobind/stl/tuple.h>
+#include <nanobind/stl/vector.h>
+#if defined(__clang__) || defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
+
+#endif // MLIR_BINDINGS_PYTHON_NANOBIND_H
diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
index 943981b1fa03dd..ae3d5bb6c74843 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -19,14 +19,12 @@
 #ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
 #define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
 
-#include <nanobind/nanobind.h>
-#include <nanobind/stl/string.h>
-
 #include <cstdint>
 
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 #include "llvm/ADT/Twine.h"
 
 // Raw CAPI type casters need to be declared before use, so always include them
@@ -631,40 +629,6 @@ class mlir_value_subclass : public pure_subclass {
 
 } // namespace nanobind_adaptors
 
-/// RAII scope intercepting all diagnostics into a string. The message must be
-/// checked before this goes out of scope.
-class CollectDiagnosticsToStringScope {
-public:
-  explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
-    handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
-                                                   /*deleteUserData=*/nullptr);
-  }
-  ~CollectDiagnosticsToStringScope() {
-    assert(errorMessage.empty() && "unchecked error message");
-    mlirContextDetachDiagnosticHandler(context, handlerID);
-  }
-
-  [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
-
-private:
-  static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
-    auto printer = +[](MlirStringRef message, void *data) {
-      *static_cast<std::string *>(data) +=
-          llvm::StringRef(message.data, message.length);
-    };
-    MlirLocation loc = mlirDiagnosticGetLocation(diag);
-    *static_cast<std::string *>(data) += "at ";
-    mlirLocationPrint(loc, printer, data);
-    *static_cast<std::string *>(data) += ": ";
-    mlirDiagnosticPrint(diag, printer, data);
-    return mlirLogicalResultSuccess();
-  }
-
-  MlirContext context;
-  MlirDiagnosticHandlerID handlerID;
-  std::string errorMessage = "";
-};
-
 } // namespace python
 } // namespace mlir
 
diff --git a/mlir/lib/Bindings/Python/AsyncPasses.cpp b/mlir/lib/Bindings/Python/AsyncPasses.cpp
index b611a758dbbb37..cfb8dcaaa72ae3 100644
--- a/mlir/lib/Bindings/Python/AsyncPasses.cpp
+++ b/mlir/lib/Bindings/Python/AsyncPasses.cpp
@@ -8,14 +8,13 @@
 
 #include "mlir-c/Dialect/Async.h"
 
-#include <pybind11/detail/common.h>
-#include <pybind11/pybind11.h>
+#include "mlir/Bindings/Python/Nanobind.h"
 
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirAsyncPasses, m) {
+NB_MODULE(_mlirAsyncPasses, m) {
   m.doc() = "MLIR Async Dialect Passes";
 
   // Register all Async passes on load.
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 560a54bcd15919..e5045cf0bba354 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -9,21 +9,21 @@
 #include "mlir-c/Dialect/GPU.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 
-#include <pybind11/detail/common.h>
-#include <pybind11/pybind11.h>
+namespace nb = nanobind;
+using namespace nanobind::literals;
 
-namespace py = pybind11;
 using namespace mlir;
 using namespace mlir::python;
-using namespace mlir::python::adaptors;
+using namespace mlir::python::nanobind_adaptors;
 
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirDialectsGPU, m) {
+NB_MODULE(_mlirDialectsGPU, m) {
   m.doc() = "MLIR GPU Dialect";
   //===-------------------------------------------------------------------===//
   // AsyncTokenType
@@ -34,11 +34,11 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
 
   mlirGPUAsyncTokenType.def_classmethod(
       "get",
-      [](py::object cls, MlirContext ctx) {
+      [](nb::object cls, MlirContext ctx) {
         return cls(mlirGPUAsyncTokenTypeGet(ctx));
       },
-      "Gets an instance of AsyncTokenType in the same context", py::arg("cls"),
-      py::arg("ctx") = py::none());
+      "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
+      nb::arg("ctx").none() = nb::none());
 
   //===-------------------------------------------------------------------===//
   // ObjectAttr
@@ -47,12 +47,12 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
   mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
       .def_classmethod(
           "get",
-          [](py::object cls, MlirAttribute target, uint32_t format,
-             py::bytes object, std::optional<MlirAttribute> mlirObjectProps,
+          [](nb::object cls, MlirAttribute target, uint32_t format,
+             nb::bytes object, std::optional<MlirAttribute> mlirObjectProps,
              std::optional<MlirAttribute> mlirKernelsAttr) {
-            py::buffer_info info(py::buffer(object).request());
-            MlirStringRef objectStrRef =
-                mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
+            MlirStringRef objectStrRef = mlirStringRefCreate(
+                static_cast<char *>(const_cast<void *>(object.data())),
+                object.size());
             return cls(mlirGPUObjectAttrGetWithKernels(
                 mlirAttributeGetContext(target), target, format, objectStrRef,
                 mlirObjectProps.has_value() ? *mlirObjectProps
@@ -61,7 +61,7 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
                                             : MlirAttribute{nullptr}));
           },
           "cls"_a, "target"_a, "format"_a, "object"_a,
-          "properties"_a = py::none(), "kernels"_a = py::none(),
+          "properties"_a.none() = nb::none(), "kernels"_a.none() = nb::none(),
           "Gets a gpu.object from parameters.")
       .def_property_readonly(
           "target",
@@ -73,18 +73,18 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
           "object",
           [](MlirAttribute self) {
             MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
-            return py::bytes(stringRef.data, stringRef.length);
+            return nb::bytes(stringRef.data, stringRef.length);
           })
       .def_property_readonly("properties",
-                             [](MlirAttribute self) {
+                             [](MlirAttribute self) -> nb::object {
                                if (mlirGPUObjectAttrHasProperties(self))
-                                 return py::cast(
+                                 return nb::cast(
                                      mlirGPUObjectAttrGetProperties(self));
-                               return py::none().cast<py::object>();
+                               return nb::none();
                              })
-      .def_property_readonly("kernels", [](MlirAttribute self) {
+      .def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
         if (mlirGPUObjectAttrHasKernels(self))
-          return py::cast(mlirGPUObjectAttrGetKernels(self));
-        return py::none().cast<py::object>();
+          return nb::cast(mlirGPUObjectAttrGetKernels(self));
+        return nb::none();
       });
 }
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index cccf1370b8cc87..f211e769d66bec 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -12,15 +12,19 @@
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/Diagnostics.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+
+namespace nb = nanobind;
+
+using namespace nanobind::literals;
 
-namespace py = pybind11;
 using namespace llvm;
 using namespace mlir;
 using namespace mlir::python;
-using namespace mlir::python::adaptors;
+using namespace mlir::python::nanobind_adaptors;
 
-void populateDialectLLVMSubmodule(const pybind11::module &m) {
+void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
 
   //===--------------------------------------------------------------------===//
   // StructType
@@ -31,35 +35,35 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
 
   llvmStructType.def_classmethod(
       "get_literal",
-      [](py::object cls, const std::vector<MlirType> &elements, bool packed,
+      [](nb::object cls, const std::vector<MlirType> &elements, bool packed,
          MlirLocation loc) {
         CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
 
         MlirType type = mlirLLVMStructTypeLiteralGetChecked(
             loc, elements.size(), elements.data(), packed);
         if (mlirTypeIsNull(type)) {
-          throw py::value_error(scope.takeMessage());
+          throw nb::value_error(scope.takeMessage().c_str());
         }
         return cls(type);
       },
-      "cls"_a, "elements"_a, py::kw_only(), "packed"_a = false,
-      "loc"_a = py::none());
+      "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+      "loc"_a.none() = nb::none());
 
   llvmStructType.def_classmethod(
       "get_identified",
-      [](py::object cls, const std::string &name, MlirContext context) {
+      [](nb::object cls, const std::string &name, MlirContext context) {
         return cls(mlirLLVMStructTypeIdentifiedGet(
             context, mlirStringRefCreate(name.data(), name.size())));
       },
-      "cls"_a, "name"_a, py::kw_only(), "context"_a = py::none());
+      "cls"_a, "name"_a, nb::kw_only(), "context"_a.none() = nb::none());
 
   llvmStructType.def_classmethod(
       "get_opaque",
-      [](py::object cls, const std::string &name, MlirContext context) {
+      [](nb::object cls, const std::string &name, MlirContext context) {
         return cls(mlirLLVMStructTypeOpaqueGet(
             context, mlirStringRefCreate(name.data(), name.size())));
       },
-      "cls"_a, "name"_a, "context"_a = py::none());
+      "cls"_a, "name"_a, "context"_a.none() = nb::none());
 
   llvmStructType.def(
       "set_body",
@@ -67,22 +71,22 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
         MlirLogicalResult result = mlirLLVMStructTypeSetBody(
             self, elements.size(), elements.data(), packed);
         if (!mlirLogicalResultIsSuccess(result)) {
-          throw py::value_error(
+          throw nb::value_error(
               "Struct body already set to different content.");
         }
       },
-      "elements"_a, py::kw_only(), "packed"_a = false);
+      "elements"_a, nb::kw_only(), "packed"_a = false);
 
   llvmStructType.def_classmethod(
       "new_identified",
-      [](py::object cls, const std::string &name,
+      [](nb::object cls, const std::string &name,
          const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
         return cls(mlirLLVMStructTypeIdentifiedNewGet(
             ctx, mlirStringRefCreate(name.data(), name.length()),
             elements.size(), elements.data(), packed));
       },
-      "cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false,
-      "context"_a = py::none());
+      "cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
+      "context"_a.none() = nb::none());
 
   llvmStructType.def_property_readonly(
       "name", [](MlirType type) -> std::optional<std::string> {
@@ -93,12 +97,12 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
         return StringRef(stringRef.data, stringRef.length).str();
       });
 
-  llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
+  llvmStructType.def_property_readonly("body", [](MlirType type) -> nb::object {
     // Don't crash in absence of a body.
     if (mlirLLVMStructTypeIsOpaque(type))
-      return py::none();
+      return nb::none();
 
-    py::list body;
+    nb::list body;
     for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
          ++i) {
       body.append(mlirLLVMStructTypeGetElementType(type, i));
@@ -119,24 +123,24 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
   mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
       .def_classmethod(
           "get",
-          [](py::object cls, std::optional<unsigned> addressSpace,
+          [](nb::object cls, std::optional<unsigned> addressSpace,
              MlirContext context) {
             CollectDiagnosticsToStringScope scope(context);
             MlirType type = mlirLLVMPointerTypeGet(
                 context, addressSpace.has_value() ? *addressSpace : 0);
             if (mlirTypeIsNull(type)) {
-              throw py::value_error(scope.takeMessage());
+              throw nb::value_error(scope.takeMessage().c_str());
             }
             return cls(type);
           },
-          "cls"_a, "address_space"_a = py::none(), py::kw_only(),
-          "context"_a = py::none())
+          "cls"_a, "address_space"_a.none() = nb::none(), nb::kw_only(),
+          "context"_a.none() = nb::none())
       .def_property_readonly("address_space", [](MlirType type) {
         return mlirLLVMPointerTypeGetAddressSpace(type);
       });
 }
 
-PYBIND11_MODULE(_mlirDialectsLLVM, m) {
+NB_MODULE(_mlirDialectsLLVM, m) {
   m.doc() = "MLIR LLVM Dialect";
 
   populateDialectLLVMSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 2e54ebeb61fb10..548df4ee100aa9 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -8,20 +8,21 @@
 
 #include "mlir-c/Dialect/Linalg.h"
 #include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 
-namespace py = pybind11;
+namespace nb = nanobind;
 
-static void populateDialectLinalgSubmodule(py::module m) {
+static void populateDialectLinalgSubmodule(nb::module_ m) {
   m.def(
       "fill_builtin_region",
       [](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); },
-      py::arg("op"),
+      nb::arg("op"),
       "Fill the region for `op`, which is assumed to be a builtin named Linalg "
       "op.");
 }
 
-PYBIND11_MODULE(_mlirDialectsLinalg, m) {
+NB_MODULE(_mlirDialectsLinalg, m) {
   m.doc() = "MLIR Linalg dialect.";
 
   populateDialectLinalgSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
index 754e0a75b0abc7..a0d6a4b4c73f92 100644
--- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
@@ -8,33 +8,33 @@
 
 #include "mlir-c/Dialect/NVGPU.h"
 #include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
-#include <pybind11/pybind11.h>
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 
-namespace py = pybind11;
+namespace nb = nanobind;
 using namespace llvm;
 using namespace mlir;
 using namespace mlir::python;
-using namespace mlir::python::adaptors;
+using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectNVGPUSubmodule(const pybind11::module &m) {
+static void populateDialectNVGPUSubmodule(const nb::module_ &m) {
   auto nvgpuTensorMapDescriptorType = mlir_type_subclass(
       m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType);
 
   nvgpuTensorMapDescriptorType.def_classmethod(
       "get",
-      [](py::object cls, MlirType tensorMemrefType, int swizzle, int l2promo,
+      [](nb::object cls, MlirType tensorMemrefType, int swizzle, int l2promo,
          int oobFill, int interleave, MlirContext ctx) {
         return cls(mlirNVGPUTensorMapDescriptorTypeGet(
             ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave));
       },
       "Gets an instance of TensorMapDescriptorType in the same context",
-      py::arg("cls"), py::arg("tensor_type"), py::arg("swizzle"),
-      py::arg("l2promo"), py::arg("oob_fill"), py::arg("interleave"),
-      py::arg("ctx") = py::none());
+      nb::arg("cls"), nb::arg("tensor_type"), nb::arg("swizzle"),
+      nb::arg("l2promo"), nb::arg("oob_fill"), nb::arg("interleave"),
+      nb::arg("ctx").none() = nb::none());
 }
 
-PYBIND11_MODULE(_mlirDialectsNVGPU, m) {
+NB_MODULE(_mlirDialectsNVGPU, m) {
   m.doc() = "MLIR NVGPU dialect.";
 
   populateDialectNVGPUSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp
index 8d3f9a7ab1d6ac..bcc6ff406c9529 100644
--- a/mlir/lib/Bindings/Python/DialectPDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectPDL.cpp
@@ -8,19 +8,16 @@
 
 #include "mlir-c/Dialect/PDL.h"
 #include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
-#include <pybind11/cast.h>
-#include <pybind11/detail/common.h>
-#include <pybind11/pybind11.h>
-#include <pybind11/pytypes.h>
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 
-namespace py = pybind11;
+namespace nb = nanobind;
 using namespace llvm;
 using namespace mlir;
 using namespace mlir::python;
-using namespace mlir::python::adaptors;
+using namespace mlir::python::nanobind_adaptors;
 
-void populateDialectPDLSubmodule(const pybind11::module &m) {
+void populateDialectPDLSubmodule(const nanobind::module_ &m) {
   //===-------------------------------------------------------------------===//
   // PDLType
   //===-------------------------------------------------------------------===//
@@ -35,11 +32,11 @@ void populateDialectPDLSubmodule(const pybind11::module &m) {
       mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType);
   attributeType.def_classmethod(
       "get",
-      [](py::object cls, MlirContext ctx) {
+      [](nb::object cls, MlirContext ctx) {
         return cls(mlirPDLAttributeTypeGet(ctx));
       },
-      "Get an instance of AttributeType in given context.", py::arg("cls"),
-      py::arg("context") = py::none());
+      "Get an instance of AttributeType in given context.", nb::arg("cls"),
+      nb::arg("context").none() = nb::none());
 
   //===-------------------------------------------------------------------===//
   // OperationType
@@ -49,11 +46,11 @@ void populateDialectPDLSubmodule(const pybind11::module &m) {
       mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType);
   operationType.def_classmethod(
       "get",
-      [](py::object cls, MlirContext ctx) {
+      [](nb::object cls, MlirContext ctx) {
         return cls(mlirPDLOperationTypeGet(ctx));
       },
-      "Get an instance of OperationType in given context.", py::arg("cls"),
-      py::arg("context") = py::none());
+      "Get an instance of OperationType in given context.", nb::arg("cls"),
+      nb::arg("context").none() = nb::none());
 
   //===-------------------------------------------------------------------===//
   // RangeType
@@ -62,12 +59,12 @@ void populateDialectPDLSubmodule(const pybind11::module &m) {
   auto rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType);
   rangeType.def_classmethod(
       "get",
-      [](py::object cls, MlirType elementType) {
+      [](nb::object cls, MlirType elementType) {
         return cls(mlirPDLRangeTypeGet(elementType));
       },
       "Gets an instance of RangeType in the same context as the provided "
       "element type.",
-      py::arg("cls"), py::arg("element_type"));
+      nb::arg("cls"), nb::arg("element_type"));
   rangeType.def_property_readonly(
       "element_type",
       [](MlirType type) { return mlirPDLRangeTypeGetElementType(type); },
@@ -80,11 +77,11 @@ void populateDialectPDLSubmodule(const pybind11::module &m) {
   auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType);
   typeType.def_classmethod(
       "get",
-      [](py::object cls, MlirContext ctx) {
+      [](nb::object cls, MlirContext ctx) {
         return cls(mlirPDLTypeTypeGet(ctx));
       },
-      "Get an instance of TypeType in given context.", py::arg("cls"),
-      py::arg("context") = py::none());
+      "Get an instance of TypeType in given context.", nb::arg("cls"),
+      nb::arg("context").none() = nb::none());
 
   //===-------------------------------------------------------------------===//
   // ValueType
@@ -93,14 +90,14 @@ void populateDialectPDLSubmodule(const pybind11::module &m) {
   auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType);
   valueType.def_classmethod(
       "get",
-      [](py::object cls, MlirContext ctx) {
+      [](nb::object cls, MlirContext ctx) {
         return cls(mlirPDLValueTypeGet(ctx));
       },
-      "Get an instance of TypeType in given context.", py::arg("cls"),
-      py::arg("context") = py::none());
+      "Get an instance of TypeType in given context.", nb::arg("cls"),
+      nb::arg("context").none() = nb::none());
 }
 
-PYBIND11_MODULE(_mlirDialectsPDL, m) {
+NB_MODULE(_mlirDialectsPDL, m) {
   m.doc() = "MLIR PDL dialect.";
   populateDialectPDLSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index 9a871f2c122d12..29f19c9c500659 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -6,21 +6,20 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir-c/Dialect/Quant.h"
-#include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
 #include <cstdint>
-#include <pybind11/cast.h>
-#include <pybind11/detail/common.h>
-#include <pybind11/pybind11.h>
 #include <vector>
 
-namespace py = pybind11;
+#include "mlir-c/Dialect/Quant.h"
+#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+
+namespace nb = nanobind;
 using namespace llvm;
 using namespace mlir;
-using namespace mlir::python::adaptors;
+using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectQuantSubmodule(const py::module &m) {
+static void populateDialectQuantSubmodule(const nb::module_ &m) {
   //===-------------------------------------------------------------------===//
   // QuantizedType
   //===-------------------------------------------------------------------===//
@@ -35,7 +34,7 @@ static void populateDialectQuantSubmodule(const py::module &m) {
       },
       "Default minimum value for the integer with the specified signedness and "
       "bit width.",
-      py::arg("is_signed"), py::arg("integral_width"));
+      nb::arg("is_signed"), nb::arg("integral_width"));
   quantizedType.def_staticmethod(
       "default_maximum_for_integer",
       [](bool isSigned, unsigned integralWidth) {
@@ -44,7 +43,7 @@ static void populateDialectQuantSubmodule(const py::module &m) {
       },
       "Default maximum value for the integer with the specified signedness and "
       "bit width.",
-      py::arg("is_signed"), py::arg("integral_width"));
+      nb::arg("is_signed"), nb::arg("integral_width"));
   quantizedType.def_property_readonly(
       "expressed_type",
       [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
@@ -82,7 +81,7 @@ static void populateDialectQuantSubmodule(const py::module &m) {
       },
       "Checks whether the candidate type can be expressed by this quantized "
       "type.",
-      py::arg("candidate"));
+      nb::arg("candidate"));
   quantizedType.def_property_readonly(
       "quantized_element_type",
       [](MlirType type) {
@@ -96,24 +95,24 @@ static void populateDialectQuantSubmodule(const py::module &m) {
             mlirQuantizedTypeCastFromStorageType(type, candidate);
         if (!mlirTypeIsNull(castResult))
           return castResult;
-        throw py::type_error("Invalid cast.");
+        throw nb::type_error("Invalid cast.");
       },
       "Casts from a type based on the storage type of this quantized type to a "
       "corresponding type based on the quantized type. Raises TypeError if the "
       "cast is not valid.",
-      py::arg("candidate"));
+      nb::arg("candidate"));
   quantizedType.def_staticmethod(
       "cast_to_storage_type",
       [](MlirType type) {
         MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
         if (!mlirTypeIsNull(castResult))
           return castResult;
-        throw py::type_error("Invalid cast.");
+        throw nb::type_error("Invalid cast.");
       },
       "Casts from a type based on a quantized type to a corresponding type "
       "based on the storage type of this quantized type. Raises TypeError if "
       "the cast is not valid.",
-      py::arg("type"));
+      nb::arg("type"));
   quantizedType.def(
       "cast_from_expressed_type",
       [](MlirType type, MlirType candidate) {
@@ -121,24 +120,24 @@ static void populateDialectQuantSubmodule(const py::module &m) {
             mlirQuantizedTypeCastFromExpressedType(type, candidate);
         if (!mlirTypeIsNull(castResult))
           return castResult;
-        throw py::type_error("Invalid cast.");
+        throw nb::type_error("Invalid cast.");
       },
       "Casts from a type based on the expressed type of this quantized type to "
       "a corresponding type based on the quantized type. Raises TypeError if "
       "the cast is not valid.",
-      py::arg("candidate"));
+      nb::arg("candidate"));
   quantizedType.def_staticmethod(
       "cast_to_expressed_type",
       [](MlirType type) {
         MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
         if (!mlirTypeIsNull(castResult))
           return castResult;
-        throw py::type_error("Invalid cast.");
+        throw nb::type_error("Invalid cast.");
       },
       "Casts from a type based on a quantized type to a corresponding type "
       "based on the expressed type of this quantized type. Raises TypeError if "
       "the cast is not valid.",
-      py::arg("type"));
+      nb::arg("type"));
   quantizedType.def(
       "cast_expressed_to_storage_type",
       [](MlirType type, MlirType candidate) {
@@ -146,12 +145,12 @@ static void populateDialectQuantSubmodule(const py::module &m) {
             mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
         if (!mlirTypeIsNull(castResult))
           return castResult;
-        throw py::type_error("Invalid cast.");
+        throw nb::type_error("Invalid cast.");
       },
       "Casts from a type based on the expressed type of this quantized type to "
       "a corresponding type based on the storage type. Raises TypeError if the "
       "cast is not valid.",
-      py::arg("candidate"));
+      nb::arg("candidate"));
 
   quantizedType.get_class().attr("FLAG_SIGNED") =
       mlirQuantizedTypeGetSignedFlag();
@@ -165,7 +164,7 @@ static void populateDialectQuantSubmodule(const py::module &m) {
                          quantizedType.get_class());
   anyQuantizedType.def_classmethod(
       "get",
-      [](py::object cls, unsigned flags, MlirType storageType,
+      [](nb::object cls, unsigned flags, MlirType storageType,
          MlirType expressedType, int64_t storageTypeMin,
          int64_t storageTypeMax) {
         return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
@@ -173,9 +172,9 @@ static void populateDialectQuantSubmodule(const py::module &m) {
       },
       "Gets an instance of AnyQuantizedType in the same context as the "
       "provided storage type.",
-      py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
-      py::arg("expressed_type"), py::arg("storage_type_min"),
-      py::arg("storage_type_max"));
+      nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
+      nb::arg("expressed_type"), nb::arg("storage_type_min"),
+      nb::arg("storage_type_max"));
 
   //===-------------------------------------------------------------------===//
   // UniformQuantizedType
@@ -186,7 +185,7 @@ static void populateDialectQuantSubmodule(const py::module &m) {
       quantizedType.get_class());
   uniformQuantizedType.def_classmethod(
       "get",
-      [](py::object cls, unsigned flags, MlirType storageType,
+      [](nb::object cls, unsigned flags, MlirType storageType,
          MlirType expressedType, double scale, int64_t zeroPoint,
          int64_t storageTypeMin, int64_t storageTypeMax) {
         return cls(mlirUniformQuantizedTypeGet(flags, storageType,
@@ -195,9 +194,9 @@ static void populateDialectQuantSubmodule(const py::module &m) {
       },
       "Gets an instance of UniformQuantizedType in the same context as the "
       "provided storage type.",
-      py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
-      py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"),
-      py::arg("storage_type_min"), py::arg("storage_type_max"));
+      nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
+      nb::arg("expressed_type"), nb::arg("scale"), nb::arg("zero_point"),
+      nb::arg("storage_type_min"), nb::arg("storage_type_max"));
   uniformQuantizedType.def_property_readonly(
       "scale",
       [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
@@ -221,12 +220,12 @@ static void populateDialectQuantSubmodule(const py::module &m) {
       quantizedType.get_class());
   uniformQuantizedPerAxisType.def_classmethod(
       "get",
-      [](py::object cls, unsigned flags, MlirType storageType,
+      [](nb::object cls, unsigned flags, MlirType storageType,
          MlirType expressedType, std::vector<double> scales,
          std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
          int64_t storageTypeMin, int64_t storageTypeMax) {
         if (scales.size() != zeroPoints.size())
-          throw py::value_error(
+          throw nb::value_error(
               "Mismatching number of scales and zero points.");
         auto nDims = static_cast<intptr_t>(scales.size());
         return cls(mlirUniformQuantizedPerAxisTypeGet(
@@ -236,10 +235,10 @@ static void populateDialectQuantSubmodule(const py::module &m) {
       },
       "Gets an instance of UniformQuantizedPerAxisType in the same context as "
       "the provided storage type.",
-      py::arg("cls"), py::arg("flags"), py::arg("storage_type"),
-      py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"),
-      py::arg("quantized_dimension"), py::arg("storage_type_min"),
-      py::arg("storage_type_max"));
+      nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
+      nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
+      nb::arg("quantized_dimension"), nb::arg("storage_type_min"),
+      nb::arg("storage_type_max"));
   uniformQuantizedPerAxisType.def_property_readonly(
       "scales",
       [](MlirType type) {
@@ -294,13 +293,13 @@ static void populateDialectQuantSubmodule(const py::module &m) {
       quantizedType.get_class());
   calibratedQuantizedType.def_classmethod(
       "get",
-      [](py::object cls, MlirType expressedType, double min, double max) {
+      [](nb::object cls, MlirType expressedType, double min, double max) {
         return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
       },
       "Gets an instance of CalibratedQuantizedType in the same context as the "
       "provided expressed type.",
-      py::arg("cls"), py::arg("expressed_type"), py::arg("min"),
-      py::arg("max"));
+      nb::arg("cls"), nb::arg("expressed_type"), nb::arg("min"),
+      nb::arg("max"));
   calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
     return mlirCalibratedQuantizedTypeGetMin(type);
   });
@@ -309,7 +308,7 @@ static void populateDialectQuantSubmodule(const py::module &m) {
   });
 }
 
-PYBIND11_MODULE(_mlirDialectsQuant, m) {
+NB_MODULE(_mlirDialectsQuant, m) {
   m.doc() = "MLIR Quantization dialect";
 
   populateDialectQuantSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index a730bf500be98c..97cebcceebd9ad 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -6,32 +6,30 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <optional>
+#include <vector>
+
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/Dialect/SparseTensor.h"
 #include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
-#include <optional>
-#include <pybind11/cast.h>
-#include <pybind11/detail/common.h>
-#include <pybind11/pybind11.h>
-#include <pybind11/pytypes.h>
-#include <vector>
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 
-namespace py = pybind11;
+namespace nb = nanobind;
 using namespace llvm;
 using namespace mlir;
-using namespace mlir::python::adaptors;
+using namespace mlir::python::nanobind_adaptors;
 
-static void populateDialectSparseTensorSubmodule(const py::module &m) {
-  py::enum_<MlirSparseTensorLevelFormat>(m, "LevelFormat", py::module_local())
+static void populateDialectSparseTensorSubmodule(const nb::module_ &m) {
+  nb::enum_<MlirSparseTensorLevelFormat>(m, "LevelFormat", nb::is_arithmetic(),
+                                         nb::is_flag())
       .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
       .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
       .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
       .value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON)
       .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED);
 
-  py::enum_<MlirSparseTensorLevelPropertyNondefault>(m, "LevelProperty",
-                                                     py::module_local())
+  nb::enum_<MlirSparseTensorLevelPropertyNondefault>(m, "LevelProperty")
       .value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED)
       .value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE)
       .value("soa", MLIR_SPARSE_PROPERTY_SOA);
@@ -40,7 +38,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
                           mlirAttributeIsASparseTensorEncodingAttr)
       .def_classmethod(
           "get",
-          [](py::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
+          [](nb::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
              std::optional<MlirAffineMap> dimToLvl,
              std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
              std::optional<MlirAttribute> explicitVal,
@@ -52,24 +50,25 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
                 crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr},
                 implicitVal ? *implicitVal : MlirAttribute{nullptr}));
           },
-          py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
-          py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
-          py::arg("explicit_val") = py::none(),
-          py::arg("implicit_val") = py::none(), py::arg("context") = py::none(),
+          nb::arg("cls"), nb::arg("lvl_types"), nb::arg("dim_to_lvl").none(),
+          nb::arg("lvl_to_dim").none(), nb::arg("pos_width"),
+          nb::arg("crd_width"), nb::arg("explicit_val").none() = nb::none(),
+          nb::arg("implicit_val").none() = nb::none(),
+          nb::arg("context").none() = nb::none(),
           "Gets a sparse_tensor.encoding from parameters.")
       .def_classmethod(
           "build_level_type",
-          [](py::object cls, MlirSparseTensorLevelFormat lvlFmt,
+          [](nb::object cls, MlirSparseTensorLevelFormat lvlFmt,
              const std::vector<MlirSparseTensorLevelPropertyNondefault>
                  &properties,
              unsigned n, unsigned m) {
             return mlirSparseTensorEncodingAttrBuildLvlType(
                 lvlFmt, properties.data(), properties.size(), n, m);
           },
-          py::arg("cls"), py::arg("lvl_fmt"),
-          py::arg("properties") =
+          nb::arg("cls"), nb::arg("lvl_fmt"),
+          nb::arg("properties") =
               std::vector<MlirSparseTensorLevelPropertyNondefault>(),
-          py::arg("n") = 0, py::arg("m") = 0,
+          nb::arg("n") = 0, nb::arg("m") = 0,
           "Builds a sparse_tensor.encoding.level_type from parameters.")
       .def_property_readonly(
           "lvl_types",
@@ -143,7 +142,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
       });
 }
 
-PYBIND11_MODULE(_mlirDialectsSparseTensor, m) {
+NB_MODULE(_mlirDialectsSparseTensor, m) {
   m.doc() = "MLIR SparseTensor dialect.";
   populateDialectSparseTensorSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 6b57e652aa9d8b..59a030ac67f570 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -6,22 +6,20 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <string>
+
 #include "mlir-c/Dialect/Transform.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
-#include <pybind11/cast.h>
-#include <pybind11/detail/common.h>
-#include <pybind11/pybind11.h>
-#include <pybind11/pytypes.h>
-#include <string>
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 
-namespace py = pybind11;
+namespace nb = nanobind;
 using namespace mlir;
 using namespace mlir::python;
-using namespace mlir::python::adaptors;
+using namespace mlir::python::nanobind_adaptors;
 
-void populateDialectTransformSubmodule(const pybind11::module &m) {
+void populateDialectTransformSubmodule(const nb::module_ &m) {
   //===-------------------------------------------------------------------===//
   // AnyOpType
   //===-------------------------------------------------------------------===//
@@ -31,11 +29,11 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
                          mlirTransformAnyOpTypeGetTypeID);
   anyOpType.def_classmethod(
       "get",
-      [](py::object cls, MlirContext ctx) {
+      [](nb::object cls, MlirContext ctx) {
         return cls(mlirTransformAnyOpTypeGet(ctx));
       },
-      "Get an instance of AnyOpType in the given context.", py::arg("cls"),
-      py::arg("context") = py::none());
+      "Get an instance of AnyOpType in the given context.", nb::arg("cls"),
+      nb::arg("context").none() = nb::none());
 
   //===-------------------------------------------------------------------===//
   // AnyParamType
@@ -46,11 +44,11 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
                          mlirTransformAnyParamTypeGetTypeID);
   anyParamType.def_classmethod(
       "get",
-      [](py::object cls, MlirContext ctx) {
+      [](nb::object cls, MlirContext ctx) {
         return cls(mlirTransformAnyParamTypeGet(ctx));
       },
-      "Get an instance of AnyParamType in the given context.", py::arg("cls"),
-      py::arg("context") = py::none());
+      "Get an instance of AnyParamType in the given context.", nb::arg("cls"),
+      nb::arg("context").none() = nb::none());
 
   //===-------------------------------------------------------------------===//
   // AnyValueType
@@ -61,11 +59,11 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
                          mlirTransformAnyValueTypeGetTypeID);
   anyValueType.def_classmethod(
       "get",
-      [](py::object cls, MlirContext ctx) {
+      [](nb::object cls, MlirContext ctx) {
         return cls(mlirTransformAnyValueTypeGet(ctx));
       },
-      "Get an instance of AnyValueType in the given context.", py::arg("cls"),
-      py::arg("context") = py::none());
+      "Get an instance of AnyValueType in the given context.", nb::arg("cls"),
+      nb::arg("context").none() = nb::none());
 
   //===-------------------------------------------------------------------===//
   // OperationType
@@ -76,21 +74,21 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
                          mlirTransformOperationTypeGetTypeID);
   operationType.def_classmethod(
       "get",
-      [](py::object cls, const std::string &operationName, MlirContext ctx) {
+      [](nb::object cls, const std::string &operationName, MlirContext ctx) {
         MlirStringRef cOperationName =
             mlirStringRefCreate(operationName.data(), operationName.size());
         return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
       },
       "Get an instance of OperationType for the given kind in the given "
       "context",
-      py::arg("cls"), py::arg("operation_name"),
-      py::arg("context") = py::none());
+      nb::arg("cls"), nb::arg("operation_name"),
+      nb::arg("context").none() = nb::none());
   operationType.def_property_readonly(
       "operation_name",
       [](MlirType type) {
         MlirStringRef operationName =
             mlirTransformOperationTypeGetOperationName(type);
-        return py::str(operationName.data, operationName.length);
+        return nb::str(operationName.data, operationName.length);
       },
       "Get the name of the payload operation accepted by the handle.");
 
@@ -103,11 +101,11 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
                          mlirTransformParamTypeGetTypeID);
   paramType.def_classmethod(
       "get",
-      [](py::object cls, MlirType type, MlirContext ctx) {
+      [](nb::object cls, MlirType type, MlirContext ctx) {
         return cls(mlirTransformParamTypeGet(ctx, type));
       },
       "Get an instance of ParamType for the given type in the given context.",
-      py::arg("cls"), py::arg("type"), py::arg("context") = py::none());
+      nb::arg("cls"), nb::arg("type"), nb::arg("context").none() = nb::none());
   paramType.def_property_readonly(
       "type",
       [](MlirType type) {
@@ -117,7 +115,7 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
       "Get the type this ParamType is associated with.");
 }
 
-PYBIND11_MODULE(_mlirDialectsTransform, m) {
+NB_MODULE(_mlirDialectsTransform, m) {
   m.doc() = "MLIR Transform dialect.";
   populateDialectTransformSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index b3df30583fc963..81dada3553622b 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -7,9 +7,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir-c/ExecutionEngine.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 
-namespace py = pybind11;
+namespace nb = nanobind;
 using namespace mlir;
 using namespace mlir::python;
 
@@ -34,23 +35,22 @@ class PyExecutionEngine {
     executionEngine.ptr = nullptr;
     referencedObjects.clear();
   }
-  pybind11::object getCapsule() {
-    return py::reinterpret_steal<py::object>(
-        mlirPythonExecutionEngineToCapsule(get()));
+  nb::object getCapsule() {
+    return nb::steal<nb::object>(mlirPythonExecutionEngineToCapsule(get()));
   }
 
   // Add an object to the list of referenced objects whose lifetime must exceed
   // those of the ExecutionEngine.
-  void addReferencedObject(const pybind11::object &obj) {
+  void addReferencedObject(const nb::object &obj) {
     referencedObjects.push_back(obj);
   }
 
-  static pybind11::object createFromCapsule(pybind11::object capsule) {
+  static nb::object createFromCapsule(nb::object capsule) {
     MlirExecutionEngine rawPm =
         mlirPythonCapsuleToExecutionEngine(capsule.ptr());
     if (mlirExecutionEngineIsNull(rawPm))
-      throw py::error_already_set();
-    return py::cast(PyExecutionEngine(rawPm), py::return_value_policy::move);
+      throw nb::python_error();
+    return nb::cast(PyExecutionEngine(rawPm), nb::rv_policy::move);
   }
 
 private:
@@ -58,44 +58,45 @@ class PyExecutionEngine {
   // We support Python ctypes closures as callbacks. Keep a list of the objects
   // so that they don't get garbage collected. (The ExecutionEngine itself
   // just holds raw pointers with no lifetime semantics).
-  std::vector<py::object> referencedObjects;
+  std::vector<nb::object> referencedObjects;
 };
 
 } // namespace
 
 /// Create the `mlir.execution_engine` module here.
-PYBIND11_MODULE(_mlirExecutionEngine, m) {
+NB_MODULE(_mlirExecutionEngine, m) {
   m.doc() = "MLIR Execution Engine";
 
   //----------------------------------------------------------------------------
   // Mapping of the top-level PassManager
   //----------------------------------------------------------------------------
-  py::class_<PyExecutionEngine>(m, "ExecutionEngine", py::module_local())
-      .def(py::init<>([](MlirModule module, int optLevel,
-                         const std::vector<std::string> &sharedLibPaths,
-                         bool enableObjectDump) {
-             llvm::SmallVector<MlirStringRef, 4> libPaths;
-             for (const std::string &path : sharedLibPaths)
-               libPaths.push_back({path.c_str(), path.length()});
-             MlirExecutionEngine executionEngine =
-                 mlirExecutionEngineCreate(module, optLevel, libPaths.size(),
-                                           libPaths.data(), enableObjectDump);
-             if (mlirExecutionEngineIsNull(executionEngine))
-               throw std::runtime_error(
-                   "Failure while creating the ExecutionEngine.");
-             return new PyExecutionEngine(executionEngine);
-           }),
-           py::arg("module"), py::arg("opt_level") = 2,
-           py::arg("shared_libs") = py::list(),
-           py::arg("enable_object_dump") = true,
-           "Create a new ExecutionEngine instance for the given Module. The "
-           "module must contain only dialects that can be translated to LLVM. "
-           "Perform transformations and code generation at the optimization "
-           "level `opt_level` if specified, or otherwise at the default "
-           "level of two (-O2). Load a list of libraries specified in "
-           "`shared_libs`.")
-      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
-                             &PyExecutionEngine::getCapsule)
+  nb::class_<PyExecutionEngine>(m, "ExecutionEngine")
+      .def(
+          "__init__",
+          [](PyExecutionEngine &self, MlirModule module, int optLevel,
+             const std::vector<std::string> &sharedLibPaths,
+             bool enableObjectDump) {
+            llvm::SmallVector<MlirStringRef, 4> libPaths;
+            for (const std::string &path : sharedLibPaths)
+              libPaths.push_back({path.c_str(), path.length()});
+            MlirExecutionEngine executionEngine =
+                mlirExecutionEngineCreate(module, optLevel, libPaths.size(),
+                                          libPaths.data(), enableObjectDump);
+            if (mlirExecutionEngineIsNull(executionEngine))
+              throw std::runtime_error(
+                  "Failure while creating the ExecutionEngine.");
+            new (&self) PyExecutionEngine(executionEngine);
+          },
+          nb::arg("module"), nb::arg("opt_level") = 2,
+          nb::arg("shared_libs") = nb::list(),
+          nb::arg("enable_object_dump") = true,
+          "Create a new ExecutionEngine instance for the given Module. The "
+          "module must contain only dialects that can be translated to LLVM. "
+          "Perform transformations and code generation at the optimization "
+          "level `opt_level` if specified, or otherwise at the default "
+          "level of two (-O2). Load a list of libraries specified in "
+          "`shared_libs`.")
+      .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyExecutionEngine::getCapsule)
       .def("_testing_release", &PyExecutionEngine::release,
            "Releases (leaks) the backing ExecutionEngine (for testing purpose)")
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyExecutionEngine::createFromCapsule)
@@ -107,21 +108,21 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) {
                 mlirStringRefCreate(func.c_str(), func.size()));
             return reinterpret_cast<uintptr_t>(res);
           },
-          py::arg("func_name"),
+          nb::arg("func_name"),
           "Lookup function `func` in the ExecutionEngine.")
       .def(
           "raw_register_runtime",
           [](PyExecutionEngine &executionEngine, const std::string &name,
-             py::object callbackObj) {
+             nb::object callbackObj) {
             executionEngine.addReferencedObject(callbackObj);
             uintptr_t rawSym =
-                py::cast<uintptr_t>(py::getattr(callbackObj, "value"));
+                nb::cast<uintptr_t>(nb::getattr(callbackObj, "value"));
             mlirExecutionEngineRegisterSymbol(
                 executionEngine.get(),
                 mlirStringRefCreate(name.c_str(), name.size()),
                 reinterpret_cast<void *>(rawSym));
           },
-          py::arg("name"), py::arg("callback"),
+          nb::arg("name"), nb::arg("callback"),
           "Register `callback` as the runtime symbol `name`.")
       .def(
           "dump_to_object_file",
@@ -130,5 +131,5 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) {
                 executionEngine.get(),
                 mlirStringRefCreate(fileName.c_str(), fileName.size()));
           },
-          py::arg("file_name"), "Dump ExecutionEngine to an object file.");
+          nb::arg("file_name"), "Dump ExecutionEngine to an object file.");
 }
diff --git a/mlir/lib/Bindings/Python/GPUPasses.cpp b/mlir/lib/Bindings/Python/GPUPasses.cpp
index e276a3ce3a56a0..be474edbe9639a 100644
--- a/mlir/lib/Bindings/Python/GPUPasses.cpp
+++ b/mlir/lib/Bindings/Python/GPUPasses.cpp
@@ -8,14 +8,13 @@
 
 #include "mlir-c/Dialect/GPU.h"
 
-#include <pybind11/detail/common.h>
-#include <pybind11/pybind11.h>
+#include "mlir/Bindings/Python/Nanobind.h"
 
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirGPUPasses, m) {
+NB_MODULE(_mlirGPUPasses, m) {
   m.doc() = "MLIR GPU Dialect Passes";
 
   // Register all GPU passes on load.
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 2db690309fab8c..2a2d2a4cd0e8e8 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -6,10 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <nanobind/nanobind.h>
-#include <nanobind/stl/string.h>
-#include <nanobind/stl/vector.h>
-
 #include <cstddef>
 #include <cstdint>
 #include <stdexcept>
@@ -23,6 +19,7 @@
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/IntegerSet.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/SmallVector.h"
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 779af09509748e..08f7d4881e137b 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -6,13 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <nanobind/nanobind.h>
-#include <nanobind/ndarray.h>
-#include <nanobind/stl/optional.h>
-#include <nanobind/stl/string.h>
-#include <nanobind/stl/string_view.h>
-#include <nanobind/stl/vector.h>
-
 #include <cstdint>
 #include <optional>
 #include <string>
@@ -24,6 +17,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/raw_ostream.h"
 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index e1c56a3984314f..cf5baf2848fd2b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -6,13 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <nanobind/nanobind.h>
-#include <nanobind/stl/function.h>
-#include <nanobind/stl/optional.h>
-#include <nanobind/stl/string.h>
-#include <nanobind/stl/tuple.h>
-#include <nanobind/stl/vector.h>
-
 #include <optional>
 #include <utility>
 
@@ -25,6 +18,7 @@
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index c339a93e31857b..9e1fedaab52352 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -6,10 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <nanobind/nanobind.h>
-#include <nanobind/stl/optional.h>
-#include <nanobind/stl/vector.h>
-
 #include <cstdint>
 #include <optional>
 #include <string>
@@ -21,6 +17,7 @@
 #include "mlir-c/IR.h"
 #include "mlir-c/Interfaces.h"
 #include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 416a14218f125d..99d23970684880 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -8,9 +8,6 @@
 
 #include "IRModule.h"
 
-#include <nanobind/nanobind.h>
-#include <nanobind/stl/string.h>
-
 #include <optional>
 #include <vector>
 
@@ -18,6 +15,7 @@
 #include "NanobindUtils.h"
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 
 namespace nb = nanobind;
 using namespace mlir;
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index a242ff26bbbf57..8fb32a225e65f1 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -10,9 +10,6 @@
 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
 #define MLIR_BINDINGS_PYTHON_IRMODULES_H
 
-#include <nanobind/nanobind.h>
-#include <nanobind/stl/string.h>
-
 #include <optional>
 #include <utility>
 #include <vector>
@@ -26,6 +23,7 @@
 #include "mlir-c/IntegerSet.h"
 #include "mlir-c/Transforms.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 #include "llvm/ADT/DenseMap.h"
 
 namespace mlir {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 5cfa51142ac08f..0f2719c10a0275 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -11,12 +11,6 @@
 #include "mlir/Bindings/Python/IRTypes.h"
 // clang-format on
 
-#include <nanobind/nanobind.h>
-#include <nanobind/stl/optional.h>
-#include <nanobind/stl/pair.h>
-#include <nanobind/stl/string.h>
-#include <nanobind/stl/vector.h>
-
 #include <optional>
 
 #include "IRModule.h"
diff --git a/mlir/lib/Bindings/Python/LinalgPasses.cpp b/mlir/lib/Bindings/Python/LinalgPasses.cpp
index 3f230207a42114..49f2ea94151a01 100644
--- a/mlir/lib/Bindings/Python/LinalgPasses.cpp
+++ b/mlir/lib/Bindings/Python/LinalgPasses.cpp
@@ -8,13 +8,13 @@
 
 #include "mlir-c/Dialect/Linalg.h"
 
-#include <pybind11/pybind11.h>
+#include "mlir/Bindings/Python/Nanobind.h"
 
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirLinalgPasses, m) {
+NB_MODULE(_mlirLinalgPasses, m) {
   m.doc() = "MLIR Linalg Dialect Passes";
 
   // Register all Linalg passes on load.
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index e5e64a921a79ad..7c4064262012ef 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,14 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <nanobind/nanobind.h>
-#include <nanobind/stl/string.h>
 
 #include "Globals.h"
 #include "IRModule.h"
 #include "NanobindUtils.h"
 #include "Pass.h"
 #include "Rewrite.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 
 namespace nb = nanobind;
 using namespace mlir;
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h
index 3b0f7f698b22d4..ee193cf9f8ef86 100644
--- a/mlir/lib/Bindings/Python/NanobindUtils.h
+++ b/mlir/lib/Bindings/Python/NanobindUtils.h
@@ -10,9 +10,8 @@
 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
 
-#include <nanobind/nanobind.h>
-
 #include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/DataTypes.h"
@@ -68,7 +67,7 @@ namespace detail {
 
 template <typename DefaultingTy>
 struct MlirDefaultingCaster {
-  NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription));
+  NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription))
 
   bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
     if (src.is_none()) {
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index b5dce4fe4128a5..15d0f5c14d761c 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,13 +8,10 @@
 
 #include "Pass.h"
 
-#include <nanobind/nanobind.h>
-#include <nanobind/stl/optional.h>
-#include <nanobind/stl/string.h>
-
 #include "IRModule.h"
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/Pass.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;
diff --git a/mlir/lib/Bindings/Python/RegisterEverything.cpp b/mlir/lib/Bindings/Python/RegisterEverything.cpp
index 6b2f6b0a6a3b86..3ba42bec5d80c3 100644
--- a/mlir/lib/Bindings/Python/RegisterEverything.cpp
+++ b/mlir/lib/Bindings/Python/RegisterEverything.cpp
@@ -7,9 +7,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir-c/RegisterEverything.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 
-PYBIND11_MODULE(_mlirRegisterEverything, m) {
+NB_MODULE(_mlirRegisterEverything, m) {
   m.doc() = "MLIR All Upstream Dialects, Translations and Passes Registration";
 
   m.def("register_dialects", [](MlirDialectRegistry registry) {
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index b2c1de4be9a69c..a3e59d8d05be6c 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -8,11 +8,10 @@
 
 #include "Rewrite.h"
 
-#include <nanobind/nanobind.h>
-
 #include "IRModule.h"
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/Rewrite.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Config/mlir-config.h"
 
 namespace nb = nanobind;
diff --git a/mlir/lib/Bindings/Python/SparseTensorPasses.cpp b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
index 2a8e2b802df9c4..8242f0973a446c 100644
--- a/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
+++ b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
@@ -8,13 +8,13 @@
 
 #include "mlir-c/Dialect/SparseTensor.h"
 
-#include <pybind11/pybind11.h>
+#include "mlir/Bindings/Python/Nanobind.h"
 
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirSparseTensorPasses, m) {
+NB_MODULE(_mlirSparseTensorPasses, m) {
   m.doc() = "MLIR SparseTensor Dialect Passes";
 
   // Register all SparseTensor passes on load.
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index 0c8c0e0a965aa7..f9b0fed62778f4 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -10,16 +10,14 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <pybind11/detail/common.h>
-#include <pybind11/pybind11.h>
-
 #include "mlir-c/Dialect/Transform/Interpreter.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
 #include "mlir/Bindings/Python/Diagnostics.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 
-namespace py = pybind11;
+namespace nb = nanobind;
 
 namespace {
 struct PyMlirTransformOptions {
@@ -36,10 +34,10 @@ struct PyMlirTransformOptions {
 };
 } // namespace
 
-static void populateTransformInterpreterSubmodule(py::module &m) {
-  py::class_<PyMlirTransformOptions>(m, "TransformOptions", py::module_local())
-      .def(py::init())
-      .def_property(
+static void populateTransformInterpreterSubmodule(nb::module_ &m) {
+  nb::class_<PyMlirTransformOptions>(m, "TransformOptions")
+      .def(nb::init<>())
+      .def_prop_rw(
           "expensive_checks",
           [](const PyMlirTransformOptions &self) {
             return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
@@ -47,7 +45,7 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
           [](PyMlirTransformOptions &self, bool value) {
             mlirTransformOptionsEnableExpensiveChecks(self.options, value);
           })
-      .def_property(
+      .def_prop_rw(
           "enforce_single_top_level_transform_op",
           [](const PyMlirTransformOptions &self) {
             return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
@@ -68,7 +66,7 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
         // Calling back into Python to invalidate everything under the payload
         // root. This is awkward, but we don't have access to PyMlirContext
         // object here otherwise.
-        py::object obj = py::cast(payloadRoot);
+        nb::object obj = nb::cast(payloadRoot);
         obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
 
         MlirLogicalResult result = mlirTransformApplyNamedSequence(
@@ -76,13 +74,14 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
         if (mlirLogicalResultIsSuccess(result))
           return;
 
-        throw py::value_error(
-            "Failed to apply named transform sequence.\nDiagnostic message " +
-            scope.takeMessage());
+        throw nb::value_error(
+            ("Failed to apply named transform sequence.\nDiagnostic message " +
+             scope.takeMessage())
+                .c_str());
       },
-      py::arg("payload_root"), py::arg("transform_root"),
-      py::arg("transform_module"),
-      py::arg("transform_options") = PyMlirTransformOptions());
+      nb::arg("payload_root"), nb::arg("transform_root"),
+      nb::arg("transform_module"),
+      nb::arg("transform_options") = PyMlirTransformOptions());
 
   m.def(
       "copy_symbols_and_merge_into",
@@ -92,15 +91,16 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
 
         MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
         if (mlirLogicalResultIsFailure(result)) {
-          throw py::value_error(
-              "Failed to merge symbols.\nDiagnostic message " +
-              scope.takeMessage());
+          throw nb::value_error(
+              ("Failed to merge symbols.\nDiagnostic message " +
+               scope.takeMessage())
+                  .c_str());
         }
       },
-      py::arg("target"), py::arg("other"));
+      nb::arg("target"), nb::arg("other"));
 }
 
-PYBIND11_MODULE(_mlirTransformInterpreter, m) {
+NB_MODULE(_mlirTransformInterpreter, m) {
   m.doc() = "MLIR Transform dialect interpreter functionality.";
   populateTransformInterpreterSubmodule(m);
 }
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 6d6b983128b80f..fb115a5f43423a 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -476,9 +476,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
     # Dialects
     MLIRCAPIFunc
 )
-if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
-set_target_properties(MLIRPythonExtension.Core PROPERTIES INTERFACE_COMPILE_OPTIONS "-Wno-cast-qual;-Wno-zero-length-array;-Wno-extra-semi;-Wno-nested-anon-types;-Wno-pedantic")
-endif()
 
 # This extension exposes an API to register all dialects, extensions, and passes
 # packaged in upstream MLIR and it is used for the upstream "mlir" Python
@@ -490,6 +487,7 @@ endif()
 declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
   MODULE_NAME _mlirRegisterEverything
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     RegisterEverything.cpp
   PRIVATE_LINK_LIBS
@@ -504,6 +502,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
   MODULE_NAME _mlirDialectsLinalg
   ADD_TO_PARENT MLIRPythonSources.Dialects.linalg
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     DialectLinalg.cpp
   PRIVATE_LINK_LIBS
@@ -517,6 +516,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
   MODULE_NAME _mlirDialectsGPU
   ADD_TO_PARENT MLIRPythonSources.Dialects.gpu
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     DialectGPU.cpp
   PRIVATE_LINK_LIBS
@@ -530,6 +530,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
   MODULE_NAME _mlirDialectsLLVM
   ADD_TO_PARENT MLIRPythonSources.Dialects.llvm
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     DialectLLVM.cpp
   PRIVATE_LINK_LIBS
@@ -543,6 +544,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
   MODULE_NAME _mlirDialectsQuant
   ADD_TO_PARENT MLIRPythonSources.Dialects.quant
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     DialectQuant.cpp
   PRIVATE_LINK_LIBS
@@ -556,6 +558,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
   MODULE_NAME _mlirDialectsNVGPU
   ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     DialectNVGPU.cpp
   PRIVATE_LINK_LIBS
@@ -569,6 +572,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
   MODULE_NAME _mlirDialectsPDL
   ADD_TO_PARENT MLIRPythonSources.Dialects.pdl
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     DialectPDL.cpp
   PRIVATE_LINK_LIBS
@@ -582,6 +586,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
   MODULE_NAME _mlirDialectsSparseTensor
   ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     DialectSparseTensor.cpp
   PRIVATE_LINK_LIBS
@@ -595,6 +600,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
   MODULE_NAME _mlirDialectsTransform
   ADD_TO_PARENT MLIRPythonSources.Dialects.transform
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     DialectTransform.cpp
   PRIVATE_LINK_LIBS
@@ -608,6 +614,7 @@ declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
   MODULE_NAME _mlirAsyncPasses
   ADD_TO_PARENT MLIRPythonSources.Dialects.async
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     AsyncPasses.cpp
   PRIVATE_LINK_LIBS
@@ -621,6 +628,7 @@ if(MLIR_ENABLE_EXECUTION_ENGINE)
     MODULE_NAME _mlirExecutionEngine
     ADD_TO_PARENT MLIRPythonSources.ExecutionEngine
     ROOT_DIR "${PYTHON_SOURCE_DIR}"
+    PYTHON_BINDINGS_LIBRARY nanobind
     SOURCES
       ExecutionEngineModule.cpp
     PRIVATE_LINK_LIBS
@@ -634,6 +642,7 @@ declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses
   MODULE_NAME _mlirGPUPasses
   ADD_TO_PARENT MLIRPythonSources.Dialects.gpu
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     GPUPasses.cpp
   PRIVATE_LINK_LIBS
@@ -646,6 +655,7 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
   MODULE_NAME _mlirLinalgPasses
   ADD_TO_PARENT MLIRPythonSources.Dialects.linalg
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     LinalgPasses.cpp
   PRIVATE_LINK_LIBS
@@ -658,6 +668,7 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
   MODULE_NAME _mlirSparseTensorPasses
   ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     SparseTensorPasses.cpp
   PRIVATE_LINK_LIBS
@@ -670,6 +681,7 @@ declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter
   MODULE_NAME _mlirTransformInterpreter
   ADD_TO_PARENT MLIRPythonSources.Dialects.transform
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
   SOURCES
     TransformInterpreter.cpp
   PRIVATE_LINK_LIBS
@@ -735,9 +747,6 @@ if(MLIR_INCLUDE_TESTS)
     EMBED_CAPI_LINK_LIBS
       MLIRCAPIPythonTestDialect
   )
-  if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
-    set_target_properties(MLIRPythonTestSources.PythonTestExtensionNanobind PROPERTIES INTERFACE_COMPILE_OPTIONS "-Wno-cast-qual;-Wno-zero-length-array;-Wno-extra-semi;-Wno-nested-anon-types;-Wno-pedantic")
-  endif()
 endif()
 
 ################################################################################
@@ -794,3 +803,4 @@ add_mlir_python_modules(MLIRPythonModules
   COMMON_CAPI_LINK_LIBS
     MLIRPythonCAPI
 )
+
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 656979f3d9a1df..c72a69830a1e8e 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -90,7 +90,7 @@ def testEncodingAttrStructure():
 
         # CHECK: lvl_types: [65536, 65536, 4406638542848]
         print(f"lvl_types: {casted.lvl_types}")
-        # CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 2097152>]
+        # CHECK: lvl_formats_enum: [{{65536|LevelFormat.dense}}, {{65536|LevelFormat.dense}}, {{2097152|LevelFormat.n_out_of_m}}]
         print(f"lvl_formats_enum: {casted.lvl_formats_enum}")
         # CHECK: structured_n: 2
         print(f"structured_n: {casted.structured_n}")
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index 7c504d04be0d13..99c81eae97a0cf 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -8,13 +8,11 @@
 // This is the nanobind edition of the PythonTest dialect module.
 //===----------------------------------------------------------------------===//
 
-#include <nanobind/nanobind.h>
-#include <nanobind/stl/vector.h>
-
 #include "PythonTestCAPI.h"
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index d2ac43ef5bcff2..f1192d069fa5f5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1131,7 +1131,8 @@ cc_binary(
     deps = [
         ":CAPIIR",
         ":CAPILinalg",
-        ":MLIRBindingsPythonHeadersAndDeps",
+        ":MLIRBindingsPythonNanobindHeadersAndDeps",
+        "@nanobind",
     ],
 )
 
@@ -1145,8 +1146,8 @@ cc_binary(
     deps = [
         ":CAPIIR",
         ":CAPILLVM",
-        ":MLIRBindingsPythonHeadersAndDeps",
-        "@pybind11",
+        ":MLIRBindingsPythonNanobindHeadersAndDeps",
+        "@nanobind",
     ],
 )
 
@@ -1160,8 +1161,8 @@ cc_binary(
     deps = [
         ":CAPIIR",
         ":CAPIQuant",
-        ":MLIRBindingsPythonHeadersAndDeps",
-        "@pybind11",
+        ":MLIRBindingsPythonNanobindHeadersAndDeps",
+        "@nanobind",
     ],
 )
 
@@ -1175,8 +1176,8 @@ cc_binary(
     deps = [
         ":CAPIIR",
         ":CAPISparseTensor",
-        ":MLIRBindingsPythonHeadersAndDeps",
-        "@pybind11",
+        ":MLIRBindingsPythonNanobindHeadersAndDeps",
+        "@nanobind",
     ],
 )
 
@@ -1190,8 +1191,8 @@ cc_binary(
     linkstatic = 0,
     deps = [
         ":CAPIExecutionEngine",
-        ":MLIRBindingsPythonHeadersAndDeps",
-        "@pybind11",
+        ":MLIRBindingsPythonNanobindHeadersAndDeps",
+        "@nanobind",
         "@rules_python//python/cc:current_py_cc_headers",
     ],
 )
@@ -1206,8 +1207,8 @@ cc_binary(
     linkstatic = 0,
     deps = [
         ":CAPILinalg",
-        ":MLIRBindingsPythonHeadersAndDeps",
-        "@pybind11",
+        ":MLIRBindingsPythonNanobindHeadersAndDeps",
+        "@nanobind",
         "@rules_python//python/cc:current_py_cc_headers",
     ],
 )



More information about the llvm-commits mailing list