[llvm] [mlir] [mlir python] Port Python core code to nanobind. (PR #118583)
Peter Hawkins via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 5 17:15:39 PST 2024
https://github.com/hawkinsp updated https://github.com/llvm/llvm-project/pull/118583
>From cb0435bbbd6a116ce57ec35632ed0ac891abd4be Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins at google.com>
Date: Wed, 27 Nov 2024 20:17:40 +0000
Subject: [PATCH 1/2] [mlir python] Add nanobind support for standalone
dialects.
This PR allows out-of-tree dialects to write Python dialect modules using nanobind
instead of pybind11.
It may make sense to migrate in-tree dialects and some of the ODS Python
infrastructure to nanobind, but that is a topic for a future change.
This PR makes the following changes:
* adds nanobind to the CMake and Bazel build systems. We also add
robin_map to the Bazel build, which is a dependency of nanobind.
* adds a PYTHON_BINDINGS_LIBRARY option to various CMake functions, such
as declare_mlir_python_extension, allowing users to select a
Python binding library.
* creates a fork of mlir/include/mlir/Bindings/Python/PybindAdaptors.h
named NanobindAdaptors.h. This plays the same role, using nanobind
instead of pybind11.
* splits CollectDiagnosticsToStringScope out of PybindAdaptors.h and
into a new header mlir/include/mlir/Bindings/Python/Diagnostics.h, since
it is code that is in no way related to pybind11 or, for that matter,
Python.
* changed the standalone Python extension example to have both pybind11
and nanobind variants.
* changed mlir/python/mlir/dialects/python_test.py to have both pybind11
and nanobind variants.
Notes:
* A slightly unfortunate thing that I needed to do in the CMake
integration was to use FindPython in addition to FindPython3, since nanobind's CMake
integration expects the Python_ names for variables. Perhaps there's a
better way to do this.
---
mlir/cmake/modules/AddMLIRPython.cmake | 27 +-
mlir/cmake/modules/MLIRDetectPythonEnv.cmake | 39 +
mlir/docs/Bindings/Python.md | 20 +-
.../examples/standalone/python/CMakeLists.txt | 22 +-
.../python/StandaloneExtensionNanobind.cpp | 35 +
...on.cpp => StandaloneExtensionPybind11.cpp} | 7 +-
.../{standalone.py => standalone_nanobind.py} | 2 +-
.../dialects/standalone_pybind11.py | 6 +
.../standalone/test/python/smoketest.py | 14 +-
.../mlir/Bindings/Python/Diagnostics.h | 59 ++
.../mlir/Bindings/Python/NanobindAdaptors.h | 671 ++++++++++++++++++
.../mlir/Bindings/Python/PybindAdaptors.h | 43 +-
mlir/lib/Bindings/Python/DialectLLVM.cpp | 4 +-
.../Bindings/Python/TransformInterpreter.cpp | 7 +-
mlir/python/CMakeLists.txt | 23 +-
mlir/python/mlir/dialects/python_test.py | 17 +-
mlir/python/requirements.txt | 1 +
mlir/test/python/dialects/python_test.py | 59 +-
mlir/test/python/lib/CMakeLists.txt | 3 +-
.../python/lib/PythonTestModuleNanobind.cpp | 121 ++++
...odule.cpp => PythonTestModulePybind11.cpp} | 4 +-
utils/bazel/WORKSPACE | 18 +
.../llvm-project-overlay/mlir/BUILD.bazel | 50 +-
utils/bazel/third_party_build/nanobind.BUILD | 25 +
utils/bazel/third_party_build/robin_map.BUILD | 12 +
25 files changed, 1184 insertions(+), 105 deletions(-)
create mode 100644 mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
rename mlir/examples/standalone/python/{StandaloneExtension.cpp => StandaloneExtensionPybind11.cpp} (81%)
rename mlir/examples/standalone/python/mlir_standalone/dialects/{standalone.py => standalone_nanobind.py} (78%)
create mode 100644 mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py
create mode 100644 mlir/include/mlir/Bindings/Python/Diagnostics.h
create mode 100644 mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
create mode 100644 mlir/test/python/lib/PythonTestModuleNanobind.cpp
rename mlir/test/python/lib/{PythonTestModule.cpp => PythonTestModulePybind11.cpp} (96%)
create mode 100644 utils/bazel/third_party_build/nanobind.BUILD
create mode 100644 utils/bazel/third_party_build/robin_map.BUILD
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 7b91f43e2d57fd..67619a90c90be9 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -114,10 +114,11 @@ endfunction()
# EMBED_CAPI_LINK_LIBS: Dependent CAPI libraries that this extension depends
# on. These will be collected for all extensions and put into an
# aggregate dylib that is linked against.
+# PYTHON_BINDINGS_LIBRARY: Either pybind11 or nanobind.
function(declare_mlir_python_extension name)
cmake_parse_arguments(ARG
""
- "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT"
+ "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT;PYTHON_BINDINGS_LIBRARY"
"SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS"
${ARGN})
@@ -126,15 +127,20 @@ function(declare_mlir_python_extension name)
endif()
set(_install_destination "src/python/${name}")
+ if(NOT ARG_PYTHON_BINDINGS_LIBRARY)
+ set(ARG_PYTHON_BINDINGS_LIBRARY "pybind11")
+ endif()
+
add_library(${name} INTERFACE)
set_target_properties(${name} PROPERTIES
# Yes: Leading-lowercase property names are load bearing and the recommended
# way to do this: https://gitlab.kitware.com/cmake/cmake/-/issues/19261
- EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS"
+ EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS;mlir_python_BINDINGS_LIBRARY"
mlir_python_SOURCES_TYPE extension
mlir_python_EXTENSION_MODULE_NAME "${ARG_MODULE_NAME}"
mlir_python_EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}"
mlir_python_DEPENDS ""
+ mlir_python_BINDINGS_LIBRARY "${ARG_PYTHON_BINDINGS_LIBRARY}"
)
# Set the interface source and link_libs properties of the target
@@ -223,12 +229,14 @@ function(add_mlir_python_modules name)
elseif(_source_type STREQUAL "extension")
# Native CPP extension.
get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
+ get_target_property(_bindings_library ${sources_target} mlir_python_BINDINGS_LIBRARY)
# Transform relative source to based on root dir.
set(_extension_target "${modules_target}.extension.${_module_name}.dso")
add_mlir_python_extension(${_extension_target} "${_module_name}"
INSTALL_COMPONENT ${modules_target}
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ PYTHON_BINDINGS_LIBRARY ${_bindings_library}
LINK_LIBS PRIVATE
${sources_target}
${ARG_COMMON_CAPI_LINK_LIBS}
@@ -634,7 +642,7 @@ endfunction()
function(add_mlir_python_extension libname extname)
cmake_parse_arguments(ARG
""
- "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY"
+ "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY;PYTHON_BINDINGS_LIBRARY"
"SOURCES;LINK_LIBS"
${ARGN})
if(ARG_UNPARSED_ARGUMENTS)
@@ -644,9 +652,16 @@ function(add_mlir_python_extension libname extname)
# The actual extension library produces a shared-object or DLL and has
# sources that must be compiled in accordance with pybind11 needs (RTTI and
# exceptions).
- pybind11_add_module(${libname}
- ${ARG_SOURCES}
- )
+ if(NOT DEFINED ARG_PYTHON_BINDINGS_LIBRARY OR ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "pybind11")
+ pybind11_add_module(${libname}
+ ${ARG_SOURCES}
+ )
+ elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind")
+ nanobind_add_module(${libname}
+ NB_DOMAIN mlir
+ ${ARG_SOURCES}
+ )
+ endif()
# The extension itself must be compiled with RTTI and exceptions enabled.
# Also, some warning classes triggered by pybind11 are disabled.
diff --git a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
index 05397b7a1e1c75..c62ac7fa615ea6 100644
--- a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
+++ b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
@@ -21,6 +21,12 @@ macro(mlir_configure_python_dev_packages)
find_package(Python3 ${LLVM_MINIMUM_PYTHON_VERSION}
COMPONENTS Interpreter ${_python_development_component} REQUIRED)
+
+ # It's a little silly to detect Python a second time, but nanobind's cmake
+ # code looks for Python_ not Python3_.
+ find_package(Python ${LLVM_MINIMUM_PYTHON_VERSION}
+ COMPONENTS Interpreter ${_python_development_component} REQUIRED)
+
unset(_python_development_component)
message(STATUS "Found python include dirs: ${Python3_INCLUDE_DIRS}")
message(STATUS "Found python libraries: ${Python3_LIBRARIES}")
@@ -31,6 +37,13 @@ macro(mlir_configure_python_dev_packages)
message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
"suffix = '${PYTHON_MODULE_SUFFIX}', "
"extension = '${PYTHON_MODULE_EXTENSION}")
+
+ mlir_detect_nanobind_install()
+ find_package(nanobind 2.2 CONFIG REQUIRED)
+ message(STATUS "Found nanobind v${nanobind_VERSION}: ${nanobind_INCLUDE_DIR}")
+ message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
+ "suffix = '${PYTHON_MODULE_SUFFIX}', "
+ "extension = '${PYTHON_MODULE_EXTENSION}")
endif()
endmacro()
@@ -58,3 +71,29 @@ function(mlir_detect_pybind11_install)
set(pybind11_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
endif()
endfunction()
+
+
+# Locates the CMake package directory of a nanobind that is installed in the
+# current python environment (e.g. via pip) and exports it to the caller as
+# nanobind_DIR so that nanobind can be found. A pip-installed nanobind is
+# typically much more recent than the OS install that would be used otherwise.
+# An explicitly provided nanobind_DIR is reported and left untouched; when the
+# probe fails, nanobind_DIR is not set.
+function(mlir_detect_nanobind_install)
+  if(nanobind_DIR)
+    message(STATUS "Using explicit nanobind cmake directory: ${nanobind_DIR} (-Dnanobind_DIR to change)")
+    return()
+  endif()
+
+  message(STATUS "Checking for nanobind in python path...")
+  # Ask the interpreter where nanobind keeps its CMake configuration files.
+  execute_process(
+    COMMAND "${Python3_EXECUTABLE}"
+      -c "import nanobind;print(nanobind.cmake_dir(), end='')"
+    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+    RESULT_VARIABLE _nanobind_probe_status
+    OUTPUT_VARIABLE _nanobind_cmake_dir
+    ERROR_QUIET)
+  if(_nanobind_probe_status EQUAL "0")
+    message(STATUS "found (${_nanobind_cmake_dir})")
+    set(nanobind_DIR "${_nanobind_cmake_dir}" PARENT_SCOPE)
+  else()
+    message(STATUS "not found (install via 'pip install nanobind' or set nanobind_DIR)")
+  endif()
+endfunction()
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 6e52c4deaad9aa..a0bd1cac118bad 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1138,12 +1138,14 @@ attributes and types must connect to the relevant C APIs for building and
inspection, which must be provided first. Bindings for `Attribute` and `Type`
subclasses can be defined using
[`include/mlir/Bindings/Python/PybindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)
-utilities that mimic pybind11 API for defining functions and properties. These
-bindings are to be included in a separate pybind11 module. The utilities also
-provide automatic casting between C API handles `MlirAttribute` and `MlirType`
-and their Python counterparts so that the C API handles can be used directly in
-binding implementations. The methods and properties provided by the bindings
-should follow the principles discussed above.
+or
+[`include/mlir/Bindings/Python/NanobindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h)
+utilities that mimic pybind11/nanobind API for defining functions and
+properties. These bindings are to be included in a separate module. The
+utilities also provide automatic casting between C API handles `MlirAttribute`
+and `MlirType` and their Python counterparts so that the C API handles can be
+used directly in binding implementations. The methods and properties provided by
+the bindings should follow the principles discussed above.
The attribute and type bindings for a dialect can be located in
`lib/Bindings/Python/Dialect<Name>.cpp` and should be compiled into a separate
@@ -1179,7 +1181,9 @@ make the passes available along with the dialect.
Dialect functionality other than IR objects or passes, such as helper functions,
can be exposed to Python similarly to attributes and types. C API is expected to
exist for this functionality, which can then be wrapped using pybind11 and
-`[include/mlir/Bindings/Python/PybindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)`
+[`include/mlir/Bindings/Python/PybindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h),
+or nanobind and
+[`include/mlir/Bindings/Python/NanobindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h)
utilities to connect to the rest of Python API. The bindings can be located in a
-separate pybind11 module or in the same module as attributes and types, and
+separate module or in the same module as attributes and types, and
loaded along with the dialect.
diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt
index a8c43827a5a375..69c82fd9135798 100644
--- a/mlir/examples/standalone/python/CMakeLists.txt
+++ b/mlir/examples/standalone/python/CMakeLists.txt
@@ -17,18 +17,32 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir_standalone"
TD_FILE dialects/StandaloneOps.td
SOURCES
- dialects/standalone.py
+ dialects/standalone_pybind11.py
+ dialects/standalone_nanobind.py
DIALECT_NAME standalone)
-declare_mlir_python_extension(StandalonePythonSources.Extension
- MODULE_NAME _standaloneDialects
+
+declare_mlir_python_extension(StandalonePythonSources.Pybind11Extension
+ MODULE_NAME _standaloneDialectsPybind11
+ ADD_TO_PARENT StandalonePythonSources
+ SOURCES
+ StandaloneExtensionPybind11.cpp
+ EMBED_CAPI_LINK_LIBS
+ StandaloneCAPI
+ PYTHON_BINDINGS_LIBRARY pybind11
+)
+
+declare_mlir_python_extension(StandalonePythonSources.NanobindExtension
+ MODULE_NAME _standaloneDialectsNanobind
ADD_TO_PARENT StandalonePythonSources
SOURCES
- StandaloneExtension.cpp
+ StandaloneExtensionNanobind.cpp
EMBED_CAPI_LINK_LIBS
StandaloneCAPI
+ PYTHON_BINDINGS_LIBRARY nanobind
)
+
################################################################################
# Common CAPI
################################################################################
diff --git a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
new file mode 100644
index 00000000000000..6d83dc585dcd1d
--- /dev/null
+++ b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
@@ -0,0 +1,35 @@
+//===- StandaloneExtensionNanobind.cpp - Extension module -----------------===//
+//
+// This is the nanobind version of the example module. There is also a pybind11
+// example in StandaloneExtensionPybind11.cpp.
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <nanobind/nanobind.h>
+
+#include "Standalone-c/Dialects.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
+namespace nb = nanobind;
+
+/// Defines the `_standaloneDialectsNanobind` extension module. It exposes a
+/// `standalone` submodule containing a single function,
+/// `register_dialect(context=None, load=True)`.
+NB_MODULE(_standaloneDialectsNanobind, m) {
+  //===--------------------------------------------------------------------===//
+  // standalone dialect
+  //===--------------------------------------------------------------------===//
+  nb::module_ standaloneModule = m.def_submodule("standalone");
+
+  // Registers the standalone dialect with `context` and, unless `load` is
+  // false, loads it as well.
+  standaloneModule.def(
+      "register_dialect",
+      [](MlirContext context, bool load) {
+        const MlirDialectHandle dialect = mlirGetDialectHandle__standalone__();
+        mlirDialectHandleRegisterDialect(dialect, context);
+        if (!load)
+          return;
+        mlirDialectHandleLoadDialect(dialect, context);
+      },
+      nb::arg("context").none() = nb::none(), nb::arg("load") = true);
+}
diff --git a/mlir/examples/standalone/python/StandaloneExtension.cpp b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
similarity index 81%
rename from mlir/examples/standalone/python/StandaloneExtension.cpp
rename to mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
index 5e83060cd48d82..397db4c20e7432 100644
--- a/mlir/examples/standalone/python/StandaloneExtension.cpp
+++ b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
@@ -1,4 +1,7 @@
-//===- StandaloneExtension.cpp - Extension module -------------------------===//
+//===- StandaloneExtensionPybind11.cpp - Extension module -----------------===//
+//
+// This is the pybind11 version of the example module. There is also a nanobind
+// example in StandaloneExtensionNanobind.cpp.
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -11,7 +14,7 @@
using namespace mlir::python::adaptors;
-PYBIND11_MODULE(_standaloneDialects, m) {
+PYBIND11_MODULE(_standaloneDialectsPybind11, m) {
//===--------------------------------------------------------------------===//
// standalone dialect
//===--------------------------------------------------------------------===//
diff --git a/mlir/examples/standalone/python/mlir_standalone/dialects/standalone.py b/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_nanobind.py
similarity index 78%
rename from mlir/examples/standalone/python/mlir_standalone/dialects/standalone.py
rename to mlir/examples/standalone/python/mlir_standalone/dialects/standalone_nanobind.py
index c958b2ac193682..6218720951c82a 100644
--- a/mlir/examples/standalone/python/mlir_standalone/dialects/standalone.py
+++ b/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_nanobind.py
@@ -3,4 +3,4 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._standalone_ops_gen import *
-from .._mlir_libs._standaloneDialects.standalone import *
+from .._mlir_libs._standaloneDialectsNanobind.standalone import *
diff --git a/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py b/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py
new file mode 100644
index 00000000000000..bfb98e404e13f2
--- /dev/null
+++ b/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py
@@ -0,0 +1,6 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._standalone_ops_gen import *
+from .._mlir_libs._standaloneDialectsPybind11.standalone import *
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index 08e08cbd2fe24c..bd40c65d161645 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -1,7 +1,17 @@
-# RUN: %python %s | FileCheck %s
+# RUN: %python %s pybind11 | FileCheck %s
+# RUN: %python %s nanobind | FileCheck %s
+import sys
from mlir_standalone.ir import *
-from mlir_standalone.dialects import builtin as builtin_d, standalone as standalone_d
+from mlir_standalone.dialects import builtin as builtin_d
+
+if sys.argv[1] == "pybind11":
+ from mlir_standalone.dialects import standalone_pybind11 as standalone_d
+elif sys.argv[1] == "nanobind":
+ from mlir_standalone.dialects import standalone_nanobind as standalone_d
+else:
+ raise ValueError("Expected either pybind11 or nanobind as arguments")
+
with Context():
standalone_d.register_dialect()
diff --git a/mlir/include/mlir/Bindings/Python/Diagnostics.h b/mlir/include/mlir/Bindings/Python/Diagnostics.h
new file mode 100644
index 00000000000000..ea80e14dde0f3a
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/Diagnostics.h
@@ -0,0 +1,59 @@
+//===- Diagnostics.h - Helpers for diagnostics in Python bindings ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
+#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
+
+#include <cassert>
+#include <string>
+
+#include "mlir-c/Diagnostics.h"
+#include "mlir-c/IR.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+namespace python {
+
+/// RAII scope intercepting all diagnostics emitted on an MlirContext into a
+/// string for as long as the scope is alive. The message must be retrieved
+/// with takeMessage() before this goes out of scope; the destructor asserts
+/// that no unchecked text is left behind.
+class CollectDiagnosticsToStringScope {
+public:
+  /// Attaches a diagnostic handler to `ctx`. The handler is given the address
+  /// of this object's message buffer as its user data.
+  explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
+    handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
+                                                   /*deleteUserData=*/nullptr);
+  }
+  /// Detaches the handler that the constructor attached.
+  ~CollectDiagnosticsToStringScope() {
+    assert(errorMessage.empty() && "unchecked error message");
+    mlirContextDetachDiagnosticHandler(context, handlerID);
+  }
+
+  /// Returns everything collected so far and leaves the buffer empty.
+  [[nodiscard]] std::string takeMessage() {
+    std::string message = std::move(errorMessage);
+    // A moved-from std::string is only guaranteed to be in a valid state, not
+    // to be empty. Clear it explicitly so the destructor's assertion holds.
+    errorMessage.clear();
+    return message;
+  }
+
+private:
+  /// Diagnostic callback. Appends "at <location>: <diagnostic>" to the
+  /// std::string that `data` points to and returns success.
+  static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
+    // Shared print callback: appends each printed chunk to the target string.
+    auto printer = +[](MlirStringRef message, void *data) {
+      *static_cast<std::string *>(data) +=
+          llvm::StringRef(message.data, message.length);
+    };
+    MlirLocation loc = mlirDiagnosticGetLocation(diag);
+    *static_cast<std::string *>(data) += "at ";
+    mlirLocationPrint(loc, printer, data);
+    *static_cast<std::string *>(data) += ": ";
+    mlirDiagnosticPrint(diag, printer, data);
+    return mlirLogicalResultSuccess();
+  }
+
+  MlirContext context;
+  MlirDiagnosticHandlerID handlerID;
+  std::string errorMessage = "";
+};
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
new file mode 100644
index 00000000000000..5e01cebcb09c91
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -0,0 +1,671 @@
+//===- NanobindAdaptors.h - Interop with MLIR APIs via nanobind -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This file contains adaptors for clients of the core MLIR Python APIs to
+// interop via MLIR CAPI types, using nanobind. The facilities here do not
+// depend on implementation details of the MLIR Python API and do not introduce
+// C++-level dependencies with it (requiring only Python and CAPI-level
+// dependencies).
+//
+// It is encouraged to be used both in-tree and out-of-tree. For in-tree use
+// cases, it should be used for dialect implementations (versus relying on
+// Pybind-based internals of the core libraries).
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
+#define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
+
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/string.h>
+
+#include <cstdint>
+
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Diagnostics.h"
+#include "mlir-c/IR.h"
+#include "llvm/ADT/Twine.h"
+
+// Raw CAPI type casters need to be declared before use, so always include them
+// first.
+namespace nanobind {
+namespace detail {
+
+/// Converts a presumed MLIR API object into the capsule that carries its C API
+/// pointer. Two kinds of input are accepted:
+///   * an explicit capsule (which can happen when two C APIs communicate
+///     directly via Python), which is returned as-is;
+///   * an object exposing the MLIR_PYTHON_CAPI_PTR_ATTR attribute (through
+///     which supported MLIR Python API objects export their contained API
+///     pointer as a capsule), whose attribute value is returned.
+/// Anything else results in a type error being thrown. Intended for use from
+/// type casters, which are invoked with a raw (unowned) handle. The returned
+/// object's lifetime may not extend beyond the apiObject handle without
+/// explicitly having its refcount increased (i.e. on return).
+static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) {
+  if (PyCapsule_CheckExact(apiObject.ptr()))
+    return nanobind::borrow<nanobind::object>(apiObject);
+
+  if (nanobind::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR))
+    return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
+
+  // Neither a capsule nor an MLIR API object: report what we got instead.
+  const std::string objectRepr =
+      nanobind::cast<std::string>(nanobind::repr(apiObject));
+  const std::string errorText =
+      (llvm::Twine("Expected an MLIR object (got ") + objectRepr + ").").str();
+  throw nanobind::type_error(errorText.c_str());
+}
+
+// Note: Currently all of the following support cast from nanobind::object to
+// the Mlir* C-API type, but only a few light-weight, context-bound ones
+// implicitly cast the other way because the use case has not yet emerged and
+// ownership is unclear.
+
+/// Casts object <-> MlirAffineMap.
+template <>
+struct type_caster<MlirAffineMap> {
+  NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"));
+
+  /// Python -> C: reads the affine map out of the capsule behind `src`. The
+  /// conversion fails when the resulting map is null.
+  bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
+    nanobind::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToAffineMap(capsule.ptr());
+    // A single null check is sufficient; the previous early return tested the
+    // same condition twice.
+    return !mlirAffineMapIsNull(value);
+  }
+
+  /// C -> Python: wraps `v` in a capsule and hands it to the capsule factory
+  /// of the `AffineMap` class in the `ir` module.
+  static handle from_cpp(MlirAffineMap v, rv_policy,
+                         cleanup_list *cleanup) noexcept {
+    nanobind::object capsule =
+        nanobind::steal<nanobind::object>(mlirPythonAffineMapToCapsule(v));
+    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+        .attr("AffineMap")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  }
+};
+
+/// Casts object <-> MlirAttribute.
+template <>
+struct type_caster<MlirAttribute> {
+  NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"));
+
+  /// Python -> C: reads the attribute out of the capsule behind `src`. The
+  /// conversion fails when the resulting attribute is null.
+  bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
+    const nanobind::object attrCapsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToAttribute(attrCapsule.ptr());
+    if (mlirAttributeIsNull(value))
+      return false;
+    return true;
+  }
+
+  /// C -> Python: wraps `v` in a capsule, builds an `ir.Attribute` from it via
+  /// the capsule factory and then invokes the maybe-downcast hook on the
+  /// result.
+  static handle from_cpp(MlirAttribute v, rv_policy,
+                         cleanup_list *cleanup) noexcept {
+    nanobind::object attrCapsule =
+        nanobind::steal<nanobind::object>(mlirPythonAttributeToCapsule(v));
+    nanobind::module_ irModule =
+        nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"));
+    nanobind::object baseAttr =
+        irModule.attr("Attribute").attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(
+            attrCapsule);
+    return baseAttr.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)().release();
+  }
+};
+
+/// Casts object -> MlirBlock (Python to C direction only).
+template <>
+struct type_caster<MlirBlock> {
+  NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"));
+
+  /// Reads the block out of the capsule behind `src`. The conversion fails
+  /// when the resulting block is null.
+  bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
+    const nanobind::object blockCapsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToBlock(blockCapsule.ptr());
+    if (mlirBlockIsNull(value))
+      return false;
+    return true;
+  }
+};
+
+/// Casts object -> MlirContext. None selects the `ir.Context.current` context.
+template <>
+struct type_caster<MlirContext> {
+ NB_TYPE_CASTER(MlirContext, const_name("MlirContext"));
+ bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
+ if (src.is_none()) {
+ // None was passed: use the current thread-bound context, Context.current.
+ // TODO: This raises an error of "No current context" currently.
+ // Update the implementation to pretty-print the helpful error that the
+ // core implementations print in this case.
+ src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Context")
+ .attr("current");
+ }
+ nanobind::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToContext(capsule.ptr());
+ return !mlirContextIsNull(value);
+ }
+};
+
+/// Casts object <-> MlirDialectRegistry.
+template <>
+struct type_caster<MlirDialectRegistry> {
+  NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"));
+
+  /// Python -> C: reads the registry out of the capsule behind `src`. The
+  /// conversion fails when the resulting registry is null.
+  bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
+    const nanobind::object registryCapsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToDialectRegistry(registryCapsule.ptr());
+    if (mlirDialectRegistryIsNull(value))
+      return false;
+    return true;
+  }
+
+  /// C -> Python: wraps `v` in a capsule and hands it to the capsule factory
+  /// of the `DialectRegistry` class in the `ir` module.
+  static handle from_cpp(MlirDialectRegistry v, rv_policy,
+                         cleanup_list *cleanup) noexcept {
+    nanobind::object registryCapsule = nanobind::steal<nanobind::object>(
+        mlirPythonDialectRegistryToCapsule(v));
+    nanobind::module_ irModule =
+        nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"));
+    nanobind::object registry =
+        irModule.attr("DialectRegistry")
+            .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(registryCapsule);
+    return registry.release();
+  }
+};
+
+/// Casts object <-> MlirLocation. None selects `ir.Location.current`.
+template <>
+struct type_caster<MlirLocation> {
+ NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation"));
+ bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
+ if (src.is_none()) {
+ // None was passed: use the current location (`Location.current`).
+ src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Location")
+ .attr("current");
+ }
+ nanobind::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToLocation(capsule.ptr());
+ return !mlirLocationIsNull(value);
+ }
+ static handle from_cpp(MlirLocation v, rv_policy,
+ cleanup_list *cleanup) noexcept {
+ nanobind::object capsule =
+ nanobind::steal<nanobind::object>(mlirPythonLocationToCapsule(v));
+ return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Location")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .release();
+ }
+};
+
+/// Casts object <-> MlirModule.
+template <>
+struct type_caster<MlirModule> {
+  NB_TYPE_CASTER(MlirModule, const_name("MlirModule"));
+
+  /// Python -> C: reads the module out of the capsule behind `src`. The
+  /// conversion fails when the resulting module is null.
+  bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
+    nanobind::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToModule(capsule.ptr());
+    return !mlirModuleIsNull(value);
+  }
+
+  /// C -> Python: wraps `v` in a capsule and hands it to the capsule factory
+  /// of the `Module` class in the `ir` module.
+  static handle from_cpp(MlirModule v, rv_policy,
+                         cleanup_list *cleanup) noexcept {
+    nanobind::object capsule =
+        nanobind::steal<nanobind::object>(mlirPythonModuleToCapsule(v));
+    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+        .attr("Module")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  } // No trailing ';' here: it would be an empty member declaration.
+};
+
+/// Casts object <-> MlirFrozenRewritePatternSet.
+template <>
+struct type_caster<MlirFrozenRewritePatternSet> {
+  NB_TYPE_CASTER(MlirFrozenRewritePatternSet,
+                 const_name("MlirFrozenRewritePatternSet"));
+
+  /// Python -> C: reads the pattern set out of the capsule behind `src`. The
+  /// conversion fails when the resulting pointer is null.
+  bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
+    nanobind::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
+    return value.ptr != nullptr;
+  }
+
+  /// C -> Python: wraps `v` in a capsule and hands it to the capsule factory
+  /// of the `FrozenRewritePatternSet` class in the `rewrite` module.
+  ///
+  /// The third parameter is a `cleanup_list *`, like in the other casters of
+  /// this header. It used to be declared as `handle` (the pybind11 `cast`
+  /// signature), which a `cleanup_list *` argument cannot convert to.
+  static handle from_cpp(MlirFrozenRewritePatternSet v, rv_policy,
+                         cleanup_list *cleanup) {
+    nanobind::object capsule = nanobind::steal<nanobind::object>(
+        mlirPythonFrozenRewritePatternSetToCapsule(v));
+    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("rewrite"))
+        .attr("FrozenRewritePatternSet")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  }
+};
+
+/// Casts object <-> MlirOperation.
+template <>
+struct type_caster<MlirOperation> {
+  NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"));
+
+  /// Python -> C: reads the operation out of the capsule behind `src`. The
+  /// conversion fails when the resulting operation is null.
+  bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
+    nanobind::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToOperation(capsule.ptr());
+    return !mlirOperationIsNull(value);
+  }
+
+  /// C -> Python: a null operation becomes None; anything else is wrapped in a
+  /// capsule and handed to the capsule factory of the `Operation` class in the
+  /// `ir` module.
+  static handle from_cpp(MlirOperation v, rv_policy,
+                         cleanup_list *cleanup) noexcept {
+    // nanobind::none() yields an owning object. Release it so the returned
+    // handle keeps that reference instead of having it dropped when the
+    // temporary is destroyed.
+    if (v.ptr == nullptr)
+      return nanobind::none().release();
+    nanobind::object capsule =
+        nanobind::steal<nanobind::object>(mlirPythonOperationToCapsule(v));
+    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+        .attr("Operation")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  }
+};
+
+/// Casts object <-> MlirValue.
+template <>
+struct type_caster<MlirValue> {
+  NB_TYPE_CASTER(MlirValue, const_name("MlirValue"));
+
+  /// Python -> C: reads the value out of the capsule behind `src`. The
+  /// conversion fails when the resulting value is null.
+  bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
+    nanobind::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToValue(capsule.ptr());
+    return !mlirValueIsNull(value);
+  }
+
+  /// C -> Python: a null value becomes None; anything else is wrapped in a
+  /// capsule, turned into an `ir.Value` via the capsule factory, and then
+  /// passed through the maybe-downcast hook.
+  static handle from_cpp(MlirValue v, rv_policy,
+                         cleanup_list *cleanup) noexcept {
+    // nanobind::none() yields an owning object. Release it so the returned
+    // handle keeps that reference instead of having it dropped when the
+    // temporary is destroyed.
+    if (v.ptr == nullptr)
+      return nanobind::none().release();
+    nanobind::object capsule =
+        nanobind::steal<nanobind::object>(mlirPythonValueToCapsule(v));
+    return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+        .attr("Value")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
+        .release();
+  }
+};
+
+/// Casts object -> MlirPassManager.
+template <>
+struct type_caster<MlirPassManager> {
+ NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"));
+ bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
+ nanobind::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToPassManager(capsule.ptr());
+ return !mlirPassManagerIsNull(value);
+ }
+};
+
+/// Casts object <-> MlirTypeID (a null type ID is cast to None).
+template <>
+struct type_caster<MlirTypeID> {
+ NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"));
+ bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
+ nanobind::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToTypeID(capsule.ptr());
+ return !mlirTypeIDIsNull(value);
+ }
+ static handle from_cpp(MlirTypeID v, rv_policy,
+ cleanup_list *cleanup) noexcept {
+ if (v.ptr == nullptr)
+ return nanobind::none().release(); // Hand the caller an owned reference.
+ nanobind::object capsule =
+ nanobind::steal<nanobind::object>(mlirPythonTypeIDToCapsule(v));
+ return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("TypeID")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .release();
+ }
+};
+
+/// Casts object <-> MlirType.
+template <>
+struct type_caster<MlirType> {
+ NB_TYPE_CASTER(MlirType, const_name("MlirType"));
+ bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
+ nanobind::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToType(capsule.ptr());
+ return !mlirTypeIsNull(value);
+ }
+ static handle from_cpp(MlirType t, rv_policy,
+ cleanup_list *cleanup) noexcept {
+ nanobind::object capsule =
+ nanobind::steal<nanobind::object>(mlirPythonTypeToCapsule(t));
+ return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Type")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
+ .release();
+ }
+};
+
+} // namespace detail
+} // namespace nanobind
+
+namespace mlir {
+namespace python {
+namespace nanobind_adaptors {
+
+/// Provides a facility like nanobind::class_ for defining a new class in a
+/// scope, but this allows extension of an arbitrary Python class, defining
+/// methods on it in a similar way. Classes defined in this way are very similar
+/// to if defined in Python in the usual way but use nanobind machinery to
+/// do it. These are not "real" nanobind classes but pure Python classes
+/// with no relation to a concrete C++ class.
+///
+/// Derived from a discussion upstream:
+/// https://github.com/pybind/pybind11/issues/1193
+/// (plus a fair amount of extra curricular poking)
+/// TODO: If this proves useful, see about including it in nanobind.
+class pure_subclass {
+public:
+ pure_subclass(nanobind::handle scope, const char *derivedClassName,
+ const nanobind::object &superClass) {
+ nanobind::object pyType =
+ nanobind::borrow<nanobind::object>((PyObject *)&PyType_Type);
+ nanobind::object metaclass = pyType(superClass);
+ nanobind::dict attributes;
+
+ thisClass = metaclass(derivedClassName, nanobind::make_tuple(superClass),
+ attributes);
+ scope.attr(derivedClassName) = thisClass;
+ }
+
+ template <typename Func, typename... Extra>
+ pure_subclass &def(const char *name, Func &&f, const Extra &...extra) {
+ nanobind::object cf = nanobind::cpp_function(
+ std::forward<Func>(f), nanobind::name(name), nanobind::is_method(),
+ nanobind::scope(thisClass), extra...);
+ thisClass.attr(name) = cf;
+ return *this;
+ }
+
+ template <typename Func, typename... Extra>
+ pure_subclass &def_property_readonly(const char *name, Func &&f,
+ const Extra &...extra) {
+ nanobind::object cf = nanobind::cpp_function(
+ std::forward<Func>(f), nanobind::name(name), nanobind::is_method(),
+ nanobind::scope(thisClass), extra...);
+ auto builtinProperty =
+ nanobind::borrow<nanobind::object>((PyObject *)&PyProperty_Type);
+ thisClass.attr(name) = builtinProperty(cf);
+ return *this;
+ }
+
+ template <typename Func, typename... Extra>
+ pure_subclass &def_staticmethod(const char *name, Func &&f,
+ const Extra &...extra) {
+ static_assert(!std::is_member_function_pointer<Func>::value,
+ "def_staticmethod(...) called with a non-static member "
+ "function pointer");
+ nanobind::object cf = nanobind::cpp_function(
+ std::forward<Func>(f),
+ nanobind::name(name), // nanobind::scope(thisClass),
+ extra...);
+ thisClass.attr(name) = cf;
+ return *this;
+ }
+
+ template <typename Func, typename... Extra>
+ pure_subclass &def_classmethod(const char *name, Func &&f,
+ const Extra &...extra) {
+ static_assert(!std::is_member_function_pointer<Func>::value,
+ "def_classmethod(...) called with a non-static member "
+ "function pointer");
+ nanobind::object cf = nanobind::cpp_function(
+ std::forward<Func>(f), nanobind::name(name), extra...);
+ // PyClassMethod_New returns a new reference, so take ownership of it
+ // instead of borrowing (which would leak one reference).
+ thisClass.attr(name) =
+ nanobind::steal<nanobind::object>(PyClassMethod_New(cf.ptr()));
+ return *this;
+ }
+
+ nanobind::object get_class() const { return thisClass; }
+
+protected:
+ nanobind::object superClass;
+ nanobind::object thisClass;
+};
+
+/// Creates a custom subclass of mlir.ir.Attribute, implementing a casting
+/// constructor and type checking methods.
+class mlir_attribute_subclass : public pure_subclass {
+public:
+ using IsAFunctionTy = bool (*)(MlirAttribute);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
+
+ /// Subclasses by looking up the super-class dynamically.
+ mlir_attribute_subclass(nanobind::handle scope, const char *attrClassName,
+ IsAFunctionTy isaFunction,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+ : mlir_attribute_subclass(
+ scope, attrClassName, isaFunction,
+ nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Attribute"),
+ getTypeIDFunction) {}
+
+ /// Subclasses with a provided mlir.ir.Attribute super-class. This must
+ /// be used if the subclass is being defined in the same extension module
+ /// as the mlir.ir class (otherwise, it will trigger a recursive
+ /// initialization).
+ mlir_attribute_subclass(nanobind::handle scope, const char *typeClassName,
+ IsAFunctionTy isaFunction,
+ const nanobind::object &superCls,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+ : pure_subclass(scope, typeClassName, superCls) {
+ // Casting constructor. Note that it is hard, if not impossible, to
+ // properly call chain to parent `__init__` in nanobind due to its special
+ // handling for init functions that don't have a fully constructed
+ // self-reference, which makes it impossible to forward it to `__init__` of
+ // a superclass. Instead, provide a custom `__new__` and call that of a
+ // superclass, which eventually calls `__init__` of the superclass. Since
+ // attribute subclasses have no additional members, we can just return the
+ // instance thus created without amending it.
+ std::string captureTypeName(
+ typeClassName); // As string in case if typeClassName is not static.
+ nanobind::object newCf = nanobind::cpp_function(
+ [superCls, isaFunction, captureTypeName](
+ nanobind::object cls, nanobind::object otherAttribute) {
+ MlirAttribute rawAttribute =
+ nanobind::cast<MlirAttribute>(otherAttribute);
+ if (!isaFunction(rawAttribute)) {
+ auto origRepr =
+ nanobind::cast<std::string>(nanobind::repr(otherAttribute));
+ throw std::invalid_argument(
+ (llvm::Twine("Cannot cast attribute to ") + captureTypeName +
+ " (from " + origRepr + ")")
+ .str());
+ }
+ nanobind::object self = superCls.attr("__new__")(cls, otherAttribute);
+ return self;
+ },
+ nanobind::name("__new__"), nanobind::arg("cls"),
+ nanobind::arg("cast_from_attr"));
+ thisClass.attr("__new__") = newCf;
+
+ // 'isinstance' method.
+ def_staticmethod(
+ "isinstance",
+ [isaFunction](MlirAttribute other) { return isaFunction(other); },
+ nanobind::arg("other_attribute"));
+ def("__repr__", [superCls, captureTypeName](nanobind::object self) {
+ return nanobind::repr(superCls(self))
+ .attr("replace")(superCls.attr("__name__"), captureTypeName);
+ });
+ if (getTypeIDFunction) {
+ def_staticmethod("get_static_typeid",
+ [getTypeIDFunction]() { return getTypeIDFunction(); });
+ nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+ getTypeIDFunction())(nanobind::cpp_function(
+ [thisClass = thisClass](const nanobind::object &mlirAttribute) {
+ return thisClass(mlirAttribute);
+ }));
+ }
+ }
+};
+
+/// Creates a custom subclass of mlir.ir.Type, implementing a casting
+/// constructor and type checking methods.
+class mlir_type_subclass : public pure_subclass {
+public:
+ using IsAFunctionTy = bool (*)(MlirType);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
+
+ /// Subclasses by looking up the super-class dynamically.
+ mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
+ IsAFunctionTy isaFunction,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+ : mlir_type_subclass(
+ scope, typeClassName, isaFunction,
+ nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Type"),
+ getTypeIDFunction) {}
+
+ /// Subclasses with a provided mlir.ir.Type super-class. This must
+ /// be used if the subclass is being defined in the same extension module
+ /// as the mlir.ir class (otherwise, it will trigger a recursive
+ /// initialization).
+ mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
+ IsAFunctionTy isaFunction,
+ const nanobind::object &superCls,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+ : pure_subclass(scope, typeClassName, superCls) {
+ // Casting constructor. Note that it is hard, if not impossible, to
+ // properly call chain to parent `__init__` in nanobind due to its special
+ // handling for init functions that don't have a fully constructed
+ // self-reference, which makes it impossible to forward it to `__init__` of
+ // a superclass. Instead, provide a custom `__new__` and call that of a
+ // superclass, which eventually calls `__init__` of the superclass. Since
+ // type subclasses have no additional members, we can just return the
+ // instance thus created without amending it.
+ std::string captureTypeName(
+ typeClassName); // As string in case if typeClassName is not static.
+ nanobind::object newCf = nanobind::cpp_function(
+ [superCls, isaFunction, captureTypeName](nanobind::object cls,
+ nanobind::object otherType) {
+ MlirType rawType = nanobind::cast<MlirType>(otherType);
+ if (!isaFunction(rawType)) {
+ auto origRepr =
+ nanobind::cast<std::string>(nanobind::repr(otherType));
+ throw std::invalid_argument((llvm::Twine("Cannot cast type to ") +
+ captureTypeName + " (from " +
+ origRepr + ")")
+ .str());
+ }
+ nanobind::object self = superCls.attr("__new__")(cls, otherType);
+ return self;
+ },
+ nanobind::name("__new__"), nanobind::arg("cls"),
+ nanobind::arg("cast_from_type"));
+ thisClass.attr("__new__") = newCf;
+
+ // 'isinstance' method.
+ def_staticmethod(
+ "isinstance",
+ [isaFunction](MlirType other) { return isaFunction(other); },
+ nanobind::arg("other_type"));
+ def("__repr__", [superCls, captureTypeName](nanobind::object self) {
+ return nanobind::repr(superCls(self))
+ .attr("replace")(superCls.attr("__name__"), captureTypeName);
+ });
+ if (getTypeIDFunction) {
+ // 'get_static_typeid' method.
+ // This is modeled as a static method instead of a static property because
+ // `def_property_readonly_static` is not available in `pure_subclass` and
+ // we do not want to introduce the complexity that pybind uses to
+ // implement it.
+ def_staticmethod("get_static_typeid",
+ [getTypeIDFunction]() { return getTypeIDFunction(); });
+ nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+ getTypeIDFunction())(nanobind::cpp_function(
+ [thisClass = thisClass](const nanobind::object &mlirType) {
+ return thisClass(mlirType);
+ }));
+ }
+ }
+};
+
+/// Creates a custom subclass of mlir.ir.Value, implementing a casting
+/// constructor and type checking methods.
+class mlir_value_subclass : public pure_subclass {
+public:
+ using IsAFunctionTy = bool (*)(MlirValue);
+
+ /// Subclasses by looking up the super-class dynamically.
+ mlir_value_subclass(nanobind::handle scope, const char *valueClassName,
+ IsAFunctionTy isaFunction)
+ : mlir_value_subclass(
+ scope, valueClassName, isaFunction,
+ nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Value")) {}
+
+ /// Subclasses with a provided mlir.ir.Value super-class. This must
+ /// be used if the subclass is being defined in the same extension module
+ /// as the mlir.ir class (otherwise, it will trigger a recursive
+ /// initialization).
+ mlir_value_subclass(nanobind::handle scope, const char *valueClassName,
+ IsAFunctionTy isaFunction,
+ const nanobind::object &superCls)
+ : pure_subclass(scope, valueClassName, superCls) {
+ // Casting constructor. Note that it is hard, if not impossible, to
+ // properly call chain to parent `__init__` in nanobind due to its special
+ // handling for init functions that don't have a fully constructed
+ // self-reference, which makes it impossible to forward it to `__init__` of
+ // a superclass. Instead, provide a custom `__new__` and call that of a
+ // superclass, which eventually calls `__init__` of the superclass. Since
+ // value subclasses have no additional members, we can just return the
+ // instance thus created without amending it.
+ std::string captureValueName(
+ valueClassName); // As string in case if valueClassName is not static.
+ nanobind::object newCf = nanobind::cpp_function(
+ [superCls, isaFunction, captureValueName](nanobind::object cls,
+ nanobind::object otherValue) {
+ MlirValue rawValue = nanobind::cast<MlirValue>(otherValue);
+ if (!isaFunction(rawValue)) {
+ auto origRepr =
+ nanobind::cast<std::string>(nanobind::repr(otherValue));
+ throw std::invalid_argument((llvm::Twine("Cannot cast value to ") +
+ captureValueName + " (from " +
+ origRepr + ")")
+ .str());
+ }
+ nanobind::object self = superCls.attr("__new__")(cls, otherValue);
+ return self;
+ },
+ nanobind::name("__new__"), nanobind::arg("cls"),
+ nanobind::arg("cast_from_value"));
+ thisClass.attr("__new__") = newCf;
+
+ // 'isinstance' method.
+ def_staticmethod(
+ "isinstance",
+ [isaFunction](MlirValue other) { return isaFunction(other); },
+ nanobind::arg("other_value"));
+ }
+};
+
+} // namespace nanobind_adaptors
+
+/// RAII scope intercepting all diagnostics into a string. The message must be
+/// checked before this goes out of scope.
+class CollectDiagnosticsToStringScope {
+public:
+ explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
+ handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
+ /*deleteUserData=*/nullptr);
+ }
+ ~CollectDiagnosticsToStringScope() {
+ assert(errorMessage.empty() && "unchecked error message");
+ mlirContextDetachDiagnosticHandler(context, handlerID);
+ }
+
+ [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
+
+private:
+ static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
+ auto printer = +[](MlirStringRef message, void *data) {
+ *static_cast<std::string *>(data) +=
+ llvm::StringRef(message.data, message.length);
+ };
+ MlirLocation loc = mlirDiagnosticGetLocation(diag);
+ *static_cast<std::string *>(data) += "at ";
+ mlirLocationPrint(loc, printer, data);
+ *static_cast<std::string *>(data) += ": ";
+ mlirDiagnosticPrint(diag, printer, data);
+ return mlirLogicalResultSuccess();
+ }
+
+ MlirContext context;
+ MlirDiagnosticHandlerID handlerID;
+ std::string errorMessage = "";
+};
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index df4b9bf713592d..c8233355d1d67b 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -1,4 +1,4 @@
-//===- PybindAdaptors.h - Adaptors for interop with MLIR APIs -------------===//
+//===- PybindAdaptors.h - Interop with MLIR APIs via pybind11 -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,9 +6,10 @@
//
//===----------------------------------------------------------------------===//
// This file contains adaptors for clients of the core MLIR Python APIs to
-// interop via MLIR CAPI types. The facilities here do not depend on
-// implementation details of the MLIR Python API and do not introduce C++-level
-// dependencies with it (requiring only Python and CAPI-level dependencies).
+// interop via MLIR CAPI types, using pybind11. The facilities here do not
+// depend on implementation details of the MLIR Python API and do not introduce
+// C++-level dependencies with it (requiring only Python and CAPI-level
+// dependencies).
//
// It is encouraged to be used both in-tree and out-of-tree. For in-tree use
// cases, it should be used for dialect implementations (versus relying on
@@ -611,40 +612,6 @@ class mlir_value_subclass : public pure_subclass {
} // namespace adaptors
-/// RAII scope intercepting all diagnostics into a string. The message must be
-/// checked before this goes out of scope.
-class CollectDiagnosticsToStringScope {
-public:
- explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
- handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
- /*deleteUserData=*/nullptr);
- }
- ~CollectDiagnosticsToStringScope() {
- assert(errorMessage.empty() && "unchecked error message");
- mlirContextDetachDiagnosticHandler(context, handlerID);
- }
-
- [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
-
-private:
- static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
- auto printer = +[](MlirStringRef message, void *data) {
- *static_cast<std::string *>(data) +=
- llvm::StringRef(message.data, message.length);
- };
- MlirLocation loc = mlirDiagnosticGetLocation(diag);
- *static_cast<std::string *>(data) += "at ";
- mlirLocationPrint(loc, printer, data);
- *static_cast<std::string *>(data) += ": ";
- mlirDiagnosticPrint(diag, printer, data);
- return mlirLogicalResultSuccess();
- }
-
- MlirContext context;
- MlirDiagnosticHandlerID handlerID;
- std::string errorMessage = "";
-};
-
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 42a4c8c0793ba8..cccf1370b8cc87 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -6,11 +6,13 @@
//
//===----------------------------------------------------------------------===//
+#include <string>
+
#include "mlir-c/Dialect/LLVM.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
-#include <string>
namespace py = pybind11;
using namespace llvm;
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index f6b4532b1b6be4..0c8c0e0a965aa7 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -10,14 +10,15 @@
//
//===----------------------------------------------------------------------===//
+#include <pybind11/detail/common.h>
+#include <pybind11/pybind11.h>
+
#include "mlir-c/Dialect/Transform/Interpreter.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
-#include <pybind11/detail/common.h>
-#include <pybind11/pybind11.h>
-
namespace py = pybind11;
namespace {
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 23187f256455bb..e1b870b53ad25c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -683,7 +683,9 @@ if(MLIR_INCLUDE_TESTS)
MLIRPythonTestSources.Dialects.PythonTest
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
ADD_TO_PARENT MLIRPythonTestSources.Dialects
- SOURCES dialects/python_test.py)
+ SOURCES
+ dialects/python_test.py
+ )
set(LLVM_TARGET_DEFINITIONS
"${MLIR_MAIN_SRC_DIR}/test/python/python_test_ops.td")
mlir_tablegen(
@@ -697,12 +699,25 @@ if(MLIR_INCLUDE_TESTS)
ADD_TO_PARENT MLIRPythonTestSources.Dialects.PythonTest
SOURCES "dialects/_python_test_ops_gen.py")
- declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtension
- MODULE_NAME _mlirPythonTest
+ declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionPybind11
+ MODULE_NAME _mlirPythonTestPybind11
+ ADD_TO_PARENT MLIRPythonTestSources.Dialects
+ ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib"
+ PYTHON_BINDINGS_LIBRARY pybind11
+ SOURCES
+ PythonTestModulePybind11.cpp
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPIPythonTestDialect
+ )
+ declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionNanobind
+ MODULE_NAME _mlirPythonTestNanobind
ADD_TO_PARENT MLIRPythonTestSources.Dialects
ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
- PythonTestModule.cpp
+ PythonTestModuleNanobind.cpp
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index b5baa80bc767fb..9380896c8c06e8 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,15 +3,14 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import (
- TestAttr,
- TestType,
- TestTensorValue,
- TestIntegerRankedTensorType,
-)
-def register_python_test_dialect(registry):
- from .._mlir_libs import _mlirPythonTest
+def register_python_test_dialect(registry, use_nanobind):
+ if use_nanobind:
+ from .._mlir_libs import _mlirPythonTestNanobind
- _mlirPythonTest.register_dialect(registry)
+ _mlirPythonTestNanobind.register_dialect(registry)
+ else:
+ from .._mlir_libs import _mlirPythonTestPybind11
+
+ _mlirPythonTestPybind11.register_dialect(registry)
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index 272d066831f927..ab8a9122919e19 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,3 +1,4 @@
+nanobind>=2.0, <3.0
numpy>=1.19.5, <=2.1.2
pybind11>=2.10.0, <=2.13.6
PyYAML>=5.4.0, <=6.0.1
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 948d1225ea489c..fd678f8321fd93 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -1,12 +1,33 @@
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: %PYTHON %s pybind11 | FileCheck %s
+# RUN: %PYTHON %s nanobind | FileCheck %s
+import sys
from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith
-test.register_python_test_dialect(get_dialect_registry())
+if sys.argv[1] == "pybind11":
+ from mlir._mlir_libs._mlirPythonTestPybind11 import (
+ TestAttr,
+ TestType,
+ TestTensorValue,
+ TestIntegerRankedTensorType,
+ )
+
+ test.register_python_test_dialect(get_dialect_registry(), use_nanobind=False)
+elif sys.argv[1] == "nanobind":
+ from mlir._mlir_libs._mlirPythonTestNanobind import (
+ TestAttr,
+ TestType,
+ TestTensorValue,
+ TestIntegerRankedTensorType,
+ )
+
+ test.register_python_test_dialect(get_dialect_registry(), use_nanobind=True)
+else:
+ raise ValueError("Expected pybind11 or nanobind as argument")
def run(f):
@@ -308,7 +329,7 @@ def testOptionalOperandOp():
@run
def testCustomAttribute():
with Context() as ctx, Location.unknown():
- a = test.TestAttr.get()
+ a = TestAttr.get()
# CHECK: #python_test.test_attr
print(a)
@@ -325,11 +346,11 @@ def testCustomAttribute():
print(repr(op2.test_attr))
# The following cast must not assert.
- b = test.TestAttr(a)
+ b = TestAttr(a)
unit = UnitAttr.get()
try:
- test.TestAttr(unit)
+ TestAttr(unit)
except ValueError as e:
assert "Cannot cast attribute to TestAttr" in str(e)
else:
@@ -338,7 +359,7 @@ def testCustomAttribute():
# The following must trigger a TypeError from our adaptors and must not
# crash.
try:
- test.TestAttr(42)
+ TestAttr(42)
except TypeError as e:
assert "Expected an MLIR object" in str(e)
else:
@@ -347,7 +368,7 @@ def testCustomAttribute():
# The following must trigger a TypeError from pybind (therefore, not
# checking its message) and must not crash.
try:
- test.TestAttr(42, 56)
+ TestAttr(42, 56)
except TypeError:
pass
else:
@@ -357,12 +378,12 @@ def testCustomAttribute():
@run
def testCustomType():
with Context() as ctx:
- a = test.TestType.get()
+ a = TestType.get()
# CHECK: !python_test.test_type
print(a)
# The following cast must not assert.
- b = test.TestType(a)
+ b = TestType(a)
# Instance custom types should have typeids
assert isinstance(b.typeid, TypeID)
# Subclasses of ir.Type should not have a static_typeid
@@ -374,7 +395,7 @@ def testCustomType():
i8 = IntegerType.get_signless(8)
try:
- test.TestType(i8)
+ TestType(i8)
except ValueError as e:
assert "Cannot cast type to TestType" in str(e)
else:
@@ -383,7 +404,7 @@ def testCustomType():
# The following must trigger a TypeError from our adaptors and must not
# crash.
try:
- test.TestType(42)
+ TestType(42)
except TypeError as e:
assert "Expected an MLIR object" in str(e)
else:
@@ -392,7 +413,7 @@ def testCustomType():
# The following must trigger a TypeError from pybind (therefore, not
# checking its message) and must not crash.
try:
- test.TestType(42, 56)
+ TestType(42, 56)
except TypeError:
pass
else:
@@ -405,7 +426,7 @@ def testTensorValue():
with Context() as ctx, Location.unknown():
i8 = IntegerType.get_signless(8)
- class Tensor(test.TestTensorValue):
+ class Tensor(TestTensorValue):
def __str__(self):
return super().__str__().replace("Value", "Tensor")
@@ -425,9 +446,9 @@ def __str__(self):
# Classes of custom types that inherit from concrete types should have
# static_typeid
- assert isinstance(test.TestIntegerRankedTensorType.static_typeid, TypeID)
+ assert isinstance(TestIntegerRankedTensorType.static_typeid, TypeID)
# And it should be equal to the in-tree concrete type
- assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid
+ assert TestIntegerRankedTensorType.static_typeid == t.type.typeid
d = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result
# CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>)
@@ -491,7 +512,7 @@ def inferReturnTypeComponents():
@run
def testCustomTypeTypeCaster():
with Context() as ctx, Location.unknown():
- a = test.TestType.get()
+ a = TestType.get()
assert a.typeid is not None
b = Type.parse("!python_test.test_type")
@@ -500,7 +521,7 @@ def testCustomTypeTypeCaster():
# CHECK: TestType(!python_test.test_type)
print(repr(b))
- c = test.TestIntegerRankedTensorType.get([10, 10], 5)
+ c = TestIntegerRankedTensorType.get([10, 10], 5)
# CHECK: tensor<10x10xi5>
print(c)
# CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
@@ -511,7 +532,7 @@ def testCustomTypeTypeCaster():
@register_type_caster(c.typeid)
def type_caster(pytype):
- return test.TestIntegerRankedTensorType(pytype)
+ return TestIntegerRankedTensorType(pytype)
except RuntimeError as e:
print(e)
@@ -530,7 +551,7 @@ def type_caster(pytype):
@register_type_caster(c.typeid, replace=True)
def type_caster(pytype):
- return test.TestIntegerRankedTensorType(pytype)
+ return TestIntegerRankedTensorType(pytype)
d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
# CHECK: tensor<10x10xi5>
diff --git a/mlir/test/python/lib/CMakeLists.txt b/mlir/test/python/lib/CMakeLists.txt
index d7cbbfbc214772..198ed8211e773f 100644
--- a/mlir/test/python/lib/CMakeLists.txt
+++ b/mlir/test/python/lib/CMakeLists.txt
@@ -1,7 +1,8 @@
set(LLVM_OPTIONAL_SOURCES
PythonTestCAPI.cpp
PythonTestDialect.cpp
- PythonTestModule.cpp
+ PythonTestModulePybind11.cpp
+ PythonTestModuleNanobind.cpp
)
add_mlir_library(MLIRPythonTestDialect
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
new file mode 100644
index 00000000000000..7c504d04be0d13
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -0,0 +1,121 @@
+//===- PythonTestModuleNanobind.cpp - PythonTest dialect extension --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This is the nanobind edition of the PythonTest dialect module.
+//===----------------------------------------------------------------------===//
+
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/vector.h>
+
+#include "PythonTestCAPI.h"
+#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
+namespace nb = nanobind;
+using namespace mlir::python::nanobind_adaptors;
+
+static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
+ return mlirTypeIsARankedTensor(t) &&
+ mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
+}
+
+NB_MODULE(_mlirPythonTestNanobind, m) {
+ m.def(
+ "register_python_test_dialect",
+ [](MlirContext context, bool load) {
+ MlirDialectHandle pythonTestDialect =
+ mlirGetDialectHandle__python_test__();
+ mlirDialectHandleRegisterDialect(pythonTestDialect, context);
+ if (load) {
+ mlirDialectHandleLoadDialect(pythonTestDialect, context);
+ }
+ },
+ nb::arg("context"), nb::arg("load") = true);
+
+ m.def(
+ "register_dialect",
+ [](MlirDialectRegistry registry) {
+ MlirDialectHandle pythonTestDialect =
+ mlirGetDialectHandle__python_test__();
+ mlirDialectHandleInsertDialect(pythonTestDialect, registry);
+ },
+ nb::arg("registry"));
+
+ mlir_attribute_subclass(m, "TestAttr",
+ mlirAttributeIsAPythonTestTestAttribute,
+ mlirPythonTestTestAttributeGetTypeID)
+ .def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirPythonTestTestAttributeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context").none() = nb::none());
+
+ mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
+ mlirPythonTestTestTypeGetTypeID)
+ .def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirPythonTestTestTypeGet(ctx));
+ },
+ nb::arg("cls"), nb::arg("context").none() = nb::none());
+
+ auto typeCls =
+ mlir_type_subclass(m, "TestIntegerRankedTensorType",
+ mlirTypeIsARankedIntegerTensor,
+ nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("RankedTensorType"))
+ .def_classmethod(
+ "get",
+ [](const nb::object &cls, std::vector<int64_t> shape,
+ unsigned width, MlirContext ctx) {
+ MlirAttribute encoding = mlirAttributeGetNull();
+ return cls(mlirRankedTensorTypeGet(
+ shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
+ encoding));
+ },
+ nb::arg("cls"), nb::arg("shape"), nb::arg("width"),
+ nb::arg("context").none() = nb::none());
+
+ assert(nb::hasattr(typeCls.get_class(), "static_typeid") &&
+ "TestIntegerRankedTensorType has no static_typeid");
+
+ MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
+
+ nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+ mlirRankedTensorTypeID, nb::arg("replace") = true)(
+ nanobind::cpp_function([typeCls](const nb::object &mlirType) {
+ return typeCls.get_class()(mlirType);
+ }));
+
+ auto valueCls = mlir_value_subclass(m, "TestTensorValue",
+ mlirTypeIsAPythonTestTestTensorValue)
+ .def("is_null", [](MlirValue &self) {
+ return mlirValueIsNull(self);
+ });
+
+ nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
+ mlirRankedTensorTypeID)(
+ nanobind::cpp_function([valueCls](const nb::object &valueObj) {
+ nb::object capsule = mlirApiObjectToCapsule(valueObj);
+ MlirValue v = mlirPythonCapsuleToValue(capsule.ptr());
+ MlirType t = mlirValueGetType(v);
+ // This is hyper-specific in order to exercise/test registering a
+ // value caster from cpp (but only for a single test case; see
+ // testTensorValue python_test.py).
+ if (mlirShapedTypeHasStaticShape(t) &&
+ mlirShapedTypeGetDimSize(t, 0) == 1 &&
+ mlirShapedTypeGetDimSize(t, 1) == 2 &&
+ mlirShapedTypeGetDimSize(t, 2) == 3)
+ return valueCls.get_class()(valueObj);
+ return valueObj;
+ }));
+}
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModulePybind11.cpp
similarity index 96%
rename from mlir/test/python/lib/PythonTestModule.cpp
rename to mlir/test/python/lib/PythonTestModulePybind11.cpp
index a4f538dcb55944..94a5f5178d16e8 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModulePybind11.cpp
@@ -5,6 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+// This is the pybind11 edition of the PythonTest dialect module.
+//===----------------------------------------------------------------------===//
#include "PythonTestCAPI.h"
#include "mlir-c/BuiltinAttributes.h"
@@ -21,7 +23,7 @@ static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
}
-PYBIND11_MODULE(_mlirPythonTest, m) {
+PYBIND11_MODULE(_mlirPythonTestPybind11, m) {
m.def(
"register_python_test_dialect",
[](MlirContext context, bool load) {
diff --git a/utils/bazel/WORKSPACE b/utils/bazel/WORKSPACE
index 7baca11eed3d39..66ba1ac1b17e1e 100644
--- a/utils/bazel/WORKSPACE
+++ b/utils/bazel/WORKSPACE
@@ -148,6 +148,24 @@ maybe(
url = "https://github.com/pybind/pybind11/archive/v2.10.3.zip",
)
+maybe(
+ http_archive,
+ name = "robin_map",
+ strip_prefix = "robin-map-1.3.0",
+ sha256 = "a8424ad3b0affd4c57ed26f0f3d8a29604f0e1f2ef2089f497f614b1c94c7236",
+ build_file = "@llvm-raw//utils/bazel/third_party_build:robin_map.BUILD",
+ url = "https://github.com/Tessil/robin-map/archive/refs/tags/v1.3.0.tar.gz",
+)
+
+maybe(
+ http_archive,
+ name = "nanobind",
+ build_file = "@llvm-raw//utils/bazel/third_party_build:nanobind.BUILD",
+ sha256 = "bfbfc7e5759f1669e4ddb48752b1ddc5647d1430e94614d6f8626df1d508e65a",
+ strip_prefix = "nanobind-2.2.0",
+ url = "https://github.com/wjakob/nanobind/archive/refs/tags/v2.2.0.tar.gz",
+)
+
load("@rules_python//python:repositories.bzl", "py_repositories", "python_register_toolchains")
py_repositories()
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 179fed2f5e9a00..544becfa30b40f 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -932,7 +932,6 @@ exports_files(
filegroup(
name = "MLIRBindingsPythonHeaderFiles",
srcs = glob([
- "lib/Bindings/Python/*.h",
"include/mlir-c/Bindings/Python/*.h",
"include/mlir/Bindings/Python/*.h",
]),
@@ -942,12 +941,10 @@ cc_library(
name = "MLIRBindingsPythonHeaders",
includes = [
"include",
- "lib/Bindings/Python",
],
textual_hdrs = [":MLIRBindingsPythonHeaderFiles"],
deps = [
":CAPIIRHeaders",
- ":CAPITransformsHeaders",
"@pybind11",
"@rules_python//python/cc:current_py_cc_headers",
],
@@ -957,17 +954,41 @@ cc_library(
name = "MLIRBindingsPythonHeadersAndDeps",
includes = [
"include",
- "lib/Bindings/Python",
],
textual_hdrs = [":MLIRBindingsPythonHeaderFiles"],
deps = [
":CAPIIR",
- ":CAPITransforms",
"@pybind11",
"@rules_python//python/cc:current_py_cc_headers",
],
)
+cc_library(
+ name = "MLIRBindingsPythonNanobindHeaders",
+ includes = [
+ "include",
+ ],
+ textual_hdrs = [":MLIRBindingsPythonHeaderFiles"],
+ deps = [
+ ":CAPIIRHeaders",
+ "@nanobind",
+ "@rules_python//python/cc:current_py_cc_headers",
+ ],
+)
+
+cc_library(
+ name = "MLIRBindingsPythonNanobindHeadersAndDeps",
+ includes = [
+ "include",
+ ],
+ textual_hdrs = [":MLIRBindingsPythonHeaderFiles"],
+ deps = [
+ ":CAPIIR",
+ "@nanobind",
+ "@rules_python//python/cc:current_py_cc_headers",
+ ],
+)
+
# These flags are needed for pybind11 to work.
PYBIND11_COPTS = [
"-fexceptions",
@@ -993,16 +1014,25 @@ filegroup(
],
)
+filegroup(
+ name = "MLIRBindingsPythonCoreHeaders",
+ srcs = glob([
+ "lib/Bindings/Python/*.h",
+ ]),
+)
+
cc_library(
name = "MLIRBindingsPythonCore",
srcs = [":MLIRBindingsPythonSourceFiles"],
copts = PYBIND11_COPTS,
features = PYBIND11_FEATURES,
+ textual_hdrs = [":MLIRBindingsPythonCoreHeaders"],
deps = [
":CAPIAsync",
":CAPIDebug",
":CAPIIR",
":CAPIInterfaces",
+ ":CAPITransforms",
":MLIRBindingsPythonHeadersAndDeps",
":Support",
":config",
@@ -1017,10 +1047,12 @@ cc_library(
srcs = [":MLIRBindingsPythonSourceFiles"],
copts = PYBIND11_COPTS,
features = PYBIND11_FEATURES,
+ textual_hdrs = [":MLIRBindingsPythonCoreHeaders"],
deps = [
":CAPIAsyncHeaders",
":CAPIDebugHeaders",
":CAPIIRHeaders",
+ ":CAPITransformsHeaders",
":MLIRBindingsPythonHeaders",
":Support",
":config",
@@ -1050,6 +1082,9 @@ cc_binary(
# These flags are needed for pybind11 to work.
copts = PYBIND11_COPTS,
features = PYBIND11_FEATURES,
+ includes = [
+ "lib/Bindings/Python",
+ ],
linkshared = 1,
linkstatic = 0,
deps = [
@@ -1063,6 +1098,9 @@ cc_binary(
srcs = ["lib/Bindings/Python/DialectLinalg.cpp"],
copts = PYBIND11_COPTS,
features = PYBIND11_FEATURES,
+ includes = [
+ "lib/Bindings/Python",
+ ],
linkshared = 1,
linkstatic = 0,
deps = [
@@ -8448,9 +8486,9 @@ cc_library(
hdrs = ["include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"],
includes = ["include"],
deps = [
+ ":Analysis",
":ConversionPassIncGen",
":ConvertToLLVMInterface",
- ":Analysis",
":IR",
":LLVMCommonConversion",
":LLVMDialect",
diff --git a/utils/bazel/third_party_build/nanobind.BUILD b/utils/bazel/third_party_build/nanobind.BUILD
new file mode 100644
index 00000000000000..262d14a040b87e
--- /dev/null
+++ b/utils/bazel/third_party_build/nanobind.BUILD
@@ -0,0 +1,25 @@
+cc_library(
+ name = "nanobind",
+ srcs = glob(
+ [
+ "src/*.cpp",
+ ],
+ exclude = ["src/nb_combined.cpp"],
+ ),
+ defines = [
+ "NB_BUILD=1",
+ "NB_SHARED=1",
+ ],
+ includes = ["include"],
+ textual_hdrs = glob(
+ [
+ "include/**/*.h",
+ "src/*.h",
+ ],
+ ),
+ visibility = ["//visibility:public"],
+ deps = [
+ "@robin_map",
+ "@rules_python//python/cc:current_py_cc_headers",
+ ],
+)
diff --git a/utils/bazel/third_party_build/robin_map.BUILD b/utils/bazel/third_party_build/robin_map.BUILD
new file mode 100644
index 00000000000000..b8d04beaed81f9
--- /dev/null
+++ b/utils/bazel/third_party_build/robin_map.BUILD
@@ -0,0 +1,12 @@
+cc_library(
+ name = "robin_map",
+ hdrs = [
+ "include/tsl/robin_growth_policy.h",
+ "include/tsl/robin_hash.h",
+ "include/tsl/robin_map.h",
+ "include/tsl/robin_set.h",
+ ],
+ includes = ["."],
+ strip_include_prefix = "include",
+ visibility = ["//visibility:public"],
+)
>From cbcde9f1f2d83aa4da560b80ebd17ecfeb8ca6db Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins at google.com>
Date: Wed, 4 Dec 2024 02:44:33 +0000
Subject: [PATCH 2/2] [mlir python] Port Python core code to nanobind.
Why? https://nanobind.readthedocs.io/en/latest/why.html says it better
than I can, but my primary motivation for this change is to improve MLIR
IR construction time from JAX.
For a complicated Google-internal LLM model in JAX, this change improves the MLIR
lowering time by around 5s (out of around 30s), which is a significant
speedup for simply switching binding frameworks.
To a large extent, this is a mechanical change, for instance changing pybind11::
to nanobind::.
Notes:
* this PR needs https://github.com/wjakob/nanobind/pull/806 to land in
nanobind first. Without that fix, importing the MLIR modules will
fail.
* this PR does not port the in-tree dialect extension modules. They can
be ported in a future PR.
* I removed the py::sibling() annotations from def_staticmethod and
def_classmethod in PybindAdaptors.h. These ask pybind11 to try to form an
overload with an existing method, but it's not possible to form mixed
pybind11/nanobind overloads this way, and the parent class is now defined in
nanobind. Better solutions may be possible here.
* nanobind does not contain an exact equivalent of pybind11's buffer
protocol support. It was not hard to add a nanobind implementation of
a similar API.
* nanobind is pickier about casting to std::vector<bool>, expecting that
the input is a sequence of bool types, not truthy values. In a couple
of places I added code to support truthy values during casting.
* nanobind distinguishes bytes (nb::bytes) from strings (e.g.,
std::string). This required nb::bytes overloads in a few places.
---
mlir/cmake/modules/MLIRDetectPythonEnv.cmake | 2 +-
mlir/include/mlir/Bindings/Python/IRTypes.h | 2 +-
.../mlir/Bindings/Python/PybindAdaptors.h | 10 +-
mlir/lib/Bindings/Python/Globals.h | 39 +-
mlir/lib/Bindings/Python/IRAffine.cpp | 265 ++--
mlir/lib/Bindings/Python/IRAttributes.cpp | 663 +++++---
mlir/lib/Bindings/Python/IRCore.cpp | 1412 +++++++++--------
mlir/lib/Bindings/Python/IRInterfaces.cpp | 171 +-
mlir/lib/Bindings/Python/IRModule.cpp | 57 +-
mlir/lib/Bindings/Python/IRModule.h | 332 ++--
mlir/lib/Bindings/Python/IRTypes.cpp | 200 +--
mlir/lib/Bindings/Python/MainModule.cpp | 56 +-
.../Python/{PybindUtils.h => NanobindUtils.h} | 84 +-
mlir/lib/Bindings/Python/Pass.cpp | 58 +-
mlir/lib/Bindings/Python/Pass.h | 4 +-
mlir/lib/Bindings/Python/Rewrite.cpp | 43 +-
mlir/lib/Bindings/Python/Rewrite.h | 4 +-
mlir/python/CMakeLists.txt | 3 +-
mlir/python/requirements.txt | 2 +-
mlir/test/python/ir/symbol_table.py | 3 +-
utils/bazel/WORKSPACE | 6 +-
.../llvm-project-overlay/mlir/BUILD.bazel | 15 +-
22 files changed, 1862 insertions(+), 1569 deletions(-)
rename mlir/lib/Bindings/Python/{PybindUtils.h => NanobindUtils.h} (85%)
diff --git a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
index c62ac7fa615ea6..d6bb65c64b8292 100644
--- a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
+++ b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
@@ -39,7 +39,7 @@ macro(mlir_configure_python_dev_packages)
"extension = '${PYTHON_MODULE_EXTENSION}")
mlir_detect_nanobind_install()
- find_package(nanobind 2.2 CONFIG REQUIRED)
+ find_package(nanobind 2.4 CONFIG REQUIRED)
message(STATUS "Found nanobind v${nanobind_VERSION}: ${nanobind_INCLUDE_DIR}")
message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
"suffix = '${PYTHON_MODULE_SUFFIX}', "
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index 9afad4c23b3f35..ba9642cf2c6a2d 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace mlir {
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index c8233355d1d67b..edc69774be9227 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -374,9 +374,8 @@ class pure_subclass {
static_assert(!std::is_member_function_pointer<Func>::value,
"def_staticmethod(...) called with a non-static member "
"function pointer");
- py::cpp_function cf(
- std::forward<Func>(f), py::name(name), py::scope(thisClass),
- py::sibling(py::getattr(thisClass, name, py::none())), extra...);
+ py::cpp_function cf(std::forward<Func>(f), py::name(name),
+ py::scope(thisClass), extra...);
thisClass.attr(cf.name()) = py::staticmethod(cf);
return *this;
}
@@ -387,9 +386,8 @@ class pure_subclass {
static_assert(!std::is_member_function_pointer<Func>::value,
"def_classmethod(...) called with a non-static member "
"function pointer");
- py::cpp_function cf(
- std::forward<Func>(f), py::name(name), py::scope(thisClass),
- py::sibling(py::getattr(thisClass, name, py::none())), extra...);
+ py::cpp_function cf(std::forward<Func>(f), py::name(name),
+ py::scope(thisClass), extra...);
thisClass.attr(cf.name()) =
py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
return *this;
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index a022067f5c7e57..0ec522d14f74bd 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -9,18 +9,17 @@
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
-#include "PybindUtils.h"
+#include <optional>
+#include <string>
+#include <vector>
+#include "NanobindUtils.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
-#include <optional>
-#include <string>
-#include <vector>
-
namespace mlir {
namespace python {
@@ -57,55 +56,55 @@ class PyGlobals {
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerAttributeBuilder(const std::string &attributeKind,
- pybind11::function pyFunc,
+ nanobind::callable pyFunc,
bool replace = false);
/// Adds a user-friendly type caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
- void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
+ void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster,
bool replace = false);
/// Adds a user-friendly value caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
void registerValueCaster(MlirTypeID mlirTypeID,
- pybind11::function valueCaster,
+ nanobind::callable valueCaster,
bool replace = false);
/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerDialectImpl(const std::string &dialectNamespace,
- pybind11::object pyClass);
+ nanobind::object pyClass);
/// Adds a concrete implementation operation class.
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
- pybind11::object pyClass, bool replace = false);
+ nanobind::object pyClass, bool replace = false);
/// Returns the custom Attribute builder for Attribute kind.
- std::optional<pybind11::function>
+ std::optional<nanobind::callable>
lookupAttributeBuilder(const std::string &attributeKind);
/// Returns the custom type caster for MlirTypeID mlirTypeID.
- std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
+ std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);
/// Returns the custom value caster for MlirTypeID mlirTypeID.
- std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
+ std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);
/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
- std::optional<pybind11::object>
+ std::optional<nanobind::object>
lookupDialectClass(const std::string &dialectNamespace);
/// Looks up a registered operation class (deriving from OpView) by operation
/// name. Note that this may trigger a load of the dialect, which can
/// arbitrarily re-enter.
- std::optional<pybind11::object>
+ std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);
private:
@@ -113,15 +112,15 @@ class PyGlobals {
/// Module name prefixes to search under for dialect implementation modules.
std::vector<std::string> dialectSearchPrefixes;
/// Map of dialect namespace to external dialect class object.
- llvm::StringMap<pybind11::object> dialectClassMap;
+ llvm::StringMap<nanobind::object> dialectClassMap;
/// Map of full operation name to external operation class object.
- llvm::StringMap<pybind11::object> operationClassMap;
+ llvm::StringMap<nanobind::object> operationClassMap;
/// Map of attribute ODS name to custom builder.
- llvm::StringMap<pybind11::object> attributeBuilderMap;
+ llvm::StringMap<nanobind::callable> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
- llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
+ llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap;
/// Map of MlirTypeID to custom value caster.
- llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
+ llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index b138e131e851ea..2db690309fab8c 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -6,20 +6,19 @@
//
//===----------------------------------------------------------------------===//
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/vector.h>
+
#include <cstddef>
#include <cstdint>
-#include <pybind11/cast.h>
-#include <pybind11/detail/common.h>
-#include <pybind11/pybind11.h>
-#include <pybind11/pytypes.h>
+#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include "IRModule.h"
-
-#include "PybindUtils.h"
-
+#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/Bindings/Python/Interop.h"
@@ -30,7 +29,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
-namespace py = pybind11;
+namespace nb = nanobind;
using namespace mlir;
using namespace mlir::python;
@@ -46,23 +45,23 @@ static const char kDumpDocstring[] =
/// Throws errors in case of failure, using "action" to describe what the caller
/// was attempting to do.
template <typename PyType, typename CType>
-static void pyListToVector(const py::list &list,
+static void pyListToVector(const nb::list &list,
llvm::SmallVectorImpl<CType> &result,
StringRef action) {
- result.reserve(py::len(list));
- for (py::handle item : list) {
+ result.reserve(nb::len(list));
+ for (nb::handle item : list) {
try {
- result.push_back(item.cast<PyType>());
- } catch (py::cast_error &err) {
+ result.push_back(nb::cast<PyType>(item));
+ } catch (nb::cast_error &err) {
std::string msg = (llvm::Twine("Invalid expression when ") + action +
" (" + err.what() + ")")
.str();
- throw py::cast_error(msg);
- } catch (py::reference_cast_error &err) {
+ throw std::runtime_error(msg.c_str());
+ } catch (std::runtime_error &err) {
std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
action + " (" + err.what() + ")")
.str();
- throw py::cast_error(msg);
+ throw std::runtime_error(msg.c_str());
}
}
}
@@ -94,7 +93,7 @@ class PyConcreteAffineExpr : public BaseTy {
// IsAFunctionTy isaFunction
// const char *pyClassName
// and redefine bindDerived.
- using ClassTy = py::class_<DerivedTy, BaseTy>;
+ using ClassTy = nb::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = bool (*)(MlirAffineExpr);
PyConcreteAffineExpr() = default;
@@ -105,24 +104,25 @@ class PyConcreteAffineExpr : public BaseTy {
static MlirAffineExpr castFrom(PyAffineExpr &orig) {
if (!DerivedTy::isaFunction(orig)) {
- auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
- throw py::value_error((Twine("Cannot cast affine expression to ") +
+ auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
+ throw nb::value_error((Twine("Cannot cast affine expression to ") +
DerivedTy::pyClassName + " (from " + origRepr +
")")
- .str());
+ .str()
+ .c_str());
}
return orig;
}
- static void bind(py::module &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
- cls.def(py::init<PyAffineExpr &>(), py::arg("expr"));
+ static void bind(nb::module_ &m) {
+ auto cls = ClassTy(m, DerivedTy::pyClassName);
+ cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr"));
cls.def_static(
"isinstance",
[](PyAffineExpr &otherAffineExpr) -> bool {
return DerivedTy::isaFunction(otherAffineExpr);
},
- py::arg("other"));
+ nb::arg("other"));
DerivedTy::bindDerived(cls);
}
@@ -144,9 +144,9 @@ class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
}
static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
- py::arg("context") = py::none());
- c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
+ c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"),
+ nb::arg("context").none() = nb::none());
+ c.def_prop_ro("value", [](PyAffineConstantExpr &self) {
return mlirAffineConstantExprGetValue(self);
});
}
@@ -164,9 +164,9 @@ class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
}
static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
- py::arg("context") = py::none());
- c.def_property_readonly("position", [](PyAffineDimExpr &self) {
+ c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"),
+ nb::arg("context").none() = nb::none());
+ c.def_prop_ro("position", [](PyAffineDimExpr &self) {
return mlirAffineDimExprGetPosition(self);
});
}
@@ -184,9 +184,9 @@ class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
}
static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
- py::arg("context") = py::none());
- c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
+ c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"),
+ nb::arg("context").none() = nb::none());
+ c.def_prop_ro("position", [](PyAffineSymbolExpr &self) {
return mlirAffineSymbolExprGetPosition(self);
});
}
@@ -209,8 +209,8 @@ class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
}
static void bindDerived(ClassTy &c) {
- c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
- c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
+ c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs);
+ c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs);
}
};
@@ -365,15 +365,14 @@ bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
return mlirAffineExprEqual(affineExpr, other.affineExpr);
}
-py::object PyAffineExpr::getCapsule() {
- return py::reinterpret_steal<py::object>(
- mlirPythonAffineExprToCapsule(*this));
+nb::object PyAffineExpr::getCapsule() {
+ return nb::steal<nb::object>(mlirPythonAffineExprToCapsule(*this));
}
-PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
+PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) {
MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
if (mlirAffineExprIsNull(rawAffineExpr))
- throw py::error_already_set();
+ throw nb::python_error();
return PyAffineExpr(
PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
rawAffineExpr);
@@ -424,14 +423,14 @@ bool PyAffineMap::operator==(const PyAffineMap &other) const {
return mlirAffineMapEqual(affineMap, other.affineMap);
}
-py::object PyAffineMap::getCapsule() {
- return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
+nb::object PyAffineMap::getCapsule() {
+ return nb::steal<nb::object>(mlirPythonAffineMapToCapsule(*this));
}
-PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
+PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) {
MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
if (mlirAffineMapIsNull(rawAffineMap))
- throw py::error_already_set();
+ throw nb::python_error();
return PyAffineMap(
PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
rawAffineMap);
@@ -454,11 +453,10 @@ class PyIntegerSetConstraint {
bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
- static void bind(py::module &m) {
- py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
- py::module_local())
- .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
- .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
+ static void bind(nb::module_ &m) {
+ nb::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
+ .def_prop_ro("expr", &PyIntegerSetConstraint::getExpr)
+ .def_prop_ro("is_eq", &PyIntegerSetConstraint::isEq);
}
private:
@@ -501,27 +499,25 @@ bool PyIntegerSet::operator==(const PyIntegerSet &other) const {
return mlirIntegerSetEqual(integerSet, other.integerSet);
}
-py::object PyIntegerSet::getCapsule() {
- return py::reinterpret_steal<py::object>(
- mlirPythonIntegerSetToCapsule(*this));
+nb::object PyIntegerSet::getCapsule() {
+ return nb::steal<nb::object>(mlirPythonIntegerSetToCapsule(*this));
}
-PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
+PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) {
MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
if (mlirIntegerSetIsNull(rawIntegerSet))
- throw py::error_already_set();
+ throw nb::python_error();
return PyIntegerSet(
PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
rawIntegerSet);
}
-void mlir::python::populateIRAffine(py::module &m) {
+void mlir::python::populateIRAffine(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of PyAffineExpr and derived classes.
//----------------------------------------------------------------------------
- py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- &PyAffineExpr::getCapsule)
+ nb::class_<PyAffineExpr>(m, "AffineExpr")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
.def("__add__", &PyAffineAddExpr::get)
.def("__add__", &PyAffineAddExpr::getRHSConstant)
@@ -558,7 +554,7 @@ void mlir::python::populateIRAffine(py::module &m) {
.def("__eq__", [](PyAffineExpr &self,
PyAffineExpr &other) { return self == other; })
.def("__eq__",
- [](PyAffineExpr &self, py::object &other) { return false; })
+ [](PyAffineExpr &self, nb::object &other) { return false; })
.def("__str__",
[](PyAffineExpr &self) {
PyPrintAccumulator printAccum;
@@ -579,7 +575,7 @@ void mlir::python::populateIRAffine(py::module &m) {
[](PyAffineExpr &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
- .def_property_readonly(
+ .def_prop_ro(
"context",
[](PyAffineExpr &self) { return self.getContext().getObject(); })
.def("compose",
@@ -632,16 +628,16 @@ void mlir::python::populateIRAffine(py::module &m) {
.def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
"Gets an affine expression containing the rounded-up result "
"of dividing an expression by a constant.")
- .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
- py::arg("context") = py::none(),
+ .def_static("get_constant", &PyAffineConstantExpr::get, nb::arg("value"),
+ nb::arg("context").none() = nb::none(),
"Gets a constant affine expression with the given value.")
.def_static(
- "get_dim", &PyAffineDimExpr::get, py::arg("position"),
- py::arg("context") = py::none(),
+ "get_dim", &PyAffineDimExpr::get, nb::arg("position"),
+ nb::arg("context").none() = nb::none(),
"Gets an affine expression of a dimension at the given position.")
.def_static(
- "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
- py::arg("context") = py::none(),
+ "get_symbol", &PyAffineSymbolExpr::get, nb::arg("position"),
+ nb::arg("context").none() = nb::none(),
"Gets an affine expression of a symbol at the given position.")
.def(
"dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
@@ -659,13 +655,12 @@ void mlir::python::populateIRAffine(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of PyAffineMap.
//----------------------------------------------------------------------------
- py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- &PyAffineMap::getCapsule)
+ nb::class_<PyAffineMap>(m, "AffineMap")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
.def("__eq__",
[](PyAffineMap &self, PyAffineMap &other) { return self == other; })
- .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
+ .def("__eq__", [](PyAffineMap &self, nb::object &other) { return false; })
.def("__str__",
[](PyAffineMap &self) {
PyPrintAccumulator printAccum;
@@ -687,7 +682,7 @@ void mlir::python::populateIRAffine(py::module &m) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
.def_static("compress_unused_symbols",
- [](py::list affineMaps, DefaultingPyMlirContext context) {
+ [](nb::list affineMaps, DefaultingPyMlirContext context) {
SmallVector<MlirAffineMap> maps;
pyListToVector<PyAffineMap, MlirAffineMap>(
affineMaps, maps, "attempting to create an AffineMap");
@@ -704,7 +699,7 @@ void mlir::python::populateIRAffine(py::module &m) {
res.emplace_back(context->getRef(), m);
return res;
})
- .def_property_readonly(
+ .def_prop_ro(
"context",
[](PyAffineMap &self) { return self.getContext().getObject(); },
"Context that owns the Affine Map")
@@ -713,7 +708,7 @@ void mlir::python::populateIRAffine(py::module &m) {
kDumpDocstring)
.def_static(
"get",
- [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
+ [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs,
DefaultingPyMlirContext context) {
SmallVector<MlirAffineExpr> affineExprs;
pyListToVector<PyAffineExpr, MlirAffineExpr>(
@@ -723,8 +718,8 @@ void mlir::python::populateIRAffine(py::module &m) {
affineExprs.size(), affineExprs.data());
return PyAffineMap(context->getRef(), map);
},
- py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
- py::arg("context") = py::none(),
+ nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"),
+ nb::arg("context").none() = nb::none(),
"Gets a map with the given expressions as results.")
.def_static(
"get_constant",
@@ -733,7 +728,7 @@ void mlir::python::populateIRAffine(py::module &m) {
mlirAffineMapConstantGet(context->get(), value);
return PyAffineMap(context->getRef(), affineMap);
},
- py::arg("value"), py::arg("context") = py::none(),
+ nb::arg("value"), nb::arg("context").none() = nb::none(),
"Gets an affine map with a single constant result")
.def_static(
"get_empty",
@@ -741,7 +736,7 @@ void mlir::python::populateIRAffine(py::module &m) {
MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
return PyAffineMap(context->getRef(), affineMap);
},
- py::arg("context") = py::none(), "Gets an empty affine map.")
+ nb::arg("context").none() = nb::none(), "Gets an empty affine map.")
.def_static(
"get_identity",
[](intptr_t nDims, DefaultingPyMlirContext context) {
@@ -749,7 +744,7 @@ void mlir::python::populateIRAffine(py::module &m) {
mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
return PyAffineMap(context->getRef(), affineMap);
},
- py::arg("n_dims"), py::arg("context") = py::none(),
+ nb::arg("n_dims"), nb::arg("context").none() = nb::none(),
"Gets an identity map with the given number of dimensions.")
.def_static(
"get_minor_identity",
@@ -759,8 +754,8 @@ void mlir::python::populateIRAffine(py::module &m) {
mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
return PyAffineMap(context->getRef(), affineMap);
},
- py::arg("n_dims"), py::arg("n_results"),
- py::arg("context") = py::none(),
+ nb::arg("n_dims"), nb::arg("n_results"),
+ nb::arg("context").none() = nb::none(),
"Gets a minor identity map with the given number of dimensions and "
"results.")
.def_static(
@@ -768,13 +763,13 @@ void mlir::python::populateIRAffine(py::module &m) {
[](std::vector<unsigned> permutation,
DefaultingPyMlirContext context) {
if (!isPermutation(permutation))
- throw py::cast_error("Invalid permutation when attempting to "
- "create an AffineMap");
+ throw std::runtime_error("Invalid permutation when attempting to "
+ "create an AffineMap");
MlirAffineMap affineMap = mlirAffineMapPermutationGet(
context->get(), permutation.size(), permutation.data());
return PyAffineMap(context->getRef(), affineMap);
},
- py::arg("permutation"), py::arg("context") = py::none(),
+ nb::arg("permutation"), nb::arg("context").none() = nb::none(),
"Gets an affine map that permutes its inputs.")
.def(
"get_submap",
@@ -782,33 +777,33 @@ void mlir::python::populateIRAffine(py::module &m) {
intptr_t numResults = mlirAffineMapGetNumResults(self);
for (intptr_t pos : resultPos) {
if (pos < 0 || pos >= numResults)
- throw py::value_error("result position out of bounds");
+ throw nb::value_error("result position out of bounds");
}
MlirAffineMap affineMap = mlirAffineMapGetSubMap(
self, resultPos.size(), resultPos.data());
return PyAffineMap(self.getContext(), affineMap);
},
- py::arg("result_positions"))
+ nb::arg("result_positions"))
.def(
"get_major_submap",
[](PyAffineMap &self, intptr_t nResults) {
if (nResults >= mlirAffineMapGetNumResults(self))
- throw py::value_error("number of results out of bounds");
+ throw nb::value_error("number of results out of bounds");
MlirAffineMap affineMap =
mlirAffineMapGetMajorSubMap(self, nResults);
return PyAffineMap(self.getContext(), affineMap);
},
- py::arg("n_results"))
+ nb::arg("n_results"))
.def(
"get_minor_submap",
[](PyAffineMap &self, intptr_t nResults) {
if (nResults >= mlirAffineMapGetNumResults(self))
- throw py::value_error("number of results out of bounds");
+ throw nb::value_error("number of results out of bounds");
MlirAffineMap affineMap =
mlirAffineMapGetMinorSubMap(self, nResults);
return PyAffineMap(self.getContext(), affineMap);
},
- py::arg("n_results"))
+ nb::arg("n_results"))
.def(
"replace",
[](PyAffineMap &self, PyAffineExpr &expression,
@@ -818,39 +813,37 @@ void mlir::python::populateIRAffine(py::module &m) {
self, expression, replacement, numResultDims, numResultSyms);
return PyAffineMap(self.getContext(), affineMap);
},
- py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"),
- py::arg("n_result_syms"))
- .def_property_readonly(
+ nb::arg("expr"), nb::arg("replacement"), nb::arg("n_result_dims"),
+ nb::arg("n_result_syms"))
+ .def_prop_ro(
"is_permutation",
[](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
- .def_property_readonly("is_projected_permutation",
- [](PyAffineMap &self) {
- return mlirAffineMapIsProjectedPermutation(self);
- })
- .def_property_readonly(
+ .def_prop_ro("is_projected_permutation",
+ [](PyAffineMap &self) {
+ return mlirAffineMapIsProjectedPermutation(self);
+ })
+ .def_prop_ro(
"n_dims",
[](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
- .def_property_readonly(
+ .def_prop_ro(
"n_inputs",
[](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
- .def_property_readonly(
+ .def_prop_ro(
"n_symbols",
[](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
- .def_property_readonly("results", [](PyAffineMap &self) {
- return PyAffineMapExprList(self);
- });
+ .def_prop_ro("results",
+ [](PyAffineMap &self) { return PyAffineMapExprList(self); });
PyAffineMapExprList::bind(m);
//----------------------------------------------------------------------------
// Mapping of PyIntegerSet.
//----------------------------------------------------------------------------
- py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- &PyIntegerSet::getCapsule)
+ nb::class_<PyIntegerSet>(m, "IntegerSet")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
.def("__eq__", [](PyIntegerSet &self,
PyIntegerSet &other) { return self == other; })
- .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
+ .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; })
.def("__str__",
[](PyIntegerSet &self) {
PyPrintAccumulator printAccum;
@@ -871,7 +864,7 @@ void mlir::python::populateIRAffine(py::module &m) {
[](PyIntegerSet &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
- .def_property_readonly(
+ .def_prop_ro(
"context",
[](PyIntegerSet &self) { return self.getContext().getObject(); })
.def(
@@ -879,14 +872,14 @@ void mlir::python::populateIRAffine(py::module &m) {
kDumpDocstring)
.def_static(
"get",
- [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
+ [](intptr_t numDims, intptr_t numSymbols, nb::list exprs,
std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
if (exprs.size() != eqFlags.size())
- throw py::value_error(
+ throw nb::value_error(
"Expected the number of constraints to match "
"that of equality flags");
- if (exprs.empty())
- throw py::value_error("Expected non-empty list of constraints");
+ if (exprs.size() == 0)
+ throw nb::value_error("Expected non-empty list of constraints");
// Copy over to a SmallVector because std::vector has a
// specialization for booleans that packs data and does not
@@ -901,8 +894,8 @@ void mlir::python::populateIRAffine(py::module &m) {
affineExprs.data(), flags.data());
return PyIntegerSet(context->getRef(), set);
},
- py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
- py::arg("eq_flags"), py::arg("context") = py::none())
+ nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"),
+ nb::arg("eq_flags"), nb::arg("context").none() = nb::none())
.def_static(
"get_empty",
[](intptr_t numDims, intptr_t numSymbols,
@@ -911,20 +904,20 @@ void mlir::python::populateIRAffine(py::module &m) {
mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
return PyIntegerSet(context->getRef(), set);
},
- py::arg("num_dims"), py::arg("num_symbols"),
- py::arg("context") = py::none())
+ nb::arg("num_dims"), nb::arg("num_symbols"),
+ nb::arg("context").none() = nb::none())
.def(
"get_replaced",
- [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
+ [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs,
intptr_t numResultDims, intptr_t numResultSymbols) {
if (static_cast<intptr_t>(dimExprs.size()) !=
mlirIntegerSetGetNumDims(self))
- throw py::value_error(
+ throw nb::value_error(
"Expected the number of dimension replacement expressions "
"to match that of dimensions");
if (static_cast<intptr_t>(symbolExprs.size()) !=
mlirIntegerSetGetNumSymbols(self))
- throw py::value_error(
+ throw nb::value_error(
"Expected the number of symbol replacement expressions "
"to match that of symbols");
@@ -940,30 +933,30 @@ void mlir::python::populateIRAffine(py::module &m) {
numResultDims, numResultSymbols);
return PyIntegerSet(self.getContext(), set);
},
- py::arg("dim_exprs"), py::arg("symbol_exprs"),
- py::arg("num_result_dims"), py::arg("num_result_symbols"))
- .def_property_readonly("is_canonical_empty",
- [](PyIntegerSet &self) {
- return mlirIntegerSetIsCanonicalEmpty(self);
- })
- .def_property_readonly(
+ nb::arg("dim_exprs"), nb::arg("symbol_exprs"),
+ nb::arg("num_result_dims"), nb::arg("num_result_symbols"))
+ .def_prop_ro("is_canonical_empty",
+ [](PyIntegerSet &self) {
+ return mlirIntegerSetIsCanonicalEmpty(self);
+ })
+ .def_prop_ro(
"n_dims",
[](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
- .def_property_readonly(
+ .def_prop_ro(
"n_symbols",
[](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
- .def_property_readonly(
+ .def_prop_ro(
"n_inputs",
[](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
- .def_property_readonly("n_equalities",
- [](PyIntegerSet &self) {
- return mlirIntegerSetGetNumEqualities(self);
- })
- .def_property_readonly("n_inequalities",
- [](PyIntegerSet &self) {
- return mlirIntegerSetGetNumInequalities(self);
- })
- .def_property_readonly("constraints", [](PyIntegerSet &self) {
+ .def_prop_ro("n_equalities",
+ [](PyIntegerSet &self) {
+ return mlirIntegerSetGetNumEqualities(self);
+ })
+ .def_prop_ro("n_inequalities",
+ [](PyIntegerSet &self) {
+ return mlirIntegerSetGetNumInequalities(self);
+ })
+ .def_prop_ro("constraints", [](PyIntegerSet &self) {
return PyIntegerSetConstraintList(self);
});
PyIntegerSetConstraint::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index cc9532f4e33b2c..c85c4e286fbb61 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -6,23 +6,29 @@
//
//===----------------------------------------------------------------------===//
+#include <nanobind/nanobind.h>
+#include <nanobind/ndarray.h>
+#include <nanobind/stl/optional.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/string_view.h>
+#include <nanobind/stl/vector.h>
+
+#include <cstdint>
+#include <memory>
#include <optional>
+#include <string>
#include <string_view>
#include <utility>
#include "IRModule.h"
-
-#include "PybindUtils.h"
-#include <pybind11/numpy.h>
-
-#include "llvm/ADT/ScopeExit.h"
-#include "llvm/Support/raw_ostream.h"
-
+#include "NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Support/raw_ostream.h"
-namespace py = pybind11;
+namespace nb = nanobind;
+using namespace nanobind::literals;
using namespace mlir;
using namespace mlir::python;
@@ -123,10 +129,108 @@ subsequent processing.
namespace {
+/// Description of a memory region exposed through the Python buffer protocol:
+/// base pointer, element size, format string, rank, and the per-dimension
+/// extents and strides.
+///
+/// An instance built from a `Py_buffer *` takes ownership of that view and, on
+/// destruction, calls PyBuffer_Release on it and frees the struct. `ptr` and
+/// `format` are copied straight out of the view, so they must not be used
+/// after the owning instance is destroyed. Holding a std::unique_ptr makes the
+/// type move-only.
+struct nb_buffer_info {
+  void *ptr = nullptr;
+  ssize_t itemsize = 0;
+  // Product of all entries of `shape` (1 when `ndim` is 0).
+  ssize_t size = 0;
+  const char *format = nullptr;
+  ssize_t ndim = 0;
+  SmallVector<ssize_t, 4> shape;
+  SmallVector<ssize_t, 4> strides;
+  bool readonly = false;
+
+  nb_buffer_info() = default;
+
+  /// Describes memory the caller keeps alive; nothing is released when the
+  /// resulting object is destroyed.
+  nb_buffer_info(void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
+                 SmallVector<ssize_t, 4> shape_in,
+                 SmallVector<ssize_t, 4> strides_in, bool readonly = false)
+      : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
+        shape(std::move(shape_in)), strides(std::move(strides_in)),
+        readonly(readonly) {
+    size = 1;
+    for (ssize_t i = 0; i < ndim; ++i) {
+      size *= shape[i];
+    }
+  }
+
+  /// Takes ownership of `view`, which must have been allocated with `new` and
+  /// filled by a successful PyObject_GetBuffer call. If the exporter supplied
+  /// no strides, C-contiguous strides are derived from the shape.
+  explicit nb_buffer_info(Py_buffer *view)
+      : nb_buffer_info(
+            view->buf, view->itemsize, view->format, view->ndim,
+            SmallVector<ssize_t, 4>(view->shape, view->shape + view->ndim),
+            view->strides ? SmallVector<ssize_t, 4>(
+                                view->strides, view->strides + view->ndim)
+                          : contiguousStrides(view),
+            view->readonly != 0) {
+    // Adopt the view only after every field has been copied out of it.
+    ownedView.reset(view);
+  }
+
+private:
+  /// Releases the exporter's buffer, then frees the Py_buffer struct itself.
+  /// NOTE(review): PyBuffer_Release presumably requires the GIL, so instances
+  /// should be destroyed while it is held -- confirm for all call sites.
+  struct ViewDeleter {
+    void operator()(Py_buffer *view) const {
+      PyBuffer_Release(view);
+      delete view;
+    }
+  };
+
+  /// Row-major strides for `view`, used when the exporter left them null.
+  static SmallVector<ssize_t, 4> contiguousStrides(const Py_buffer *view) {
+    SmallVector<ssize_t, 4> result;
+    result.resize(static_cast<size_t>(view->ndim), view->itemsize);
+    for (ssize_t i = view->ndim - 1; i > 0; --i)
+      result[i - 1] = result[i] * view->shape[i];
+    return result;
+  }
+
+  // Null for instances that merely describe caller-owned memory.
+  std::unique_ptr<Py_buffer, ViewDeleter> ownedView;
+};
+
+/// nb::object subclass accepting any Python object for which
+/// PyObject_CheckBuffer succeeds.
+class nb_buffer : public nb::object {
+  NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer);
+
+  /// Acquires a strided, formatted view of the object and returns its
+  /// description. The returned nb_buffer_info owns the view and releases it
+  /// when destroyed. Throws nb::python_error if the buffer cannot be acquired.
+  nb_buffer_info request() const {
+    int flags = PyBUF_STRIDES | PyBUF_FORMAT;
+    // Hold the allocation in a smart pointer so the failure path cannot leak.
+    auto view = std::make_unique<Py_buffer>();
+    if (PyObject_GetBuffer(ptr(), view.get(), flags) != 0) {
+      throw nb::python_error();
+    }
+    return nb_buffer_info(view.release());
+  }
+};
+
+/// Maps a C++ arithmetic type to the single-character format code used by the
+/// Python `struct` module / buffer protocol. The primary template has no
+/// `format()` member, so an unsupported type is rejected at compile time.
+template <typename T>
+struct nb_format_descriptor {};
+
+template <>
+struct nb_format_descriptor<bool> {
+  static const char *format() { return "?"; }
+};
+template <>
+struct nb_format_descriptor<signed char> {
+  static const char *format() { return "b"; }
+};
+template <>
+struct nb_format_descriptor<unsigned char> {
+  static const char *format() { return "B"; }
+};
+template <>
+struct nb_format_descriptor<short> {
+  static const char *format() { return "h"; }
+};
+template <>
+struct nb_format_descriptor<unsigned short> {
+  static const char *format() { return "H"; }
+};
+template <>
+struct nb_format_descriptor<int> {
+  static const char *format() { return "i"; }
+};
+template <>
+struct nb_format_descriptor<unsigned int> {
+  static const char *format() { return "I"; }
+};
+template <>
+struct nb_format_descriptor<long> {
+  static const char *format() { return "l"; }
+};
+template <>
+struct nb_format_descriptor<unsigned long> {
+  static const char *format() { return "L"; }
+};
+// `long long` is a distinct type from `long` even when both are 64 bits wide,
+// and some platforms define the 64-bit fixed-width typedefs in terms of it.
+template <>
+struct nb_format_descriptor<long long> {
+  static const char *format() { return "q"; }
+};
+template <>
+struct nb_format_descriptor<unsigned long long> {
+  static const char *format() { return "Q"; }
+};
+template <>
+struct nb_format_descriptor<float> {
+  static const char *format() { return "f"; }
+};
+template <>
+struct nb_format_descriptor<double> {
+  static const char *format() { return "d"; }
+};
+
static MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size());
}
+/// Builds an MlirStringRef from the data pointer and length of a Python
+/// `bytes` object; the pointer is passed through, not copied.
+static MlirStringRef toMlirStringRef(const nb::bytes &s) {
+  const auto *payload = static_cast<const char *>(s.data());
+  return mlirStringRefCreate(payload, s.size());
+}
+
class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
@@ -142,9 +246,9 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
return PyAffineMapAttribute(affineMap.getContext(), attr);
},
- py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
- c.def_property_readonly("value", mlirAffineMapAttrGetValue,
- "Returns the value of the AffineMap attribute");
+ nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
+ c.def_prop_ro("value", mlirAffineMapAttrGetValue,
+ "Returns the value of the AffineMap attribute");
}
};
@@ -164,25 +268,24 @@ class PyIntegerSetAttribute
MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
return PyIntegerSetAttribute(integerSet.getContext(), attr);
},
- py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
+ nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
}
};
template <typename T>
-static T pyTryCast(py::handle object) {
+static T pyTryCast(nb::handle object) {
try {
- return object.cast<T>();
- } catch (py::cast_error &err) {
- std::string msg =
- std::string(
- "Invalid attribute when attempting to create an ArrayAttribute (") +
- err.what() + ")";
- throw py::cast_error(msg);
- } catch (py::reference_cast_error &err) {
+ return nb::cast<T>(object);
+ } catch (nb::cast_error &err) {
+ std::string msg = std::string("Invalid attribute when attempting to "
+ "create an ArrayAttribute (") +
+ err.what() + ")";
+ throw std::runtime_error(msg.c_str());
+ } catch (std::runtime_error &err) {
std::string msg = std::string("Invalid attribute (None?) when attempting "
"to create an ArrayAttribute (") +
err.what() + ")";
- throw py::cast_error(msg);
+ throw std::runtime_error(msg.c_str());
}
}
@@ -205,14 +308,13 @@ class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
EltTy dunderNext() {
// Throw if the index has reached the end.
if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
- throw py::stop_iteration();
+ throw nb::stop_iteration();
return DerivedT::getElement(attr.get(), nextIndex++);
}
/// Bind the iterator class.
- static void bind(py::module &m) {
- py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
- py::module_local())
+ static void bind(nb::module_ &m) {
+ nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
.def("__iter__", &PyDenseArrayIterator::dunderIter)
.def("__next__", &PyDenseArrayIterator::dunderNext);
}
@@ -230,17 +332,35 @@ class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
/// Bind the attribute class.
static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
// Bind the constructor.
- c.def_static(
- "get",
- [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
- return getAttribute(values, ctx->getRef());
- },
- py::arg("values"), py::arg("context") = py::none(),
- "Gets a uniqued dense array attribute");
+ if constexpr (std::is_same_v<EltTy, bool>) {
+ c.def_static(
+ "get",
+ [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
+ std::vector<bool> values;
+ for (nb::handle py_value : py_values) {
+ int is_true = PyObject_IsTrue(py_value.ptr());
+ if (is_true < 0) {
+ throw nb::python_error();
+ }
+ values.push_back(is_true);
+ }
+ return getAttribute(values, ctx->getRef());
+ },
+ nb::arg("values"), nb::arg("context").none() = nb::none(),
+ "Gets a uniqued dense array attribute");
+ } else {
+ c.def_static(
+ "get",
+ [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
+ return getAttribute(values, ctx->getRef());
+ },
+ nb::arg("values"), nb::arg("context").none() = nb::none(),
+ "Gets a uniqued dense array attribute");
+ }
// Bind the array methods.
c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
if (i >= mlirDenseArrayGetNumElements(arr))
- throw py::index_error("DenseArray index out of range");
+ throw nb::index_error("DenseArray index out of range");
return arr.getItem(i);
});
c.def("__len__", [](const DerivedT &arr) {
@@ -248,13 +368,13 @@ class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
});
c.def("__iter__",
[](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
- c.def("__add__", [](DerivedT &arr, const py::list &extras) {
+ c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
std::vector<EltTy> values;
intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
- values.reserve(numOldElements + py::len(extras));
+ values.reserve(numOldElements + nb::len(extras));
for (intptr_t i = 0; i < numOldElements; ++i)
values.push_back(arr.getItem(i));
- for (py::handle attr : extras)
+ for (nb::handle attr : extras)
values.push_back(pyTryCast<EltTy>(attr));
return getAttribute(values, arr.getContext());
});
@@ -358,13 +478,12 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
MlirAttribute dunderNext() {
// TODO: Throw is an inefficient way to stop iteration.
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
- throw py::stop_iteration();
+ throw nb::stop_iteration();
return mlirArrayAttrGetElement(attr.get(), nextIndex++);
}
- static void bind(py::module &m) {
- py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
- py::module_local())
+ static void bind(nb::module_ &m) {
+ nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
.def("__iter__", &PyArrayAttributeIterator::dunderIter)
.def("__next__", &PyArrayAttributeIterator::dunderNext);
}
@@ -381,9 +500,9 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](py::list attributes, DefaultingPyMlirContext context) {
+ [](nb::list attributes, DefaultingPyMlirContext context) {
SmallVector<MlirAttribute> mlirAttributes;
- mlirAttributes.reserve(py::len(attributes));
+ mlirAttributes.reserve(nb::len(attributes));
for (auto attribute : attributes) {
mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
}
@@ -391,12 +510,12 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
context->get(), mlirAttributes.size(), mlirAttributes.data());
return PyArrayAttribute(context->getRef(), attr);
},
- py::arg("attributes"), py::arg("context") = py::none(),
+ nb::arg("attributes"), nb::arg("context").none() = nb::none(),
"Gets a uniqued Array attribute");
c.def("__getitem__",
[](PyArrayAttribute &arr, intptr_t i) {
if (i >= mlirArrayAttrGetNumElements(arr))
- throw py::index_error("ArrayAttribute index out of range");
+ throw nb::index_error("ArrayAttribute index out of range");
return arr.getItem(i);
})
.def("__len__",
@@ -406,13 +525,13 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
.def("__iter__", [](const PyArrayAttribute &arr) {
return PyArrayAttributeIterator(arr);
});
- c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
+ c.def("__add__", [](PyArrayAttribute arr, nb::list extras) {
std::vector<MlirAttribute> attributes;
intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
- attributes.reserve(numOldElements + py::len(extras));
+ attributes.reserve(numOldElements + nb::len(extras));
for (intptr_t i = 0; i < numOldElements; ++i)
attributes.push_back(arr.getItem(i));
- for (py::handle attr : extras)
+ for (nb::handle attr : extras)
attributes.push_back(pyTryCast<PyAttribute>(attr));
MlirAttribute arrayAttr = mlirArrayAttrGet(
arr.getContext()->get(), attributes.size(), attributes.data());
@@ -440,7 +559,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
throw MLIRError("Invalid attribute", errors.take());
return PyFloatAttribute(type.getContext(), attr);
},
- py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
+ nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(),
"Gets an uniqued float point attribute associated to a type");
c.def_static(
"get_f32",
@@ -449,7 +568,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
context->get(), mlirF32TypeGet(context->get()), value);
return PyFloatAttribute(context->getRef(), attr);
},
- py::arg("value"), py::arg("context") = py::none(),
+ nb::arg("value"), nb::arg("context").none() = nb::none(),
"Gets an uniqued float point attribute associated to a f32 type");
c.def_static(
"get_f64",
@@ -458,10 +577,10 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
context->get(), mlirF64TypeGet(context->get()), value);
return PyFloatAttribute(context->getRef(), attr);
},
- py::arg("value"), py::arg("context") = py::none(),
+ nb::arg("value"), nb::arg("context").none() = nb::none(),
"Gets an uniqued float point attribute associated to a f64 type");
- c.def_property_readonly("value", mlirFloatAttrGetValueDouble,
- "Returns the value of the float attribute");
+ c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
+ "Returns the value of the float attribute");
c.def("__float__", mlirFloatAttrGetValueDouble,
"Converts the value of the float attribute to a Python float");
}
@@ -481,20 +600,20 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
MlirAttribute attr = mlirIntegerAttrGet(type, value);
return PyIntegerAttribute(type.getContext(), attr);
},
- py::arg("type"), py::arg("value"),
+ nb::arg("type"), nb::arg("value"),
"Gets an uniqued integer attribute associated to a type");
- c.def_property_readonly("value", toPyInt,
- "Returns the value of the integer attribute");
+ c.def_prop_ro("value", toPyInt,
+ "Returns the value of the integer attribute");
c.def("__int__", toPyInt,
"Converts the value of the integer attribute to a Python int");
- c.def_property_readonly_static("static_typeid",
- [](py::object & /*class*/) -> MlirTypeID {
- return mlirIntegerAttrGetTypeID();
- });
+ c.def_prop_ro_static("static_typeid",
+ [](nb::object & /*class*/) -> MlirTypeID {
+ return mlirIntegerAttrGetTypeID();
+ });
}
private:
- static py::int_ toPyInt(PyIntegerAttribute &self) {
+ static int64_t toPyInt(PyIntegerAttribute &self) {
MlirType type = mlirAttributeGetType(self);
if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
return mlirIntegerAttrGetValueInt(self);
@@ -518,10 +637,10 @@ class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
return PyBoolAttribute(context->getRef(), attr);
},
- py::arg("value"), py::arg("context") = py::none(),
+ nb::arg("value"), nb::arg("context").none() = nb::none(),
"Gets an uniqued bool attribute");
- c.def_property_readonly("value", mlirBoolAttrGetValue,
- "Returns the value of the bool attribute");
+ c.def_prop_ro("value", mlirBoolAttrGetValue,
+ "Returns the value of the bool attribute");
c.def("__bool__", mlirBoolAttrGetValue,
"Converts the value of the bool attribute to a Python bool");
}
@@ -555,9 +674,9 @@ class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
DefaultingPyMlirContext context) {
return PySymbolRefAttribute::fromList(symbols, context.resolve());
},
- py::arg("symbols"), py::arg("context") = py::none(),
+ nb::arg("symbols"), nb::arg("context").none() = nb::none(),
"Gets a uniqued SymbolRef attribute from a list of symbol names");
- c.def_property_readonly(
+ c.def_prop_ro(
"value",
[](PySymbolRefAttribute &self) {
std::vector<std::string> symbols = {
@@ -589,13 +708,13 @@ class PyFlatSymbolRefAttribute
mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
return PyFlatSymbolRefAttribute(context->getRef(), attr);
},
- py::arg("value"), py::arg("context") = py::none(),
+ nb::arg("value"), nb::arg("context").none() = nb::none(),
"Gets a uniqued FlatSymbolRef attribute");
- c.def_property_readonly(
+ c.def_prop_ro(
"value",
[](PyFlatSymbolRefAttribute &self) {
MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
- return py::str(stringRef.data, stringRef.length);
+ return nb::str(stringRef.data, stringRef.length);
},
"Returns the value of the FlatSymbolRef attribute as a string");
}
@@ -612,29 +731,29 @@ class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](std::string dialectNamespace, py::buffer buffer, PyType &type,
+ [](std::string dialectNamespace, nb_buffer buffer, PyType &type,
DefaultingPyMlirContext context) {
- const py::buffer_info bufferInfo = buffer.request();
+ const nb_buffer_info bufferInfo = buffer.request();
intptr_t bufferSize = bufferInfo.size;
MlirAttribute attr = mlirOpaqueAttrGet(
context->get(), toMlirStringRef(dialectNamespace), bufferSize,
static_cast<char *>(bufferInfo.ptr), type);
return PyOpaqueAttribute(context->getRef(), attr);
},
- py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
- py::arg("context") = py::none(), "Gets an Opaque attribute.");
- c.def_property_readonly(
+ nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
+ nb::arg("context").none() = nb::none(), "Gets an Opaque attribute.");
+ c.def_prop_ro(
"dialect_namespace",
[](PyOpaqueAttribute &self) {
MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
- return py::str(stringRef.data, stringRef.length);
+ return nb::str(stringRef.data, stringRef.length);
},
"Returns the dialect namespace for the Opaque attribute as a string");
- c.def_property_readonly(
+ c.def_prop_ro(
"data",
[](PyOpaqueAttribute &self) {
MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
- return py::bytes(stringRef.data, stringRef.length);
+ return nb::bytes(stringRef.data, stringRef.length);
},
"Returns the data for the Opaqued attributes as `bytes`");
}
@@ -656,7 +775,16 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
mlirStringAttrGet(context->get(), toMlirStringRef(value));
return PyStringAttribute(context->getRef(), attr);
},
- py::arg("value"), py::arg("context") = py::none(),
+ nb::arg("value"), nb::arg("context").none() = nb::none(),
+ "Gets a uniqued string attribute");
+ c.def_static(
+ "get",
+ [](nb::bytes value, DefaultingPyMlirContext context) {
+ MlirAttribute attr =
+ mlirStringAttrGet(context->get(), toMlirStringRef(value));
+ return PyStringAttribute(context->getRef(), attr);
+ },
+ nb::arg("value"), nb::arg("context").none() = nb::none(),
"Gets a uniqued string attribute");
c.def_static(
"get_typed",
@@ -665,20 +793,20 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
mlirStringAttrTypedGet(type, toMlirStringRef(value));
return PyStringAttribute(type.getContext(), attr);
},
- py::arg("type"), py::arg("value"),
+ nb::arg("type"), nb::arg("value"),
"Gets a uniqued string attribute associated to a type");
- c.def_property_readonly(
+ c.def_prop_ro(
"value",
[](PyStringAttribute &self) {
MlirStringRef stringRef = mlirStringAttrGetValue(self);
- return py::str(stringRef.data, stringRef.length);
+ return nb::str(stringRef.data, stringRef.length);
},
"Returns the value of the string attribute");
- c.def_property_readonly(
+ c.def_prop_ro(
"value_bytes",
[](PyStringAttribute &self) {
MlirStringRef stringRef = mlirStringAttrGetValue(self);
- return py::bytes(stringRef.data, stringRef.length);
+ return nb::bytes(stringRef.data, stringRef.length);
},
"Returns the value of the string attribute as `bytes`");
}
@@ -693,12 +821,11 @@ class PyDenseElementsAttribute
using PyConcreteAttribute::PyConcreteAttribute;
static PyDenseElementsAttribute
- getFromList(py::list attributes, std::optional<PyType> explicitType,
+ getFromList(nb::list attributes, std::optional<PyType> explicitType,
DefaultingPyMlirContext contextWrapper) {
-
- const size_t numAttributes = py::len(attributes);
+ const size_t numAttributes = nb::len(attributes);
if (numAttributes == 0)
- throw py::value_error("Attributes list must be non-empty.");
+ throw nb::value_error("Attributes list must be non-empty.");
MlirType shapedType;
if (explicitType) {
@@ -708,8 +835,8 @@ class PyDenseElementsAttribute
std::string message;
llvm::raw_string_ostream os(message);
os << "Expected a static ShapedType for the shaped_type parameter: "
- << py::repr(py::cast(*explicitType));
- throw py::value_error(message);
+ << nb::cast<std::string_view>(nb::repr(nb::cast(*explicitType)));
+ throw nb::value_error(message.c_str());
}
shapedType = *explicitType;
} else {
@@ -722,7 +849,7 @@ class PyDenseElementsAttribute
SmallVector<MlirAttribute> mlirAttributes;
mlirAttributes.reserve(numAttributes);
- for (const py::handle &attribute : attributes) {
+ for (const nb::handle &attribute : attributes) {
MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
MlirType attrType = mlirAttributeGetType(mlirAttribute);
mlirAttributes.push_back(mlirAttribute);
@@ -731,9 +858,11 @@ class PyDenseElementsAttribute
std::string message;
llvm::raw_string_ostream os(message);
os << "All attributes must be of the same type and match "
- << "the type parameter: expected=" << py::repr(py::cast(shapedType))
- << ", but got=" << py::repr(py::cast(attrType));
- throw py::value_error(message);
+ << "the type parameter: expected="
+ << nb::cast<std::string_view>(nb::repr(nb::cast(shapedType)))
+ << ", but got="
+ << nb::cast<std::string_view>(nb::repr(nb::cast(attrType)));
+ throw nb::value_error(message.c_str());
}
}
@@ -744,7 +873,7 @@ class PyDenseElementsAttribute
}
static PyDenseElementsAttribute
- getFromBuffer(py::buffer array, bool signless,
+ getFromBuffer(nb_buffer array, bool signless,
std::optional<PyType> explicitType,
std::optional<std::vector<int64_t>> explicitShape,
DefaultingPyMlirContext contextWrapper) {
@@ -755,7 +884,7 @@ class PyDenseElementsAttribute
}
Py_buffer view;
if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
- throw py::error_already_set();
+ throw nb::python_error();
}
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
@@ -778,25 +907,29 @@ class PyDenseElementsAttribute
if (!mlirAttributeIsAInteger(elementAttr) &&
!mlirAttributeIsAFloat(elementAttr)) {
std::string message = "Illegal element type for DenseElementsAttr: ";
- message.append(py::repr(py::cast(elementAttr)));
- throw py::value_error(message);
+ message.append(
+ nb::cast<std::string_view>(nb::repr(nb::cast(elementAttr))));
+ throw nb::value_error(message.c_str());
}
if (!mlirTypeIsAShaped(shapedType) ||
!mlirShapedTypeHasStaticShape(shapedType)) {
std::string message =
"Expected a static ShapedType for the shaped_type parameter: ";
- message.append(py::repr(py::cast(shapedType)));
- throw py::value_error(message);
+ message.append(
+ nb::cast<std::string_view>(nb::repr(nb::cast(shapedType))));
+ throw nb::value_error(message.c_str());
}
MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
MlirType attrType = mlirAttributeGetType(elementAttr);
if (!mlirTypeEqual(shapedElementType, attrType)) {
std::string message =
"Shaped element type and attribute type must be equal: shaped=";
- message.append(py::repr(py::cast(shapedType)));
+ message.append(
+ nb::cast<std::string_view>(nb::repr(nb::cast(shapedType))));
message.append(", element=");
- message.append(py::repr(py::cast(elementAttr)));
- throw py::value_error(message);
+ message.append(
+ nb::cast<std::string_view>(nb::repr(nb::cast(elementAttr))));
+ throw nb::value_error(message.c_str());
}
MlirAttribute elements =
@@ -806,7 +939,7 @@ class PyDenseElementsAttribute
intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
- py::buffer_info accessBuffer() {
+ nb_buffer_info accessBuffer() {
MlirType shapedType = mlirAttributeGetType(*this);
MlirType elementType = mlirShapedTypeGetElementType(shapedType);
std::string format;
@@ -889,32 +1022,36 @@ class PyDenseElementsAttribute
static void bindDerived(ClassTy &c) {
c.def("__len__", &PyDenseElementsAttribute::dunderLen)
.def_static("get", PyDenseElementsAttribute::getFromBuffer,
- py::arg("array"), py::arg("signless") = true,
- py::arg("type") = py::none(), py::arg("shape") = py::none(),
- py::arg("context") = py::none(),
+ nb::arg("array"), nb::arg("signless") = true,
+ nb::arg("type").none() = nb::none(),
+ nb::arg("shape").none() = nb::none(),
+ nb::arg("context").none() = nb::none(),
kDenseElementsAttrGetDocstring)
.def_static("get", PyDenseElementsAttribute::getFromList,
- py::arg("attrs"), py::arg("type") = py::none(),
- py::arg("context") = py::none(),
+ nb::arg("attrs"), nb::arg("type").none() = nb::none(),
+ nb::arg("context").none() = nb::none(),
kDenseElementsAttrGetFromListDocstring)
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
- py::arg("shaped_type"), py::arg("element_attr"),
+ nb::arg("shaped_type"), nb::arg("element_attr"),
"Gets a DenseElementsAttr where all values are the same")
- .def_property_readonly("is_splat",
- [](PyDenseElementsAttribute &self) -> bool {
- return mlirDenseElementsAttrIsSplat(self);
- })
- .def("get_splat_value",
- [](PyDenseElementsAttribute &self) {
- if (!mlirDenseElementsAttrIsSplat(self))
- throw py::value_error(
- "get_splat_value called on a non-splat attribute");
- return mlirDenseElementsAttrGetSplatValue(self);
- })
- .def_buffer(&PyDenseElementsAttribute::accessBuffer);
+ .def_prop_ro("is_splat",
+ [](PyDenseElementsAttribute &self) -> bool {
+ return mlirDenseElementsAttrIsSplat(self);
+ })
+ .def("get_splat_value", [](PyDenseElementsAttribute &self) {
+ if (!mlirDenseElementsAttrIsSplat(self))
+ throw nb::value_error(
+ "get_splat_value called on a non-splat attribute");
+ return mlirDenseElementsAttrGetSplatValue(self);
+ });
}
+ static PyType_Slot slots[];
+
private:
+ static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
+ static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
+
static bool isUnsignedIntegerFormat(std::string_view format) {
if (format.empty())
return false;
@@ -1039,27 +1176,27 @@ class PyDenseElementsAttribute
return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
}
- // There is a complication for boolean numpy arrays, as numpy represents them
- // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans
- // per byte.
+ // There is a complication for boolean numpy arrays, as numpy represents
+ // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
+ // booleans per byte.
static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
MlirContext &context) {
if (llvm::endianness::native != llvm::endianness::little) {
- // Given we have no good way of testing the behavior on big-endian systems
- // we will throw
- throw py::type_error("Constructing a bit-packed MLIR attribute is "
+ // Given we have no good way of testing the behavior on big-endian
+ // systems we will throw
+ throw nb::type_error("Constructing a bit-packed MLIR attribute is "
"unsupported on big-endian systems");
}
+ nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
+ /*data=*/static_cast<uint8_t *>(view.buf),
+ /*shape=*/{static_cast<size_t>(view.len)});
- py::array_t<uint8_t> unpackedArray(view.len,
- static_cast<uint8_t *>(view.buf));
-
- py::module numpy = py::module::import("numpy");
- py::object packbitsFunc = numpy.attr("packbits");
- py::object packedBooleans =
- packbitsFunc(unpackedArray, "bitorder"_a = "little");
- py::buffer_info pythonBuffer = packedBooleans.cast<py::buffer>().request();
+ nb::module_ numpy = nb::module_::import_("numpy");
+ nb::object packbitsFunc = numpy.attr("packbits");
+ nb::object packedBooleans =
+ packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
+ nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
MlirType bitpackedType =
getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
@@ -1073,11 +1210,11 @@ class PyDenseElementsAttribute
// This does the opposite transformation of
// `getBitpackedAttributeFromBooleanBuffer`
- py::buffer_info getBooleanBufferFromBitpackedAttribute() {
+ nb_buffer_info getBooleanBufferFromBitpackedAttribute() {
if (llvm::endianness::native != llvm::endianness::little) {
- // Given we have no good way of testing the behavior on big-endian systems
- // we will throw
- throw py::type_error("Constructing a numpy array from a MLIR attribute "
+ // Given we have no good way of testing the behavior on big-endian
+ // systems we will throw
+ throw nb::type_error("Constructing a numpy array from a MLIR attribute "
"is unsupported on big-endian systems");
}
@@ -1085,21 +1222,24 @@ class PyDenseElementsAttribute
int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
uint8_t *bitpackedData = static_cast<uint8_t *>(
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
- py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);
+ nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
+ /*data=*/bitpackedData,
+ /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
- py::module numpy = py::module::import("numpy");
- py::object unpackbitsFunc = numpy.attr("unpackbits");
- py::object equalFunc = numpy.attr("equal");
- py::object reshapeFunc = numpy.attr("reshape");
- py::array unpackedBooleans =
- unpackbitsFunc(packedArray, "bitorder"_a = "little");
+ nb::module_ numpy = nb::module_::import_("numpy");
+ nb::object unpackbitsFunc = numpy.attr("unpackbits");
+ nb::object equalFunc = numpy.attr("equal");
+ nb::object reshapeFunc = numpy.attr("reshape");
+ nb::object unpackedBooleans =
+ unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
// Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
// We need to:
// 1. Slice away the padded bits
// 2. Make the boolean array have the correct shape
// 3. Convert the array to a boolean array
- unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)];
+ unpackedBooleans = unpackedBooleans[nb::slice(
+ nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
unpackedBooleans = equalFunc(unpackedBooleans, 1);
MlirType shapedType = mlirAttributeGetType(*this);
@@ -1110,15 +1250,15 @@ class PyDenseElementsAttribute
}
unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
- // Make sure the returned py::buffer_view claims ownership of the data in
+ // Make sure the returned nb::buffer_view claims ownership of the data in
// `pythonBuffer` so it remains valid when Python reads it
- py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
+ nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
return pythonBuffer.request();
}
template <typename Type>
- py::buffer_info bufferInfo(MlirType shapedType,
- const char *explicitFormat = nullptr) {
+ nb_buffer_info bufferInfo(MlirType shapedType,
+ const char *explicitFormat = nullptr) {
intptr_t rank = mlirShapedTypeGetRank(shapedType);
// Prepare the data for the buffer_info.
// Buffer is configured for read-only access below.
@@ -1142,19 +1282,69 @@ class PyDenseElementsAttribute
}
strides.push_back(sizeof(Type));
}
- std::string format;
+ const char *format;
if (explicitFormat) {
format = explicitFormat;
} else {
- format = py::format_descriptor<Type>::format();
+ format = nb_format_descriptor<Type>::format();
}
- return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
- /*readonly=*/true);
+ return nb_buffer_info(data, sizeof(Type), format, rank, std::move(shape),
+ std::move(strides),
+ /*readonly=*/true);
}
}; // namespace
-/// Refinement of the PyDenseElementsAttribute for attributes containing integer
-/// (and boolean) values. Supports element access.
+/// Extra type slots for PyDenseElementsAttribute: registers bf_getbuffer and
+/// bf_releasebuffer as the Py_bf_getbuffer / Py_bf_releasebuffer handlers.
+/// The table ends with a {0, nullptr} sentinel entry.
+PyType_Slot PyDenseElementsAttribute::slots[] = {
+ {Py_bf_getbuffer,
+ reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)},
+ {Py_bf_releasebuffer,
+ reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)},
+ {0, nullptr},
+};
+
+/// Buffer-protocol `bf_getbuffer` slot: fills `view` from the descriptor
+/// returned by accessBuffer().
+///
+/// Returns 0 on success. On failure returns -1 with a Python exception set and
+/// `view->obj` left null. The descriptor is stored in `view->internal` and is
+/// freed by bf_releasebuffer.
+/*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj,
+                                                       Py_buffer *view,
+                                                       int flags) {
+  view->obj = nullptr;
+  // Heap-allocate the descriptor *before* taking pointers into it. Pointing
+  // `view->shape` / `view->strides` at a stack-local descriptor and moving it
+  // afterwards leaves them dangling whenever the containers keep their
+  // elements inline (a move then relocates the elements).
+  std::unique_ptr<nb_buffer_info> info;
+  try {
+    auto *attr = nb::cast<PyDenseElementsAttribute *>(nb::handle(obj));
+    info = std::make_unique<nb_buffer_info>(attr->accessBuffer());
+  } catch (nb::python_error &e) {
+    e.restore();
+    nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer");
+    return -1;
+  } catch (const std::exception &e) {
+    // A C++ exception must never unwind through a CPython slot; report any
+    // other failure (e.g. a failed cast or an unsupported element type) as a
+    // BufferError instead.
+    PyErr_SetString(PyExc_BufferError, e.what());
+    return -1;
+  }
+  // The protocol requires failing when a writable view of read-only storage
+  // is requested, rather than silently handing out a read-only one.
+  if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE && info->readonly) {
+    PyErr_SetString(PyExc_BufferError,
+                    "Writable buffer requested for readonly storage");
+    return -1;
+  }
+  view->obj = obj;
+  view->ndim = 1;
+  view->buf = info->ptr;
+  view->itemsize = info->itemsize;
+  // Total length in bytes: itemsize times the product of all extents.
+  view->len = info->itemsize;
+  for (auto s : info->shape) {
+    view->len *= s;
+  }
+  view->readonly = info->readonly;
+  // Fields the consumer did not ask for must be null, not uninitialized.
+  view->format = nullptr;
+  view->shape = nullptr;
+  view->strides = nullptr;
+  if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
+    view->format = const_cast<char *>(info->format);
+  }
+  if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
+    view->ndim = static_cast<int>(info->ndim);
+    view->strides = info->strides.data();
+    view->shape = info->shape.data();
+  }
+  view->suboffsets = nullptr;
+  // Ownership of the descriptor passes to the view; the pointers stored above
+  // refer to this same heap object.
+  view->internal = info.release();
+  Py_INCREF(obj);
+  return 0;
+}
+
+/// Buffer-protocol `bf_releasebuffer` slot: deletes the heap-allocated
+/// nb_buffer_info that bf_getbuffer stored in `view->internal`. The exporter
+/// object argument is unused.
+/*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *,
+ Py_buffer *view) {
+ delete reinterpret_cast<nb_buffer_info *>(view->internal);
+}
+
+/// Refinement of the PyDenseElementsAttribute for attributes containing
+/// integer (and boolean) values. Supports element access.
class PyDenseIntElementsAttribute
: public PyConcreteAttribute<PyDenseIntElementsAttribute,
PyDenseElementsAttribute> {
@@ -1163,11 +1353,11 @@ class PyDenseIntElementsAttribute
static constexpr const char *pyClassName = "DenseIntElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;
- /// Returns the element at the given linear position. Asserts if the index is
- /// out of range.
- py::int_ dunderGetItem(intptr_t pos) {
+ /// Returns the element at the given linear position. Asserts if the index
+ /// is out of range.
+ nb::object dunderGetItem(intptr_t pos) {
if (pos < 0 || pos >= dunderLen()) {
- throw py::index_error("attempt to access out of bounds element");
+ throw nb::index_error("attempt to access out of bounds element");
}
MlirType type = mlirAttributeGetType(*this);
@@ -1175,7 +1365,7 @@ class PyDenseIntElementsAttribute
assert(mlirTypeIsAInteger(type) &&
"expected integer element type in dense int elements attribute");
// Dispatch element extraction to an appropriate C function based on the
- // elemental type of the attribute. py::int_ is implicitly constructible
+ // elemental type of the attribute. nb::int_ is implicitly constructible
// from any C++ integral type and handles bitwidth correctly.
// TODO: consider caching the type properties in the constructor to avoid
// querying them on each element access.
@@ -1183,38 +1373,38 @@ class PyDenseIntElementsAttribute
bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
if (isUnsigned) {
if (width == 1) {
- return mlirDenseElementsAttrGetBoolValue(*this, pos);
+ return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
}
if (width == 8) {
- return mlirDenseElementsAttrGetUInt8Value(*this, pos);
+ return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
}
if (width == 16) {
- return mlirDenseElementsAttrGetUInt16Value(*this, pos);
+ return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
}
if (width == 32) {
- return mlirDenseElementsAttrGetUInt32Value(*this, pos);
+ return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
}
if (width == 64) {
- return mlirDenseElementsAttrGetUInt64Value(*this, pos);
+ return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
}
} else {
if (width == 1) {
- return mlirDenseElementsAttrGetBoolValue(*this, pos);
+ return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
}
if (width == 8) {
- return mlirDenseElementsAttrGetInt8Value(*this, pos);
+ return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
}
if (width == 16) {
- return mlirDenseElementsAttrGetInt16Value(*this, pos);
+ return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
}
if (width == 32) {
- return mlirDenseElementsAttrGetInt32Value(*this, pos);
+ return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
}
if (width == 64) {
- return mlirDenseElementsAttrGetInt64Value(*this, pos);
+ return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
}
}
- throw py::type_error("Unsupported integer type");
+ throw nb::type_error("Unsupported integer type");
}
static void bindDerived(ClassTy &c) {
@@ -1231,7 +1421,7 @@ class PyDenseResourceElementsAttribute
using PyConcreteAttribute::PyConcreteAttribute;
static PyDenseResourceElementsAttribute
- getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type,
+ getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type,
std::optional<size_t> alignment, bool isMutable,
DefaultingPyMlirContext contextWrapper) {
if (!mlirTypeIsAShaped(type)) {
@@ -1244,7 +1434,7 @@ class PyDenseResourceElementsAttribute
int flags = PyBUF_STRIDES;
std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
- throw py::error_already_set();
+ throw nb::python_error();
}
// This scope releaser will only release if we haven't yet transferred
@@ -1289,12 +1479,12 @@ class PyDenseResourceElementsAttribute
}
static void bindDerived(ClassTy &c) {
- c.def_static("get_from_buffer",
- PyDenseResourceElementsAttribute::getFromBuffer,
- py::arg("array"), py::arg("name"), py::arg("type"),
- py::arg("alignment") = py::none(),
- py::arg("is_mutable") = false, py::arg("context") = py::none(),
- kDenseResourceElementsAttrGetFromBufferDocstring);
+ c.def_static(
+ "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
+ nb::arg("array"), nb::arg("name"), nb::arg("type"),
+ nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
+ nb::arg("context").none() = nb::none(),
+ kDenseResourceElementsAttrGetFromBufferDocstring);
}
};
@@ -1318,12 +1508,12 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
c.def("__len__", &PyDictAttribute::dunderLen);
c.def_static(
"get",
- [](py::dict attributes, DefaultingPyMlirContext context) {
+ [](nb::dict attributes, DefaultingPyMlirContext context) {
SmallVector<MlirNamedAttribute> mlirNamedAttributes;
mlirNamedAttributes.reserve(attributes.size());
- for (auto &it : attributes) {
- auto &mlirAttr = it.second.cast<PyAttribute &>();
- auto name = it.first.cast<std::string>();
+ for (std::pair<nb::handle, nb::handle> it : attributes) {
+ auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
+ auto name = nb::cast<std::string>(it.first);
mlirNamedAttributes.push_back(mlirNamedAttributeGet(
mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
toMlirStringRef(name)),
@@ -1334,18 +1524,18 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
mlirNamedAttributes.data());
return PyDictAttribute(context->getRef(), attr);
},
- py::arg("value") = py::dict(), py::arg("context") = py::none(),
+ nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(),
"Gets an uniqued dict attribute");
c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
MlirAttribute attr =
mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
if (mlirAttributeIsNull(attr))
- throw py::key_error("attempt to access a non-existent attribute");
+ throw nb::key_error("attempt to access a non-existent attribute");
return attr;
});
c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
if (index < 0 || index >= self.dunderLen()) {
- throw py::index_error("attempt to access out of bounds attribute");
+ throw nb::index_error("attempt to access out of bounds attribute");
}
MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
return PyNamedAttribute(
@@ -1365,25 +1555,25 @@ class PyDenseFPElementsAttribute
static constexpr const char *pyClassName = "DenseFPElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;
- py::float_ dunderGetItem(intptr_t pos) {
+ /// Returns the element at linear position `pos` as a Python float. Throws
+ /// nb::index_error when `pos` is out of range and nb::type_error when the
+ /// element type is neither f32 nor f64.
+ nb::float_ dunderGetItem(intptr_t pos) {
if (pos < 0 || pos >= dunderLen()) {
- throw py::index_error("attempt to access out of bounds element");
+ throw nb::index_error("attempt to access out of bounds element");
}
MlirType type = mlirAttributeGetType(*this);
type = mlirShapedTypeGetElementType(type);
// Dispatch element extraction to an appropriate C function based on the
- // elemental type of the attribute. py::float_ is implicitly constructible
+ // elemental type of the attribute. nb::float_ is constructed explicitly
// from float and double.
// TODO: consider caching the type properties in the constructor to avoid
// querying them on each element access.
if (mlirTypeIsAF32(type)) {
- return mlirDenseElementsAttrGetFloatValue(*this, pos);
+ return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
}
if (mlirTypeIsAF64(type)) {
- return mlirDenseElementsAttrGetDoubleValue(*this, pos);
+ return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
}
- throw py::type_error("Unsupported floating-point type");
+ throw nb::type_error("Unsupported floating-point type");
}
static void bindDerived(ClassTy &c) {
@@ -1406,9 +1596,9 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
MlirAttribute attr = mlirTypeAttrGet(value.get());
return PyTypeAttribute(context->getRef(), attr);
},
- py::arg("value"), py::arg("context") = py::none(),
+ nb::arg("value"), nb::arg("context").none() = nb::none(),
"Gets a uniqued Type attribute");
- c.def_property_readonly("value", [](PyTypeAttribute &self) {
+ c.def_prop_ro("value", [](PyTypeAttribute &self) {
return mlirTypeAttrGetValue(self.get());
});
}
@@ -1430,7 +1620,7 @@ class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
return PyUnitAttribute(context->getRef(),
mlirUnitAttrGet(context->get()));
},
- py::arg("context") = py::none(), "Create a Unit attribute.");
+ nb::arg("context").none() = nb::none(), "Create a Unit attribute.");
}
};
@@ -1453,7 +1643,8 @@ class PyStridedLayoutAttribute
ctx->get(), offset, strides.size(), strides.data());
return PyStridedLayoutAttribute(ctx->getRef(), attr);
},
- py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
+ nb::arg("offset"), nb::arg("strides"),
+ nb::arg("context").none() = nb::none(),
"Gets a strided layout attribute.");
c.def_static(
"get_fully_dynamic",
@@ -1465,16 +1656,17 @@ class PyStridedLayoutAttribute
ctx->get(), dynamic, strides.size(), strides.data());
return PyStridedLayoutAttribute(ctx->getRef(), attr);
},
- py::arg("rank"), py::arg("context") = py::none(),
- "Gets a strided layout attribute with dynamic offset and strides of a "
+ nb::arg("rank"), nb::arg("context").none() = nb::none(),
+ "Gets a strided layout attribute with dynamic offset and strides of "
+ "a "
"given rank.");
- c.def_property_readonly(
+ c.def_prop_ro(
"offset",
[](PyStridedLayoutAttribute &self) {
return mlirStridedLayoutAttrGetOffset(self);
},
"Returns the value of the float point attribute");
- c.def_property_readonly(
+ c.def_prop_ro(
"strides",
[](PyStridedLayoutAttribute &self) {
intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
@@ -1488,63 +1680,64 @@ class PyStridedLayoutAttribute
}
};
-py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
+/// Wraps `pyAttribute` in the first dense-array attribute class (bool, i8,
+/// i16, i32, i64, f32, f64 - tried in that order) whose isaFunction accepts
+/// it, and returns the wrapper as a Python object. Throws nb::type_error,
+/// with the attribute's repr in the message, when none of them matches.
+nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
- return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
+ return nb::cast(PyDenseBoolArrayAttribute(pyAttribute));
if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
- return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
+ return nb::cast(PyDenseI8ArrayAttribute(pyAttribute));
if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
- return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
+ return nb::cast(PyDenseI16ArrayAttribute(pyAttribute));
if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
- return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
+ return nb::cast(PyDenseI32ArrayAttribute(pyAttribute));
if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
- return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
+ return nb::cast(PyDenseI64ArrayAttribute(pyAttribute));
if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
- return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
+ return nb::cast(PyDenseF32ArrayAttribute(pyAttribute));
if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
- return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
+ return nb::cast(PyDenseF64ArrayAttribute(pyAttribute));
std::string msg =
std::string("Can't cast unknown element type DenseArrayAttr (") +
- std::string(py::repr(py::cast(pyAttribute))) + ")";
- throw py::cast_error(msg);
+ nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
+ throw nb::type_error(msg.c_str());
}
-py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
+/// Wraps `pyAttribute` as a PyDenseFPElementsAttribute or, failing that, a
+/// PyDenseIntElementsAttribute, according to each class's isaFunction, and
+/// returns the wrapper as a Python object. Throws nb::type_error, with the
+/// attribute's repr in the message, when neither matches.
+nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
- return py::cast(PyDenseFPElementsAttribute(pyAttribute));
+ return nb::cast(PyDenseFPElementsAttribute(pyAttribute));
if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
- return py::cast(PyDenseIntElementsAttribute(pyAttribute));
+ return nb::cast(PyDenseIntElementsAttribute(pyAttribute));
std::string msg =
std::string(
"Can't cast unknown element type DenseIntOrFPElementsAttr (") +
- std::string(py::repr(py::cast(pyAttribute))) + ")";
- throw py::cast_error(msg);
+ nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
+ throw nb::type_error(msg.c_str());
}
-py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
+/// Wraps `pyAttribute` as a PyBoolAttribute or, failing that, a
+/// PyIntegerAttribute, according to each class's isaFunction, and returns the
+/// wrapper as a Python object. Throws nb::type_error, with the attribute's
+/// repr in the message, when neither matches.
+nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
if (PyBoolAttribute::isaFunction(pyAttribute))
- return py::cast(PyBoolAttribute(pyAttribute));
+ return nb::cast(PyBoolAttribute(pyAttribute));
if (PyIntegerAttribute::isaFunction(pyAttribute))
- return py::cast(PyIntegerAttribute(pyAttribute));
+ return nb::cast(PyIntegerAttribute(pyAttribute));
std::string msg =
- std::string("Can't cast unknown element type DenseArrayAttr (") +
- std::string(py::repr(py::cast(pyAttribute))) + ")";
- throw py::cast_error(msg);
+ // The previous text named DenseArrayAttr, copied from the dense-array
+ // caster; this caster only ever deals with Bool/Integer attributes.
+ std::string("Can't cast unknown IntegerAttr or BoolAttr (") +
+ nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
+ throw nb::type_error(msg.c_str());
}
-py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
+/// Wraps `pyAttribute` as a PyFlatSymbolRefAttribute or, failing that, a
+/// PySymbolRefAttribute, according to each class's isaFunction, and returns
+/// the wrapper as a Python object. Throws nb::type_error, with the
+/// attribute's repr in the message, when neither matches.
+nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
- return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
+ return nb::cast(PyFlatSymbolRefAttribute(pyAttribute));
if (PySymbolRefAttribute::isaFunction(pyAttribute))
- return py::cast(PySymbolRefAttribute(pyAttribute));
+ return nb::cast(PySymbolRefAttribute(pyAttribute));
std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
- std::string(py::repr(py::cast(pyAttribute))) + ")";
- throw py::cast_error(msg);
+ nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) +
+ ")";
+ throw nb::type_error(msg.c_str());
}
} // namespace
-void mlir::python::populateIRAttributes(py::module &m) {
+void mlir::python::populateIRAttributes(nb::module_ &m) {
PyAffineMapAttribute::bind(m);
PyDenseBoolArrayAttribute::bind(m);
PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
@@ -1562,24 +1755,26 @@ void mlir::python::populateIRAttributes(py::module &m) {
PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
PyGlobals::get().registerTypeCaster(
mlirDenseArrayAttrGetTypeID(),
- pybind11::cpp_function(denseArrayAttributeCaster));
+ nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster)));
PyArrayAttribute::bind(m);
PyArrayAttribute::PyArrayAttributeIterator::bind(m);
PyBoolAttribute::bind(m);
- PyDenseElementsAttribute::bind(m);
+ PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots);
PyDenseFPElementsAttribute::bind(m);
PyDenseIntElementsAttribute::bind(m);
PyGlobals::get().registerTypeCaster(
mlirDenseIntOrFPElementsAttrGetTypeID(),
- pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
+ nb::cast<nb::callable>(
+ nb::cpp_function(denseIntOrFPElementsAttributeCaster)));
PyDenseResourceElementsAttribute::bind(m);
PyDictAttribute::bind(m);
PySymbolRefAttribute::bind(m);
PyGlobals::get().registerTypeCaster(
mlirSymbolRefAttrGetTypeID(),
- pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
+ nb::cast<nb::callable>(
+ nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)));
PyFlatSymbolRefAttribute::bind(m);
PyOpaqueAttribute::bind(m);
@@ -1590,7 +1785,7 @@ void mlir::python::populateIRAttributes(py::module &m) {
PyTypeAttribute::bind(m);
PyGlobals::get().registerTypeCaster(
mlirIntegerAttrGetTypeID(),
- pybind11::cpp_function(integerOrBoolAttributeCaster));
+ nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster)));
PyUnitAttribute::bind(m);
PyStridedLayoutAttribute::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 3e96f8c60ba7cd..ff4ad1a0d806c7 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -6,26 +6,31 @@
//
//===----------------------------------------------------------------------===//
-#include "IRModule.h"
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/function.h>
+#include <nanobind/stl/optional.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/tuple.h>
+#include <nanobind/stl/vector.h>
-#include "Globals.h"
-#include "PybindUtils.h"
+#include <optional>
+#include <utility>
+#include "Globals.h"
+#include "IRModule.h"
+#include "NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
-#include <optional>
-#include <utility>
-
-namespace py = pybind11;
-using namespace py::literals;
+namespace nb = nanobind;
+using namespace nb::literals;
using namespace mlir;
using namespace mlir::python;
@@ -190,18 +195,18 @@ operations.
/// Helper for creating an @classmethod.
+///
+/// Wraps `f` (with any extra binding arguments in `args`) in a cpp_function
+/// and passes it to PyClassMethod_New. That call returns a new reference, so
+/// the result is adopted with nb::steal; nb::borrow would add a second
+/// reference that nothing ever releases.
template <class Func, typename... Args>
-py::object classmethod(Func f, Args... args) {
- py::object cf = py::cpp_function(f, args...);
- return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
+nb::object classmethod(Func f, Args... args) {
+ nb::object cf = nb::cpp_function(f, args...);
+ return nb::steal<nb::object>(PyClassMethod_New(cf.ptr()));
}
-static py::object
+static nb::object
createCustomDialectWrapper(const std::string &dialectNamespace,
- py::object dialectDescriptor) {
+ nb::object dialectDescriptor) {
auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
if (!dialectClass) {
// Use the base class.
- return py::cast(PyDialect(std::move(dialectDescriptor)));
+ return nb::cast(PyDialect(std::move(dialectDescriptor)));
}
// Create the custom implementation.
@@ -212,42 +217,47 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size());
}
+/// Builds an MlirStringRef from the data pointer and size of a Python bytes
+/// object. Only the pointer and length are passed on here, so `s` presumably
+/// has to outlive the returned ref - TODO confirm against
+/// mlirStringRefCreate.
+static MlirStringRef toMlirStringRef(const nb::bytes &s) {
+ return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
+}
+
/// Create a block, using the current location context if no locations are
/// specified.
+///
+/// Each element of `pyArgTypes` is cast to PyType and each element of
+/// `pyArgLocs`, when given, to PyLocation. Throws nb::value_error when the
+/// number of locations differs from the number of argument types.
-static MlirBlock createBlock(const py::sequence &pyArgTypes,
- const std::optional<py::sequence> &pyArgLocs) {
+static MlirBlock createBlock(const nb::sequence &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
SmallVector<MlirType> argTypes;
- argTypes.reserve(pyArgTypes.size());
+ argTypes.reserve(nb::len(pyArgTypes));
for (const auto &pyType : pyArgTypes)
- argTypes.push_back(pyType.cast<PyType &>());
+ argTypes.push_back(nb::cast<PyType &>(pyType));
SmallVector<MlirLocation> argLocs;
if (pyArgLocs) {
- argLocs.reserve(pyArgLocs->size());
+ argLocs.reserve(nb::len(*pyArgLocs));
for (const auto &pyLoc : *pyArgLocs)
- argLocs.push_back(pyLoc.cast<PyLocation &>());
+ argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
} else if (!argTypes.empty()) {
+ // No explicit locations: every argument gets the location returned by
+ // DefaultingPyLocation::resolve().
argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
}
if (argTypes.size() != argLocs.size())
- throw py::value_error(("Expected " + Twine(argTypes.size()) +
+ throw nb::value_error(("Expected " + Twine(argTypes.size()) +
" locations, got: " + Twine(argLocs.size()))
- .str());
+ .str()
+ .c_str());
return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
}
/// Wrapper for the global LLVM debugging flag.
struct PyGlobalDebugFlag {
- static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
+ static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
- static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
+ static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); }
- static void bind(py::module &m) {
+ static void bind(nb::module_ &m) {
// Debug flags.
- py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
- .def_property_static("flag", &PyGlobalDebugFlag::get,
- &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
+ nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
+ .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
+ &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
.def_static(
"set_types",
[](const std::string &type) {
@@ -268,20 +278,20 @@ struct PyAttrBuilderMap {
static bool dunderContains(const std::string &attributeKind) {
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
}
- static py::function dundeGetItemNamed(const std::string &attributeKind) {
+ static nb::callable dundeGetItemNamed(const std::string &attributeKind) {
auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
if (!builder)
- throw py::key_error(attributeKind);
+ throw nb::key_error(attributeKind.c_str());
return *builder;
}
static void dundeSetItemNamed(const std::string &attributeKind,
- py::function func, bool replace) {
+ nb::callable func, bool replace) {
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
replace);
}
- static void bind(py::module &m) {
- py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
+ static void bind(nb::module_ &m) {
+ nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
.def_static("contains", &PyAttrBuilderMap::dunderContains)
.def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
.def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed,
@@ -295,8 +305,8 @@ struct PyAttrBuilderMap {
// PyBlock
//------------------------------------------------------------------------------
-py::object PyBlock::getCapsule() {
- return py::reinterpret_steal<py::object>(mlirPythonBlockToCapsule(get()));
+/// Returns the object produced by mlirPythonBlockToCapsule for this block.
+/// The result is adopted with nb::steal, i.e. treated as an already-owned
+/// reference.
+nb::object PyBlock::getCapsule() {
+ return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
}
//------------------------------------------------------------------------------
@@ -315,14 +325,14 @@ class PyRegionIterator {
PyRegion dunderNext() {
operation->checkValid();
if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
- throw py::stop_iteration();
+ throw nb::stop_iteration();
}
MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
return PyRegion(operation, region);
}
- static void bind(py::module &m) {
- py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
+ static void bind(nb::module_ &m) {
+ nb::class_<PyRegionIterator>(m, "RegionIterator")
.def("__iter__", &PyRegionIterator::dunderIter)
.def("__next__", &PyRegionIterator::dunderNext);
}
@@ -351,14 +361,14 @@ class PyRegionList {
PyRegion dunderGetItem(intptr_t index) {
// dunderLen checks validity.
if (index < 0 || index >= dunderLen()) {
- throw py::index_error("attempt to access out of bounds region");
+ throw nb::index_error("attempt to access out of bounds region");
}
MlirRegion region = mlirOperationGetRegion(operation->get(), index);
return PyRegion(operation, region);
}
- static void bind(py::module &m) {
- py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
+ static void bind(nb::module_ &m) {
+ nb::class_<PyRegionList>(m, "RegionSequence")
.def("__len__", &PyRegionList::dunderLen)
.def("__iter__", &PyRegionList::dunderIter)
.def("__getitem__", &PyRegionList::dunderGetItem);
@@ -378,7 +388,7 @@ class PyBlockIterator {
PyBlock dunderNext() {
operation->checkValid();
if (mlirBlockIsNull(next)) {
- throw py::stop_iteration();
+ throw nb::stop_iteration();
}
PyBlock returnBlock(operation, next);
@@ -386,8 +396,8 @@ class PyBlockIterator {
return returnBlock;
}
- static void bind(py::module &m) {
- py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
+ static void bind(nb::module_ &m) {
+ nb::class_<PyBlockIterator>(m, "BlockIterator")
.def("__iter__", &PyBlockIterator::dunderIter)
.def("__next__", &PyBlockIterator::dunderNext);
}
@@ -424,7 +434,7 @@ class PyBlockList {
PyBlock dunderGetItem(intptr_t index) {
operation->checkValid();
if (index < 0) {
- throw py::index_error("attempt to access out of bounds block");
+ throw nb::index_error("attempt to access out of bounds block");
}
MlirBlock block = mlirRegionGetFirstBlock(region);
while (!mlirBlockIsNull(block)) {
@@ -434,24 +444,26 @@ class PyBlockList {
block = mlirBlockGetNextInRegion(block);
index -= 1;
}
- throw py::index_error("attempt to access out of bounds block");
+ throw nb::index_error("attempt to access out of bounds block");
}
- PyBlock appendBlock(const py::args &pyArgTypes,
- const std::optional<py::sequence> &pyArgLocs) {
+ PyBlock appendBlock(const nb::args &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
operation->checkValid();
- MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
+ MlirBlock block =
+ createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
mlirRegionAppendOwnedBlock(region, block);
return PyBlock(operation, block);
}
- static void bind(py::module &m) {
- py::class_<PyBlockList>(m, "BlockList", py::module_local())
+ static void bind(nb::module_ &m) {
+ nb::class_<PyBlockList>(m, "BlockList")
.def("__getitem__", &PyBlockList::dunderGetItem)
.def("__iter__", &PyBlockList::dunderIter)
.def("__len__", &PyBlockList::dunderLen)
.def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
- py::arg("arg_locs") = std::nullopt);
+ nb::arg("args"), nb::kw_only(),
+ nb::arg("arg_locs") = std::nullopt);
}
private:
@@ -466,10 +478,10 @@ class PyOperationIterator {
PyOperationIterator &dunderIter() { return *this; }
- py::object dunderNext() {
+ nb::object dunderNext() {
parentOperation->checkValid();
if (mlirOperationIsNull(next)) {
- throw py::stop_iteration();
+ throw nb::stop_iteration();
}
PyOperationRef returnOperation =
@@ -478,8 +490,8 @@ class PyOperationIterator {
return returnOperation->createOpView();
}
- static void bind(py::module &m) {
- py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
+ static void bind(nb::module_ &m) {
+ nb::class_<PyOperationIterator>(m, "OperationIterator")
.def("__iter__", &PyOperationIterator::dunderIter)
.def("__next__", &PyOperationIterator::dunderNext);
}
@@ -515,10 +527,10 @@ class PyOperationList {
return count;
}
- py::object dunderGetItem(intptr_t index) {
+ nb::object dunderGetItem(intptr_t index) {
parentOperation->checkValid();
if (index < 0) {
- throw py::index_error("attempt to access out of bounds operation");
+ throw nb::index_error("attempt to access out of bounds operation");
}
MlirOperation childOp = mlirBlockGetFirstOperation(block);
while (!mlirOperationIsNull(childOp)) {
@@ -529,11 +541,11 @@ class PyOperationList {
childOp = mlirOperationGetNextInBlock(childOp);
index -= 1;
}
- throw py::index_error("attempt to access out of bounds operation");
+ throw nb::index_error("attempt to access out of bounds operation");
}
- static void bind(py::module &m) {
- py::class_<PyOperationList>(m, "OperationList", py::module_local())
+ static void bind(nb::module_ &m) {
+ nb::class_<PyOperationList>(m, "OperationList")
.def("__getitem__", &PyOperationList::dunderGetItem)
.def("__iter__", &PyOperationList::dunderIter)
.def("__len__", &PyOperationList::dunderLen);
@@ -548,7 +560,7 @@ class PyOpOperand {
public:
PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
- py::object getOwner() {
+ nb::object getOwner() {
MlirOperation owner = mlirOpOperandGetOwner(opOperand);
PyMlirContextRef context =
PyMlirContext::forContext(mlirOperationGetContext(owner));
@@ -557,11 +569,10 @@ class PyOpOperand {
size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
- static void bind(py::module &m) {
- py::class_<PyOpOperand>(m, "OpOperand", py::module_local())
- .def_property_readonly("owner", &PyOpOperand::getOwner)
- .def_property_readonly("operand_number",
- &PyOpOperand::getOperandNumber);
+ static void bind(nb::module_ &m) {
+ nb::class_<PyOpOperand>(m, "OpOperand")
+ .def_prop_ro("owner", &PyOpOperand::getOwner)
+ .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber);
}
private:
@@ -576,15 +587,15 @@ class PyOpOperandIterator {
PyOpOperand dunderNext() {
if (mlirOpOperandIsNull(opOperand))
- throw py::stop_iteration();
+ throw nb::stop_iteration();
PyOpOperand returnOpOperand(opOperand);
opOperand = mlirOpOperandGetNextUse(opOperand);
return returnOpOperand;
}
- static void bind(py::module &m) {
- py::class_<PyOpOperandIterator>(m, "OpOperandIterator", py::module_local())
+ static void bind(nb::module_ &m) {
+ nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
.def("__iter__", &PyOpOperandIterator::dunderIter)
.def("__next__", &PyOpOperandIterator::dunderNext);
}
@@ -600,7 +611,7 @@ class PyOpOperandIterator {
//------------------------------------------------------------------------------
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
- py::gil_scoped_acquire acquire;
+ nb::gil_scoped_acquire acquire;
auto &liveContexts = getLiveContexts();
liveContexts[context.ptr] = this;
}
@@ -609,41 +620,36 @@ PyMlirContext::~PyMlirContext() {
// Note that the only public way to construct an instance is via the
// forContext method, which always puts the associated handle into
// liveContexts.
- py::gil_scoped_acquire acquire;
+ nb::gil_scoped_acquire acquire;
getLiveContexts().erase(context.ptr);
mlirContextDestroy(context);
}
-py::object PyMlirContext::getCapsule() {
- return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
+nb::object PyMlirContext::getCapsule() {
+ return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
}
-py::object PyMlirContext::createFromCapsule(py::object capsule) {
+nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
if (mlirContextIsNull(rawContext))
- throw py::error_already_set();
+ throw nb::python_error();
return forContext(rawContext).releaseObject();
}
-PyMlirContext *PyMlirContext::createNewContextForInit() {
- MlirContext context = mlirContextCreateWithThreading(false);
- return new PyMlirContext(context);
-}
-
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
- py::gil_scoped_acquire acquire;
+ nb::gil_scoped_acquire acquire;
auto &liveContexts = getLiveContexts();
auto it = liveContexts.find(context.ptr);
if (it == liveContexts.end()) {
// Create.
PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
- py::object pyRef = py::cast(unownedContextWrapper);
- assert(pyRef && "cast to py::object failed");
+ nb::object pyRef = nb::cast(unownedContextWrapper);
+ assert(pyRef && "cast to nb::object failed");
liveContexts[context.ptr] = unownedContextWrapper;
return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
}
// Use existing.
- py::object pyRef = py::cast(it->second);
+ nb::object pyRef = nb::cast(it->second);
return PyMlirContextRef(it->second, std::move(pyRef));
}
@@ -717,23 +723,23 @@ void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
-pybind11::object PyMlirContext::contextEnter() {
- return PyThreadContextEntry::pushContext(*this);
+nb::object PyMlirContext::contextEnter(nb::object context) {
+ return PyThreadContextEntry::pushContext(context);
}
-void PyMlirContext::contextExit(const pybind11::object &excType,
- const pybind11::object &excVal,
- const pybind11::object &excTb) {
+void PyMlirContext::contextExit(const nb::object &excType,
+ const nb::object &excVal,
+ const nb::object &excTb) {
PyThreadContextEntry::popContext(*this);
}
-py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
+nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
// Note that ownership is transferred to the delete callback below by way of
// an explicit inc_ref (borrow).
PyDiagnosticHandler *pyHandler =
new PyDiagnosticHandler(get(), std::move(callback));
- py::object pyHandlerObject =
- py::cast(pyHandler, py::return_value_policy::take_ownership);
+ nb::object pyHandlerObject =
+ nb::cast(pyHandler, nb::rv_policy::take_ownership);
pyHandlerObject.inc_ref();
// In these C callbacks, the userData is a PyDiagnosticHandler* that is
@@ -741,17 +747,17 @@ py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
auto handlerCallback =
+[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
- py::object pyDiagnosticObject =
- py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
+ nb::object pyDiagnosticObject =
+ nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
bool result = false;
{
// Since this can be called from arbitrary C++ contexts, always get the
// gil.
- py::gil_scoped_acquire gil;
+ nb::gil_scoped_acquire gil;
try {
- result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
+ result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
} catch (std::exception &e) {
fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
e.what());
@@ -768,8 +774,7 @@ py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
pyHandler->registeredID.reset();
// Decrement reference, balancing the inc_ref() above.
- py::object pyHandlerObject =
- py::cast(pyHandler, py::return_value_policy::reference);
+ nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
pyHandlerObject.dec_ref();
};
@@ -819,9 +824,9 @@ PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
return &stack.back();
}
-void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
- py::object insertionPoint,
- py::object location) {
+void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
+ nb::object insertionPoint,
+ nb::object location) {
auto &stack = getStack();
stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
std::move(location));
@@ -844,19 +849,19 @@ void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
PyMlirContext *PyThreadContextEntry::getContext() {
if (!context)
return nullptr;
- return py::cast<PyMlirContext *>(context);
+ return nb::cast<PyMlirContext *>(context);
}
PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
if (!insertionPoint)
return nullptr;
- return py::cast<PyInsertionPoint *>(insertionPoint);
+ return nb::cast<PyInsertionPoint *>(insertionPoint);
}
PyLocation *PyThreadContextEntry::getLocation() {
if (!location)
return nullptr;
- return py::cast<PyLocation *>(location);
+ return nb::cast<PyLocation *>(location);
}
PyMlirContext *PyThreadContextEntry::getDefaultContext() {
@@ -874,12 +879,11 @@ PyLocation *PyThreadContextEntry::getDefaultLocation() {
return tos ? tos->getLocation() : nullptr;
}
-py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
- py::object contextObj = py::cast(context);
- push(FrameKind::Context, /*context=*/contextObj,
- /*insertionPoint=*/py::object(),
- /*location=*/py::object());
- return contextObj;
+nb::object PyThreadContextEntry::pushContext(nb::object context) {
+ push(FrameKind::Context, /*context=*/context,
+ /*insertionPoint=*/nb::object(),
+ /*location=*/nb::object());
+ return context;
}
void PyThreadContextEntry::popContext(PyMlirContext &context) {
@@ -892,15 +896,16 @@ void PyThreadContextEntry::popContext(PyMlirContext &context) {
stack.pop_back();
}
-py::object
-PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
- py::object contextObj =
+nb::object
+PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
+ PyInsertionPoint &insertionPoint =
+ nb::cast<PyInsertionPoint &>(insertionPointObj);
+ nb::object contextObj =
insertionPoint.getBlock().getParentOperation()->getContext().getObject();
- py::object insertionPointObj = py::cast(insertionPoint);
push(FrameKind::InsertionPoint,
/*context=*/contextObj,
/*insertionPoint=*/insertionPointObj,
- /*location=*/py::object());
+ /*location=*/nb::object());
return insertionPointObj;
}
@@ -915,11 +920,11 @@ void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
stack.pop_back();
}
-py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
- py::object contextObj = location.getContext().getObject();
- py::object locationObj = py::cast(location);
+nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
+ PyLocation &location = nb::cast<PyLocation &>(locationObj);
+ nb::object contextObj = location.getContext().getObject();
push(FrameKind::Location, /*context=*/contextObj,
- /*insertionPoint=*/py::object(),
+ /*insertionPoint=*/nb::object(),
/*location=*/locationObj);
return locationObj;
}
@@ -941,15 +946,15 @@ void PyThreadContextEntry::popLocation(PyLocation &location) {
void PyDiagnostic::invalidate() {
valid = false;
if (materializedNotes) {
- for (auto &noteObject : *materializedNotes) {
- PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
+ for (nb::handle noteObject : *materializedNotes) {
+ PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
note->invalidate();
}
}
}
PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
- py::object callback)
+ nb::object callback)
: context(context), callback(std::move(callback)) {}
PyDiagnosticHandler::~PyDiagnosticHandler() = default;
@@ -984,32 +989,36 @@ PyLocation PyDiagnostic::getLocation() {
return PyLocation(PyMlirContext::forContext(context), loc);
}
-py::str PyDiagnostic::getMessage() {
+nb::str PyDiagnostic::getMessage() {
checkValid();
- py::object fileObject = py::module::import("io").attr("StringIO")();
+ nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
PyFileAccumulator accum(fileObject, /*binary=*/false);
mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
- return fileObject.attr("getvalue")();
+ return nb::cast<nb::str>(fileObject.attr("getvalue")());
}
-py::tuple PyDiagnostic::getNotes() {
+nb::tuple PyDiagnostic::getNotes() {
checkValid();
if (materializedNotes)
return *materializedNotes;
intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
- materializedNotes = py::tuple(numNotes);
+ nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
for (intptr_t i = 0; i < numNotes; ++i) {
MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
- (*materializedNotes)[i] = PyDiagnostic(noteDiag);
+ nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
+ PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
}
+ materializedNotes = std::move(notes);
+
return *materializedNotes;
}
PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() {
std::vector<DiagnosticInfo> notes;
- for (py::handle n : getNotes())
- notes.emplace_back(n.cast<PyDiagnostic>().getInfo());
- return {getSeverity(), getLocation(), getMessage(), std::move(notes)};
+ for (nb::handle n : getNotes())
+ notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
+ return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
+ std::move(notes)};
}
//------------------------------------------------------------------------------
@@ -1023,22 +1032,21 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key,
if (mlirDialectIsNull(dialect)) {
std::string msg = (Twine("Dialect '") + key + "' not found").str();
if (attrError)
- throw py::attribute_error(msg);
- throw py::index_error(msg);
+ throw nb::attribute_error(msg.c_str());
+ throw nb::index_error(msg.c_str());
}
return dialect;
}
-py::object PyDialectRegistry::getCapsule() {
- return py::reinterpret_steal<py::object>(
- mlirPythonDialectRegistryToCapsule(*this));
+nb::object PyDialectRegistry::getCapsule() {
+ return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
}
-PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) {
+PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) {
MlirDialectRegistry rawRegistry =
mlirPythonCapsuleToDialectRegistry(capsule.ptr());
if (mlirDialectRegistryIsNull(rawRegistry))
- throw py::error_already_set();
+ throw nb::python_error();
return PyDialectRegistry(rawRegistry);
}
@@ -1046,25 +1054,25 @@ PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) {
// PyLocation
//------------------------------------------------------------------------------
-py::object PyLocation::getCapsule() {
- return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
+nb::object PyLocation::getCapsule() {
+ return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
}
-PyLocation PyLocation::createFromCapsule(py::object capsule) {
+PyLocation PyLocation::createFromCapsule(nb::object capsule) {
MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
if (mlirLocationIsNull(rawLoc))
- throw py::error_already_set();
+ throw nb::python_error();
return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
rawLoc);
}
-py::object PyLocation::contextEnter() {
- return PyThreadContextEntry::pushLocation(*this);
+nb::object PyLocation::contextEnter(nb::object locationObj) {
+ return PyThreadContextEntry::pushLocation(locationObj);
}
-void PyLocation::contextExit(const pybind11::object &excType,
- const pybind11::object &excVal,
- const pybind11::object &excTb) {
+void PyLocation::contextExit(const nb::object &excType,
+ const nb::object &excVal,
+ const nb::object &excTb) {
PyThreadContextEntry::popLocation(*this);
}
@@ -1087,7 +1095,7 @@ PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}
PyModule::~PyModule() {
- py::gil_scoped_acquire acquire;
+ nb::gil_scoped_acquire acquire;
auto &liveModules = getContext()->liveModules;
assert(liveModules.count(module.ptr) == 1 &&
"destroying module not in live map");
@@ -1099,7 +1107,7 @@ PyModuleRef PyModule::forModule(MlirModule module) {
MlirContext context = mlirModuleGetContext(module);
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
- py::gil_scoped_acquire acquire;
+ nb::gil_scoped_acquire acquire;
auto &liveModules = contextRef->liveModules;
auto it = liveModules.find(module.ptr);
if (it == liveModules.end()) {
@@ -1108,8 +1116,7 @@ PyModuleRef PyModule::forModule(MlirModule module) {
// Note that the default return value policy on cast is automatic_reference,
// which does not take ownership (delete will not be called).
// Just be explicit.
- py::object pyRef =
- py::cast(unownedModule, py::return_value_policy::take_ownership);
+ nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
unownedModule->handle = pyRef;
liveModules[module.ptr] =
std::make_pair(unownedModule->handle, unownedModule);
@@ -1117,19 +1124,19 @@ PyModuleRef PyModule::forModule(MlirModule module) {
}
// Use existing.
PyModule *existing = it->second.second;
- py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
+ nb::object pyRef = nb::borrow<nb::object>(it->second.first);
return PyModuleRef(existing, std::move(pyRef));
}
-py::object PyModule::createFromCapsule(py::object capsule) {
+nb::object PyModule::createFromCapsule(nb::object capsule) {
MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
if (mlirModuleIsNull(rawModule))
- throw py::error_already_set();
+ throw nb::python_error();
return forModule(rawModule).releaseObject();
}
-py::object PyModule::getCapsule() {
- return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
+nb::object PyModule::getCapsule() {
+ return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
}
//------------------------------------------------------------------------------
@@ -1158,7 +1165,7 @@ PyOperation::~PyOperation() {
PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
MlirOperation operation,
- py::object parentKeepAlive) {
+ nb::object parentKeepAlive) {
auto &liveOperations = contextRef->liveOperations;
// Create.
PyOperation *unownedOperation =
@@ -1166,8 +1173,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
// Note that the default return value policy on cast is automatic_reference,
// which does not take ownership (delete will not be called).
// Just be explicit.
- py::object pyRef =
- py::cast(unownedOperation, py::return_value_policy::take_ownership);
+ nb::object pyRef = nb::cast(unownedOperation, nb::rv_policy::take_ownership);
unownedOperation->handle = pyRef;
if (parentKeepAlive) {
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
@@ -1178,7 +1184,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
MlirOperation operation,
- py::object parentKeepAlive) {
+ nb::object parentKeepAlive) {
auto &liveOperations = contextRef->liveOperations;
auto it = liveOperations.find(operation.ptr);
if (it == liveOperations.end()) {
@@ -1188,13 +1194,13 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
}
// Use existing.
PyOperation *existing = it->second.second;
- py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
+ nb::object pyRef = nb::borrow<nb::object>(it->second.first);
return PyOperationRef(existing, std::move(pyRef));
}
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
MlirOperation operation,
- py::object parentKeepAlive) {
+ nb::object parentKeepAlive) {
auto &liveOperations = contextRef->liveOperations;
assert(liveOperations.count(operation.ptr) == 0 &&
"cannot create detached operation that already exists");
@@ -1227,12 +1233,12 @@ void PyOperation::checkValid() const {
void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
- bool assumeVerified, py::object fileObject,
+ bool assumeVerified, nb::object fileObject,
bool binary, bool skipRegions) {
PyOperation &operation = getOperation();
operation.checkValid();
if (fileObject.is_none())
- fileObject = py::module::import("sys").attr("stdout");
+ fileObject = nb::module_::import_("sys").attr("stdout");
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
if (largeElementsLimit)
@@ -1255,18 +1261,18 @@ void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
mlirOpPrintingFlagsDestroy(flags);
}
-void PyOperationBase::print(PyAsmState &state, py::object fileObject,
+void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
bool binary) {
PyOperation &operation = getOperation();
operation.checkValid();
if (fileObject.is_none())
- fileObject = py::module::import("sys").attr("stdout");
+ fileObject = nb::module_::import_("sys").attr("stdout");
PyFileAccumulator accum(fileObject, binary);
mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
accum.getUserData());
}
-void PyOperationBase::writeBytecode(const py::object &fileObject,
+void PyOperationBase::writeBytecode(const nb::object &fileObject,
std::optional<int64_t> bytecodeVersion) {
PyOperation &operation = getOperation();
operation.checkValid();
@@ -1282,9 +1288,10 @@ void PyOperationBase::writeBytecode(const py::object &fileObject,
operation, config, accum.getCallback(), accum.getUserData());
mlirBytecodeWriterConfigDestroy(config);
if (mlirLogicalResultIsFailure(res))
- throw py::value_error((Twine("Unable to honor desired bytecode version ") +
+ throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
Twine(*bytecodeVersion))
- .str());
+ .str()
+ .c_str());
}
void PyOperationBase::walk(
@@ -1296,7 +1303,7 @@ void PyOperationBase::walk(
std::function<MlirWalkResult(MlirOperation)> callback;
bool gotException;
std::string exceptionWhat;
- py::object exceptionType;
+ nb::object exceptionType;
};
UserData userData{callback, false, {}, {}};
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
@@ -1304,10 +1311,10 @@ void PyOperationBase::walk(
UserData *calleeUserData = static_cast<UserData *>(userData);
try {
return (calleeUserData->callback)(op);
- } catch (py::error_already_set &e) {
+ } catch (nb::python_error &e) {
calleeUserData->gotException = true;
- calleeUserData->exceptionWhat = e.what();
- calleeUserData->exceptionType = e.type();
+ calleeUserData->exceptionWhat = std::string(e.what());
+ calleeUserData->exceptionType = nb::borrow(e.type());
return MlirWalkResult::MlirWalkResultInterrupt;
}
};
@@ -1319,16 +1326,16 @@ void PyOperationBase::walk(
}
}
-py::object PyOperationBase::getAsm(bool binary,
+nb::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
bool assumeVerified, bool skipRegions) {
- py::object fileObject;
+ nb::object fileObject;
if (binary) {
- fileObject = py::module::import("io").attr("BytesIO")();
+ fileObject = nb::module_::import_("io").attr("BytesIO")();
} else {
- fileObject = py::module::import("io").attr("StringIO")();
+ fileObject = nb::module_::import_("io").attr("StringIO")();
}
print(/*largeElementsLimit=*/largeElementsLimit,
/*enableDebugInfo=*/enableDebugInfo,
@@ -1372,7 +1379,7 @@ bool PyOperationBase::verify() {
std::optional<PyOperationRef> PyOperation::getParentOperation() {
checkValid();
if (!isAttached())
- throw py::value_error("Detached operations have no parent");
+ throw nb::value_error("Detached operations have no parent");
MlirOperation operation = mlirOperationGetParentOperation(get());
if (mlirOperationIsNull(operation))
return {};
@@ -1388,42 +1395,42 @@ PyBlock PyOperation::getBlock() {
return PyBlock{std::move(*parentOperation), block};
}
-py::object PyOperation::getCapsule() {
+nb::object PyOperation::getCapsule() {
checkValid();
- return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
+ return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
}
-py::object PyOperation::createFromCapsule(py::object capsule) {
+nb::object PyOperation::createFromCapsule(nb::object capsule) {
MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
if (mlirOperationIsNull(rawOperation))
- throw py::error_already_set();
+ throw nb::python_error();
MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
.releaseObject();
}
static void maybeInsertOperation(PyOperationRef &op,
- const py::object &maybeIp) {
+ const nb::object &maybeIp) {
// InsertPoint active?
- if (!maybeIp.is(py::cast(false))) {
+ if (!maybeIp.is(nb::cast(false))) {
PyInsertionPoint *ip;
if (maybeIp.is_none()) {
ip = PyThreadContextEntry::getDefaultInsertionPoint();
} else {
- ip = py::cast<PyInsertionPoint *>(maybeIp);
+ ip = nb::cast<PyInsertionPoint *>(maybeIp);
}
if (ip)
ip->insert(*op.get());
}
}
-py::object PyOperation::create(const std::string &name,
+nb::object PyOperation::create(const std::string &name,
std::optional<std::vector<PyType *>> results,
std::optional<std::vector<PyValue *>> operands,
- std::optional<py::dict> attributes,
+ std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
int regions, DefaultingPyLocation location,
- const py::object &maybeIp, bool inferType) {
+ const nb::object &maybeIp, bool inferType) {
llvm::SmallVector<MlirValue, 4> mlirOperands;
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1431,14 +1438,14 @@ py::object PyOperation::create(const std::string &name,
// General parameter validation.
if (regions < 0)
- throw py::value_error("number of regions must be >= 0");
+ throw nb::value_error("number of regions must be >= 0");
// Unpack/validate operands.
if (operands) {
mlirOperands.reserve(operands->size());
for (PyValue *operand : *operands) {
if (!operand)
- throw py::value_error("operand value cannot be None");
+ throw nb::value_error("operand value cannot be None");
mlirOperands.push_back(operand->get());
}
}
@@ -1449,38 +1456,38 @@ py::object PyOperation::create(const std::string &name,
for (PyType *result : *results) {
// TODO: Verify result type originate from the same context.
if (!result)
- throw py::value_error("result type cannot be None");
+ throw nb::value_error("result type cannot be None");
mlirResults.push_back(*result);
}
}
// Unpack/validate attributes.
if (attributes) {
mlirAttributes.reserve(attributes->size());
- for (auto &it : *attributes) {
+ for (std::pair<nb::handle, nb::handle> it : *attributes) {
std::string key;
try {
- key = it.first.cast<std::string>();
- } catch (py::cast_error &err) {
+ key = nb::cast<std::string>(it.first);
+ } catch (nb::cast_error &err) {
std::string msg = "Invalid attribute key (not a string) when "
"attempting to create the operation \"" +
name + "\" (" + err.what() + ")";
- throw py::cast_error(msg);
+ throw nb::type_error(msg.c_str());
}
try {
- auto &attribute = it.second.cast<PyAttribute &>();
+ auto &attribute = nb::cast<PyAttribute &>(it.second);
// TODO: Verify attribute originates from the same context.
mlirAttributes.emplace_back(std::move(key), attribute);
- } catch (py::reference_cast_error &) {
+ } catch (nb::cast_error &err) {
+ std::string msg = "Invalid attribute value for the key \"" + key +
+ "\" when attempting to create the operation \"" +
+ name + "\" (" + err.what() + ")";
+ throw nb::type_error(msg.c_str());
+ } catch (std::runtime_error &err) {
// This exception seems thrown when the value is "None".
std::string msg =
"Found an invalid (`None`?) attribute value for the key \"" + key +
"\" when attempting to create the operation \"" + name + "\"";
- throw py::cast_error(msg);
- } catch (py::cast_error &err) {
- std::string msg = "Invalid attribute value for the key \"" + key +
- "\" when attempting to create the operation \"" +
- name + "\" (" + err.what() + ")";
- throw py::cast_error(msg);
+ throw std::runtime_error(msg);
}
}
}
@@ -1490,7 +1497,7 @@ py::object PyOperation::create(const std::string &name,
for (auto *successor : *successors) {
// TODO: Verify successor originate from the same context.
if (!successor)
- throw py::value_error("successor block cannot be None");
+ throw nb::value_error("successor block cannot be None");
mlirSuccessors.push_back(successor->get());
}
}
@@ -1535,7 +1542,7 @@ py::object PyOperation::create(const std::string &name,
// Construct the operation.
MlirOperation operation = mlirOperationCreate(&state);
if (!operation.ptr)
- throw py::value_error("Operation creation failed");
+ throw nb::value_error("Operation creation failed");
PyOperationRef created =
PyOperation::createDetached(location->getContext(), operation);
maybeInsertOperation(created, maybeIp);
@@ -1543,7 +1550,7 @@ py::object PyOperation::create(const std::string &name,
return created.getObject();
}
-py::object PyOperation::clone(const py::object &maybeIp) {
+nb::object PyOperation::clone(const nb::object &maybeIp) {
MlirOperation clonedOperation = mlirOperationClone(operation);
PyOperationRef cloned =
PyOperation::createDetached(getContext(), clonedOperation);
@@ -1552,15 +1559,15 @@ py::object PyOperation::clone(const py::object &maybeIp) {
return cloned->createOpView();
}
-py::object PyOperation::createOpView() {
+nb::object PyOperation::createOpView() {
checkValid();
MlirIdentifier ident = mlirOperationGetName(get());
MlirStringRef identStr = mlirIdentifierStr(ident);
auto operationCls = PyGlobals::get().lookupOperationClass(
StringRef(identStr.data, identStr.length));
if (operationCls)
- return PyOpView::constructDerived(*operationCls, *getRef().get());
- return py::cast(PyOpView(getRef().getObject()));
+ return PyOpView::constructDerived(*operationCls, getRef().getObject());
+ return nb::cast(PyOpView(getRef().getObject()));
}
void PyOperation::erase() {
@@ -1573,8 +1580,8 @@ void PyOperation::erase() {
// PyOpView
//------------------------------------------------------------------------------
-static void populateResultTypes(StringRef name, py::list resultTypeList,
- const py::object &resultSegmentSpecObj,
+static void populateResultTypes(StringRef name, nb::list resultTypeList,
+ const nb::object &resultSegmentSpecObj,
std::vector<int32_t> &resultSegmentLengths,
std::vector<PyType *> &resultTypes) {
resultTypes.reserve(resultTypeList.size());
@@ -1582,26 +1589,28 @@ static void populateResultTypes(StringRef name, py::list resultTypeList,
// Non-variadic result unpacking.
for (const auto &it : llvm::enumerate(resultTypeList)) {
try {
- resultTypes.push_back(py::cast<PyType *>(it.value()));
+ resultTypes.push_back(nb::cast<PyType *>(it.value()));
if (!resultTypes.back())
- throw py::cast_error();
- } catch (py::cast_error &err) {
- throw py::value_error((llvm::Twine("Result ") +
+ throw nb::cast_error();
+ } catch (nb::cast_error &err) {
+ throw nb::value_error((llvm::Twine("Result ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Type (" + err.what() + ")")
- .str());
+ .str()
+ .c_str());
}
}
} else {
// Sized result unpacking.
- auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
+ auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
if (resultSegmentSpec.size() != resultTypeList.size()) {
- throw py::value_error((llvm::Twine("Operation \"") + name +
+ throw nb::value_error((llvm::Twine("Operation \"") + name +
"\" requires " +
llvm::Twine(resultSegmentSpec.size()) +
" result segments but was provided " +
llvm::Twine(resultTypeList.size()))
- .str());
+ .str()
+ .c_str());
}
resultSegmentLengths.reserve(resultTypeList.size());
for (const auto &it :
@@ -1610,7 +1619,7 @@ static void populateResultTypes(StringRef name, py::list resultTypeList,
if (segmentSpec == 1 || segmentSpec == 0) {
// Unpack unary element.
try {
- auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
+ auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
if (resultType) {
resultTypes.push_back(resultType);
resultSegmentLengths.push_back(1);
@@ -1618,14 +1627,20 @@ static void populateResultTypes(StringRef name, py::list resultTypeList,
// Allowed to be optional.
resultSegmentLengths.push_back(0);
} else {
- throw py::cast_error("was None and result is not optional");
+ throw nb::value_error(
+ (llvm::Twine("Result ") + llvm::Twine(it.index()) +
+ " of operation \"" + name +
+ "\" must be a Type (was None and result is not optional)")
+ .str()
+ .c_str());
}
- } catch (py::cast_error &err) {
- throw py::value_error((llvm::Twine("Result ") +
+ } catch (nb::cast_error &err) {
+ throw nb::value_error((llvm::Twine("Result ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Type (" + err.what() +
")")
- .str());
+ .str()
+ .c_str());
}
} else if (segmentSpec == -1) {
// Unpack sequence by appending.
@@ -1635,72 +1650,75 @@ static void populateResultTypes(StringRef name, py::list resultTypeList,
resultSegmentLengths.push_back(0);
} else {
// Unpack the list.
- auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
- for (py::object segmentItem : segment) {
- resultTypes.push_back(py::cast<PyType *>(segmentItem));
+ auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
+ for (nb::handle segmentItem : segment) {
+ resultTypes.push_back(nb::cast<PyType *>(segmentItem));
if (!resultTypes.back()) {
- throw py::cast_error("contained a None item");
+ throw nb::type_error("contained a None item");
}
}
- resultSegmentLengths.push_back(segment.size());
+ resultSegmentLengths.push_back(nb::len(segment));
}
} catch (std::exception &err) {
// NOTE: Sloppy to be using a catch-all here, but there are at least
// three different unrelated exceptions that can be thrown in the
// above "casts". Just keep the scope above small and catch them all.
- throw py::value_error((llvm::Twine("Result ") +
+ throw nb::value_error((llvm::Twine("Result ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Sequence of Types (" +
err.what() + ")")
- .str());
+ .str()
+ .c_str());
}
} else {
- throw py::value_error("Unexpected segment spec");
+ throw nb::value_error("Unexpected segment spec");
}
}
}
}
-py::object PyOpView::buildGeneric(
- const py::object &cls, std::optional<py::list> resultTypeList,
- py::list operandList, std::optional<py::dict> attributes,
+nb::object PyOpView::buildGeneric(
+ const nb::object &cls, std::optional<nb::list> resultTypeList,
+ nb::list operandList, std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
- const py::object &maybeIp) {
+ const nb::object &maybeIp) {
PyMlirContextRef context = location->getContext();
// Class level operation construction metadata.
- std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
+ std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
// Operand and result segment specs are either none, which does no
// variadic unpacking, or a list of ints with segment sizes, where each
// element is either a positive number (typically 1 for a scalar) or -1 to
// indicate that it is derived from the length of the same-indexed operand
// or result (implying that it is a list at that position).
- py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
- py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
+ nb::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
+ nb::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
std::vector<int32_t> operandSegmentLengths;
std::vector<int32_t> resultSegmentLengths;
// Validate/determine region count.
- auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
+ auto opRegionSpec = nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
int opMinRegionCount = std::get<0>(opRegionSpec);
bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
if (!regions) {
regions = opMinRegionCount;
}
if (*regions < opMinRegionCount) {
- throw py::value_error(
+ throw nb::value_error(
(llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
llvm::Twine(opMinRegionCount) +
" regions but was built with regions=" + llvm::Twine(*regions))
- .str());
+ .str()
+ .c_str());
}
if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
- throw py::value_error(
+ throw nb::value_error(
(llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
llvm::Twine(opMinRegionCount) +
" regions but was built with regions=" + llvm::Twine(*regions))
- .str());
+ .str()
+ .c_str());
}
// Unpack results.
@@ -1717,26 +1735,28 @@ py::object PyOpView::buildGeneric(
// Non-sized operand unpacking.
for (const auto &it : llvm::enumerate(operandList)) {
try {
- operands.push_back(py::cast<PyValue *>(it.value()));
+ operands.push_back(nb::cast<PyValue *>(it.value()));
if (!operands.back())
- throw py::cast_error();
- } catch (py::cast_error &err) {
- throw py::value_error((llvm::Twine("Operand ") +
+ throw nb::cast_error();
+ } catch (nb::cast_error &err) {
+ throw nb::value_error((llvm::Twine("Operand ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Value (" + err.what() + ")")
- .str());
+ .str()
+ .c_str());
}
}
} else {
// Sized operand unpacking.
- auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
+ auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
if (operandSegmentSpec.size() != operandList.size()) {
- throw py::value_error((llvm::Twine("Operation \"") + name +
+ throw nb::value_error((llvm::Twine("Operation \"") + name +
"\" requires " +
llvm::Twine(operandSegmentSpec.size()) +
"operand segments but was provided " +
llvm::Twine(operandList.size()))
- .str());
+ .str()
+ .c_str());
}
operandSegmentLengths.reserve(operandList.size());
for (const auto &it :
@@ -1745,7 +1765,7 @@ py::object PyOpView::buildGeneric(
if (segmentSpec == 1 || segmentSpec == 0) {
// Unpack unary element.
try {
- auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
+ auto *operandValue = nb::cast<PyValue *>(std::get<0>(it.value()));
if (operandValue) {
operands.push_back(operandValue);
operandSegmentLengths.push_back(1);
@@ -1753,14 +1773,20 @@ py::object PyOpView::buildGeneric(
// Allowed to be optional.
operandSegmentLengths.push_back(0);
} else {
- throw py::cast_error("was None and operand is not optional");
+ throw nb::value_error(
+ (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+ " of operation \"" + name +
+ "\" must be a Value (was None and operand is not optional)")
+ .str()
+ .c_str());
}
- } catch (py::cast_error &err) {
- throw py::value_error((llvm::Twine("Operand ") +
+ } catch (nb::cast_error &err) {
+ throw nb::value_error((llvm::Twine("Operand ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Value (" + err.what() +
")")
- .str());
+ .str()
+ .c_str());
}
} else if (segmentSpec == -1) {
// Unpack sequence by appending.
@@ -1770,27 +1796,28 @@ py::object PyOpView::buildGeneric(
operandSegmentLengths.push_back(0);
} else {
// Unpack the list.
- auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
- for (py::object segmentItem : segment) {
- operands.push_back(py::cast<PyValue *>(segmentItem));
+ auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
+ for (nb::handle segmentItem : segment) {
+ operands.push_back(nb::cast<PyValue *>(segmentItem));
if (!operands.back()) {
- throw py::cast_error("contained a None item");
+ throw nb::type_error("contained a None item");
}
}
- operandSegmentLengths.push_back(segment.size());
+ operandSegmentLengths.push_back(nb::len(segment));
}
} catch (std::exception &err) {
// NOTE: Sloppy to be using a catch-all here, but there are at least
// three different unrelated exceptions that can be thrown in the
// above "casts". Just keep the scope above small and catch them all.
- throw py::value_error((llvm::Twine("Operand ") +
+ throw nb::value_error((llvm::Twine("Operand ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Sequence of Values (" +
err.what() + ")")
- .str());
+ .str()
+ .c_str());
}
} else {
- throw py::value_error("Unexpected segment spec");
+ throw nb::value_error("Unexpected segment spec");
}
}
}
@@ -1799,13 +1826,13 @@ py::object PyOpView::buildGeneric(
if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
// Dup.
if (attributes) {
- attributes = py::dict(*attributes);
+ attributes = nb::dict(*attributes);
} else {
- attributes = py::dict();
+ attributes = nb::dict();
}
if (attributes->contains("resultSegmentSizes") ||
attributes->contains("operandSegmentSizes")) {
- throw py::value_error("Manually setting a 'resultSegmentSizes' or "
+ throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
"'operandSegmentSizes' attribute is unsupported. "
"Use Operation.create for such low-level access.");
}
@@ -1839,21 +1866,18 @@ py::object PyOpView::buildGeneric(
!resultTypeList);
}
-pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
- const PyOperation &operation) {
- // TODO: pybind11 2.6 supports a more direct form.
- // Upgrade many years from now.
- // auto opViewType = py::type::of<PyOpView>();
- py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
- py::object instance = cls.attr("__new__")(cls);
+nb::object PyOpView::constructDerived(const nb::object &cls,
+ const nb::object &operation) {
+ nb::handle opViewType = nb::type<PyOpView>();
+ nb::object instance = cls.attr("__new__")(cls);
opViewType.attr("__init__")(instance, operation);
return instance;
}
-PyOpView::PyOpView(const py::object &operationObject)
+PyOpView::PyOpView(const nb::object &operationObject)
// Casting through the PyOperationBase base-class and then back to the
// Operation lets us accept any PyOperationBase subclass.
- : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
+ : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
operationObject(operation.getRef().getObject()) {}
//------------------------------------------------------------------------------
@@ -1869,7 +1893,7 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
void PyInsertionPoint::insert(PyOperationBase &operationBase) {
PyOperation &operation = operationBase.getOperation();
if (operation.isAttached())
- throw py::value_error(
+ throw nb::value_error(
"Attempt to insert operation that is already attached");
block.getParentOperation()->checkValid();
MlirOperation beforeOp = {nullptr};
@@ -1882,7 +1906,7 @@ void PyInsertionPoint::insert(PyOperationBase &operationBase) {
// already end in a known terminator (violating this will cause assertion
// failures later).
if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
- throw py::index_error("Cannot insert operation at the end of a block "
+ throw nb::index_error("Cannot insert operation at the end of a block "
"that already has a terminator. Did you mean to "
"use 'InsertionPoint.at_block_terminator(block)' "
"versus 'InsertionPoint(block)'?");
@@ -1908,19 +1932,19 @@ PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
MlirOperation terminator = mlirBlockGetTerminator(block.get());
if (mlirOperationIsNull(terminator))
- throw py::value_error("Block has no terminator");
+ throw nb::value_error("Block has no terminator");
PyOperationRef terminatorOpRef = PyOperation::forOperation(
block.getParentOperation()->getContext(), terminator);
return PyInsertionPoint{block, std::move(terminatorOpRef)};
}
-py::object PyInsertionPoint::contextEnter() {
- return PyThreadContextEntry::pushInsertionPoint(*this);
+nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
+ return PyThreadContextEntry::pushInsertionPoint(insertPoint);
}
-void PyInsertionPoint::contextExit(const pybind11::object &excType,
- const pybind11::object &excVal,
- const pybind11::object &excTb) {
+void PyInsertionPoint::contextExit(const nb::object &excType,
+ const nb::object &excVal,
+ const nb::object &excTb) {
PyThreadContextEntry::popInsertionPoint(*this);
}
@@ -1932,14 +1956,14 @@ bool PyAttribute::operator==(const PyAttribute &other) const {
return mlirAttributeEqual(attr, other.attr);
}
-py::object PyAttribute::getCapsule() {
- return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
+nb::object PyAttribute::getCapsule() {
+ return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
}
-PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
+PyAttribute PyAttribute::createFromCapsule(nb::object capsule) {
MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
if (mlirAttributeIsNull(rawAttr))
- throw py::error_already_set();
+ throw nb::python_error();
return PyAttribute(
PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
}
@@ -1964,14 +1988,14 @@ bool PyType::operator==(const PyType &other) const {
return mlirTypeEqual(type, other.type);
}
-py::object PyType::getCapsule() {
- return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
+nb::object PyType::getCapsule() {
+ return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
}
-PyType PyType::createFromCapsule(py::object capsule) {
+PyType PyType::createFromCapsule(nb::object capsule) {
MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
if (mlirTypeIsNull(rawType))
- throw py::error_already_set();
+ throw nb::python_error();
return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
rawType);
}
@@ -1980,14 +2004,14 @@ PyType PyType::createFromCapsule(py::object capsule) {
// PyTypeID.
//------------------------------------------------------------------------------
-py::object PyTypeID::getCapsule() {
- return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this));
+nb::object PyTypeID::getCapsule() {
+ return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
}
-PyTypeID PyTypeID::createFromCapsule(py::object capsule) {
+PyTypeID PyTypeID::createFromCapsule(nb::object capsule) {
MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
if (mlirTypeIDIsNull(mlirTypeID))
- throw py::error_already_set();
+ throw nb::python_error();
return PyTypeID(mlirTypeID);
}
bool PyTypeID::operator==(const PyTypeID &other) const {
@@ -1998,36 +2022,36 @@ bool PyTypeID::operator==(const PyTypeID &other) const {
// PyValue and subclasses.
//------------------------------------------------------------------------------
-pybind11::object PyValue::getCapsule() {
- return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
+nb::object PyValue::getCapsule() {
+ return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
}
-pybind11::object PyValue::maybeDownCast() {
+nb::object PyValue::maybeDownCast() {
MlirType type = mlirValueGetType(get());
MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
- std::optional<pybind11::function> valueCaster =
+ std::optional<nb::callable> valueCaster =
PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
- // py::return_value_policy::move means use std::move to move the return value
+ // nb::rv_policy::move means use std::move to move the return value
// contents into a new instance that will be owned by Python.
- py::object thisObj = py::cast(this, py::return_value_policy::move);
+ nb::object thisObj = nb::cast(this, nb::rv_policy::move);
if (!valueCaster)
return thisObj;
return valueCaster.value()(thisObj);
}
-PyValue PyValue::createFromCapsule(pybind11::object capsule) {
+PyValue PyValue::createFromCapsule(nb::object capsule) {
MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
if (mlirValueIsNull(value))
- throw py::error_already_set();
+ throw nb::python_error();
MlirOperation owner;
if (mlirValueIsAOpResult(value))
owner = mlirOpResultGetOwner(value);
if (mlirValueIsABlockArgument(value))
owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
if (mlirOperationIsNull(owner))
- throw py::error_already_set();
+ throw nb::python_error();
MlirContext ctx = mlirOperationGetContext(owner);
PyOperationRef ownerRef =
PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
@@ -2042,16 +2066,17 @@ PySymbolTable::PySymbolTable(PyOperationBase &operation)
: operation(operation.getOperation().getRef()) {
symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
if (mlirSymbolTableIsNull(symbolTable)) {
- throw py::cast_error("Operation is not a Symbol Table.");
+ throw nb::type_error("Operation is not a Symbol Table.");
}
}
-py::object PySymbolTable::dunderGetItem(const std::string &name) {
+nb::object PySymbolTable::dunderGetItem(const std::string &name) {
operation->checkValid();
MlirOperation symbol = mlirSymbolTableLookup(
symbolTable, mlirStringRefCreate(name.data(), name.length()));
if (mlirOperationIsNull(symbol))
- throw py::key_error("Symbol '" + name + "' not in the symbol table.");
+ throw nb::key_error(
+ ("Symbol '" + name + "' not in the symbol table.").c_str());
return PyOperation::forOperation(operation->getContext(), symbol,
operation.getObject())
@@ -2069,8 +2094,8 @@ void PySymbolTable::erase(PyOperationBase &symbol) {
}
void PySymbolTable::dunderDel(const std::string &name) {
- py::object operation = dunderGetItem(name);
- erase(py::cast<PyOperationBase &>(operation));
+ nb::object operation = dunderGetItem(name);
+ erase(nb::cast<PyOperationBase &>(operation));
}
MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
@@ -2079,7 +2104,7 @@ MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
if (mlirAttributeIsNull(symbolAttr))
- throw py::value_error("Expected operation to have a symbol name.");
+ throw nb::value_error("Expected operation to have a symbol name.");
return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
}
@@ -2091,7 +2116,7 @@ MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
MlirAttribute existingNameAttr =
mlirOperationGetAttributeByName(operation.get(), attrName);
if (mlirAttributeIsNull(existingNameAttr))
- throw py::value_error("Expected operation to have a symbol name.");
+ throw nb::value_error("Expected operation to have a symbol name.");
return existingNameAttr;
}
@@ -2104,7 +2129,7 @@ void PySymbolTable::setSymbolName(PyOperationBase &symbol,
MlirAttribute existingNameAttr =
mlirOperationGetAttributeByName(operation.get(), attrName);
if (mlirAttributeIsNull(existingNameAttr))
- throw py::value_error("Expected operation to have a symbol name.");
+ throw nb::value_error("Expected operation to have a symbol name.");
MlirAttribute newNameAttr =
mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
@@ -2117,7 +2142,7 @@ MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
MlirAttribute existingVisAttr =
mlirOperationGetAttributeByName(operation.get(), attrName);
if (mlirAttributeIsNull(existingVisAttr))
- throw py::value_error("Expected operation to have a symbol visibility.");
+ throw nb::value_error("Expected operation to have a symbol visibility.");
return existingVisAttr;
}
@@ -2125,7 +2150,7 @@ void PySymbolTable::setVisibility(PyOperationBase &symbol,
const std::string &visibility) {
if (visibility != "public" && visibility != "private" &&
visibility != "nested")
- throw py::value_error(
+ throw nb::value_error(
"Expected visibility to be 'public', 'private' or 'nested'");
PyOperation &operation = symbol.getOperation();
operation.checkValid();
@@ -2133,7 +2158,7 @@ void PySymbolTable::setVisibility(PyOperationBase &symbol,
MlirAttribute existingVisAttr =
mlirOperationGetAttributeByName(operation.get(), attrName);
if (mlirAttributeIsNull(existingVisAttr))
- throw py::value_error("Expected operation to have a symbol visibility.");
+ throw nb::value_error("Expected operation to have a symbol visibility.");
MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
toMlirStringRef(visibility));
mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
@@ -2148,20 +2173,20 @@ void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
from.getOperation())))
- throw py::value_error("Symbol rename failed");
+ throw nb::value_error("Symbol rename failed");
}
void PySymbolTable::walkSymbolTables(PyOperationBase &from,
bool allSymUsesVisible,
- py::object callback) {
+ nb::object callback) {
PyOperation &fromOperation = from.getOperation();
fromOperation.checkValid();
struct UserData {
PyMlirContextRef context;
- py::object callback;
+ nb::object callback;
bool gotException;
std::string exceptionWhat;
- py::object exceptionType;
+ nb::object exceptionType;
};
UserData userData{
fromOperation.getContext(), std::move(callback), false, {}, {}};
@@ -2175,10 +2200,10 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
return;
try {
calleeUserData->callback(pyFoundOp.getObject(), isVisible);
- } catch (py::error_already_set &e) {
+ } catch (nb::python_error &e) {
calleeUserData->gotException = true;
calleeUserData->exceptionWhat = e.what();
- calleeUserData->exceptionType = e.type();
+ calleeUserData->exceptionType = nb::borrow(e.type());
}
},
static_cast<void *>(&userData));
@@ -2200,7 +2225,7 @@ class PyConcreteValue : public PyValue {
// IsAFunctionTy isaFunction
// const char *pyClassName
// and redefine bindDerived.
- using ClassTy = py::class_<DerivedTy, PyValue>;
+ using ClassTy = nb::class_<DerivedTy, PyValue>;
using IsAFunctionTy = bool (*)(MlirValue);
PyConcreteValue() = default;
@@ -2213,25 +2238,26 @@ class PyConcreteValue : public PyValue {
/// type mismatches.
static MlirValue castFrom(PyValue &orig) {
if (!DerivedTy::isaFunction(orig.get())) {
- auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
- throw py::value_error((Twine("Cannot cast value to ") +
+ auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
+ throw nb::value_error((Twine("Cannot cast value to ") +
DerivedTy::pyClassName + " (from " + origRepr +
")")
- .str());
+ .str()
+ .c_str());
}
return orig.get();
}
/// Binds the Python module objects to functions of this class.
- static void bind(py::module &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
- cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
+ static void bind(nb::module_ &m) {
+ auto cls = ClassTy(m, DerivedTy::pyClassName);
+ cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
cls.def_static(
"isinstance",
[](PyValue &otherValue) -> bool {
return DerivedTy::isaFunction(otherValue);
},
- py::arg("other_value"));
+ nb::arg("other_value"));
cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](DerivedTy &self) { return self.maybeDownCast(); });
DerivedTy::bindDerived(cls);
@@ -2249,11 +2275,11 @@ class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
using PyConcreteValue::PyConcreteValue;
static void bindDerived(ClassTy &c) {
- c.def_property_readonly("owner", [](PyBlockArgument &self) {
+ c.def_prop_ro("owner", [](PyBlockArgument &self) {
return PyBlock(self.getParentOperation(),
mlirBlockArgumentGetOwner(self.get()));
});
- c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
+ c.def_prop_ro("arg_number", [](PyBlockArgument &self) {
return mlirBlockArgumentGetArgNumber(self.get());
});
c.def(
@@ -2261,7 +2287,7 @@ class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
[](PyBlockArgument &self, PyType type) {
return mlirBlockArgumentSetType(self.get(), type);
},
- py::arg("type"));
+ nb::arg("type"));
}
};
@@ -2273,14 +2299,14 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
using PyConcreteValue::PyConcreteValue;
static void bindDerived(ClassTy &c) {
- c.def_property_readonly("owner", [](PyOpResult &self) {
+ c.def_prop_ro("owner", [](PyOpResult &self) {
assert(
mlirOperationEqual(self.getParentOperation()->get(),
mlirOpResultGetOwner(self.get())) &&
"expected the owner of the value in Python to match that in the IR");
return self.getParentOperation().getObject();
});
- c.def_property_readonly("result_number", [](PyOpResult &self) {
+ c.def_prop_ro("result_number", [](PyOpResult &self) {
return mlirOpResultGetResultNumber(self.get());
});
}
@@ -2317,7 +2343,7 @@ class PyBlockArgumentList
operation(std::move(operation)), block(block) {}
static void bindDerived(ClassTy &c) {
- c.def_property_readonly("types", [](PyBlockArgumentList &self) {
+ c.def_prop_ro("types", [](PyBlockArgumentList &self) {
return getValueTypes(self, self.operation->getContext());
});
}
@@ -2422,10 +2448,10 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
operation(std::move(operation)) {}
static void bindDerived(ClassTy &c) {
- c.def_property_readonly("types", [](PyOpResultList &self) {
+ c.def_prop_ro("types", [](PyOpResultList &self) {
return getValueTypes(self, self.operation->getContext());
});
- c.def_property_readonly("owner", [](PyOpResultList &self) {
+ c.def_prop_ro("owner", [](PyOpResultList &self) {
return self.operation->createOpView();
});
}
@@ -2508,14 +2534,14 @@ class PyOpAttributeMap {
MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
toMlirStringRef(name));
if (mlirAttributeIsNull(attr)) {
- throw py::key_error("attempt to access a non-existent attribute");
+ throw nb::key_error("attempt to access a non-existent attribute");
}
return attr;
}
PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
if (index < 0 || index >= dunderLen()) {
- throw py::index_error("attempt to access out of bounds attribute");
+ throw nb::index_error("attempt to access out of bounds attribute");
}
MlirNamedAttribute namedAttr =
mlirOperationGetAttribute(operation->get(), index);
@@ -2534,7 +2560,7 @@ class PyOpAttributeMap {
int removed = mlirOperationRemoveAttributeByName(operation->get(),
toMlirStringRef(name));
if (!removed)
- throw py::key_error("attempt to delete a non-existent attribute");
+ throw nb::key_error("attempt to delete a non-existent attribute");
}
intptr_t dunderLen() {
@@ -2546,8 +2572,8 @@ class PyOpAttributeMap {
operation->get(), toMlirStringRef(name)));
}
- static void bind(py::module &m) {
- py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
+ static void bind(nb::module_ &m) {
+ nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
.def("__contains__", &PyOpAttributeMap::dunderContains)
.def("__len__", &PyOpAttributeMap::dunderLen)
.def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
@@ -2566,21 +2592,21 @@ class PyOpAttributeMap {
// Populates the core exports of the 'ir' submodule.
//------------------------------------------------------------------------------
-void mlir::python::populateIRCore(py::module &m) {
+void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
// Enums.
//----------------------------------------------------------------------------
- py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
+ nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
.value("ERROR", MlirDiagnosticError)
.value("WARNING", MlirDiagnosticWarning)
.value("NOTE", MlirDiagnosticNote)
.value("REMARK", MlirDiagnosticRemark);
- py::enum_<MlirWalkOrder>(m, "WalkOrder", py::module_local())
+ nb::enum_<MlirWalkOrder>(m, "WalkOrder")
.value("PRE_ORDER", MlirWalkPreOrder)
.value("POST_ORDER", MlirWalkPostOrder);
- py::enum_<MlirWalkResult>(m, "WalkResult", py::module_local())
+ nb::enum_<MlirWalkResult>(m, "WalkResult")
.value("ADVANCE", MlirWalkResultAdvance)
.value("INTERRUPT", MlirWalkResultInterrupt)
.value("SKIP", MlirWalkResultSkip);
@@ -2588,33 +2614,37 @@ void mlir::python::populateIRCore(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of Diagnostics.
//----------------------------------------------------------------------------
- py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
- .def_property_readonly("severity", &PyDiagnostic::getSeverity)
- .def_property_readonly("location", &PyDiagnostic::getLocation)
- .def_property_readonly("message", &PyDiagnostic::getMessage)
- .def_property_readonly("notes", &PyDiagnostic::getNotes)
- .def("__str__", [](PyDiagnostic &self) -> py::str {
+ nb::class_<PyDiagnostic>(m, "Diagnostic")
+ .def_prop_ro("severity", &PyDiagnostic::getSeverity)
+ .def_prop_ro("location", &PyDiagnostic::getLocation)
+ .def_prop_ro("message", &PyDiagnostic::getMessage)
+ .def_prop_ro("notes", &PyDiagnostic::getNotes)
+ .def("__str__", [](PyDiagnostic &self) -> nb::str {
if (!self.isValid())
- return "<Invalid Diagnostic>";
+ return nb::str("<Invalid Diagnostic>");
return self.getMessage();
});
- py::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo",
- py::module_local())
- .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); }))
- .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity)
- .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location)
- .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message)
- .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes)
+ nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
+ .def("__init__",
+ [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
+ new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
+ })
+ .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity)
+ .def_ro("location", &PyDiagnostic::DiagnosticInfo::location)
+ .def_ro("message", &PyDiagnostic::DiagnosticInfo::message)
+ .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes)
.def("__str__",
[](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
- py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
+ nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
.def("detach", &PyDiagnosticHandler::detach)
- .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
- .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
+ .def_prop_ro("attached", &PyDiagnosticHandler::isAttached)
+ .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError)
.def("__enter__", &PyDiagnosticHandler::contextEnter)
- .def("__exit__", &PyDiagnosticHandler::contextExit);
+ .def("__exit__", &PyDiagnosticHandler::contextExit,
+ nb::arg("exc_type").none(), nb::arg("exc_value").none(),
+ nb::arg("traceback").none());
//----------------------------------------------------------------------------
// Mapping of MlirContext.
@@ -2622,8 +2652,12 @@ void mlir::python::populateIRCore(py::module &m) {
// __init__.py will subclass it with site-specific functionality and set a
// "Context" attribute on this module.
//----------------------------------------------------------------------------
- py::class_<PyMlirContext>(m, "_BaseContext", py::module_local())
- .def(py::init<>(&PyMlirContext::createNewContextForInit))
+ nb::class_<PyMlirContext>(m, "_BaseContext")
+ .def("__init__",
+ [](PyMlirContext &self) {
+ MlirContext context = mlirContextCreateWithThreading(false);
+ new (&self) PyMlirContext(context);
+ })
.def_static("_get_live_count", &PyMlirContext::getLiveCount)
.def("_get_context_again",
[](PyMlirContext &self) {
@@ -2635,28 +2669,28 @@ void mlir::python::populateIRCore(py::module &m) {
&PyMlirContext::getLiveOperationObjects)
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
.def("_clear_live_operations_inside",
- py::overload_cast<MlirOperation>(
+ nb::overload_cast<MlirOperation>(
&PyMlirContext::clearOperationsInside))
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- &PyMlirContext::getCapsule)
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
.def("__enter__", &PyMlirContext::contextEnter)
- .def("__exit__", &PyMlirContext::contextExit)
- .def_property_readonly_static(
+ .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
+ nb::arg("exc_value").none(), nb::arg("traceback").none())
+ .def_prop_ro_static(
"current",
- [](py::object & /*class*/) {
+ [](nb::object & /*class*/) {
auto *context = PyThreadContextEntry::getDefaultContext();
if (!context)
- return py::none().cast<py::object>();
- return py::cast(context);
+ return nb::none();
+ return nb::cast(context);
},
"Gets the Context bound to the current thread or raises ValueError")
- .def_property_readonly(
+ .def_prop_ro(
"dialects",
[](PyMlirContext &self) { return PyDialects(self.getRef()); },
"Gets a container for accessing dialects by name")
- .def_property_readonly(
+ .def_prop_ro(
"d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
"Alias for 'dialect'")
.def(
@@ -2665,14 +2699,14 @@ void mlir::python::populateIRCore(py::module &m) {
MlirDialect dialect = mlirContextGetOrLoadDialect(
self.get(), {name.data(), name.size()});
if (mlirDialectIsNull(dialect)) {
- throw py::value_error(
- (Twine("Dialect '") + name + "' not found").str());
+ throw nb::value_error(
+ (Twine("Dialect '") + name + "' not found").str().c_str());
}
return PyDialectDescriptor(self.getRef(), dialect);
},
- py::arg("dialect_name"),
+ nb::arg("dialect_name"),
"Gets or loads a dialect by name, returning its descriptor object")
- .def_property(
+ .def_prop_rw(
"allow_unregistered_dialects",
[](PyMlirContext &self) -> bool {
return mlirContextGetAllowUnregisteredDialects(self.get());
@@ -2681,32 +2715,32 @@ void mlir::python::populateIRCore(py::module &m) {
mlirContextSetAllowUnregisteredDialects(self.get(), value);
})
.def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
- py::arg("callback"),
+ nb::arg("callback"),
"Attaches a diagnostic handler that will receive callbacks")
.def(
"enable_multithreading",
[](PyMlirContext &self, bool enable) {
mlirContextEnableMultithreading(self.get(), enable);
},
- py::arg("enable"))
+ nb::arg("enable"))
.def(
"is_registered_operation",
[](PyMlirContext &self, std::string &name) {
return mlirContextIsRegisteredOperation(
self.get(), MlirStringRef{name.data(), name.size()});
},
- py::arg("operation_name"))
+ nb::arg("operation_name"))
.def(
"append_dialect_registry",
[](PyMlirContext &self, PyDialectRegistry &registry) {
mlirContextAppendDialectRegistry(self.get(), registry);
},
- py::arg("registry"))
- .def_property("emit_error_diagnostics", nullptr,
- &PyMlirContext::setEmitErrorDiagnostics,
- "Emit error diagnostics to diagnostic handlers. By default "
- "error diagnostics are captured and reported through "
- "MLIRError exceptions.")
+ nb::arg("registry"))
+ .def_prop_rw("emit_error_diagnostics", nullptr,
+ &PyMlirContext::setEmitErrorDiagnostics,
+ "Emit error diagnostics to diagnostic handlers. By default "
+ "error diagnostics are captured and reported through "
+ "MLIRError exceptions.")
.def("load_all_available_dialects", [](PyMlirContext &self) {
mlirContextLoadAllAvailableDialects(self.get());
});
@@ -2714,13 +2748,12 @@ void mlir::python::populateIRCore(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of PyDialectDescriptor
//----------------------------------------------------------------------------
- py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
- .def_property_readonly("namespace",
- [](PyDialectDescriptor &self) {
- MlirStringRef ns =
- mlirDialectGetNamespace(self.get());
- return py::str(ns.data, ns.length);
- })
+ nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
+ .def_prop_ro("namespace",
+ [](PyDialectDescriptor &self) {
+ MlirStringRef ns = mlirDialectGetNamespace(self.get());
+ return nb::str(ns.data, ns.length);
+ })
.def("__repr__", [](PyDialectDescriptor &self) {
MlirStringRef ns = mlirDialectGetNamespace(self.get());
std::string repr("<DialectDescriptor ");
@@ -2732,66 +2765,66 @@ void mlir::python::populateIRCore(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of PyDialects
//----------------------------------------------------------------------------
- py::class_<PyDialects>(m, "Dialects", py::module_local())
+ nb::class_<PyDialects>(m, "Dialects")
.def("__getitem__",
[=](PyDialects &self, std::string keyName) {
MlirDialect dialect =
self.getDialectForKey(keyName, /*attrError=*/false);
- py::object descriptor =
- py::cast(PyDialectDescriptor{self.getContext(), dialect});
+ nb::object descriptor =
+ nb::cast(PyDialectDescriptor{self.getContext(), dialect});
return createCustomDialectWrapper(keyName, std::move(descriptor));
})
.def("__getattr__", [=](PyDialects &self, std::string attrName) {
MlirDialect dialect =
self.getDialectForKey(attrName, /*attrError=*/true);
- py::object descriptor =
- py::cast(PyDialectDescriptor{self.getContext(), dialect});
+ nb::object descriptor =
+ nb::cast(PyDialectDescriptor{self.getContext(), dialect});
return createCustomDialectWrapper(attrName, std::move(descriptor));
});
//----------------------------------------------------------------------------
// Mapping of PyDialect
//----------------------------------------------------------------------------
- py::class_<PyDialect>(m, "Dialect", py::module_local())
- .def(py::init<py::object>(), py::arg("descriptor"))
- .def_property_readonly(
- "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
- .def("__repr__", [](py::object self) {
+ nb::class_<PyDialect>(m, "Dialect")
+ .def(nb::init<nb::object>(), nb::arg("descriptor"))
+ .def_prop_ro("descriptor",
+ [](PyDialect &self) { return self.getDescriptor(); })
+ .def("__repr__", [](nb::object self) {
auto clazz = self.attr("__class__");
- return py::str("<Dialect ") +
- self.attr("descriptor").attr("namespace") + py::str(" (class ") +
- clazz.attr("__module__") + py::str(".") +
- clazz.attr("__name__") + py::str(")>");
+ return nb::str("<Dialect ") +
+ self.attr("descriptor").attr("namespace") + nb::str(" (class ") +
+ clazz.attr("__module__") + nb::str(".") +
+ clazz.attr("__name__") + nb::str(")>");
});
//----------------------------------------------------------------------------
// Mapping of PyDialectRegistry
//----------------------------------------------------------------------------
- py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local())
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- &PyDialectRegistry::getCapsule)
+ nb::class_<PyDialectRegistry>(m, "DialectRegistry")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
- .def(py::init<>());
+ .def(nb::init<>());
//----------------------------------------------------------------------------
// Mapping of Location
//----------------------------------------------------------------------------
- py::class_<PyLocation>(m, "Location", py::module_local())
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
+ nb::class_<PyLocation>(m, "Location")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
.def("__enter__", &PyLocation::contextEnter)
- .def("__exit__", &PyLocation::contextExit)
+ .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
+ nb::arg("exc_value").none(), nb::arg("traceback").none())
.def("__eq__",
[](PyLocation &self, PyLocation &other) -> bool {
return mlirLocationEqual(self, other);
})
- .def("__eq__", [](PyLocation &self, py::object other) { return false; })
- .def_property_readonly_static(
+ .def("__eq__", [](PyLocation &self, nb::object other) { return false; })
+ .def_prop_ro_static(
"current",
- [](py::object & /*class*/) {
+ [](nb::object & /*class*/) {
auto *loc = PyThreadContextEntry::getDefaultLocation();
if (!loc)
- throw py::value_error("No current Location");
+ throw nb::value_error("No current Location");
return loc;
},
"Gets the Location bound to the current thread or raises ValueError")
@@ -2801,14 +2834,14 @@ void mlir::python::populateIRCore(py::module &m) {
return PyLocation(context->getRef(),
mlirLocationUnknownGet(context->get()));
},
- py::arg("context") = py::none(),
+ nb::arg("context").none() = nb::none(),
"Gets a Location representing an unknown location")
.def_static(
"callsite",
[](PyLocation callee, const std::vector<PyLocation> &frames,
DefaultingPyMlirContext context) {
if (frames.empty())
- throw py::value_error("No caller frames provided");
+ throw nb::value_error("No caller frames provided");
MlirLocation caller = frames.back().get();
for (const PyLocation &frame :
llvm::reverse(llvm::ArrayRef(frames).drop_back()))
@@ -2816,7 +2849,8 @@ void mlir::python::populateIRCore(py::module &m) {
return PyLocation(context->getRef(),
mlirLocationCallSiteGet(callee.get(), caller));
},
- py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
+ nb::arg("callee"), nb::arg("frames"),
+ nb::arg("context").none() = nb::none(),
kContextGetCallSiteLocationDocstring)
.def_static(
"file",
@@ -2827,8 +2861,9 @@ void mlir::python::populateIRCore(py::module &m) {
mlirLocationFileLineColGet(
context->get(), toMlirStringRef(filename), line, col));
},
- py::arg("filename"), py::arg("line"), py::arg("col"),
- py::arg("context") = py::none(), kContextGetFileLocationDocstring)
+ nb::arg("filename"), nb::arg("line"), nb::arg("col"),
+ nb::arg("context").none() = nb::none(),
+ kContextGetFileLocationDocstring)
.def_static(
"fused",
[](const std::vector<PyLocation> &pyLocations,
@@ -2843,8 +2878,9 @@ void mlir::python::populateIRCore(py::module &m) {
metadata ? metadata->get() : MlirAttribute{0});
return PyLocation(context->getRef(), location);
},
- py::arg("locations"), py::arg("metadata") = py::none(),
- py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
+ nb::arg("locations"), nb::arg("metadata").none() = nb::none(),
+ nb::arg("context").none() = nb::none(),
+ kContextGetFusedLocationDocstring)
.def_static(
"name",
[](std::string name, std::optional<PyLocation> childLoc,
@@ -2856,21 +2892,22 @@ void mlir::python::populateIRCore(py::module &m) {
childLoc ? childLoc->get()
: mlirLocationUnknownGet(context->get())));
},
- py::arg("name"), py::arg("childLoc") = py::none(),
- py::arg("context") = py::none(), kContextGetNameLocationDocString)
+ nb::arg("name"), nb::arg("childLoc").none() = nb::none(),
+ nb::arg("context").none() = nb::none(),
+ kContextGetNameLocationDocString)
.def_static(
"from_attr",
[](PyAttribute &attribute, DefaultingPyMlirContext context) {
return PyLocation(context->getRef(),
mlirLocationFromAttribute(attribute));
},
- py::arg("attribute"), py::arg("context") = py::none(),
+ nb::arg("attribute"), nb::arg("context").none() = nb::none(),
"Gets a Location from a LocationAttr")
- .def_property_readonly(
+ .def_prop_ro(
"context",
[](PyLocation &self) { return self.getContext().getObject(); },
"Context that owns the Location")
- .def_property_readonly(
+ .def_prop_ro(
"attr",
[](PyLocation &self) { return mlirLocationGetAttribute(self); },
"Get the underlying LocationAttr")
@@ -2879,7 +2916,7 @@ void mlir::python::populateIRCore(py::module &m) {
[](PyLocation &self, std::string message) {
mlirEmitError(self, message.c_str());
},
- py::arg("message"), "Emits an error at this location")
+ nb::arg("message"), "Emits an error at this location")
.def("__repr__", [](PyLocation &self) {
PyPrintAccumulator printAccum;
mlirLocationPrint(self, printAccum.getCallback(),
@@ -2890,8 +2927,8 @@ void mlir::python::populateIRCore(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of Module
//----------------------------------------------------------------------------
- py::class_<PyModule>(m, "Module", py::module_local())
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
+ nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
.def_static(
"parse",
@@ -2903,7 +2940,19 @@ void mlir::python::populateIRCore(py::module &m) {
throw MLIRError("Unable to parse module assembly", errors.take());
return PyModule::forModule(module).releaseObject();
},
- py::arg("asm"), py::arg("context") = py::none(),
+ nb::arg("asm"), nb::arg("context").none() = nb::none(),
+ kModuleParseDocstring)
+ .def_static(
+ "parse",
+ [](nb::bytes moduleAsm, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParse(
+ context->get(), toMlirStringRef(moduleAsm));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("asm"), nb::arg("context").none() = nb::none(),
kModuleParseDocstring)
.def_static(
"create",
@@ -2911,12 +2960,12 @@ void mlir::python::populateIRCore(py::module &m) {
MlirModule module = mlirModuleCreateEmpty(loc);
return PyModule::forModule(module).releaseObject();
},
- py::arg("loc") = py::none(), "Creates an empty module")
- .def_property_readonly(
+ nb::arg("loc").none() = nb::none(), "Creates an empty module")
+ .def_prop_ro(
"context",
[](PyModule &self) { return self.getContext().getObject(); },
"Context that created the Module")
- .def_property_readonly(
+ .def_prop_ro(
"operation",
[](PyModule &self) {
return PyOperation::forOperation(self.getContext(),
@@ -2925,7 +2974,7 @@ void mlir::python::populateIRCore(py::module &m) {
.releaseObject();
},
"Accesses the module as an operation")
- .def_property_readonly(
+ .def_prop_ro(
"body",
[](PyModule &self) {
PyOperationRef moduleOp = PyOperation::forOperation(
@@ -2943,7 +2992,7 @@ void mlir::python::populateIRCore(py::module &m) {
kDumpDocstring)
.def(
"__str__",
- [](py::object self) {
+ [](nb::object self) {
// Defer to the operation's __str__.
return self.attr("operation").attr("__str__")();
},
@@ -2952,27 +3001,26 @@ void mlir::python::populateIRCore(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of Operation.
//----------------------------------------------------------------------------
- py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- [](PyOperationBase &self) {
- return self.getOperation().getCapsule();
- })
+ nb::class_<PyOperationBase>(m, "_OperationBase")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
+ [](PyOperationBase &self) {
+ return self.getOperation().getCapsule();
+ })
.def("__eq__",
[](PyOperationBase &self, PyOperationBase &other) {
return &self.getOperation() == &other.getOperation();
})
.def("__eq__",
- [](PyOperationBase &self, py::object other) { return false; })
+ [](PyOperationBase &self, nb::object other) { return false; })
.def("__hash__",
[](PyOperationBase &self) {
return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
})
- .def_property_readonly("attributes",
- [](PyOperationBase &self) {
- return PyOpAttributeMap(
- self.getOperation().getRef());
- })
- .def_property_readonly(
+ .def_prop_ro("attributes",
+ [](PyOperationBase &self) {
+ return PyOpAttributeMap(self.getOperation().getRef());
+ })
+ .def_prop_ro(
"context",
[](PyOperationBase &self) {
PyOperation &concreteOperation = self.getOperation();
@@ -2980,46 +3028,44 @@ void mlir::python::populateIRCore(py::module &m) {
return concreteOperation.getContext().getObject();
},
"Context that owns the Operation")
- .def_property_readonly("name",
- [](PyOperationBase &self) {
- auto &concreteOperation = self.getOperation();
- concreteOperation.checkValid();
- MlirOperation operation =
- concreteOperation.get();
- MlirStringRef name = mlirIdentifierStr(
- mlirOperationGetName(operation));
- return py::str(name.data, name.length);
- })
- .def_property_readonly("operands",
- [](PyOperationBase &self) {
- return PyOpOperandList(
- self.getOperation().getRef());
- })
- .def_property_readonly("regions",
- [](PyOperationBase &self) {
- return PyRegionList(
- self.getOperation().getRef());
- })
- .def_property_readonly(
+ .def_prop_ro("name",
+ [](PyOperationBase &self) {
+ auto &concreteOperation = self.getOperation();
+ concreteOperation.checkValid();
+ MlirOperation operation = concreteOperation.get();
+ MlirStringRef name =
+ mlirIdentifierStr(mlirOperationGetName(operation));
+ return nb::str(name.data, name.length);
+ })
+ .def_prop_ro("operands",
+ [](PyOperationBase &self) {
+ return PyOpOperandList(self.getOperation().getRef());
+ })
+ .def_prop_ro("regions",
+ [](PyOperationBase &self) {
+ return PyRegionList(self.getOperation().getRef());
+ })
+ .def_prop_ro(
"results",
[](PyOperationBase &self) {
return PyOpResultList(self.getOperation().getRef());
},
"Returns the list of Operation results.")
- .def_property_readonly(
+ .def_prop_ro(
"result",
[](PyOperationBase &self) {
auto &operation = self.getOperation();
auto numResults = mlirOperationGetNumResults(operation);
if (numResults != 1) {
auto name = mlirIdentifierStr(mlirOperationGetName(operation));
- throw py::value_error(
+ throw nb::value_error(
(Twine("Cannot call .result on operation ") +
StringRef(name.data, name.length) + " which has " +
Twine(numResults) +
" results (it is only valid for operations with a "
"single result)")
- .str());
+ .str()
+ .c_str());
}
return PyOpResult(operation.getRef(),
mlirOperationGetResult(operation, 0))
@@ -3027,7 +3073,7 @@ void mlir::python::populateIRCore(py::module &m) {
},
"Shortcut to get an op result if it has only one (throws an error "
"otherwise).")
- .def_property_readonly(
+ .def_prop_ro(
"location",
[](PyOperationBase &self) {
PyOperation &operation = self.getOperation();
@@ -3036,14 +3082,13 @@ void mlir::python::populateIRCore(py::module &m) {
},
"Returns the source location the operation was defined or derived "
"from.")
- .def_property_readonly("parent",
- [](PyOperationBase &self) -> py::object {
- auto parent =
- self.getOperation().getParentOperation();
- if (parent)
- return parent->getObject();
- return py::none();
- })
+ .def_prop_ro("parent",
+ [](PyOperationBase &self) -> nb::object {
+ auto parent = self.getOperation().getParentOperation();
+ if (parent)
+ return parent->getObject();
+ return nb::none();
+ })
.def(
"__str__",
[](PyOperationBase &self) {
@@ -3058,75 +3103,76 @@ void mlir::python::populateIRCore(py::module &m) {
},
"Returns the assembly form of the operation.")
.def("print",
- py::overload_cast<PyAsmState &, pybind11::object, bool>(
+ nb::overload_cast<PyAsmState &, nb::object, bool>(
&PyOperationBase::print),
- py::arg("state"), py::arg("file") = py::none(),
- py::arg("binary") = false, kOperationPrintStateDocstring)
+ nb::arg("state"), nb::arg("file").none() = nb::none(),
+ nb::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
- py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
- bool, py::object, bool, bool>(
+ nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
+ bool, nb::object, bool, bool>(
&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
- py::arg("large_elements_limit") = py::none(),
- py::arg("enable_debug_info") = false,
- py::arg("pretty_debug_info") = false,
- py::arg("print_generic_op_form") = false,
- py::arg("use_local_scope") = false,
- py::arg("assume_verified") = false, py::arg("file") = py::none(),
- py::arg("binary") = false, py::arg("skip_regions") = false,
- kOperationPrintDocstring)
- .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
- py::arg("desired_version") = py::none(),
+ nb::arg("large_elements_limit").none() = nb::none(),
+ nb::arg("enable_debug_info") = false,
+ nb::arg("pretty_debug_info") = false,
+ nb::arg("print_generic_op_form") = false,
+ nb::arg("use_local_scope") = false,
+ nb::arg("assume_verified") = false,
+ nb::arg("file").none() = nb::none(), nb::arg("binary") = false,
+ nb::arg("skip_regions") = false, kOperationPrintDocstring)
+ .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
+ nb::arg("desired_version").none() = nb::none(),
kOperationPrintBytecodeDocstring)
.def("get_asm", &PyOperationBase::getAsm,
// Careful: Lots of arguments must match up with get_asm method.
- py::arg("binary") = false,
- py::arg("large_elements_limit") = py::none(),
- py::arg("enable_debug_info") = false,
- py::arg("pretty_debug_info") = false,
- py::arg("print_generic_op_form") = false,
- py::arg("use_local_scope") = false,
- py::arg("assume_verified") = false, py::arg("skip_regions") = false,
+ nb::arg("binary") = false,
+ nb::arg("large_elements_limit").none() = nb::none(),
+ nb::arg("enable_debug_info") = false,
+ nb::arg("pretty_debug_info") = false,
+ nb::arg("print_generic_op_form") = false,
+ nb::arg("use_local_scope") = false,
+ nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
kOperationGetAsmDocstring)
.def("verify", &PyOperationBase::verify,
"Verify the operation. Raises MLIRError if verification fails, and "
"returns true otherwise.")
- .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
+ .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
"Puts self immediately after the other operation in its parent "
"block.")
- .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
+ .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
"Puts self immediately before the other operation in its parent "
"block.")
.def(
"clone",
- [](PyOperationBase &self, py::object ip) {
+ [](PyOperationBase &self, nb::object ip) {
return self.getOperation().clone(ip);
},
- py::arg("ip") = py::none())
+ nb::arg("ip").none() = nb::none())
.def(
"detach_from_parent",
[](PyOperationBase &self) {
PyOperation &operation = self.getOperation();
operation.checkValid();
if (!operation.isAttached())
- throw py::value_error("Detached operation has no parent.");
+ throw nb::value_error("Detached operation has no parent.");
operation.detachFromParent();
return operation.createOpView();
},
"Detaches the operation from its parent block.")
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
- .def("walk", &PyOperationBase::walk, py::arg("callback"),
- py::arg("walk_order") = MlirWalkPostOrder);
-
- py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
- .def_static("create", &PyOperation::create, py::arg("name"),
- py::arg("results") = py::none(),
- py::arg("operands") = py::none(),
- py::arg("attributes") = py::none(),
- py::arg("successors") = py::none(), py::arg("regions") = 0,
- py::arg("loc") = py::none(), py::arg("ip") = py::none(),
- py::arg("infer_type") = false, kOperationCreateDocstring)
+ .def("walk", &PyOperationBase::walk, nb::arg("callback"),
+ nb::arg("walk_order") = MlirWalkPostOrder);
+
+ nb::class_<PyOperation, PyOperationBase>(m, "Operation")
+ .def_static("create", &PyOperation::create, nb::arg("name"),
+ nb::arg("results").none() = nb::none(),
+ nb::arg("operands").none() = nb::none(),
+ nb::arg("attributes").none() = nb::none(),
+ nb::arg("successors").none() = nb::none(),
+ nb::arg("regions") = 0, nb::arg("loc").none() = nb::none(),
+ nb::arg("ip").none() = nb::none(),
+ nb::arg("infer_type") = false, kOperationCreateDocstring)
.def_static(
"parse",
[](const std::string &sourceStr, const std::string &sourceName,
@@ -3134,16 +3180,15 @@ void mlir::python::populateIRCore(py::module &m) {
return PyOperation::parse(context->getRef(), sourceStr, sourceName)
->createOpView();
},
- py::arg("source"), py::kw_only(), py::arg("source_name") = "",
- py::arg("context") = py::none(),
+ nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
+ nb::arg("context").none() = nb::none(),
"Parses an operation. Supports both text assembly format and binary "
"bytecode format.")
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- &PyOperation::getCapsule)
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
- .def_property_readonly("operation", [](py::object self) { return self; })
- .def_property_readonly("opview", &PyOperation::createOpView)
- .def_property_readonly(
+ .def_prop_ro("operation", [](nb::object self) { return self; })
+ .def_prop_ro("opview", &PyOperation::createOpView)
+ .def_prop_ro(
"successors",
[](PyOperationBase &self) {
return PyOpSuccessors(self.getOperation().getRef());
@@ -3151,30 +3196,33 @@ void mlir::python::populateIRCore(py::module &m) {
"Returns the list of Operation successors.");
auto opViewClass =
- py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
- .def(py::init<py::object>(), py::arg("operation"))
- .def_property_readonly("operation", &PyOpView::getOperationObject)
- .def_property_readonly("opview", [](py::object self) { return self; })
+ nb::class_<PyOpView, PyOperationBase>(m, "OpView")
+ .def(nb::init<nb::object>(), nb::arg("operation"))
+ .def_prop_ro("operation", &PyOpView::getOperationObject)
+ .def_prop_ro("opview", [](nb::object self) { return self; })
.def(
"__str__",
- [](PyOpView &self) { return py::str(self.getOperationObject()); })
- .def_property_readonly(
+ [](PyOpView &self) { return nb::str(self.getOperationObject()); })
+ .def_prop_ro(
"successors",
[](PyOperationBase &self) {
return PyOpSuccessors(self.getOperation().getRef());
},
"Returns the list of Operation successors.");
- opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
- opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
- opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
+ opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
+ opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
+ opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
opViewClass.attr("build_generic") = classmethod(
- &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
- py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
- py::arg("successors") = py::none(), py::arg("regions") = py::none(),
- py::arg("loc") = py::none(), py::arg("ip") = py::none(),
+ &PyOpView::buildGeneric, nb::arg("cls"),
+ nb::arg("results").none() = nb::none(),
+ nb::arg("operands").none() = nb::none(),
+ nb::arg("attributes").none() = nb::none(),
+ nb::arg("successors").none() = nb::none(),
+ nb::arg("regions").none() = nb::none(),
+ nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
"Builds a specific, generated OpView based on class level attributes.");
opViewClass.attr("parse") = classmethod(
- [](const py::object &cls, const std::string &sourceStr,
+ [](const nb::object &cls, const std::string &sourceStr,
const std::string &sourceName, DefaultingPyMlirContext context) {
PyOperationRef parsed =
PyOperation::parse(context->getRef(), sourceStr, sourceName);
@@ -3185,30 +3233,30 @@ void mlir::python::populateIRCore(py::module &m) {
// `OpView` subclasses, and is not intended to be used on `OpView`
// directly.
std::string clsOpName =
- py::cast<std::string>(cls.attr("OPERATION_NAME"));
+ nb::cast<std::string>(cls.attr("OPERATION_NAME"));
MlirStringRef identifier =
mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
std::string_view parsedOpName(identifier.data, identifier.length);
if (clsOpName != parsedOpName)
throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
parsedOpName + "'");
- return PyOpView::constructDerived(cls, *parsed.get());
+ return PyOpView::constructDerived(cls, parsed.getObject());
},
- py::arg("cls"), py::arg("source"), py::kw_only(),
- py::arg("source_name") = "", py::arg("context") = py::none(),
+ nb::arg("cls"), nb::arg("source"), nb::kw_only(),
+ nb::arg("source_name") = "", nb::arg("context").none() = nb::none(),
"Parses a specific, generated OpView based on class level attributes");
//----------------------------------------------------------------------------
// Mapping of PyRegion.
//----------------------------------------------------------------------------
- py::class_<PyRegion>(m, "Region", py::module_local())
- .def_property_readonly(
+ nb::class_<PyRegion>(m, "Region")
+ .def_prop_ro(
"blocks",
[](PyRegion &self) {
return PyBlockList(self.getParentOperation(), self.get());
},
"Returns a forward-optimized sequence of blocks.")
- .def_property_readonly(
+ .def_prop_ro(
"owner",
[](PyRegion &self) {
return self.getParentOperation()->createOpView();
@@ -3226,27 +3274,27 @@ void mlir::python::populateIRCore(py::module &m) {
[](PyRegion &self, PyRegion &other) {
return self.get().ptr == other.get().ptr;
})
- .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
+ .def("__eq__", [](PyRegion &self, nb::object &other) { return false; });
//----------------------------------------------------------------------------
// Mapping of PyBlock.
//----------------------------------------------------------------------------
- py::class_<PyBlock>(m, "Block", py::module_local())
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
- .def_property_readonly(
+ nb::class_<PyBlock>(m, "Block")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
+ .def_prop_ro(
"owner",
[](PyBlock &self) {
return self.getParentOperation()->createOpView();
},
"Returns the owning operation of this block.")
- .def_property_readonly(
+ .def_prop_ro(
"region",
[](PyBlock &self) {
MlirRegion region = mlirBlockGetParentRegion(self.get());
return PyRegion(self.getParentOperation(), region);
},
"Returns the owning region of this block.")
- .def_property_readonly(
+ .def_prop_ro(
"arguments",
[](PyBlock &self) {
return PyBlockArgumentList(self.getParentOperation(), self.get());
@@ -3265,7 +3313,7 @@ void mlir::python::populateIRCore(py::module &m) {
return mlirBlockEraseArgument(self.get(), index);
},
"Erase the argument at 'index' and remove it from the argument list.")
- .def_property_readonly(
+ .def_prop_ro(
"operations",
[](PyBlock &self) {
return PyOperationList(self.getParentOperation(), self.get());
@@ -3273,15 +3321,15 @@ void mlir::python::populateIRCore(py::module &m) {
"Returns a forward-optimized sequence of operations.")
.def_static(
"create_at_start",
- [](PyRegion &parent, const py::list &pyArgTypes,
- const std::optional<py::sequence> &pyArgLocs) {
+ [](PyRegion &parent, const nb::sequence &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
parent.checkValid();
MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
mlirRegionInsertOwnedBlock(parent, 0, block);
return PyBlock(parent.getParentOperation(), block);
},
- py::arg("parent"), py::arg("arg_types") = py::list(),
- py::arg("arg_locs") = std::nullopt,
+ nb::arg("parent"), nb::arg("arg_types") = nb::list(),
+ nb::arg("arg_locs") = std::nullopt,
"Creates and returns a new Block at the beginning of the given "
"region (with given argument types and locations).")
.def(
@@ -3295,28 +3343,32 @@ void mlir::python::populateIRCore(py::module &m) {
"Append this block to a region, transferring ownership if necessary")
.def(
"create_before",
- [](PyBlock &self, const py::args &pyArgTypes,
- const std::optional<py::sequence> &pyArgLocs) {
+ [](PyBlock &self, const nb::args &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
self.checkValid();
- MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
+ MlirBlock block =
+ createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
MlirRegion region = mlirBlockGetParentRegion(self.get());
mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
return PyBlock(self.getParentOperation(), block);
},
- py::arg("arg_locs") = std::nullopt,
+ nb::arg("arg_types"), nb::kw_only(),
+ nb::arg("arg_locs") = std::nullopt,
"Creates and returns a new Block before this block "
"(with given argument types and locations).")
.def(
"create_after",
- [](PyBlock &self, const py::args &pyArgTypes,
- const std::optional<py::sequence> &pyArgLocs) {
+ [](PyBlock &self, const nb::args &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
self.checkValid();
- MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
+ MlirBlock block =
+ createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
MlirRegion region = mlirBlockGetParentRegion(self.get());
mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
return PyBlock(self.getParentOperation(), block);
},
- py::arg("arg_locs") = std::nullopt,
+ nb::arg("arg_types"), nb::kw_only(),
+ nb::arg("arg_locs") = std::nullopt,
"Creates and returns a new Block after this block "
"(with given argument types and locations).")
.def(
@@ -3333,7 +3385,7 @@ void mlir::python::populateIRCore(py::module &m) {
[](PyBlock &self, PyBlock &other) {
return self.get().ptr == other.get().ptr;
})
- .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
+ .def("__eq__", [](PyBlock &self, nb::object &other) { return false; })
.def("__hash__",
[](PyBlock &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
@@ -3359,7 +3411,7 @@ void mlir::python::populateIRCore(py::module &m) {
operation.getOperation().setAttached(
self.getParentOperation().getObject());
},
- py::arg("operation"),
+ nb::arg("operation"),
"Appends an operation to this block. If the operation is currently "
"in another block, it will be moved.");
@@ -3367,39 +3419,41 @@ void mlir::python::populateIRCore(py::module &m) {
// Mapping of PyInsertionPoint.
//----------------------------------------------------------------------------
- py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
- .def(py::init<PyBlock &>(), py::arg("block"),
+ nb::class_<PyInsertionPoint>(m, "InsertionPoint")
+ .def(nb::init<PyBlock &>(), nb::arg("block"),
"Inserts after the last operation but still inside the block.")
.def("__enter__", &PyInsertionPoint::contextEnter)
- .def("__exit__", &PyInsertionPoint::contextExit)
- .def_property_readonly_static(
+ .def("__exit__", &PyInsertionPoint::contextExit,
+ nb::arg("exc_type").none(), nb::arg("exc_value").none(),
+ nb::arg("traceback").none())
+ .def_prop_ro_static(
"current",
- [](py::object & /*class*/) {
+ [](nb::object & /*class*/) {
auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
if (!ip)
- throw py::value_error("No current InsertionPoint");
+ throw nb::value_error("No current InsertionPoint");
return ip;
},
"Gets the InsertionPoint bound to the current thread or raises "
"ValueError if none has been set")
- .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
+ .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
"Inserts before a referenced operation.")
.def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
- py::arg("block"), "Inserts at the beginning of the block.")
+ nb::arg("block"), "Inserts at the beginning of the block.")
.def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
- py::arg("block"), "Inserts before the block terminator.")
- .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
+ nb::arg("block"), "Inserts before the block terminator.")
+ .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
"Inserts an operation.")
- .def_property_readonly(
+ .def_prop_ro(
"block", [](PyInsertionPoint &self) { return self.getBlock(); },
"Returns the block that this InsertionPoint points to.")
- .def_property_readonly(
+ .def_prop_ro(
"ref_operation",
- [](PyInsertionPoint &self) -> py::object {
+ [](PyInsertionPoint &self) -> nb::object {
auto refOperation = self.getRefOperation();
if (refOperation)
return refOperation->getObject();
- return py::none();
+ return nb::none();
},
"The reference operation before which new operations are "
"inserted, or None if the insertion point is at the end of "
@@ -3408,13 +3462,12 @@ void mlir::python::populateIRCore(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of PyAttribute.
//----------------------------------------------------------------------------
- py::class_<PyAttribute>(m, "Attribute", py::module_local())
+ nb::class_<PyAttribute>(m, "Attribute")
// Delegate to the PyAttribute copy constructor, which will also lifetime
// extend the backing context which owns the MlirAttribute.
- .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
+ .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
"Casts the passed attribute to the generic Attribute")
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- &PyAttribute::getCapsule)
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
.def_static(
"parse",
@@ -3426,24 +3479,24 @@ void mlir::python::populateIRCore(py::module &m) {
throw MLIRError("Unable to parse attribute", errors.take());
return attr;
},
- py::arg("asm"), py::arg("context") = py::none(),
+ nb::arg("asm"), nb::arg("context").none() = nb::none(),
"Parses an attribute from an assembly form. Raises an MLIRError on "
"failure.")
- .def_property_readonly(
+ .def_prop_ro(
"context",
[](PyAttribute &self) { return self.getContext().getObject(); },
"Context that owns the Attribute")
- .def_property_readonly(
- "type", [](PyAttribute &self) { return mlirAttributeGetType(self); })
+ .def_prop_ro("type",
+ [](PyAttribute &self) { return mlirAttributeGetType(self); })
.def(
"get_named",
[](PyAttribute &self, std::string name) {
return PyNamedAttribute(self, std::move(name));
},
- py::keep_alive<0, 1>(), "Binds a name to the attribute")
+ nb::keep_alive<0, 1>(), "Binds a name to the attribute")
.def("__eq__",
[](PyAttribute &self, PyAttribute &other) { return self == other; })
- .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
+ .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; })
.def("__hash__",
[](PyAttribute &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
@@ -3474,36 +3527,35 @@ void mlir::python::populateIRCore(py::module &m) {
printAccum.parts.append(")");
return printAccum.join();
})
- .def_property_readonly(
- "typeid",
- [](PyAttribute &self) -> MlirTypeID {
- MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
- assert(!mlirTypeIDIsNull(mlirTypeID) &&
- "mlirTypeID was expected to be non-null.");
- return mlirTypeID;
- })
+ .def_prop_ro("typeid",
+ [](PyAttribute &self) -> MlirTypeID {
+ MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
+ assert(!mlirTypeIDIsNull(mlirTypeID) &&
+ "mlirTypeID was expected to be non-null.");
+ return mlirTypeID;
+ })
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
- std::optional<pybind11::function> typeCaster =
+ std::optional<nb::callable> typeCaster =
PyGlobals::get().lookupTypeCaster(mlirTypeID,
mlirAttributeGetDialect(self));
if (!typeCaster)
- return py::cast(self);
+ return nb::cast(self);
return typeCaster.value()(self);
});
//----------------------------------------------------------------------------
// Mapping of PyNamedAttribute
//----------------------------------------------------------------------------
- py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
+ nb::class_<PyNamedAttribute>(m, "NamedAttribute")
.def("__repr__",
[](PyNamedAttribute &self) {
PyPrintAccumulator printAccum;
printAccum.parts.append("NamedAttribute(");
printAccum.parts.append(
- py::str(mlirIdentifierStr(self.namedAttr.name).data,
+ nb::str(mlirIdentifierStr(self.namedAttr.name).data,
mlirIdentifierStr(self.namedAttr.name).length));
printAccum.parts.append("=");
mlirAttributePrint(self.namedAttr.attribute,
@@ -3512,28 +3564,28 @@ void mlir::python::populateIRCore(py::module &m) {
printAccum.parts.append(")");
return printAccum.join();
})
- .def_property_readonly(
+ .def_prop_ro(
"name",
[](PyNamedAttribute &self) {
- return py::str(mlirIdentifierStr(self.namedAttr.name).data,
+ return nb::str(mlirIdentifierStr(self.namedAttr.name).data,
mlirIdentifierStr(self.namedAttr.name).length);
},
"The name of the NamedAttribute binding")
- .def_property_readonly(
+ .def_prop_ro(
"attr",
[](PyNamedAttribute &self) { return self.namedAttr.attribute; },
- py::keep_alive<0, 1>(),
+ nb::keep_alive<0, 1>(),
"The underlying generic attribute of the NamedAttribute binding");
//----------------------------------------------------------------------------
// Mapping of PyType.
//----------------------------------------------------------------------------
- py::class_<PyType>(m, "Type", py::module_local())
+ nb::class_<PyType>(m, "Type")
// Delegate to the PyType copy constructor, which will also lifetime
// extend the backing context which owns the MlirType.
- .def(py::init<PyType &>(), py::arg("cast_from_type"),
+ .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
"Casts the passed type to the generic Type")
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
.def_static(
"parse",
@@ -3545,13 +3597,15 @@ void mlir::python::populateIRCore(py::module &m) {
throw MLIRError("Unable to parse type", errors.take());
return type;
},
- py::arg("asm"), py::arg("context") = py::none(),
+ nb::arg("asm"), nb::arg("context").none() = nb::none(),
kContextParseTypeDocstring)
- .def_property_readonly(
+ .def_prop_ro(
"context", [](PyType &self) { return self.getContext().getObject(); },
"Context that owns the Type")
.def("__eq__", [](PyType &self, PyType &other) { return self == other; })
- .def("__eq__", [](PyType &self, py::object &other) { return false; })
+ .def(
+ "__eq__", [](PyType &self, nb::object &other) { return false; },
+ nb::arg("other").none())
.def("__hash__",
[](PyType &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
@@ -3585,28 +3639,27 @@ void mlir::python::populateIRCore(py::module &m) {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
- std::optional<pybind11::function> typeCaster =
+ std::optional<nb::callable> typeCaster =
PyGlobals::get().lookupTypeCaster(mlirTypeID,
mlirTypeGetDialect(self));
if (!typeCaster)
- return py::cast(self);
+ return nb::cast(self);
return typeCaster.value()(self);
})
- .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
+ .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
if (!mlirTypeIDIsNull(mlirTypeID))
return mlirTypeID;
- auto origRepr =
- pybind11::repr(pybind11::cast(self)).cast<std::string>();
- throw py::value_error(
- (origRepr + llvm::Twine(" has no typeid.")).str());
+ auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
+ throw nb::value_error(
+ (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
});
//----------------------------------------------------------------------------
// Mapping of PyTypeID.
//----------------------------------------------------------------------------
- py::class_<PyTypeID>(m, "TypeID", py::module_local())
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
+ nb::class_<PyTypeID>(m, "TypeID")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
// Note, this tests whether the underlying TypeIDs are the same,
// not whether the wrapper MlirTypeIDs are the same, nor whether
@@ -3614,7 +3667,7 @@ void mlir::python::populateIRCore(py::module &m) {
.def("__eq__",
[](PyTypeID &self, PyTypeID &other) { return self == other; })
.def("__eq__",
- [](PyTypeID &self, const py::object &other) { return false; })
+ [](PyTypeID &self, const nb::object &other) { return false; })
// Note, this gives the hash value of the underlying TypeID, not the
// hash value of the Python object, nor the hash value of the
// MlirTypeID wrapper.
@@ -3625,20 +3678,20 @@ void mlir::python::populateIRCore(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of Value.
//----------------------------------------------------------------------------
- py::class_<PyValue>(m, "Value", py::module_local())
- .def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"))
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
+ nb::class_<PyValue>(m, "Value")
+ .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
- .def_property_readonly(
+ .def_prop_ro(
"context",
[](PyValue &self) { return self.getParentOperation()->getContext(); },
"Context in which the value lives.")
.def(
"dump", [](PyValue &self) { mlirValueDump(self.get()); },
kDumpDocstring)
- .def_property_readonly(
+ .def_prop_ro(
"owner",
- [](PyValue &self) -> py::object {
+ [](PyValue &self) -> nb::object {
MlirValue v = self.get();
if (mlirValueIsAOpResult(v)) {
assert(
@@ -3651,22 +3704,22 @@ void mlir::python::populateIRCore(py::module &m) {
if (mlirValueIsABlockArgument(v)) {
MlirBlock block = mlirBlockArgumentGetOwner(self.get());
- return py::cast(PyBlock(self.getParentOperation(), block));
+ return nb::cast(PyBlock(self.getParentOperation(), block));
}
assert(false && "Value must be a block argument or an op result");
- return py::none();
+ return nb::none();
})
- .def_property_readonly("uses",
- [](PyValue &self) {
- return PyOpOperandIterator(
- mlirValueGetFirstUse(self.get()));
- })
+ .def_prop_ro("uses",
+ [](PyValue &self) {
+ return PyOpOperandIterator(
+ mlirValueGetFirstUse(self.get()));
+ })
.def("__eq__",
[](PyValue &self, PyValue &other) {
return self.get().ptr == other.get().ptr;
})
- .def("__eq__", [](PyValue &self, py::object other) { return false; })
+ .def("__eq__", [](PyValue &self, nb::object other) { return false; })
.def("__hash__",
[](PyValue &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
@@ -3698,26 +3751,26 @@ void mlir::python::populateIRCore(py::module &m) {
mlirAsmStateDestroy(valueState);
return printAccum.join();
},
- py::arg("use_local_scope") = false)
+ nb::arg("use_local_scope") = false)
.def(
"get_name",
- [](PyValue &self, std::reference_wrapper<PyAsmState> state) {
+ [](PyValue &self, PyAsmState &state) {
PyPrintAccumulator printAccum;
- MlirAsmState valueState = state.get().get();
+ MlirAsmState valueState = state.get();
mlirValuePrintAsOperand(self.get(), valueState,
printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
},
- py::arg("state"), kGetNameAsOperand)
- .def_property_readonly(
- "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
+ nb::arg("state"), kGetNameAsOperand)
+ .def_prop_ro("type",
+ [](PyValue &self) { return mlirValueGetType(self.get()); })
.def(
"set_type",
[](PyValue &self, const PyType &type) {
return mlirValueSetType(self.get(), type);
},
- py::arg("type"))
+ nb::arg("type"))
.def(
"replace_all_uses_with",
[](PyValue &self, PyValue &with) {
@@ -3730,22 +3783,22 @@ void mlir::python::populateIRCore(py::module &m) {
MlirOperation exceptedUser = exception.get();
mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
},
- py::arg("with"), py::arg("exceptions"),
+ nb::arg("with"), nb::arg("exceptions"),
kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
- [](MlirValue self, MlirValue with, py::list exceptions) {
+ [](MlirValue self, MlirValue with, nb::list exceptions) {
// Convert Python list to a SmallVector of MlirOperations
llvm::SmallVector<MlirOperation> exceptionOps;
- for (py::handle exception : exceptions) {
- exceptionOps.push_back(exception.cast<PyOperation &>().get());
+ for (nb::handle exception : exceptions) {
+ exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
}
mlirValueReplaceAllUsesExcept(
self, with, static_cast<intptr_t>(exceptionOps.size()),
exceptionOps.data());
},
- py::arg("with"), py::arg("exceptions"),
+ nb::arg("with"), nb::arg("exceptions"),
kValueReplaceAllUsesExceptDocstring)
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](PyValue &self) { return self.maybeDownCast(); });
@@ -3753,20 +3806,20 @@ void mlir::python::populateIRCore(py::module &m) {
PyOpResult::bind(m);
PyOpOperand::bind(m);
- py::class_<PyAsmState>(m, "AsmState", py::module_local())
- .def(py::init<PyValue &, bool>(), py::arg("value"),
- py::arg("use_local_scope") = false)
- .def(py::init<PyOperationBase &, bool>(), py::arg("op"),
- py::arg("use_local_scope") = false);
+ nb::class_<PyAsmState>(m, "AsmState")
+ .def(nb::init<PyValue &, bool>(), nb::arg("value"),
+ nb::arg("use_local_scope") = false)
+ .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
+ nb::arg("use_local_scope") = false);
//----------------------------------------------------------------------------
// Mapping of SymbolTable.
//----------------------------------------------------------------------------
- py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
- .def(py::init<PyOperationBase &>())
+ nb::class_<PySymbolTable>(m, "SymbolTable")
+ .def(nb::init<PyOperationBase &>())
.def("__getitem__", &PySymbolTable::dunderGetItem)
- .def("insert", &PySymbolTable::insert, py::arg("operation"))
- .def("erase", &PySymbolTable::erase, py::arg("operation"))
+ .def("insert", &PySymbolTable::insert, nb::arg("operation"))
+ .def("erase", &PySymbolTable::erase, nb::arg("operation"))
.def("__delitem__", &PySymbolTable::dunderDel)
.def("__contains__",
[](PySymbolTable &table, const std::string &name) {
@@ -3775,19 +3828,19 @@ void mlir::python::populateIRCore(py::module &m) {
})
// Static helpers.
.def_static("set_symbol_name", &PySymbolTable::setSymbolName,
- py::arg("symbol"), py::arg("name"))
+ nb::arg("symbol"), nb::arg("name"))
.def_static("get_symbol_name", &PySymbolTable::getSymbolName,
- py::arg("symbol"))
+ nb::arg("symbol"))
.def_static("get_visibility", &PySymbolTable::getVisibility,
- py::arg("symbol"))
+ nb::arg("symbol"))
.def_static("set_visibility", &PySymbolTable::setVisibility,
- py::arg("symbol"), py::arg("visibility"))
+ nb::arg("symbol"), nb::arg("visibility"))
.def_static("replace_all_symbol_uses",
- &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
- py::arg("new_symbol"), py::arg("from_op"))
+ &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
+ nb::arg("new_symbol"), nb::arg("from_op"))
.def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
- py::arg("from_op"), py::arg("all_sym_uses_visible"),
- py::arg("callback"));
+ nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
+ nb::arg("callback"));
// Container bindings.
PyBlockArgumentList::bind(m);
@@ -3809,14 +3862,15 @@ void mlir::python::populateIRCore(py::module &m) {
// Attribute builder getter.
PyAttrBuilderMap::bind(m);
- py::register_local_exception_translator([](std::exception_ptr p) {
+ nb::register_exception_translator([](const std::exception_ptr &p,
+ void *payload) {
// We can't define exceptions with custom fields through pybind, so instead
// the exception class is defined in python and imported here.
try {
if (p)
std::rethrow_exception(p);
} catch (const MLIRError &e) {
- py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("MLIRError")(e.message, e.errorDiagnostics);
PyErr_SetObject(PyExc_Exception, obj.ptr());
}
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 54cfa56066eb8b..c339a93e31857b 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/optional.h>
+#include <nanobind/stl/vector.h>
+
#include <cstdint>
#include <optional>
-#include <pybind11/cast.h>
-#include <pybind11/detail/common.h>
-#include <pybind11/pybind11.h>
-#include <pybind11/pytypes.h>
#include <string>
#include <utility>
#include <vector>
@@ -24,7 +24,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
-namespace py = pybind11;
+namespace nb = nanobind;
namespace mlir {
namespace python {
@@ -53,10 +53,10 @@ namespace {
/// Takes in an optional ist of operands and converts them into a SmallVector
/// of MlirVlaues. Returns an empty SmallVector if the list is empty.
-llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
+llvm::SmallVector<MlirValue> wrapOperands(std::optional<nb::list> operandList) {
llvm::SmallVector<MlirValue> mlirOperands;
- if (!operandList || operandList->empty()) {
+ if (!operandList || operandList->size() == 0) {
return mlirOperands;
}
@@ -68,40 +68,42 @@ llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
PyValue *val;
try {
- val = py::cast<PyValue *>(it.value());
+ val = nb::cast<PyValue *>(it.value());
if (!val)
- throw py::cast_error();
+ throw nb::cast_error();
mlirOperands.push_back(val->get());
continue;
- } catch (py::cast_error &err) {
+ } catch (nb::cast_error &err) {
// Intentionally unhandled to try sequence below first.
(void)err;
}
try {
- auto vals = py::cast<py::sequence>(it.value());
- for (py::object v : vals) {
+ auto vals = nb::cast<nb::sequence>(it.value());
+ for (nb::handle v : vals) {
try {
- val = py::cast<PyValue *>(v);
+ val = nb::cast<PyValue *>(v);
if (!val)
- throw py::cast_error();
+ throw nb::cast_error();
mlirOperands.push_back(val->get());
- } catch (py::cast_error &err) {
- throw py::value_error(
+ } catch (nb::cast_error &err) {
+ throw nb::value_error(
(llvm::Twine("Operand ") + llvm::Twine(it.index()) +
" must be a Value or Sequence of Values (" + err.what() + ")")
- .str());
+ .str()
+ .c_str());
}
}
continue;
- } catch (py::cast_error &err) {
- throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+ } catch (nb::cast_error &err) {
+ throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
" must be a Value or Sequence of Values (" +
err.what() + ")")
- .str());
+ .str()
+ .c_str());
}
- throw py::cast_error();
+ throw nb::cast_error();
}
return mlirOperands;
@@ -144,24 +146,24 @@ wrapRegions(std::optional<std::vector<PyRegion>> regions) {
template <typename ConcreteIface>
class PyConcreteOpInterface {
protected:
- using ClassTy = py::class_<ConcreteIface>;
+ using ClassTy = nb::class_<ConcreteIface>;
using GetTypeIDFunctionTy = MlirTypeID (*)();
public:
/// Constructs an interface instance from an object that is either an
/// operation or a subclass of OpView. In the latter case, only the static
/// methods of the interface are accessible to the caller.
- PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
+ PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
: obj(std::move(object)) {
try {
- operation = &py::cast<PyOperation &>(obj);
- } catch (py::cast_error &) {
+ operation = &nb::cast<PyOperation &>(obj);
+ } catch (nb::cast_error &) {
// Do nothing.
}
try {
- operation = &py::cast<PyOpView &>(obj).getOperation();
- } catch (py::cast_error &) {
+ operation = &nb::cast<PyOpView &>(obj).getOperation();
+ } catch (nb::cast_error &) {
// Do nothing.
}
@@ -169,7 +171,7 @@ class PyConcreteOpInterface {
if (!mlirOperationImplementsInterface(*operation,
ConcreteIface::getInterfaceID())) {
std::string msg = "the operation does not implement ";
- throw py::value_error(msg + ConcreteIface::pyClassName);
+ throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
}
MlirIdentifier identifier = mlirOperationGetName(*operation);
@@ -177,9 +179,9 @@ class PyConcreteOpInterface {
opName = std::string(stringRef.data, stringRef.length);
} else {
try {
- opName = obj.attr("OPERATION_NAME").template cast<std::string>();
- } catch (py::cast_error &) {
- throw py::type_error(
+ opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
+ } catch (nb::cast_error &) {
+ throw nb::type_error(
"Op interface does not refer to an operation or OpView class");
}
@@ -187,22 +189,19 @@ class PyConcreteOpInterface {
mlirStringRefCreate(opName.data(), opName.length()),
context.resolve().get(), ConcreteIface::getInterfaceID())) {
std::string msg = "the operation does not implement ";
- throw py::value_error(msg + ConcreteIface::pyClassName);
+ throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
}
}
}
/// Creates the Python bindings for this class in the given module.
- static void bind(py::module &m) {
- py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
- py::module_local());
- cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
- py::arg("context") = py::none(), constructorDoc)
- .def_property_readonly("operation",
- &PyConcreteOpInterface::getOperationObject,
- operationDoc)
- .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
- opviewDoc);
+ static void bind(nb::module_ &m) {
+ nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
+ cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
+ nb::arg("context").none() = nb::none(), constructorDoc)
+ .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
+ operationDoc)
+ .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
ConcreteIface::bindDerived(cls);
}
@@ -216,9 +215,9 @@ class PyConcreteOpInterface {
/// Returns the operation instance from which this object was constructed.
/// Throws a type error if this object was constructed from a subclass of
/// OpView.
- py::object getOperationObject() {
+ nb::object getOperationObject() {
if (operation == nullptr) {
- throw py::type_error("Cannot get an operation from a static interface");
+ throw nb::type_error("Cannot get an operation from a static interface");
}
return operation->getRef().releaseObject();
@@ -227,9 +226,9 @@ class PyConcreteOpInterface {
/// Returns the opview of the operation instance from which this object was
/// constructed. Throws a type error if this object was constructed form a
/// subclass of OpView.
- py::object getOpView() {
+ nb::object getOpView() {
if (operation == nullptr) {
- throw py::type_error("Cannot get an opview from a static interface");
+ throw nb::type_error("Cannot get an opview from a static interface");
}
return operation->createOpView();
@@ -242,7 +241,7 @@ class PyConcreteOpInterface {
private:
PyOperation *operation = nullptr;
std::string opName;
- py::object obj;
+ nb::object obj;
};
/// Python wrapper for InferTypeOpInterface. This interface has only static
@@ -276,7 +275,7 @@ class PyInferTypeOpInterface
/// Given the arguments required to build an operation, attempts to infer its
/// return types. Throws value_error on failure.
std::vector<PyType>
- inferReturnTypes(std::optional<py::list> operandList,
+ inferReturnTypes(std::optional<nb::list> operandList,
std::optional<PyAttribute> attributes, void *properties,
std::optional<std::vector<PyRegion>> regions,
DefaultingPyMlirContext context,
@@ -299,7 +298,7 @@ class PyInferTypeOpInterface
mlirRegions.data(), &appendResultsCallback, &data);
if (mlirLogicalResultIsFailure(result)) {
- throw py::value_error("Failed to infer result types");
+ throw nb::value_error("Failed to infer result types");
}
return inferredTypes;
@@ -307,11 +306,12 @@ class PyInferTypeOpInterface
static void bindDerived(ClassTy &cls) {
cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
- py::arg("operands") = py::none(),
- py::arg("attributes") = py::none(),
- py::arg("properties") = py::none(), py::arg("regions") = py::none(),
- py::arg("context") = py::none(), py::arg("loc") = py::none(),
- inferReturnTypesDoc);
+ nb::arg("operands").none() = nb::none(),
+ nb::arg("attributes").none() = nb::none(),
+ nb::arg("properties").none() = nb::none(),
+ nb::arg("regions").none() = nb::none(),
+ nb::arg("context").none() = nb::none(),
+ nb::arg("loc").none() = nb::none(), inferReturnTypesDoc);
}
};
@@ -319,9 +319,9 @@ class PyInferTypeOpInterface
class PyShapedTypeComponents {
public:
PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
- PyShapedTypeComponents(py::list shape, MlirType elementType)
+ PyShapedTypeComponents(nb::list shape, MlirType elementType)
: shape(std::move(shape)), elementType(elementType), ranked(true) {}
- PyShapedTypeComponents(py::list shape, MlirType elementType,
+ PyShapedTypeComponents(nb::list shape, MlirType elementType,
MlirAttribute attribute)
: shape(std::move(shape)), elementType(elementType), attribute(attribute),
ranked(true) {}
@@ -330,10 +330,9 @@ class PyShapedTypeComponents {
: shape(other.shape), elementType(other.elementType),
attribute(other.attribute), ranked(other.ranked) {}
- static void bind(py::module &m) {
- py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
- py::module_local())
- .def_property_readonly(
+ static void bind(nb::module_ &m) {
+ nb::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents")
+ .def_prop_ro(
"element_type",
[](PyShapedTypeComponents &self) { return self.elementType; },
"Returns the element type of the shaped type components.")
@@ -342,57 +341,57 @@ class PyShapedTypeComponents {
[](PyType &elementType) {
return PyShapedTypeComponents(elementType);
},
- py::arg("element_type"),
+ nb::arg("element_type"),
"Create an shaped type components object with only the element "
"type.")
.def_static(
"get",
- [](py::list shape, PyType &elementType) {
+ [](nb::list shape, PyType &elementType) {
return PyShapedTypeComponents(std::move(shape), elementType);
},
- py::arg("shape"), py::arg("element_type"),
+ nb::arg("shape"), nb::arg("element_type"),
"Create a ranked shaped type components object.")
.def_static(
"get",
- [](py::list shape, PyType &elementType, PyAttribute &attribute) {
+ [](nb::list shape, PyType &elementType, PyAttribute &attribute) {
return PyShapedTypeComponents(std::move(shape), elementType,
attribute);
},
- py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
+ nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"),
"Create a ranked shaped type components object with attribute.")
- .def_property_readonly(
+ .def_prop_ro(
"has_rank",
[](PyShapedTypeComponents &self) -> bool { return self.ranked; },
"Returns whether the given shaped type component is ranked.")
- .def_property_readonly(
+ .def_prop_ro(
"rank",
- [](PyShapedTypeComponents &self) -> py::object {
+ [](PyShapedTypeComponents &self) -> nb::object {
if (!self.ranked) {
- return py::none();
+ return nb::none();
}
- return py::int_(self.shape.size());
+ return nb::int_(self.shape.size());
},
"Returns the rank of the given ranked shaped type components. If "
"the shaped type components does not have a rank, None is "
"returned.")
- .def_property_readonly(
+ .def_prop_ro(
"shape",
- [](PyShapedTypeComponents &self) -> py::object {
+ [](PyShapedTypeComponents &self) -> nb::object {
if (!self.ranked) {
- return py::none();
+ return nb::none();
}
- return py::list(self.shape);
+ return nb::list(self.shape);
},
"Returns the shape of the ranked shaped type components as a list "
"of integers. Returns none if the shaped type component does not "
"have a rank.");
}
- pybind11::object getCapsule();
- static PyShapedTypeComponents createFromCapsule(pybind11::object capsule);
+ nb::object getCapsule();
+ static PyShapedTypeComponents createFromCapsule(nb::object capsule);
private:
- py::list shape;
+ nb::list shape;
MlirType elementType;
MlirAttribute attribute;
bool ranked{false};
@@ -424,7 +423,7 @@ class PyInferShapedTypeOpInterface
if (!hasRank) {
data->inferredShapedTypeComponents.emplace_back(elementType);
} else {
- py::list shapeList;
+ nb::list shapeList;
for (intptr_t i = 0; i < rank; ++i) {
shapeList.append(shape[i]);
}
@@ -436,7 +435,7 @@ class PyInferShapedTypeOpInterface
/// Given the arguments required to build an operation, attempts to infer the
/// shaped type components. Throws value_error on failure.
std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
- std::optional<py::list> operandList,
+ std::optional<nb::list> operandList,
std::optional<PyAttribute> attributes, void *properties,
std::optional<std::vector<PyRegion>> regions,
DefaultingPyMlirContext context, DefaultingPyLocation location) {
@@ -458,7 +457,7 @@ class PyInferShapedTypeOpInterface
mlirRegions.data(), &appendResultsCallback, &data);
if (mlirLogicalResultIsFailure(result)) {
- throw py::value_error("Failed to infer result shape type components");
+ throw nb::value_error("Failed to infer result shape type components");
}
return inferredShapedTypeComponents;
@@ -467,14 +466,16 @@ class PyInferShapedTypeOpInterface
static void bindDerived(ClassTy &cls) {
cls.def("inferReturnTypeComponents",
&PyInferShapedTypeOpInterface::inferReturnTypeComponents,
- py::arg("operands") = py::none(),
- py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
- py::arg("properties") = py::none(), py::arg("context") = py::none(),
- py::arg("loc") = py::none(), inferReturnTypeComponentsDoc);
+ nb::arg("operands").none() = nb::none(),
+ nb::arg("attributes").none() = nb::none(),
+ nb::arg("regions").none() = nb::none(),
+ nb::arg("properties").none() = nb::none(),
+ nb::arg("context").none() = nb::none(),
+ nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc);
}
};
-void populateIRInterfaces(py::module &m) {
+void populateIRInterfaces(nb::module_ &m) {
PyInferTypeOpInterface::bind(m);
PyShapedTypeComponents::bind(m);
PyInferShapedTypeOpInterface::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 6727860c094a2a..416a14218f125d 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -7,16 +7,19 @@
//===----------------------------------------------------------------------===//
#include "IRModule.h"
-#include "Globals.h"
-#include "PybindUtils.h"
-#include "mlir-c/Bindings/Python/Interop.h"
-#include "mlir-c/Support.h"
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/string.h>
#include <optional>
#include <vector>
-namespace py = pybind11;
+#include "Globals.h"
+#include "NanobindUtils.h"
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Support.h"
+
+namespace nb = nanobind;
using namespace mlir;
using namespace mlir::python;
@@ -41,14 +44,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
return true;
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
- py::object loaded = py::none();
+ nb::object loaded = nb::none();
for (std::string moduleName : localSearchPrefixes) {
moduleName.push_back('.');
moduleName.append(dialectNamespace.data(), dialectNamespace.size());
try {
- loaded = py::module::import(moduleName.c_str());
- } catch (py::error_already_set &e) {
+ loaded = nb::module_::import_(moduleName.c_str());
+ } catch (nb::python_error &e) {
if (e.matches(PyExc_ModuleNotFoundError)) {
continue;
}
@@ -66,41 +69,39 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
}
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
- py::function pyFunc, bool replace) {
- py::object &found = attributeBuilderMap[attributeKind];
+ nb::callable pyFunc, bool replace) {
+ nb::object &found = attributeBuilderMap[attributeKind];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
attributeKind +
"' is already registered with func: " +
- py::str(found).operator std::string())
+ nb::cast<std::string>(nb::str(found)))
.str());
}
found = std::move(pyFunc);
}
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
- pybind11::function typeCaster,
- bool replace) {
- pybind11::object &found = typeCasterMap[mlirTypeID];
+ nb::callable typeCaster, bool replace) {
+ nb::object &found = typeCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Type caster is already registered with caster: " +
- py::str(found).operator std::string());
+ nb::cast<std::string>(nb::str(found)));
found = std::move(typeCaster);
}
void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
- pybind11::function valueCaster,
- bool replace) {
- pybind11::object &found = valueCasterMap[mlirTypeID];
+ nb::callable valueCaster, bool replace) {
+ nb::object &found = valueCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Value caster is already registered: " +
- py::repr(found).cast<std::string>());
+ nb::cast<std::string>(nb::repr(found)));
found = std::move(valueCaster);
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
- py::object pyClass) {
- py::object &found = dialectClassMap[dialectNamespace];
+ nb::object pyClass) {
+ nb::object &found = dialectClassMap[dialectNamespace];
if (found) {
throw std::runtime_error((llvm::Twine("Dialect namespace '") +
dialectNamespace + "' is already registered.")
@@ -110,8 +111,8 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
}
void PyGlobals::registerOperationImpl(const std::string &operationName,
- py::object pyClass, bool replace) {
- py::object &found = operationClassMap[operationName];
+ nb::object pyClass, bool replace) {
+ nb::object &found = operationClassMap[operationName];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
"' is already registered.")
@@ -120,7 +121,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
found = std::move(pyClass);
}
-std::optional<py::function>
+std::optional<nb::callable>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
const auto foundIt = attributeBuilderMap.find(attributeKind);
if (foundIt != attributeBuilderMap.end()) {
@@ -130,7 +131,7 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
return std::nullopt;
}
-std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
+std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
@@ -142,7 +143,7 @@ std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
return std::nullopt;
}
-std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
+std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
@@ -154,7 +155,7 @@ std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
return std::nullopt;
}
-std::optional<py::object>
+std::optional<nb::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
// Make sure dialect module is loaded.
if (!loadDialectModule(dialectNamespace))
@@ -168,7 +169,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
return std::nullopt;
}
-std::optional<pybind11::object>
+std::optional<nb::object>
PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
// Make sure dialect module is loaded.
auto split = operationName.split('.');
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 172898cfda0c52..a242ff26bbbf57 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -10,20 +10,22 @@
#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
#define MLIR_BINDINGS_PYTHON_IRMODULES_H
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/string.h>
+
#include <optional>
#include <utility>
#include <vector>
#include "Globals.h"
-#include "PybindUtils.h"
-
+#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
@@ -49,7 +51,7 @@ class PyValue;
template <typename T>
class PyObjectRef {
public:
- PyObjectRef(T *referrent, pybind11::object object)
+ PyObjectRef(T *referrent, nanobind::object object)
: referrent(referrent), object(std::move(object)) {
assert(this->referrent &&
"cannot construct PyObjectRef with null referrent");
@@ -67,13 +69,13 @@ class PyObjectRef {
int getRefCount() {
if (!object)
return 0;
- return object.ref_count();
+ return Py_REFCNT(object.ptr());
}
/// Releases the object held by this instance, returning it.
/// This is the proper thing to return from a function that wants to return
/// the reference. Note that this does not work from initializers.
- pybind11::object releaseObject() {
+ nanobind::object releaseObject() {
assert(referrent && object);
referrent = nullptr;
auto stolen = std::move(object);
@@ -85,7 +87,7 @@ class PyObjectRef {
assert(referrent && object);
return referrent;
}
- pybind11::object getObject() {
+ nanobind::object getObject() {
assert(referrent && object);
return object;
}
@@ -93,7 +95,7 @@ class PyObjectRef {
private:
T *referrent;
- pybind11::object object;
+ nanobind::object object;
};
/// Tracks an entry in the thread context stack. New entries are pushed onto
@@ -112,9 +114,9 @@ class PyThreadContextEntry {
Location,
};
- PyThreadContextEntry(FrameKind frameKind, pybind11::object context,
- pybind11::object insertionPoint,
- pybind11::object location)
+ PyThreadContextEntry(FrameKind frameKind, nanobind::object context,
+ nanobind::object insertionPoint,
+ nanobind::object location)
: context(std::move(context)), insertionPoint(std::move(insertionPoint)),
location(std::move(location)), frameKind(frameKind) {}
@@ -134,26 +136,26 @@ class PyThreadContextEntry {
/// Stack management.
static PyThreadContextEntry *getTopOfStack();
- static pybind11::object pushContext(PyMlirContext &context);
+ static nanobind::object pushContext(nanobind::object context);
static void popContext(PyMlirContext &context);
- static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint);
+ static nanobind::object pushInsertionPoint(nanobind::object insertionPoint);
static void popInsertionPoint(PyInsertionPoint &insertionPoint);
- static pybind11::object pushLocation(PyLocation &location);
+ static nanobind::object pushLocation(nanobind::object location);
static void popLocation(PyLocation &location);
/// Gets the thread local stack.
static std::vector<PyThreadContextEntry> &getStack();
private:
- static void push(FrameKind frameKind, pybind11::object context,
- pybind11::object insertionPoint, pybind11::object location);
+ static void push(FrameKind frameKind, nanobind::object context,
+ nanobind::object insertionPoint, nanobind::object location);
/// An object reference to the PyContext.
- pybind11::object context;
+ nanobind::object context;
/// An object reference to the current insertion point.
- pybind11::object insertionPoint;
+ nanobind::object insertionPoint;
/// An object reference to the current location.
- pybind11::object location;
+ nanobind::object location;
// The kind of push that was performed.
FrameKind frameKind;
};
@@ -163,14 +165,15 @@ using PyMlirContextRef = PyObjectRef<PyMlirContext>;
class PyMlirContext {
public:
PyMlirContext() = delete;
+ PyMlirContext(MlirContext context);
PyMlirContext(const PyMlirContext &) = delete;
PyMlirContext(PyMlirContext &&) = delete;
- /// For the case of a python __init__ (py::init) method, pybind11 is quite
- /// strict about needing to return a pointer that is not yet associated to
- /// an py::object. Since the forContext() method acts like a pool, possibly
- /// returning a recycled context, it does not satisfy this need. The usual
- /// way in python to accomplish such a thing is to override __new__, but
+ /// For the case of a python __init__ (nanobind::init) method, pybind11 is
+ /// quite strict about needing to return a pointer that is not yet associated
+ /// to an nanobind::object. Since the forContext() method acts like a pool,
+ /// possibly returning a recycled context, it does not satisfy this need. The
+ /// usual way in python to accomplish such a thing is to override __new__, but
/// that is also not supported by pybind11. Instead, we use this entry
/// point which always constructs a fresh context (which cannot alias an
/// existing one because it is fresh).
@@ -187,17 +190,17 @@ class PyMlirContext {
/// Gets a strong reference to this context, which will ensure it is kept
/// alive for the life of the reference.
PyMlirContextRef getRef() {
- return PyMlirContextRef(this, pybind11::cast(this));
+ return PyMlirContextRef(this, nanobind::cast(this));
}
/// Gets a capsule wrapping the void* within the MlirContext.
- pybind11::object getCapsule();
+ nanobind::object getCapsule();
/// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
/// Note that PyMlirContext instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirContext
/// is taken by calling this function.
- static pybind11::object createFromCapsule(pybind11::object capsule);
+ static nanobind::object createFromCapsule(nanobind::object capsule);
/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();
@@ -237,14 +240,14 @@ class PyMlirContext {
size_t getLiveModuleCount();
/// Enter and exit the context manager.
- pybind11::object contextEnter();
- void contextExit(const pybind11::object &excType,
- const pybind11::object &excVal,
- const pybind11::object &excTb);
+ static nanobind::object contextEnter(nanobind::object context);
+ void contextExit(const nanobind::object &excType,
+ const nanobind::object &excVal,
+ const nanobind::object &excTb);
/// Attaches a Python callback as a diagnostic handler, returning a
/// registration object (internally a PyDiagnosticHandler).
- pybind11::object attachDiagnosticHandler(pybind11::object callback);
+ nanobind::object attachDiagnosticHandler(nanobind::object callback);
/// Controls whether error diagnostics should be propagated to diagnostic
/// handlers, instead of being captured by `ErrorCapture`.
@@ -252,8 +255,6 @@ class PyMlirContext {
struct ErrorCapture;
private:
- PyMlirContext(MlirContext context);
-
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
// preserving the relationship that an MlirContext maps to a single
// PyMlirContext wrapper. This could be replaced in the future with an
@@ -268,7 +269,7 @@ class PyMlirContext {
// from this map, and while it still exists as an instance, any
// attempt to access it will raise an error.
using LiveModuleMap =
- llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
+ llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
LiveModuleMap liveModules;
// Interns all live operations associated with this context. Operations
@@ -276,7 +277,7 @@ class PyMlirContext {
// removed from this map, and while it still exists as an instance, any
// attempt to access it will raise an error.
using LiveOperationMap =
- llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
+ llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
LiveOperationMap liveOperations;
bool emitErrorDiagnostics = false;
@@ -324,19 +325,19 @@ class PyLocation : public BaseContextObject {
MlirLocation get() const { return loc; }
/// Enter and exit the context manager.
- pybind11::object contextEnter();
- void contextExit(const pybind11::object &excType,
- const pybind11::object &excVal,
- const pybind11::object &excTb);
+ static nanobind::object contextEnter(nanobind::object location);
+ void contextExit(const nanobind::object &excType,
+ const nanobind::object &excVal,
+ const nanobind::object &excTb);
/// Gets a capsule wrapping the void* within the MlirLocation.
- pybind11::object getCapsule();
+ nanobind::object getCapsule();
/// Creates a PyLocation from the MlirLocation wrapped by a capsule.
/// Note that PyLocation instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirLocation
/// is taken by calling this function.
- static PyLocation createFromCapsule(pybind11::object capsule);
+ static PyLocation createFromCapsule(nanobind::object capsule);
private:
MlirLocation loc;
@@ -353,8 +354,8 @@ class PyDiagnostic {
bool isValid() { return valid; }
MlirDiagnosticSeverity getSeverity();
PyLocation getLocation();
- pybind11::str getMessage();
- pybind11::tuple getNotes();
+ nanobind::str getMessage();
+ nanobind::tuple getNotes();
/// Materialized diagnostic information. This is safe to access outside the
/// diagnostic callback.
@@ -373,7 +374,7 @@ class PyDiagnostic {
/// If notes have been materialized from the diagnostic, then this will
/// be populated with the corresponding objects (all castable to
/// PyDiagnostic).
- std::optional<pybind11::tuple> materializedNotes;
+ std::optional<nanobind::tuple> materializedNotes;
bool valid = true;
};
@@ -398,7 +399,7 @@ class PyDiagnostic {
/// is no way to attach an existing handler object).
class PyDiagnosticHandler {
public:
- PyDiagnosticHandler(MlirContext context, pybind11::object callback);
+ PyDiagnosticHandler(MlirContext context, nanobind::object callback);
~PyDiagnosticHandler();
bool isAttached() { return registeredID.has_value(); }
@@ -407,16 +408,16 @@ class PyDiagnosticHandler {
/// Detaches the handler. Does nothing if not attached.
void detach();
- pybind11::object contextEnter() { return pybind11::cast(this); }
- void contextExit(const pybind11::object &excType,
- const pybind11::object &excVal,
- const pybind11::object &excTb) {
+ nanobind::object contextEnter() { return nanobind::cast(this); }
+ void contextExit(const nanobind::object &excType,
+ const nanobind::object &excVal,
+ const nanobind::object &excTb) {
detach();
}
private:
MlirContext context;
- pybind11::object callback;
+ nanobind::object callback;
std::optional<MlirDiagnosticHandlerID> registeredID;
bool hadError = false;
friend class PyMlirContext;
@@ -477,12 +478,12 @@ class PyDialects : public BaseContextObject {
/// objects of this type will be returned directly.
class PyDialect {
public:
- PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}
+ PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {}
- pybind11::object getDescriptor() { return descriptor; }
+ nanobind::object getDescriptor() { return descriptor; }
private:
- pybind11::object descriptor;
+ nanobind::object descriptor;
};
/// Wrapper around an MlirDialectRegistry.
@@ -505,8 +506,8 @@ class PyDialectRegistry {
operator MlirDialectRegistry() const { return registry; }
MlirDialectRegistry get() const { return registry; }
- pybind11::object getCapsule();
- static PyDialectRegistry createFromCapsule(pybind11::object capsule);
+ nanobind::object getCapsule();
+ static PyDialectRegistry createFromCapsule(nanobind::object capsule);
private:
MlirDialectRegistry registry;
@@ -542,26 +543,25 @@ class PyModule : public BaseContextObject {
/// Gets a strong reference to this module.
PyModuleRef getRef() {
- return PyModuleRef(this,
- pybind11::reinterpret_borrow<pybind11::object>(handle));
+ return PyModuleRef(this, nanobind::borrow<nanobind::object>(handle));
}
/// Gets a capsule wrapping the void* within the MlirModule.
/// Note that the module does not (yet) provide a corresponding factory for
/// constructing from a capsule as that would require uniquing PyModule
/// instances, which is not currently done.
- pybind11::object getCapsule();
+ nanobind::object getCapsule();
/// Creates a PyModule from the MlirModule wrapped by a capsule.
/// Note that PyModule instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirModule
/// is taken by calling this function.
- static pybind11::object createFromCapsule(pybind11::object capsule);
+ static nanobind::object createFromCapsule(nanobind::object capsule);
private:
PyModule(PyMlirContextRef contextRef, MlirModule module);
MlirModule module;
- pybind11::handle handle;
+ nanobind::handle handle;
};
class PyAsmState;
@@ -574,18 +574,18 @@ class PyOperationBase {
/// Implements the bound 'print' method and helps with others.
void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
- bool assumeVerified, py::object fileObject, bool binary,
+ bool assumeVerified, nanobind::object fileObject, bool binary,
bool skipRegions);
- void print(PyAsmState &state, py::object fileObject, bool binary);
+ void print(PyAsmState &state, nanobind::object fileObject, bool binary);
- pybind11::object getAsm(bool binary,
+ nanobind::object getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
bool assumeVerified, bool skipRegions);
// Implement the bound 'writeBytecode' method.
- void writeBytecode(const pybind11::object &fileObject,
+ void writeBytecode(const nanobind::object &fileObject,
std::optional<int64_t> bytecodeVersion);
// Implement the walk method.
@@ -621,13 +621,13 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
/// it with a parentKeepAlive.
static PyOperationRef
forOperation(PyMlirContextRef contextRef, MlirOperation operation,
- pybind11::object parentKeepAlive = pybind11::object());
+ nanobind::object parentKeepAlive = nanobind::object());
/// Creates a detached operation. The operation must not be associated with
/// any existing live operation.
static PyOperationRef
createDetached(PyMlirContextRef contextRef, MlirOperation operation,
- pybind11::object parentKeepAlive = pybind11::object());
+ nanobind::object parentKeepAlive = nanobind::object());
/// Parses a source string (either text assembly or bytecode), creating a
/// detached operation.
@@ -640,7 +640,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
void detachFromParent() {
mlirOperationRemoveFromParent(getOperation());
setDetached();
- parentKeepAlive = pybind11::object();
+ parentKeepAlive = nanobind::object();
}
/// Gets the backing operation.
@@ -651,12 +651,11 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
}
PyOperationRef getRef() {
- return PyOperationRef(
- this, pybind11::reinterpret_borrow<pybind11::object>(handle));
+ return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle));
}
bool isAttached() { return attached; }
- void setAttached(const pybind11::object &parent = pybind11::object()) {
+ void setAttached(const nanobind::object &parent = nanobind::object()) {
assert(!attached && "operation already attached");
attached = true;
}
@@ -675,24 +674,24 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
std::optional<PyOperationRef> getParentOperation();
/// Gets a capsule wrapping the void* within the MlirOperation.
- pybind11::object getCapsule();
+ nanobind::object getCapsule();
/// Creates a PyOperation from the MlirOperation wrapped by a capsule.
/// Ownership of the underlying MlirOperation is taken by calling this
/// function.
- static pybind11::object createFromCapsule(pybind11::object capsule);
+ static nanobind::object createFromCapsule(nanobind::object capsule);
/// Creates an operation. See corresponding python docstring.
- static pybind11::object
+ static nanobind::object
create(const std::string &name, std::optional<std::vector<PyType *>> results,
std::optional<std::vector<PyValue *>> operands,
- std::optional<pybind11::dict> attributes,
+ std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
- DefaultingPyLocation location, const pybind11::object &ip,
+ DefaultingPyLocation location, const nanobind::object &ip,
bool inferType);
/// Creates an OpView suitable for this operation.
- pybind11::object createOpView();
+ nanobind::object createOpView();
/// Erases the underlying MlirOperation, removes its pointer from the
/// parent context's live operations map, and sets the valid bit false.
@@ -702,23 +701,23 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
void setInvalid() { valid = false; }
/// Clones this operation.
- pybind11::object clone(const pybind11::object &ip);
+ nanobind::object clone(const nanobind::object &ip);
private:
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
static PyOperationRef createInstance(PyMlirContextRef contextRef,
MlirOperation operation,
- pybind11::object parentKeepAlive);
+ nanobind::object parentKeepAlive);
MlirOperation operation;
- pybind11::handle handle;
+ nanobind::handle handle;
// Keeps the parent alive, regardless of whether it is an Operation or
// Module.
// TODO: As implemented, this facility is only sufficient for modeling the
// trivial module parent back-reference. Generalize this to also account for
// transitions from detached to attached and address TODOs in the
// ir_operation.py regarding testing corresponding lifetime guarantees.
- pybind11::object parentKeepAlive;
+ nanobind::object parentKeepAlive;
bool attached = true;
bool valid = true;
@@ -733,17 +732,17 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
/// python types.
class PyOpView : public PyOperationBase {
public:
- PyOpView(const pybind11::object &operationObject);
+ PyOpView(const nanobind::object &operationObject);
PyOperation &getOperation() override { return operation; }
- pybind11::object getOperationObject() { return operationObject; }
+ nanobind::object getOperationObject() { return operationObject; }
- static pybind11::object buildGeneric(
- const pybind11::object &cls, std::optional<pybind11::list> resultTypeList,
- pybind11::list operandList, std::optional<pybind11::dict> attributes,
+ static nanobind::object buildGeneric(
+ const nanobind::object &cls, std::optional<nanobind::list> resultTypeList,
+ nanobind::list operandList, std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
- const pybind11::object &maybeIp);
+ const nanobind::object &maybeIp);
/// Construct an instance of a class deriving from OpView, bypassing its
/// `__init__` method. The derived class will typically define a constructor
@@ -752,12 +751,12 @@ class PyOpView : public PyOperationBase {
///
/// The caller is responsible for verifying that `operation` is a valid
/// operation to construct `cls` with.
- static pybind11::object constructDerived(const pybind11::object &cls,
- const PyOperation &operation);
+ static nanobind::object constructDerived(const nanobind::object &cls,
+ const nanobind::object &operation);
private:
PyOperation &operation; // For efficient, cast-free access from C++
- pybind11::object operationObject; // Holds the reference.
+ nanobind::object operationObject; // Holds the reference.
};
/// Wrapper around an MlirRegion.
@@ -830,7 +829,7 @@ class PyBlock {
void checkValid() { return parentOperation->checkValid(); }
/// Gets a capsule wrapping the void* within the MlirBlock.
- pybind11::object getCapsule();
+ nanobind::object getCapsule();
private:
PyOperationRef parentOperation;
@@ -858,10 +857,10 @@ class PyInsertionPoint {
void insert(PyOperationBase &operationBase);
/// Enter and exit the context manager.
- pybind11::object contextEnter();
- void contextExit(const pybind11::object &excType,
- const pybind11::object &excVal,
- const pybind11::object &excTb);
+ static nanobind::object contextEnter(nanobind::object insertionPoint);
+ void contextExit(const nanobind::object &excType,
+ const nanobind::object &excVal,
+ const nanobind::object &excTb);
PyBlock &getBlock() { return block; }
std::optional<PyOperationRef> &getRefOperation() { return refOperation; }
@@ -886,13 +885,13 @@ class PyType : public BaseContextObject {
MlirType get() const { return type; }
/// Gets a capsule wrapping the void* within the MlirType.
- pybind11::object getCapsule();
+ nanobind::object getCapsule();
/// Creates a PyType from the MlirType wrapped by a capsule.
/// Note that PyType instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirType
/// is taken by calling this function.
- static PyType createFromCapsule(pybind11::object capsule);
+ static PyType createFromCapsule(nanobind::object capsule);
private:
MlirType type;
@@ -912,10 +911,10 @@ class PyTypeID {
MlirTypeID get() { return typeID; }
/// Gets a capsule wrapping the void* within the MlirTypeID.
- pybind11::object getCapsule();
+ nanobind::object getCapsule();
/// Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
- static PyTypeID createFromCapsule(pybind11::object capsule);
+ static PyTypeID createFromCapsule(nanobind::object capsule);
private:
MlirTypeID typeID;
@@ -932,7 +931,7 @@ class PyConcreteType : public BaseTy {
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
// const char *pyClassName
- using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
+ using ClassTy = nanobind::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = bool (*)(MlirType);
using GetTypeIDFunctionTy = MlirTypeID (*)();
static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
@@ -945,34 +944,38 @@ class PyConcreteType : public BaseTy {
static MlirType castFrom(PyType &orig) {
if (!DerivedTy::isaFunction(orig)) {
- auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
- throw py::value_error((llvm::Twine("Cannot cast type to ") +
- DerivedTy::pyClassName + " (from " + origRepr +
- ")")
- .str());
+ auto origRepr =
+ nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig)));
+ throw nanobind::value_error((llvm::Twine("Cannot cast type to ") +
+ DerivedTy::pyClassName + " (from " +
+ origRepr + ")")
+ .str()
+ .c_str());
}
return orig;
}
- static void bind(pybind11::module &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
- cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(),
- pybind11::arg("cast_from_type"));
+ static void bind(nanobind::module_ &m) {
+ auto cls = ClassTy(m, DerivedTy::pyClassName);
+ cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(),
+ nanobind::arg("cast_from_type"));
cls.def_static(
"isinstance",
[](PyType &otherType) -> bool {
return DerivedTy::isaFunction(otherType);
},
- pybind11::arg("other"));
- cls.def_property_readonly_static(
- "static_typeid", [](py::object & /*class*/) -> MlirTypeID {
+ nanobind::arg("other"));
+ cls.def_prop_ro_static(
+ "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID {
if (DerivedTy::getTypeIdFunction)
return DerivedTy::getTypeIdFunction();
- throw py::attribute_error(
- (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str());
+ throw nanobind::attribute_error(
+ (DerivedTy::pyClassName + llvm::Twine(" has no typeid."))
+ .str()
+ .c_str());
});
- cls.def_property_readonly("typeid", [](PyType &self) {
- return py::cast(self).attr("typeid").cast<MlirTypeID>();
+ cls.def_prop_ro("typeid", [](PyType &self) {
+ return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid"));
});
cls.def("__repr__", [](DerivedTy &self) {
PyPrintAccumulator printAccum;
@@ -986,8 +989,8 @@ class PyConcreteType : public BaseTy {
if (DerivedTy::getTypeIdFunction) {
PyGlobals::get().registerTypeCaster(
DerivedTy::getTypeIdFunction(),
- pybind11::cpp_function(
- [](PyType pyType) -> DerivedTy { return pyType; }));
+ nanobind::cast<nanobind::callable>(nanobind::cpp_function(
+ [](PyType pyType) -> DerivedTy { return pyType; })));
}
DerivedTy::bindDerived(cls);
@@ -1008,13 +1011,13 @@ class PyAttribute : public BaseContextObject {
MlirAttribute get() const { return attr; }
/// Gets a capsule wrapping the void* within the MlirAttribute.
- pybind11::object getCapsule();
+ nanobind::object getCapsule();
/// Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
/// Note that PyAttribute instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirAttribute
/// is taken by calling this function.
- static PyAttribute createFromCapsule(pybind11::object capsule);
+ static PyAttribute createFromCapsule(nanobind::object capsule);
private:
MlirAttribute attr;
@@ -1054,7 +1057,7 @@ class PyConcreteAttribute : public BaseTy {
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
// const char *pyClassName
- using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
+ using ClassTy = nanobind::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = bool (*)(MlirAttribute);
using GetTypeIDFunctionTy = MlirTypeID (*)();
static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
@@ -1067,37 +1070,45 @@ class PyConcreteAttribute : public BaseTy {
static MlirAttribute castFrom(PyAttribute &orig) {
if (!DerivedTy::isaFunction(orig)) {
- auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
- throw py::value_error((llvm::Twine("Cannot cast attribute to ") +
- DerivedTy::pyClassName + " (from " + origRepr +
- ")")
- .str());
+ auto origRepr =
+ nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig)));
+ throw nanobind::value_error((llvm::Twine("Cannot cast attribute to ") +
+ DerivedTy::pyClassName + " (from " +
+ origRepr + ")")
+ .str()
+ .c_str());
}
return orig;
}
- static void bind(pybind11::module &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(),
- pybind11::module_local());
- cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(),
- pybind11::arg("cast_from_attr"));
+ static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) {
+ ClassTy cls;
+ if (slots) {
+ cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots));
+ } else {
+ cls = ClassTy(m, DerivedTy::pyClassName);
+ }
+ cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(),
+ nanobind::arg("cast_from_attr"));
cls.def_static(
"isinstance",
[](PyAttribute &otherAttr) -> bool {
return DerivedTy::isaFunction(otherAttr);
},
- pybind11::arg("other"));
- cls.def_property_readonly(
+ nanobind::arg("other"));
+ cls.def_prop_ro(
"type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
- cls.def_property_readonly_static(
- "static_typeid", [](py::object & /*class*/) -> MlirTypeID {
+ cls.def_prop_ro_static(
+ "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID {
if (DerivedTy::getTypeIdFunction)
return DerivedTy::getTypeIdFunction();
- throw py::attribute_error(
- (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str());
+ throw nanobind::attribute_error(
+ (DerivedTy::pyClassName + llvm::Twine(" has no typeid."))
+ .str()
+ .c_str());
});
- cls.def_property_readonly("typeid", [](PyAttribute &self) {
- return py::cast(self).attr("typeid").cast<MlirTypeID>();
+ cls.def_prop_ro("typeid", [](PyAttribute &self) {
+ return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid"));
});
cls.def("__repr__", [](DerivedTy &self) {
PyPrintAccumulator printAccum;
@@ -1112,9 +1123,10 @@ class PyConcreteAttribute : public BaseTy {
if (DerivedTy::getTypeIdFunction) {
PyGlobals::get().registerTypeCaster(
DerivedTy::getTypeIdFunction(),
- pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
- return pyAttribute;
- }));
+ nanobind::cast<nanobind::callable>(
+ nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
+ return pyAttribute;
+ })));
}
DerivedTy::bindDerived(cls);
@@ -1146,13 +1158,13 @@ class PyValue {
void checkValid() { return parentOperation->checkValid(); }
/// Gets a capsule wrapping the void* within the MlirValue.
- pybind11::object getCapsule();
+ nanobind::object getCapsule();
- pybind11::object maybeDownCast();
+ nanobind::object maybeDownCast();
/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
/// the underlying MlirValue is still tied to the owning operation.
- static PyValue createFromCapsule(pybind11::object capsule);
+ static PyValue createFromCapsule(nanobind::object capsule);
private:
PyOperationRef parentOperation;
@@ -1169,13 +1181,13 @@ class PyAffineExpr : public BaseContextObject {
MlirAffineExpr get() const { return affineExpr; }
/// Gets a capsule wrapping the void* within the MlirAffineExpr.
- pybind11::object getCapsule();
+ nanobind::object getCapsule();
/// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
/// Note that PyAffineExpr instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
/// is taken by calling this function.
- static PyAffineExpr createFromCapsule(pybind11::object capsule);
+ static PyAffineExpr createFromCapsule(nanobind::object capsule);
PyAffineExpr add(const PyAffineExpr &other) const;
PyAffineExpr mul(const PyAffineExpr &other) const;
@@ -1196,13 +1208,13 @@ class PyAffineMap : public BaseContextObject {
MlirAffineMap get() const { return affineMap; }
/// Gets a capsule wrapping the void* within the MlirAffineMap.
- pybind11::object getCapsule();
+ nanobind::object getCapsule();
/// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule.
/// Note that PyAffineMap instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirAffineMap
/// is taken by calling this function.
- static PyAffineMap createFromCapsule(pybind11::object capsule);
+ static PyAffineMap createFromCapsule(nanobind::object capsule);
private:
MlirAffineMap affineMap;
@@ -1217,12 +1229,12 @@ class PyIntegerSet : public BaseContextObject {
MlirIntegerSet get() const { return integerSet; }
/// Gets a capsule wrapping the void* within the MlirIntegerSet.
- pybind11::object getCapsule();
+ nanobind::object getCapsule();
/// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
/// Note that PyIntegerSet instances may be uniqued, so the returned object
/// may be a pre-existing object. Integer sets are owned by the context.
- static PyIntegerSet createFromCapsule(pybind11::object capsule);
+ static PyIntegerSet createFromCapsule(nanobind::object capsule);
private:
MlirIntegerSet integerSet;
@@ -1239,7 +1251,7 @@ class PySymbolTable {
/// Returns the symbol (opview) with the given name, throws if there is no
/// such symbol in the table.
- pybind11::object dunderGetItem(const std::string &name);
+ nanobind::object dunderGetItem(const std::string &name);
/// Removes the given operation from the symbol table and erases it.
void erase(PyOperationBase &symbol);
@@ -1269,7 +1281,7 @@ class PySymbolTable {
/// Walks all symbol tables under and including 'from'.
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible,
- pybind11::object callback);
+ nanobind::object callback);
/// Casts the bindings class into the C API structure.
operator MlirSymbolTable() { return symbolTable; }
@@ -1289,16 +1301,16 @@ struct MLIRError {
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
};
-void populateIRAffine(pybind11::module &m);
-void populateIRAttributes(pybind11::module &m);
-void populateIRCore(pybind11::module &m);
-void populateIRInterfaces(pybind11::module &m);
-void populateIRTypes(pybind11::module &m);
+void populateIRAffine(nanobind::module_ &m);
+void populateIRAttributes(nanobind::module_ &m);
+void populateIRCore(nanobind::module_ &m);
+void populateIRInterfaces(nanobind::module_ &m);
+void populateIRTypes(nanobind::module_ &m);
} // namespace python
} // namespace mlir
-namespace pybind11 {
+namespace nanobind {
namespace detail {
template <>
@@ -1309,6 +1321,6 @@ struct type_caster<mlir::python::DefaultingPyLocation>
: MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};
} // namespace detail
-} // namespace pybind11
+} // namespace nanobind
#endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 6f192bc4bffeef..de21a3a3e63c2a 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -6,19 +6,26 @@
//
//===----------------------------------------------------------------------===//
+// clang-format off
#include "IRModule.h"
+#include "mlir/Bindings/Python/IRTypes.h"
+// clang-format on
-#include "PybindUtils.h"
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/optional.h>
+#include <nanobind/stl/pair.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/vector.h>
-#include "mlir/Bindings/Python/IRTypes.h"
+#include <optional>
+#include "IRModule.h"
+#include "NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"
-#include <optional>
-
-namespace py = pybind11;
+namespace nb = nanobind;
using namespace mlir;
using namespace mlir::python;
@@ -48,7 +55,7 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> {
MlirType t = mlirIntegerTypeGet(context->get(), width);
return PyIntegerType(context->getRef(), t);
},
- py::arg("width"), py::arg("context") = py::none(),
+ nb::arg("width"), nb::arg("context").none() = nb::none(),
"Create a signless integer type");
c.def_static(
"get_signed",
@@ -56,7 +63,7 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> {
MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
return PyIntegerType(context->getRef(), t);
},
- py::arg("width"), py::arg("context") = py::none(),
+ nb::arg("width"), nb::arg("context").none() = nb::none(),
"Create a signed integer type");
c.def_static(
"get_unsigned",
@@ -64,25 +71,25 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> {
MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
return PyIntegerType(context->getRef(), t);
},
- py::arg("width"), py::arg("context") = py::none(),
+ nb::arg("width"), nb::arg("context").none() = nb::none(),
"Create an unsigned integer type");
- c.def_property_readonly(
+ c.def_prop_ro(
"width",
[](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
"Returns the width of the integer type");
- c.def_property_readonly(
+ c.def_prop_ro(
"is_signless",
[](PyIntegerType &self) -> bool {
return mlirIntegerTypeIsSignless(self);
},
"Returns whether this is a signless integer");
- c.def_property_readonly(
+ c.def_prop_ro(
"is_signed",
[](PyIntegerType &self) -> bool {
return mlirIntegerTypeIsSigned(self);
},
"Returns whether this is a signed integer");
- c.def_property_readonly(
+ c.def_prop_ro(
"is_unsigned",
[](PyIntegerType &self) -> bool {
return mlirIntegerTypeIsUnsigned(self);
@@ -107,7 +114,7 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
MlirType t = mlirIndexTypeGet(context->get());
return PyIndexType(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a index type.");
+ nb::arg("context").none() = nb::none(), "Create a index type.");
}
};
@@ -118,7 +125,7 @@ class PyFloatType : public PyConcreteType<PyFloatType> {
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
- c.def_property_readonly(
+ c.def_prop_ro(
"width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
"Returns the width of the floating-point type");
}
@@ -141,7 +148,7 @@ class PyFloat4E2M1FNType
MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
return PyFloat4E2M1FNType(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a float4_e2m1fn type.");
+ nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type.");
}
};
@@ -162,7 +169,7 @@ class PyFloat6E2M3FNType
MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
return PyFloat6E2M3FNType(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a float6_e2m3fn type.");
+ nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type.");
}
};
@@ -183,7 +190,7 @@ class PyFloat6E3M2FNType
MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
return PyFloat6E3M2FNType(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a float6_e3m2fn type.");
+ nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type.");
}
};
@@ -204,7 +211,7 @@ class PyFloat8E4M3FNType
MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
return PyFloat8E4M3FNType(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a float8_e4m3fn type.");
+ nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type.");
}
};
@@ -224,7 +231,7 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
MlirType t = mlirFloat8E5M2TypeGet(context->get());
return PyFloat8E5M2Type(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a float8_e5m2 type.");
+ nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type.");
}
};
@@ -244,7 +251,7 @@ class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
MlirType t = mlirFloat8E4M3TypeGet(context->get());
return PyFloat8E4M3Type(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a float8_e4m3 type.");
+ nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type.");
}
};
@@ -265,7 +272,8 @@ class PyFloat8E4M3FNUZType
MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
return PyFloat8E4M3FNUZType(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a float8_e4m3fnuz type.");
+ nb::arg("context").none() = nb::none(),
+ "Create a float8_e4m3fnuz type.");
}
};
@@ -286,7 +294,8 @@ class PyFloat8E4M3B11FNUZType
MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
return PyFloat8E4M3B11FNUZType(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type.");
+ nb::arg("context").none() = nb::none(),
+ "Create a float8_e4m3b11fnuz type.");
}
};
@@ -307,7 +316,8 @@ class PyFloat8E5M2FNUZType
MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
return PyFloat8E5M2FNUZType(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a float8_e5m2fnuz type.");
+ nb::arg("context").none() = nb::none(),
+ "Create a float8_e5m2fnuz type.");
}
};
@@ -327,7 +337,7 @@ class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
MlirType t = mlirFloat8E3M4TypeGet(context->get());
return PyFloat8E3M4Type(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a float8_e3m4 type.");
+ nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type.");
}
};
@@ -348,7 +358,8 @@ class PyFloat8E8M0FNUType
MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
return PyFloat8E8M0FNUType(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a float8_e8m0fnu type.");
+ nb::arg("context").none() = nb::none(),
+ "Create a float8_e8m0fnu type.");
}
};
@@ -368,7 +379,7 @@ class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
MlirType t = mlirBF16TypeGet(context->get());
return PyBF16Type(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a bf16 type.");
+ nb::arg("context").none() = nb::none(), "Create a bf16 type.");
}
};
@@ -388,7 +399,7 @@ class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
MlirType t = mlirF16TypeGet(context->get());
return PyF16Type(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a f16 type.");
+ nb::arg("context").none() = nb::none(), "Create a f16 type.");
}
};
@@ -408,7 +419,7 @@ class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
MlirType t = mlirTF32TypeGet(context->get());
return PyTF32Type(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a tf32 type.");
+ nb::arg("context").none() = nb::none(), "Create a tf32 type.");
}
};
@@ -428,7 +439,7 @@ class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
MlirType t = mlirF32TypeGet(context->get());
return PyF32Type(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a f32 type.");
+ nb::arg("context").none() = nb::none(), "Create a f32 type.");
}
};
@@ -448,7 +459,7 @@ class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
MlirType t = mlirF64TypeGet(context->get());
return PyF64Type(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a f64 type.");
+ nb::arg("context").none() = nb::none(), "Create a f64 type.");
}
};
@@ -468,7 +479,7 @@ class PyNoneType : public PyConcreteType<PyNoneType> {
MlirType t = mlirNoneTypeGet(context->get());
return PyNoneType(context->getRef(), t);
},
- py::arg("context") = py::none(), "Create a none type.");
+ nb::arg("context").none() = nb::none(), "Create a none type.");
}
};
@@ -490,14 +501,15 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
MlirType t = mlirComplexTypeGet(elementType);
return PyComplexType(elementType.getContext(), t);
}
- throw py::value_error(
+ throw nb::value_error(
(Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
+ nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
"' and expected floating point or integer type.")
- .str());
+ .str()
+ .c_str());
},
"Create a complex type");
- c.def_property_readonly(
+ c.def_prop_ro(
"element_type",
[](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
"Returns element type.");
@@ -508,22 +520,22 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
// Shaped Type Interface - ShapedType
void mlir::PyShapedType::bindDerived(ClassTy &c) {
- c.def_property_readonly(
+ c.def_prop_ro(
"element_type",
[](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
"Returns the element type of the shaped type.");
- c.def_property_readonly(
+ c.def_prop_ro(
"has_rank",
[](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
"Returns whether the given shaped type is ranked.");
- c.def_property_readonly(
+ c.def_prop_ro(
"rank",
[](PyShapedType &self) {
self.requireHasRank();
return mlirShapedTypeGetRank(self);
},
"Returns the rank of the given ranked shaped type.");
- c.def_property_readonly(
+ c.def_prop_ro(
"has_static_shape",
[](PyShapedType &self) -> bool {
return mlirShapedTypeHasStaticShape(self);
@@ -535,7 +547,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
self.requireHasRank();
return mlirShapedTypeIsDynamicDim(self, dim);
},
- py::arg("dim"),
+ nb::arg("dim"),
"Returns whether the dim-th dimension of the given shaped type is "
"dynamic.");
c.def(
@@ -544,12 +556,12 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
self.requireHasRank();
return mlirShapedTypeGetDimSize(self, dim);
},
- py::arg("dim"),
+ nb::arg("dim"),
"Returns the dim-th dimension of the given ranked shaped type.");
c.def_static(
"is_dynamic_size",
[](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
- py::arg("dim_size"),
+ nb::arg("dim_size"),
"Returns whether the given dimension size indicates a dynamic "
"dimension.");
c.def(
@@ -558,10 +570,10 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
self.requireHasRank();
return mlirShapedTypeIsDynamicStrideOrOffset(val);
},
- py::arg("dim_size"),
+ nb::arg("dim_size"),
"Returns whether the given value is used as a placeholder for dynamic "
"strides and offsets in shaped types.");
- c.def_property_readonly(
+ c.def_prop_ro(
"shape",
[](PyShapedType &self) {
self.requireHasRank();
@@ -587,7 +599,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
void mlir::PyShapedType::requireHasRank() {
if (!mlirShapedTypeHasRank(*this)) {
- throw py::value_error(
+ throw nb::value_error(
"calling this method requires that the type has a rank.");
}
}
@@ -607,15 +619,15 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyVectorType::get, py::arg("shape"),
- py::arg("element_type"), py::kw_only(),
- py::arg("scalable") = py::none(),
- py::arg("scalable_dims") = py::none(),
- py::arg("loc") = py::none(), "Create a vector type")
- .def_property_readonly(
+ c.def_static("get", &PyVectorType::get, nb::arg("shape"),
+ nb::arg("element_type"), nb::kw_only(),
+ nb::arg("scalable").none() = nb::none(),
+ nb::arg("scalable_dims").none() = nb::none(),
+ nb::arg("loc").none() = nb::none(), "Create a vector type")
+ .def_prop_ro(
"scalable",
[](MlirType self) { return mlirVectorTypeIsScalable(self); })
- .def_property_readonly("scalable_dims", [](MlirType self) {
+ .def_prop_ro("scalable_dims", [](MlirType self) {
std::vector<bool> scalableDims;
size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
scalableDims.reserve(rank);
@@ -627,11 +639,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
private:
static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
- std::optional<py::list> scalable,
+ std::optional<nb::list> scalable,
std::optional<std::vector<int64_t>> scalableDims,
DefaultingPyLocation loc) {
if (scalable && scalableDims) {
- throw py::value_error("'scalable' and 'scalable_dims' kwargs "
+ throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
"are mutually exclusive.");
}
@@ -639,10 +651,10 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
MlirType type;
if (scalable) {
if (scalable->size() != shape.size())
- throw py::value_error("Expected len(scalable) == len(shape).");
+ throw nb::value_error("Expected len(scalable) == len(shape).");
SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
- *scalable, [](const py::handle &h) { return h.cast<bool>(); }));
+ *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
scalableDimFlags.data(),
elementType);
@@ -650,7 +662,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
SmallVector<bool> scalableDimFlags(shape.size(), false);
for (int64_t dim : *scalableDims) {
if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
- throw py::value_error("Scalable dimension index out of bounds.");
+ throw nb::value_error("Scalable dimension index out of bounds.");
scalableDimFlags[dim] = true;
}
type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
@@ -689,17 +701,17 @@ class PyRankedTensorType
throw MLIRError("Invalid type", errors.take());
return PyRankedTensorType(elementType.getContext(), t);
},
- py::arg("shape"), py::arg("element_type"),
- py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
- "Create a ranked tensor type");
- c.def_property_readonly(
- "encoding",
- [](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
- MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
- if (mlirAttributeIsNull(encoding))
- return std::nullopt;
- return encoding;
- });
+ nb::arg("shape"), nb::arg("element_type"),
+ nb::arg("encoding").none() = nb::none(),
+ nb::arg("loc").none() = nb::none(), "Create a ranked tensor type");
+ c.def_prop_ro("encoding",
+ [](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
+ MlirAttribute encoding =
+ mlirRankedTensorTypeGetEncoding(self.get());
+ if (mlirAttributeIsNull(encoding))
+ return std::nullopt;
+ return encoding;
+ });
}
};
@@ -723,7 +735,7 @@ class PyUnrankedTensorType
throw MLIRError("Invalid type", errors.take());
return PyUnrankedTensorType(elementType.getContext(), t);
},
- py::arg("element_type"), py::arg("loc") = py::none(),
+ nb::arg("element_type"), nb::arg("loc").none() = nb::none(),
"Create a unranked tensor type");
}
};
@@ -754,10 +766,11 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
throw MLIRError("Invalid type", errors.take());
return PyMemRefType(elementType.getContext(), t);
},
- py::arg("shape"), py::arg("element_type"),
- py::arg("layout") = py::none(), py::arg("memory_space") = py::none(),
- py::arg("loc") = py::none(), "Create a memref type")
- .def_property_readonly(
+ nb::arg("shape"), nb::arg("element_type"),
+ nb::arg("layout").none() = nb::none(),
+ nb::arg("memory_space").none() = nb::none(),
+ nb::arg("loc").none() = nb::none(), "Create a memref type")
+ .def_prop_ro(
"layout",
[](PyMemRefType &self) -> MlirAttribute {
return mlirMemRefTypeGetLayout(self);
@@ -775,14 +788,14 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
return {strides, offset};
},
"The strides and offset of the MemRef type.")
- .def_property_readonly(
+ .def_prop_ro(
"affine_map",
[](PyMemRefType &self) -> PyAffineMap {
MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
return PyAffineMap(self.getContext(), map);
},
"The layout of the MemRef type as an affine map.")
- .def_property_readonly(
+ .def_prop_ro(
"memory_space",
[](PyMemRefType &self) -> std::optional<MlirAttribute> {
MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
@@ -820,9 +833,9 @@ class PyUnrankedMemRefType
throw MLIRError("Invalid type", errors.take());
return PyUnrankedMemRefType(elementType.getContext(), t);
},
- py::arg("element_type"), py::arg("memory_space"),
- py::arg("loc") = py::none(), "Create a unranked memref type")
- .def_property_readonly(
+ nb::arg("element_type"), nb::arg("memory_space").none(),
+ nb::arg("loc").none() = nb::none(), "Create a unranked memref type")
+ .def_prop_ro(
"memory_space",
[](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> {
MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
@@ -851,15 +864,15 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
elements.data());
return PyTupleType(context->getRef(), t);
},
- py::arg("elements"), py::arg("context") = py::none(),
+ nb::arg("elements"), nb::arg("context").none() = nb::none(),
"Create a tuple type");
c.def(
"get_type",
[](PyTupleType &self, intptr_t pos) {
return mlirTupleTypeGetType(self, pos);
},
- py::arg("pos"), "Returns the pos-th type in the tuple type.");
- c.def_property_readonly(
+ nb::arg("pos"), "Returns the pos-th type in the tuple type.");
+ c.def_prop_ro(
"num_types",
[](PyTupleType &self) -> intptr_t {
return mlirTupleTypeGetNumTypes(self);
@@ -887,13 +900,14 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
results.size(), results.data());
return PyFunctionType(context->getRef(), t);
},
- py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
+ nb::arg("inputs"), nb::arg("results"),
+ nb::arg("context").none() = nb::none(),
"Gets a FunctionType from a list of input and result types");
- c.def_property_readonly(
+ c.def_prop_ro(
"inputs",
[](PyFunctionType &self) {
MlirType t = self;
- py::list types;
+ nb::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
++i) {
types.append(mlirFunctionTypeGetInput(t, i));
@@ -901,10 +915,10 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
return types;
},
"Returns the list of input types in the FunctionType.");
- c.def_property_readonly(
+ c.def_prop_ro(
"results",
[](PyFunctionType &self) {
- py::list types;
+ nb::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
++i) {
types.append(mlirFunctionTypeGetResult(self, i));
@@ -938,21 +952,21 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
toMlirStringRef(typeData));
return PyOpaqueType(context->getRef(), type);
},
- py::arg("dialect_namespace"), py::arg("buffer"),
- py::arg("context") = py::none(),
+ nb::arg("dialect_namespace"), nb::arg("buffer"),
+ nb::arg("context").none() = nb::none(),
"Create an unregistered (opaque) dialect type.");
- c.def_property_readonly(
+ c.def_prop_ro(
"dialect_namespace",
[](PyOpaqueType &self) {
MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
- return py::str(stringRef.data, stringRef.length);
+ return nb::str(stringRef.data, stringRef.length);
},
"Returns the dialect namespace for the Opaque type as a string.");
- c.def_property_readonly(
+ c.def_prop_ro(
"data",
[](PyOpaqueType &self) {
MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
- return py::str(stringRef.data, stringRef.length);
+ return nb::str(stringRef.data, stringRef.length);
},
"Returns the data for the Opaque type as a string.");
}
@@ -960,7 +974,7 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
} // namespace
-void mlir::python::populateIRTypes(py::module &m) {
+void mlir::python::populateIRTypes(nb::module_ &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
PyIndexType::bind(m);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 7c27021902de31..e5e64a921a79ad 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,29 +6,31 @@
//
//===----------------------------------------------------------------------===//
-#include "PybindUtils.h"
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/string.h>
#include "Globals.h"
#include "IRModule.h"
+#include "NanobindUtils.h"
#include "Pass.h"
#include "Rewrite.h"
-namespace py = pybind11;
+namespace nb = nanobind;
using namespace mlir;
-using namespace py::literals;
+using namespace nb::literals;
using namespace mlir::python;
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
-PYBIND11_MODULE(_mlir, m) {
+NB_MODULE(_mlir, m) {
m.doc() = "MLIR Python Native Extension";
- py::class_<PyGlobals>(m, "_Globals", py::module_local())
- .def_property("dialect_search_modules",
- &PyGlobals::getDialectSearchPrefixes,
- &PyGlobals::setDialectSearchPrefixes)
+ nb::class_<PyGlobals>(m, "_Globals")
+ .def_prop_rw("dialect_search_modules",
+ &PyGlobals::getDialectSearchPrefixes,
+ &PyGlobals::setDialectSearchPrefixes)
.def(
"append_dialect_search_prefix",
[](PyGlobals &self, std::string moduleName) {
@@ -45,22 +47,21 @@ PYBIND11_MODULE(_mlir, m) {
"dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
- "operation_name"_a, "operation_class"_a, py::kw_only(),
+ "operation_name"_a, "operation_class"_a, nb::kw_only(),
"replace"_a = false,
"Testing hook for directly registering an operation");
// Aside from making the globals accessible to python, having python manage
// it is necessary to make sure it is destroyed (and releases its python
// resources) properly.
- m.attr("globals") =
- py::cast(new PyGlobals, py::return_value_policy::take_ownership);
+ m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
// Registration decorators.
m.def(
"register_dialect",
- [](py::type pyClass) {
+ [](nb::type_object pyClass) {
std::string dialectNamespace =
- pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
+ nanobind::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
return pyClass;
},
@@ -68,45 +69,46 @@ PYBIND11_MODULE(_mlir, m) {
"Class decorator for registering a custom Dialect wrapper");
m.def(
"register_operation",
- [](const py::type &dialectClass, bool replace) -> py::cpp_function {
- return py::cpp_function(
- [dialectClass, replace](py::type opClass) -> py::type {
+ [](const nb::type_object &dialectClass, bool replace) -> nb::object {
+ return nb::cpp_function(
+ [dialectClass,
+ replace](nb::type_object opClass) -> nb::type_object {
std::string operationName =
- opClass.attr("OPERATION_NAME").cast<std::string>();
+ nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
PyGlobals::get().registerOperationImpl(operationName, opClass,
replace);
// Dict-stuff the new opClass by name onto the dialect class.
- py::object opClassName = opClass.attr("__name__");
+ nb::object opClassName = opClass.attr("__name__");
dialectClass.attr(opClassName) = opClass;
return opClass;
});
},
- "dialect_class"_a, py::kw_only(), "replace"_a = false,
+ "dialect_class"_a, nb::kw_only(), "replace"_a = false,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
m.def(
MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
- [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
- return py::cpp_function([mlirTypeID,
- replace](py::object typeCaster) -> py::object {
+ [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
+ return nb::cpp_function([mlirTypeID, replace](
+ nb::callable typeCaster) -> nb::object {
PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
return typeCaster;
});
},
- "typeid"_a, py::kw_only(), "replace"_a = false,
+ "typeid"_a, nb::kw_only(), "replace"_a = false,
"Register a type caster for casting MLIR types to custom user types.");
m.def(
MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
- [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
- return py::cpp_function(
- [mlirTypeID, replace](py::object valueCaster) -> py::object {
+ [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
+ return nb::cpp_function(
+ [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
replace);
return valueCaster;
});
},
- "typeid"_a, py::kw_only(), "replace"_a = false,
+ "typeid"_a, nb::kw_only(), "replace"_a = false,
"Register a value caster for casting MLIR values to custom user values.");
// Define and populate IR submodule.
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h
similarity index 85%
rename from mlir/lib/Bindings/Python/PybindUtils.h
rename to mlir/lib/Bindings/Python/NanobindUtils.h
index 38462ac8ba6db9..3b0f7f698b22d4 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/NanobindUtils.h
@@ -1,4 +1,5 @@
-//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
+//===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++
+//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -9,13 +10,21 @@
#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
+#include <nanobind/nanobind.h>
+
#include "mlir-c/Support.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/DataTypes.h"
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
+template <>
+struct std::iterator_traits<nanobind::detail::fast_iterator> {
+ using value_type = nanobind::handle;
+ using reference = const value_type;
+ using pointer = void;
+ using difference_type = std::ptrdiff_t;
+ using iterator_category = std::forward_iterator_tag;
+};
namespace mlir {
namespace python {
@@ -54,14 +63,14 @@ class Defaulting {
} // namespace python
} // namespace mlir
-namespace pybind11 {
+namespace nanobind {
namespace detail {
template <typename DefaultingTy>
struct MlirDefaultingCaster {
- PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));
+ NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription));
- bool load(pybind11::handle src, bool) {
+ bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
if (src.is_none()) {
// Note that we do want an exception to propagate from here as it will be
// the most informative.
@@ -76,20 +85,20 @@ struct MlirDefaultingCaster {
// code to produce nice error messages (other than "Cannot cast...").
try {
value = DefaultingTy{
- pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
+ nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)};
return true;
} catch (std::exception &) {
return false;
}
}
- static handle cast(DefaultingTy src, return_value_policy policy,
- handle parent) {
- return pybind11::cast(src, policy);
+ static handle from_cpp(DefaultingTy src, rv_policy policy,
+ cleanup_list *cleanup) noexcept {
+ return nanobind::cast(src, policy);
}
};
} // namespace detail
-} // namespace pybind11
+} // namespace nanobind
//------------------------------------------------------------------------------
// Conversion utilities.
@@ -100,7 +109,7 @@ namespace mlir {
/// Accumulates into a python string from a method that accepts an
/// MlirStringCallback.
struct PyPrintAccumulator {
- pybind11::list parts;
+ nanobind::list parts;
void *getUserData() { return this; }
@@ -108,15 +117,15 @@ struct PyPrintAccumulator {
return [](MlirStringRef part, void *userData) {
PyPrintAccumulator *printAccum =
static_cast<PyPrintAccumulator *>(userData);
- pybind11::str pyPart(part.data,
+ nanobind::str pyPart(part.data,
part.length); // Decodes as UTF-8 by default.
printAccum->parts.append(std::move(pyPart));
};
}
- pybind11::str join() {
- pybind11::str delim("", 0);
- return delim.attr("join")(parts);
+ nanobind::str join() {
+ nanobind::str delim("", 0);
+ return nanobind::cast<nanobind::str>(delim.attr("join")(parts));
}
};
@@ -124,21 +133,21 @@ struct PyPrintAccumulator {
/// or binary.
class PyFileAccumulator {
public:
- PyFileAccumulator(const pybind11::object &fileObject, bool binary)
+ PyFileAccumulator(const nanobind::object &fileObject, bool binary)
: pyWriteFunction(fileObject.attr("write")), binary(binary) {}
void *getUserData() { return this; }
MlirStringCallback getCallback() {
return [](MlirStringRef part, void *userData) {
- pybind11::gil_scoped_acquire acquire;
+ nanobind::gil_scoped_acquire acquire;
PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
if (accum->binary) {
// Note: Still has to copy and not avoidable with this API.
- pybind11::bytes pyBytes(part.data, part.length);
+ nanobind::bytes pyBytes(part.data, part.length);
accum->pyWriteFunction(pyBytes);
} else {
- pybind11::str pyStr(part.data,
+ nanobind::str pyStr(part.data,
part.length); // Decodes as UTF-8 by default.
accum->pyWriteFunction(pyStr);
}
@@ -146,7 +155,7 @@ class PyFileAccumulator {
}
private:
- pybind11::object pyWriteFunction;
+ nanobind::object pyWriteFunction;
bool binary;
};
@@ -163,17 +172,17 @@ struct PySinglePartStringAccumulator {
assert(!accum->invoked &&
"PySinglePartStringAccumulator called back multiple times");
accum->invoked = true;
- accum->value = pybind11::str(part.data, part.length);
+ accum->value = nanobind::str(part.data, part.length);
};
}
- pybind11::str takeValue() {
+ nanobind::str takeValue() {
assert(invoked && "PySinglePartStringAccumulator not called back");
return std::move(value);
}
private:
- pybind11::str value;
+ nanobind::str value;
bool invoked = false;
};
@@ -208,7 +217,7 @@ struct PySinglePartStringAccumulator {
template <typename Derived, typename ElementTy>
class Sliceable {
protected:
- using ClassTy = pybind11::class_<Derived>;
+ using ClassTy = nanobind::class_<Derived>;
/// Transforms `index` into a legal value to access the underlying sequence.
/// Returns <0 on failure.
@@ -237,7 +246,7 @@ class Sliceable {
/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
- pybind11::object getItem(intptr_t index) {
+ nanobind::object getItem(intptr_t index) {
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
@@ -250,20 +259,20 @@ class Sliceable {
->getRawElement(linearizeIndex(index))
.maybeDownCast();
else
- return pybind11::cast(
+ return nanobind::cast(
static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
}
/// Returns a new instance of the pseudo-container restricted to the given
/// slice. Returns a nullptr object on failure.
- pybind11::object getItemSlice(PyObject *slice) {
+ nanobind::object getItemSlice(PyObject *slice) {
ssize_t start, stop, extraStep, sliceLength;
if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
&sliceLength) != 0) {
PyErr_SetString(PyExc_IndexError, "index out of range");
return {};
}
- return pybind11::cast(static_cast<Derived *>(this)->slice(
+ return nanobind::cast(static_cast<Derived *>(this)->slice(
startIndex + start * step, sliceLength, step * extraStep));
}
@@ -279,7 +288,7 @@ class Sliceable {
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
- throw pybind11::index_error("index out of range");
+ throw nanobind::index_error("index out of range");
}
return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
@@ -304,39 +313,38 @@ class Sliceable {
}
/// Binds the indexing and length methods in the Python class.
- static void bind(pybind11::module &m) {
- auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
- pybind11::module_local())
+ static void bind(nanobind::module_ &m) {
+ auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName)
.def("__add__", &Sliceable::dunderAdd);
Derived::bindDerived(clazz);
// Manually implement the sequence protocol via the C API. We do this
- // because it is approx 4x faster than via pybind11, largely because that
+ // because it is approx 4x faster than via nanobind, largely because that
// formulation requires a C++ exception to be thrown to detect end of
// sequence.
// Since we are in a C-context, any C++ exception that happens here
// will terminate the program. There is nothing in this implementation
// that should throw in a non-terminal way, so we forgo further
// exception marshalling.
- // See: https://github.com/pybind/pybind11/issues/2842
+ // See: https://github.com/pybind/pybind11/issues/2842
auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
"must be heap type");
heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
- auto self = pybind11::cast<Derived *>(rawSelf);
+ auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
return self->length;
};
// sq_item is called as part of the sequence protocol for iteration,
// list construction, etc.
heap_type->as_sequence.sq_item =
+[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
- auto self = pybind11::cast<Derived *>(rawSelf);
+ auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
return self->getItem(index).release().ptr();
};
// mp_subscript is used for both slices and integer lookups.
heap_type->as_mapping.mp_subscript =
+[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
- auto self = pybind11::cast<Derived *>(rawSelf);
+ auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
if (!PyErr_Occurred()) {
// Integer indexing.
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index e991deaae2daa5..b5dce4fe4128a5 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,12 +8,16 @@
#include "Pass.h"
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/optional.h>
+#include <nanobind/stl/string.h>
+
#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Pass.h"
-namespace py = pybind11;
-using namespace py::literals;
+namespace nb = nanobind;
+using namespace nb::literals;
using namespace mlir;
using namespace mlir::python;
@@ -34,16 +38,15 @@ class PyPassManager {
MlirPassManager get() { return passManager; }
void release() { passManager.ptr = nullptr; }
- pybind11::object getCapsule() {
- return py::reinterpret_steal<py::object>(
- mlirPythonPassManagerToCapsule(get()));
+ nb::object getCapsule() {
+ return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
}
- static pybind11::object createFromCapsule(pybind11::object capsule) {
+ static nb::object createFromCapsule(nb::object capsule) {
MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
if (mlirPassManagerIsNull(rawPm))
- throw py::error_already_set();
- return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
+ throw nb::python_error();
+ return nb::cast(PyPassManager(rawPm), nb::rv_policy::move);
}
private:
@@ -53,22 +56,23 @@ class PyPassManager {
} // namespace
/// Create the `mlir.passmanager` here.
-void mlir::python::populatePassManagerSubmodule(py::module &m) {
+void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
- py::class_<PyPassManager>(m, "PassManager", py::module_local())
- .def(py::init<>([](const std::string &anchorOp,
- DefaultingPyMlirContext context) {
- MlirPassManager passManager = mlirPassManagerCreateOnOperation(
- context->get(),
- mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
- return new PyPassManager(passManager);
- }),
- "anchor_op"_a = py::str("any"), "context"_a = py::none(),
- "Create a new PassManager for the current (or provided) Context.")
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- &PyPassManager::getCapsule)
+ nb::class_<PyPassManager>(m, "PassManager")
+ .def(
+ "__init__",
+ [](PyPassManager &self, const std::string &anchorOp,
+ DefaultingPyMlirContext context) {
+ MlirPassManager passManager = mlirPassManagerCreateOnOperation(
+ context->get(),
+ mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
+ new (&self) PyPassManager(passManager);
+ },
+ "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(),
+ "Create a new PassManager for the current (or provided) Context.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
.def("_testing_release", &PyPassManager::release,
"Releases (leaks) the backing pass manager (testing)")
@@ -101,9 +105,9 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
"print_before_all"_a = false, "print_after_all"_a = true,
"print_module_scope"_a = false, "print_after_change"_a = false,
"print_after_failure"_a = false,
- "large_elements_limit"_a = py::none(), "enable_debug_info"_a = false,
- "print_generic_op_form"_a = false,
- "tree_printing_dir_path"_a = py::none(),
+ "large_elements_limit"_a.none() = nb::none(),
+ "enable_debug_info"_a = false, "print_generic_op_form"_a = false,
+ "tree_printing_dir_path"_a.none() = nb::none(),
"Enable IR printing, default as mlir-print-ir-after-all.")
.def(
"enable_verifier",
@@ -121,10 +125,10 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
mlirStringRefCreate(pipeline.data(), pipeline.size()),
errorMsg.getCallback(), errorMsg.getUserData());
if (mlirLogicalResultIsFailure(status))
- throw py::value_error(std::string(errorMsg.join()));
+ throw nb::value_error(errorMsg.join().c_str());
return new PyPassManager(passManager);
},
- "pipeline"_a, "context"_a = py::none(),
+ "pipeline"_a, "context"_a.none() = nb::none(),
"Parse a textual pass-pipeline and return a top-level PassManager "
"that can be applied on a Module. Throw a ValueError if the pipeline "
"can't be parsed")
@@ -137,7 +141,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
mlirStringRefCreate(pipeline.data(), pipeline.size()),
errorMsg.getCallback(), errorMsg.getUserData());
if (mlirLogicalResultIsFailure(status))
- throw py::value_error(std::string(errorMsg.join()));
+ throw nb::value_error(errorMsg.join().c_str());
},
"pipeline"_a,
"Add textual pipeline elements to the pass manager. Throws a "
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index 3a500d5e8257ac..bc409435218299 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -9,12 +9,12 @@
#ifndef MLIR_BINDINGS_PYTHON_PASS_H
#define MLIR_BINDINGS_PYTHON_PASS_H
-#include "PybindUtils.h"
+#include "NanobindUtils.h"
namespace mlir {
namespace python {
-void populatePassManagerSubmodule(pybind11::module &m);
+void populatePassManagerSubmodule(nanobind::module_ &m);
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 1d8128be9f0826..b2c1de4be9a69c 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -8,14 +8,16 @@
#include "Rewrite.h"
+#include <nanobind/nanobind.h>
+
#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Rewrite.h"
#include "mlir/Config/mlir-config.h"
-namespace py = pybind11;
+namespace nb = nanobind;
using namespace mlir;
-using namespace py::literals;
+using namespace nb::literals;
using namespace mlir::python;
namespace {
@@ -54,18 +56,17 @@ class PyFrozenRewritePatternSet {
}
MlirFrozenRewritePatternSet get() { return set; }
- pybind11::object getCapsule() {
- return py::reinterpret_steal<py::object>(
+ nb::object getCapsule() {
+ return nb::steal<nb::object>(
mlirPythonFrozenRewritePatternSetToCapsule(get()));
}
- static pybind11::object createFromCapsule(pybind11::object capsule) {
+ static nb::object createFromCapsule(nb::object capsule) {
MlirFrozenRewritePatternSet rawPm =
mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
if (rawPm.ptr == nullptr)
- throw py::error_already_set();
- return py::cast(PyFrozenRewritePatternSet(rawPm),
- py::return_value_policy::move);
+ throw nb::python_error();
+ return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
}
private:
@@ -75,25 +76,27 @@ class PyFrozenRewritePatternSet {
} // namespace
/// Create the `mlir.rewrite` here.
-void mlir::python::populateRewriteSubmodule(py::module &m) {
+void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
- py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
- .def(py::init<>([](MlirModule module) {
- return mlirPDLPatternModuleFromModule(module);
- }),
- "module"_a, "Create a PDL module from the given module.")
+ nb::class_<PyPDLPatternModule>(m, "PDLModule")
+ .def(
+ "__init__",
+ [](PyPDLPatternModule &self, MlirModule module) {
+ new (&self)
+ PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
+ },
+ "module"_a, "Create a PDL module from the given module.")
.def("freeze", [](PyPDLPatternModule &self) {
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
mlirRewritePatternSetFromPDLPatternModule(self.get())));
});
-#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg
- py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
- py::module_local())
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- &PyFrozenRewritePatternSet::getCapsule)
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+ nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
+ &PyFrozenRewritePatternSet::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
@@ -102,7 +105,7 @@ void mlir::python::populateRewriteSubmodule(py::module &m) {
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
if (mlirLogicalResultIsFailure(status))
// FIXME: Not sure this is the right error to throw here.
- throw py::value_error("pattern application failed to converge");
+ throw nb::value_error("pattern application failed to converge");
},
"module"_a, "set"_a,
"Applys the given patterns to the given module greedily while folding "
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
index 997b80adda3038..ae89e2b9589f13 100644
--- a/mlir/lib/Bindings/Python/Rewrite.h
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -9,12 +9,12 @@
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
#define MLIR_BINDINGS_PYTHON_REWRITE_H
-#include "PybindUtils.h"
+#include "NanobindUtils.h"
namespace mlir {
namespace python {
-void populateRewriteSubmodule(pybind11::module &m);
+void populateRewriteSubmodule(nanobind::module_ &m);
} // namespace python
} // namespace mlir
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index e1b870b53ad25c..d3ca940b408276 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -440,6 +440,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
MODULE_NAME _mlir
ADD_TO_PARENT MLIRPythonSources.Core
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
MainModule.cpp
IRAffine.cpp
@@ -455,7 +456,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
Globals.h
IRModule.h
Pass.h
- PybindUtils.h
+ NanobindUtils.h
Rewrite.h
PRIVATE_LINK_LIBS
LLVMSupport
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index ab8a9122919e19..f240d6ef944ec7 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,4 +1,4 @@
-nanobind>=2.0, <3.0
+nanobind>=2.4, <3.0
numpy>=1.19.5, <=2.1.2
pybind11>=2.10.0, <=2.13.6
PyYAML>=5.4.0, <=6.0.1
diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
index 577721ab2111f5..8b6d7ea5a197d7 100644
--- a/mlir/test/python/ir/symbol_table.py
+++ b/mlir/test/python/ir/symbol_table.py
@@ -176,5 +176,6 @@ def error_callback(symbol_table_op, uses_visible):
try:
SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
except RuntimeError as e:
- # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
+ # CHECK: GOT EXCEPTION: Exception raised in callback:
+ # CHECK: AssertionError: Raised from python
print(f"GOT EXCEPTION: {e}")
diff --git a/utils/bazel/WORKSPACE b/utils/bazel/WORKSPACE
index 66ba1ac1b17e1e..005a4b9d7b5ad2 100644
--- a/utils/bazel/WORKSPACE
+++ b/utils/bazel/WORKSPACE
@@ -161,9 +161,9 @@ maybe(
http_archive,
name = "nanobind",
build_file = "@llvm-raw//utils/bazel/third_party_build:nanobind.BUILD",
- sha256 = "bfbfc7e5759f1669e4ddb48752b1ddc5647d1430e94614d6f8626df1d508e65a",
- strip_prefix = "nanobind-2.2.0",
- url = "https://github.com/wjakob/nanobind/archive/refs/tags/v2.2.0.tar.gz",
+ sha256 = "bb35deaed7efac5029ed1e33880a415638352f757d49207a8e6013fefb6c49a7",
+ strip_prefix = "nanobind-2.4.0",
+ url = "https://github.com/wjakob/nanobind/archive/refs/tags/v2.4.0.tar.gz",
)
load("@rules_python//python:repositories.bzl", "py_repositories", "python_register_toolchains")
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 544becfa30b40f..aee8aab8498ce2 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1026,6 +1026,9 @@ cc_library(
srcs = [":MLIRBindingsPythonSourceFiles"],
copts = PYBIND11_COPTS,
features = PYBIND11_FEATURES,
+ includes = [
+ "lib/Bindings/Python",
+ ],
textual_hdrs = [":MLIRBindingsPythonCoreHeaders"],
deps = [
":CAPIAsync",
@@ -1033,11 +1036,11 @@ cc_library(
":CAPIIR",
":CAPIInterfaces",
":CAPITransforms",
- ":MLIRBindingsPythonHeadersAndDeps",
+ ":MLIRBindingsPythonNanobindHeadersAndDeps",
":Support",
":config",
"//llvm:Support",
- "@pybind11",
+ "@nanobind",
"@rules_python//python/cc:current_py_cc_headers",
],
)
@@ -1047,17 +1050,20 @@ cc_library(
srcs = [":MLIRBindingsPythonSourceFiles"],
copts = PYBIND11_COPTS,
features = PYBIND11_FEATURES,
+ includes = [
+ "lib/Bindings/Python",
+ ],
textual_hdrs = [":MLIRBindingsPythonCoreHeaders"],
deps = [
":CAPIAsyncHeaders",
":CAPIDebugHeaders",
":CAPIIRHeaders",
":CAPITransformsHeaders",
- ":MLIRBindingsPythonHeaders",
+ ":MLIRBindingsPythonNanobindHeaders",
":Support",
":config",
"//llvm:Support",
- "@pybind11",
+ "@nanobind",
"@rules_python//python/cc:current_py_cc_headers",
],
)
@@ -1090,6 +1096,7 @@ cc_binary(
deps = [
":MLIRBindingsPythonCore",
":MLIRBindingsPythonHeadersAndDeps",
+ "@nanobind",
],
)
More information about the llvm-commits
mailing list