[Mlir-commits] [llvm] [mlir] [mlir][Python] create MLIRPythonSupport (PR #171775)
Maksim Levental
llvmlistbot at llvm.org
Fri Jan 2 13:01:55 PST 2026
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/171775
>From 68d6a2fd61657c29df3fd6d16f85282f5ef36d2f Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Wed, 10 Dec 2025 23:57:13 -0800
Subject: [PATCH] [mlir][Python] create MLIRPythonSupport
---
mlir/cmake/modules/AddMLIRPython.cmake | 221 +-
mlir/docs/Bindings/Python.md | 7 +
.../include/Standalone-c/Dialects.h | 7 +
.../examples/standalone/lib/CAPI/Dialects.cpp | 13 +
.../python/StandaloneExtensionNanobind.cpp | 27 +
.../standalone/test/python/smoketest.py | 4 +
mlir/include/mlir-c/Support.h | 2 +
.../mlir}/Bindings/Python/Globals.h | 16 +-
.../mlir/Bindings/Python/IRCore.h} | 686 +++++-
mlir/include/mlir/Bindings/Python/IRTypes.h | 9 +-
.../mlir}/Bindings/Python/NanobindUtils.h | 0
mlir/lib/Bindings/Python/DialectSMT.cpp | 2 +-
.../Python/{IRModule.cpp => Globals.cpp} | 20 +-
mlir/lib/Bindings/Python/IRAffine.cpp | 40 +-
mlir/lib/Bindings/Python/IRAttributes.cpp | 30 +-
mlir/lib/Bindings/Python/IRCore.cpp | 1984 ++++++++---------
mlir/lib/Bindings/Python/IRInterfaces.cpp | 6 +-
mlir/lib/Bindings/Python/IRTypes.cpp | 42 +-
mlir/lib/Bindings/Python/MainModule.cpp | 160 +-
mlir/lib/Bindings/Python/Pass.cpp | 43 +-
mlir/lib/Bindings/Python/Pass.h | 5 +-
mlir/lib/Bindings/Python/Rewrite.cpp | 112 +-
mlir/lib/Bindings/Python/Rewrite.h | 6 +-
mlir/python/CMakeLists.txt | 22 +-
mlir/test/Examples/standalone/test.wheel.toy | 2 +
mlir/test/python/dialects/python_test.py | 36 +-
.../python/lib/PythonTestModuleNanobind.cpp | 70 +-
.../llvm-project-overlay/mlir/BUILD.bazel | 2 +-
28 files changed, 2061 insertions(+), 1513 deletions(-)
rename mlir/{lib => include/mlir}/Bindings/Python/Globals.h (96%)
rename mlir/{lib/Bindings/Python/IRModule.h => include/mlir/Bindings/Python/IRCore.h} (69%)
rename mlir/{lib => include/mlir}/Bindings/Python/NanobindUtils.h (100%)
rename mlir/lib/Bindings/Python/{IRModule.cpp => Globals.cpp} (95%)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 8c301faf0941a..f4d078dfe7118 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -228,7 +228,7 @@ endfunction()
# aggregate dylib that is linked against.
function(declare_mlir_python_extension name)
cmake_parse_arguments(ARG
- ""
+ "_PRIVATE_SUPPORT_LIB"
"ROOT_DIR;MODULE_NAME;ADD_TO_PARENT"
"SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS"
${ARGN})
@@ -236,6 +236,11 @@ function(declare_mlir_python_extension name)
if(NOT ARG_ROOT_DIR)
set(ARG_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
endif()
+ if(ARG__PRIVATE_SUPPORT_LIB)
+ set(SOURCES_TYPE "support")
+ else()
+ set(SOURCES_TYPE "extension")
+ endif()
set(_install_destination "src/python/${name}")
add_library(${name} INTERFACE)
@@ -243,7 +248,7 @@ function(declare_mlir_python_extension name)
# Yes: Leading-lowercase property names are load bearing and the recommended
# way to do this: https://gitlab.kitware.com/cmake/cmake/-/issues/19261
EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS"
- mlir_python_SOURCES_TYPE extension
+ mlir_python_SOURCES_TYPE "${SOURCES_TYPE}"
mlir_python_EXTENSION_MODULE_NAME "${ARG_MODULE_NAME}"
mlir_python_EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}"
mlir_python_DEPENDS ""
@@ -297,6 +302,58 @@ function(_mlir_python_install_sources name source_root_dir destination)
)
endfunction()
+# Function: build_nanobind_lib
+# Builds the shared nanobind runtime library for one nanobind domain, installs
+# it, and returns its target name to the caller in NB_LIBRARY_TARGET_NAME.
+# Arguments:
+#   INSTALL_COMPONENT: Install component the library is attached to.
+#   INSTALL_DESTINATION: Install-tree directory for the library.
+#   OUTPUT_DIRECTORY: Build-tree directory for the library.
+#   MLIR_BINDINGS_PYTHON_NB_DOMAIN: Required nanobind domain. It is also part of
+#     the target name, so each domain gets its own library.
+function(build_nanobind_lib)
+  cmake_parse_arguments(ARG
+    ""
+    "INSTALL_COMPONENT;INSTALL_DESTINATION;OUTPUT_DIRECTORY;MLIR_BINDINGS_PYTHON_NB_DOMAIN"
+    ""
+    ${ARGN})
+  if(ARG_UNPARSED_ARGUMENTS)
+    message(FATAL_ERROR "Unhandled arguments to build_nanobind_lib: ${ARG_UNPARSED_ARGUMENTS}")
+  endif()
+  # An empty domain would silently yield a target named "nanobind-" (or
+  # "nanobind-ft-") compiled with an empty NB_DOMAIN definition.
+  if("${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}" STREQUAL "")
+    message(FATAL_ERROR "build_nanobind_lib requires MLIR_BINDINGS_PYTHON_NB_DOMAIN")
+  endif()
+
+  # Only build in free-threaded mode if the Python ABI supports it.
+  # See https://github.com/wjakob/nanobind/blob/4ba51fcf795971c5d603d875ae4184bc0c9bd8e6/cmake/nanobind-config.cmake#L363-L371.
+  # Reset first: function scopes see the caller's variables, so a stray `_ft`
+  # from an enclosing scope would otherwise end up in the target name.
+  set(_ft "")
+  if(NB_ABI MATCHES "[0-9]t")
+    set(_ft "-ft")
+  endif()
+  # nanobind does a string match on the suffix to figure out whether to build
+  # the lib with free threading...
+  set(NB_LIBRARY_TARGET_NAME "nanobind${_ft}-${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
+  set(NB_LIBRARY_TARGET_NAME "${NB_LIBRARY_TARGET_NAME}" PARENT_SCOPE)
+  nanobind_build_library(${NB_LIBRARY_TARGET_NAME} AS_SYSINCLUDE)
+  target_compile_definitions(${NB_LIBRARY_TARGET_NAME}
+    PRIVATE
+    NB_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+  )
+  if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+    # nanobind handles this correctly for macOS by explicitly setting -U for all the necessary Python symbols
+    # (see https://github.com/wjakob/nanobind/blob/master/cmake/darwin-ld-cpython.sym)
+    # but since we set -z,defs in llvm/cmake/modules/HandleLLVMOptions.cmake:340 for all Linux shlibs
+    # we need to negate it here (we could have our own linux-ld-cpython.sym but that would be too much
+    # maintenance).
+    target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "LINKER:-z,undefs")
+  endif()
+  # nanobind configures with LTO for shared build which doesn't work everywhere
+  # (see https://github.com/llvm/llvm-project/issues/139602).
+  if(NOT LLVM_ENABLE_LTO)
+    set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES
+      INTERPROCEDURAL_OPTIMIZATION_RELEASE OFF
+      INTERPROCEDURAL_OPTIMIZATION_MINSIZEREL OFF
+    )
+  endif()
+  set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES
+    LIBRARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+    BINARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+    # Needed for windows (and doesn't hurt others).
+    RUNTIME_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+    ARCHIVE_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+  )
+  mlir_python_setup_extension_rpath(${NB_LIBRARY_TARGET_NAME})
+  install(TARGETS ${NB_LIBRARY_TARGET_NAME}
+    COMPONENT ${ARG_INSTALL_COMPONENT}
+    LIBRARY DESTINATION "${ARG_INSTALL_DESTINATION}"
+    RUNTIME DESTINATION "${ARG_INSTALL_DESTINATION}"
+  )
+endfunction()
+
# Function: add_mlir_python_modules
# Adds python modules to a project, building them from a list of declared
# source groupings (see declare_mlir_python_sources and
@@ -308,6 +365,11 @@ endfunction()
# for non-relocatable modules or a deeper directory tree for relocatable.
# INSTALL_PREFIX: Prefix into the install tree for installing the package.
# Typically mirrors the path above but without an absolute path.
+# MLIR_BINDINGS_PYTHON_NB_DOMAIN: nanobind (and MLIR) domain within which
+# extensions will be compiled. This determines whether this package
+# will share nanobind types with other bindings packages. Expected to be unique
+# per project (and per specific set of bindings, for projects with multiple
+# bindings packages).
# DECLARED_SOURCES: List of declared source groups to include. The entire
# DAG of source modules is included.
# COMMON_CAPI_LINK_LIBS: List of dylibs (typically one) to make every
@@ -315,11 +377,32 @@ endfunction()
function(add_mlir_python_modules name)
cmake_parse_arguments(ARG
""
- "ROOT_PREFIX;INSTALL_PREFIX"
+ "ROOT_PREFIX;INSTALL_PREFIX;MLIR_BINDINGS_PYTHON_NB_DOMAIN"
"COMMON_CAPI_LINK_LIBS;DECLARED_SOURCES"
${ARGN})
+
+  # TODO(max): do the same for MLIR_PYTHON_PACKAGE_PREFIX?
+  # Resolve the nanobind domain in order of precedence: the explicit argument,
+  # then the MLIR_BINDINGS_PYTHON_NB_DOMAIN variable, then the `mlir` default.
+  # Compare against "" instead of relying on if(<var>) truthiness, which would
+  # treat domain names such as "n", "no" or "off" as unset.
+  if("${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}" STREQUAL "")
+    set(ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN "${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
+  endif()
+  if("${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}" STREQUAL "")
+    message(WARNING "MLIR_BINDINGS_PYTHON_NB_DOMAIN CMake var is not set - setting to a default `mlir`. \
+It is highly recommended to set this to something unique so that your project's Python bindings do not collide with \
+others'. You can also pass it explicitly to `add_mlir_python_modules`. \
+See https://github.com/llvm/llvm-project/pull/171775 for more information.")
+    set(ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir")
+  endif()
+
+ # This call sets NB_LIBRARY_TARGET_NAME.
+ build_nanobind_lib(
+ INSTALL_COMPONENT ${name}
+ INSTALL_DESTINATION "${ARG_INSTALL_PREFIX}/_mlir_libs"
+ OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ )
+
# Helper to process an individual target.
- function(_process_target modules_target sources_target)
+ function(_process_target modules_target sources_target support_libs)
get_target_property(_source_type ${sources_target} mlir_python_SOURCES_TYPE)
if(_source_type STREQUAL "pure")
@@ -337,16 +420,20 @@ function(add_mlir_python_modules name)
get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
# Transform relative source to based on root dir.
set(_extension_target "${modules_target}.extension.${_module_name}.dso")
- add_mlir_python_extension(${_extension_target} "${_module_name}"
+ add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME}
INSTALL_COMPONENT ${modules_target}
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
LINK_LIBS PRIVATE
${sources_target}
${ARG_COMMON_CAPI_LINK_LIBS}
+ ${support_libs}
)
add_dependencies(${modules_target} ${_extension_target})
mlir_python_setup_extension_rpath(${_extension_target})
+ elseif(_source_type STREQUAL "support")
+ # do nothing because already built
else()
message(SEND_ERROR "Unrecognized source type '${_source_type}' for python source target ${sources_target}")
return()
@@ -356,8 +443,36 @@ function(add_mlir_python_modules name)
# Build the modules target.
add_custom_target(${name} ALL)
_flatten_mlir_python_targets(_flat_targets ${ARG_DECLARED_SOURCES})
+
+ # Build all support libs first.
+ set(_mlir_python_support_libs)
+ foreach(sources_target ${_flat_targets})
+ get_target_property(_source_type ${sources_target} mlir_python_SOURCES_TYPE)
+ if(_source_type STREQUAL "support")
+ get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
+ # Use a similar mechanism as nanobind to help the runtime loader pick the correct lib.
+ set(_module_name "${_module_name}-${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
+ set(_extension_target "${name}.extension.${_module_name}.so")
+ add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME}
+ INSTALL_COMPONENT ${name}
+ INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
+ OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ _PRIVATE_SUPPORT_LIB
+ LINK_LIBS PRIVATE
+ LLVMSupport
+ ${sources_target}
+ ${ARG_COMMON_CAPI_LINK_LIBS}
+ )
+ add_dependencies(${name} ${_extension_target})
+ mlir_python_setup_extension_rpath(${_extension_target})
+ list(APPEND _mlir_python_support_libs "${_extension_target}")
+ endif()
+ endforeach()
+
+ # Build extensions.
foreach(sources_target ${_flat_targets})
- _process_target(${name} ${sources_target})
+ _process_target(${name} ${sources_target} "${_mlir_python_support_libs}")
endforeach()
# Create an install target.
@@ -622,7 +737,7 @@ function(add_mlir_python_common_capi_library name)
set_target_properties(${name} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
BINARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
- # Needed for windows (and don't hurt others).
+ # Needed for windows (and doesn't hurt others).
RUNTIME_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
ARCHIVE_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
)
@@ -742,10 +857,10 @@ endfunction()
################################################################################
# Build python extension
################################################################################
-function(add_mlir_python_extension libname extname)
+function(add_mlir_python_extension libname extname nb_library_target_name)
cmake_parse_arguments(ARG
- ""
- "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY"
+ "_PRIVATE_SUPPORT_LIB"
+ "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY;MLIR_BINDINGS_PYTHON_NB_DOMAIN"
"SOURCES;LINK_LIBS"
${ARGN})
if(ARG_UNPARSED_ARGUMENTS)
@@ -761,10 +876,41 @@ function(add_mlir_python_extension libname extname)
set(eh_rtti_enable -frtti -fexceptions)
endif ()
- nanobind_add_module(${libname}
- NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
- FREE_THREADED
- ${ARG_SOURCES}
+  # Support libs are plain shared libraries that extensions link against;
+  # extensions are nanobind modules built against the shared nanobind library.
+  if(ARG__PRIVATE_SUPPORT_LIB)
+    add_library(${libname} SHARED ${ARG_SOURCES})
+    if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+      # nanobind handles this correctly for MacOS by explicitly setting -U for all the necessary Python symbols
+      # (see https://github.com/wjakob/nanobind/blob/master/cmake/darwin-ld-cpython.sym)
+      # but since we set -z,defs in llvm/cmake/modules/HandleLLVMOptions.cmake:340 for all Linux shlibs
+      # we need to negate it here (we could have our own linux-ld-cpython.sym but that would be too much
+      # maintenance).
+      target_link_options(${libname} PRIVATE "LINKER:-z,undefs")
+    endif()
+    nanobind_link_options(${libname})
+    target_compile_definitions(${libname}
+      PRIVATE
+      NB_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+      MLIR_CAPI_BUILDING_LIBRARY=1
+    )
+    if(MSVC)
+      set_property(TARGET ${libname} PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
+    endif()
+  else()
+    nanobind_add_module(${libname}
+      NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+      FREE_THREADED
+      NB_SHARED
+      ${ARG_SOURCES}
+    )
+  endif()
+  target_link_libraries(${libname} PRIVATE ${nb_library_target_name})
+  # Define the domain macro exactly once for both support libs and extensions,
+  # from the explicit argument. A second definition taken from the global
+  # MLIR_BINDINGS_PYTHON_NB_DOMAIN variable would conflict with this one
+  # whenever the two differ (e.g. the global is unset and `mlir` was defaulted).
+  target_compile_definitions(${libname}
+    PRIVATE
+    MLIR_BINDINGS_PYTHON_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
)
if(APPLE)
# In llvm/cmake/modules/HandleLLVMOptions.cmake:268 we set -Wl,-flat_namespace which breaks
@@ -778,29 +924,28 @@ function(add_mlir_python_extension libname extname)
# Avoid some warnings from upstream nanobind.
# If a superproject set MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES, let
# the super project handle compile options as it wishes.
- get_property(NB_LIBRARY_TARGET_NAME TARGET ${libname} PROPERTY LINK_LIBRARIES)
- target_compile_options(${NB_LIBRARY_TARGET_NAME}
+ target_compile_options(${nb_library_target_name}
PRIVATE
-Wno-c++98-compat-extra-semi
- -Wno-cast-qual
- -Wno-covered-switch-default
- -Wno-deprecated-literal-operator
- -Wno-nested-anon-types
- -Wno-unused-parameter
- -Wno-zero-length-array
- -Wno-missing-field-initializers
+ -Wno-cast-qual
+ -Wno-covered-switch-default
+ -Wno-deprecated-literal-operator
+ -Wno-nested-anon-types
+ -Wno-unused-parameter
+ -Wno-zero-length-array
+ -Wno-missing-field-initializers
${eh_rtti_enable})
target_compile_options(${libname}
PRIVATE
-Wno-c++98-compat-extra-semi
- -Wno-cast-qual
- -Wno-covered-switch-default
- -Wno-deprecated-literal-operator
- -Wno-nested-anon-types
- -Wno-unused-parameter
- -Wno-zero-length-array
- -Wno-missing-field-initializers
+ -Wno-cast-qual
+ -Wno-covered-switch-default
+ -Wno-deprecated-literal-operator
+ -Wno-nested-anon-types
+ -Wno-unused-parameter
+ -Wno-zero-length-array
+ -Wno-missing-field-initializers
${eh_rtti_enable})
endif()
@@ -818,12 +963,26 @@ function(add_mlir_python_extension libname extname)
target_compile_options(${libname} PRIVATE ${eh_rtti_enable})
+ # Quoting CMake:
+ #
+ # "If you use it on normal shared libraries which other targets link against, on some platforms a
+ # linker will insert a full path to the library (as specified at link time) into the dynamic section of the
+ # dependent binary. Therefore, once installed, dynamic loader may eventually fail to locate the library
+ # for the binary."
+ #
+ # So for support libs we do need an SO name but for extensions we do not (they're MODULEs anyway -
+ # i.e., can't be linked against, only loaded).
+ if (ARG__PRIVATE_SUPPORT_LIB)
+ set(_no_soname OFF)
+ else ()
+ set(_no_soname ON)
+ endif ()
# Configure the output to match python expectations.
set_target_properties(
${libname} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${ARG_OUTPUT_DIRECTORY}
OUTPUT_NAME "${extname}"
- NO_SONAME ON
+ NO_SONAME ${_no_soname}
)
if(WIN32)
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 4f4f531f7723c..4278774933a4a 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -37,6 +37,13 @@
LLVM ERROR: ... unregistered/uninitialized dialect/type/pass ...`
```
+* **`MLIR_BINDINGS_PYTHON_NB_DOMAIN`**: `STRING`
+
+  The nanobind (and MLIR) domain within which extensions are compiled.
+  It determines whether this package shares nanobind types with other bindings packages,
+  so it should be unique per project (and per set of bindings, for projects with multiple bindings packages).
+  It can also be passed explicitly to `add_mlir_python_modules`; if it is set neither way,
+  `add_mlir_python_modules` emits a warning and falls back to the default domain `mlir`.
+
### Recommended development practices
It is recommended to use a Python virtual environment. Many ways exist for this,
diff --git a/mlir/examples/standalone/include/Standalone-c/Dialects.h b/mlir/examples/standalone/include/Standalone-c/Dialects.h
index b3e47752ccc69..5aa9e004cb9fe 100644
--- a/mlir/examples/standalone/include/Standalone-c/Dialects.h
+++ b/mlir/examples/standalone/include/Standalone-c/Dialects.h
@@ -17,6 +17,13 @@ extern "C" {
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Standalone, standalone);
+/// Gets the standalone dialect's custom type (printed as
+/// `!standalone.custom<"...">`) in `ctx`, parameterized by the string `value`.
+MLIR_CAPI_EXPORTED MlirType mlirStandaloneCustomTypeGet(MlirContext ctx,
+ MlirStringRef value);
+
+/// Returns true if `t` is a standalone dialect custom type.
+MLIR_CAPI_EXPORTED bool mlirStandaloneTypeIsACustomType(MlirType t);
+
+/// Returns the TypeID of the standalone dialect custom type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirStandaloneCustomTypeGetTypeID(void);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/examples/standalone/lib/CAPI/Dialects.cpp b/mlir/examples/standalone/lib/CAPI/Dialects.cpp
index 98006e81a3d26..4de55ba485490 100644
--- a/mlir/examples/standalone/lib/CAPI/Dialects.cpp
+++ b/mlir/examples/standalone/lib/CAPI/Dialects.cpp
@@ -9,7 +9,20 @@
#include "Standalone-c/Dialects.h"
#include "Standalone/StandaloneDialect.h"
+#include "Standalone/StandaloneTypes.h"
#include "mlir/CAPI/Registration.h"
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Standalone, standalone,
mlir::standalone::StandaloneDialect)
+
+// Builds the `!standalone.custom<...>` type holding `value` in the context
+// behind `ctx` and hands it back to C callers as an opaque MlirType.
+MlirType mlirStandaloneCustomTypeGet(MlirContext ctx, MlirStringRef value) {
+  auto *context = unwrap(ctx);
+  auto str = unwrap(value);
+  return wrap(mlir::standalone::CustomType::get(context, str));
+}
+
+// C-visible isa<> check for the standalone CustomType.
+bool mlirStandaloneTypeIsACustomType(MlirType t) {
+  mlir::Type type = unwrap(t);
+  return llvm::isa<mlir::standalone::CustomType>(type);
+}
+
+// Exposes the TypeID of the standalone CustomType to C callers.
+MlirTypeID mlirStandaloneCustomTypeGetTypeID() {
+  mlir::TypeID typeID = mlir::standalone::CustomType::getTypeID();
+  return wrap(typeID);
+}
diff --git a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
index 0ec6cdfa7994b..c568369913595 100644
--- a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
+++ b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
@@ -11,17 +11,44 @@
#include "Standalone-c/Dialects.h"
#include "mlir-c/Dialect/Arith.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
+/// Python binding for the standalone dialect's CustomType. It is registered as
+/// `CustomType` in the `standalone` submodule and built on the PyConcreteType
+/// helper from the MLIR Python bindings headers.
+struct PyCustomType
+ : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyCustomType> {
+ /// C API hooks PyConcreteType is configured with: the isa predicate and the
+ /// TypeID getter for CustomType.
+ static constexpr IsAFunctionTy isaFunction = mlirStandaloneTypeIsACustomType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirStandaloneCustomTypeGetTypeID;
+ /// Name of the Python class.
+ static constexpr const char *pyClassName = "CustomType";
+ using PyConcreteType::PyConcreteType;
+
+ /// Adds the static `CustomType.get(value, context=None)` factory. A `None`
+ /// context resolves to the current context via DefaultingPyMlirContext.
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const std::string &value,
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ context) {
+ // Keep a strong reference to the context alongside the new type.
+ return PyCustomType(
+ context->getRef(),
+ mlirStandaloneCustomTypeGet(
+ context.get()->get(),
+ mlirStringRefCreateFromCString(value.c_str())));
+ },
+ nb::arg("value"), nb::arg("context").none() = nb::none());
+ }
+};
+
NB_MODULE(_standaloneDialectsNanobind, m) {
//===--------------------------------------------------------------------===//
// standalone dialect
//===--------------------------------------------------------------------===//
auto standaloneM = m.def_submodule("standalone");
+ PyCustomType::bind(standaloneM);
+
standaloneM.def(
"register_dialects",
[](MlirContext context, bool load) {
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index 09040eb2f45dc..fe4e40e6e8a99 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -19,6 +19,10 @@
# CHECK: standalone.foo %[[C2]] : i32
print(str(standalone_module), file=sys.stderr)
+ custom_type = standalone_d.CustomType.get("foo")
+ # CHECK: !standalone.custom<"foo">
+ print(custom_type, file=sys.stderr)
+
# CHECK: Testing mlir package
print("Testing mlir package", file=sys.stderr)
diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h
index 78fc94f93439e..6abd8894227c3 100644
--- a/mlir/include/mlir-c/Support.h
+++ b/mlir/include/mlir-c/Support.h
@@ -46,6 +46,8 @@
#define MLIR_CAPI_EXPORTED __attribute__((visibility("default")))
#endif
+/// Export annotation for the C++ Python-bindings API (e.g. PyGlobals,
+/// PyMlirContext); currently identical to MLIR_CAPI_EXPORTED.
+#define MLIR_PYTHON_API_EXPORTED MLIR_CAPI_EXPORTED
+
#ifdef __cplusplus
extern "C" {
#endif
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
similarity index 96%
rename from mlir/lib/Bindings/Python/Globals.h
rename to mlir/include/mlir/Bindings/Python/Globals.h
index 1e81f53e465ac..5548a716cbe21 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -15,10 +15,11 @@
#include <unordered_set>
#include <vector>
-#include "NanobindUtils.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir/CAPI/Support.h"
+
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
@@ -27,19 +28,16 @@
namespace mlir {
namespace python {
-
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Globals that are always accessible once the extension has been initialized.
/// Methods of this class are thread-safe.
-class PyGlobals {
+class MLIR_PYTHON_API_EXPORTED PyGlobals {
public:
PyGlobals();
~PyGlobals();
/// Most code should get the globals via this static accessor.
- static PyGlobals &get() {
- assert(instance && "PyGlobals is null");
- return *instance;
- }
+ static PyGlobals &get();
/// Get and set the list of parent modules to search for dialect
/// implementation classes.
@@ -119,7 +117,7 @@ class PyGlobals {
std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);
- class TracebackLoc {
+ class MLIR_PYTHON_API_EXPORTED TracebackLoc {
public:
bool locTracebacksEnabled();
@@ -199,7 +197,7 @@ class PyGlobals {
TracebackLoc tracebackLoc;
TypeIDAllocator typeIDAllocator;
};
-
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/include/mlir/Bindings/Python/IRCore.h
similarity index 69%
rename from mlir/lib/Bindings/Python/IRModule.h
rename to mlir/include/mlir/Bindings/Python/IRCore.h
index e706be3b4d32a..d8662137b60e7 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1,4 +1,4 @@
-//===- IRModules.h - IR Submodules of pybind module -----------------------===//
+//===- IRCore.h - IR helpers of python bindings ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,8 +7,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===----------------------------------------------------------------------===//
-#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
-#define MLIR_BINDINGS_PYTHON_IRMODULES_H
+#ifndef MLIR_BINDINGS_PYTHON_IRCORE_H
+#define MLIR_BINDINGS_PYTHON_IRCORE_H
#include <optional>
#include <sstream>
@@ -20,17 +20,21 @@
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ThreadPool.h"
namespace mlir {
namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
class PyBlock;
class PyDiagnostic;
@@ -47,10 +51,20 @@ class PyType;
class PySymbolTable;
class PyValue;
+/// Wrapper for the global LLVM debugging flag, exposed to Python through
+/// `bind`. Only declarations are visible here; definitions live elsewhere.
+struct MLIR_PYTHON_API_EXPORTED PyGlobalDebugFlag {
+ /// Enables or disables the flag. NOTE(review): `o` appears to be the Python
+ /// receiver of a static property setter; confirm against the definition.
+ static void set(nanobind::object &o, bool enable);
+ /// Returns the current state of the flag.
+ static bool get(const nanobind::object &);
+ /// Registers the Python-facing API on module `m`.
+ static void bind(nanobind::module_ &m);
+
+private:
+ /// nanobind free-threading mutex; presumably serializes flag access from
+ /// concurrent Python threads (confirm in the definition).
+ static nanobind::ft_mutex mutex;
+};
+
/// Template for a reference to a concrete type which captures a python
/// reference to its underlying python object.
template <typename T>
-class PyObjectRef {
+class MLIR_PYTHON_API_EXPORTED PyObjectRef {
public:
PyObjectRef(T *referrent, nanobind::object object)
: referrent(referrent), object(std::move(object)) {
@@ -109,7 +123,7 @@ class PyObjectRef {
/// Context. Pushing a Context will not modify the Location or InsertionPoint
/// unless if they are from a different context, in which case, they are
/// cleared.
-class PyThreadContextEntry {
+class MLIR_PYTHON_API_EXPORTED PyThreadContextEntry {
public:
enum class FrameKind {
Context,
@@ -165,22 +179,16 @@ class PyThreadContextEntry {
/// Wrapper around MlirLlvmThreadPool
/// Python object owns the C++ thread pool
-class PyThreadPool {
+class MLIR_PYTHON_API_EXPORTED PyThreadPool {
public:
- PyThreadPool() {
- ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
- }
+ PyThreadPool();
PyThreadPool(const PyThreadPool &) = delete;
PyThreadPool(PyThreadPool &&) = delete;
int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }
- std::string _mlir_thread_pool_ptr() const {
- std::stringstream ss;
- ss << ownedThreadPool.get();
- return ss.str();
- }
+ std::string _mlir_thread_pool_ptr() const;
private:
std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
@@ -188,7 +196,7 @@ class PyThreadPool {
/// Wrapper around MlirContext.
using PyMlirContextRef = PyObjectRef<PyMlirContext>;
-class PyMlirContext {
+class MLIR_PYTHON_API_EXPORTED PyMlirContext {
public:
PyMlirContext() = delete;
PyMlirContext(MlirContext context);
@@ -205,9 +213,7 @@ class PyMlirContext {
/// Gets a strong reference to this context, which will ensure it is kept
/// alive for the life of the reference.
- PyMlirContextRef getRef() {
- return PyMlirContextRef(this, nanobind::cast(this));
- }
+ PyMlirContextRef getRef();
/// Gets a capsule wrapping the void* within the MlirContext.
nanobind::object getCapsule();
@@ -269,7 +275,7 @@ class PyMlirContext {
/// Used in function arguments when None should resolve to the current context
/// manager set instance.
-class DefaultingPyMlirContext
+class MLIR_PYTHON_API_EXPORTED DefaultingPyMlirContext
: public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
public:
using Defaulting::Defaulting;
@@ -281,7 +287,7 @@ class DefaultingPyMlirContext
/// MlirContext. The lifetime of the context will extend at least to the
/// lifetime of these instances.
/// Immutable objects that depend on a context extend this directly.
-class BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED BaseContextObject {
public:
BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
assert(this->contextRef &&
@@ -296,7 +302,7 @@ class BaseContextObject {
};
/// Wrapper around an MlirLocation.
-class PyLocation : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyLocation : public BaseContextObject {
public:
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
: BaseContextObject(std::move(contextRef)), loc(loc) {}
@@ -323,16 +329,35 @@ class PyLocation : public BaseContextObject {
MlirLocation loc;
};
+/// Mirror of the C `MlirDiagnosticSeverity` enum with the same underlying
+/// type; each enumerator is initialized from the same-named C enumerator.
+enum PyDiagnosticSeverity : std::underlying_type_t<MlirDiagnosticSeverity> {
+ MlirDiagnosticError = MlirDiagnosticError,
+ MlirDiagnosticWarning = MlirDiagnosticWarning,
+ MlirDiagnosticNote = MlirDiagnosticNote,
+ MlirDiagnosticRemark = MlirDiagnosticRemark
+};
+
+/// Mirror of the C `MlirWalkResult` enum, returned by operation-walk
+/// callbacks; each enumerator is initialized from the same-named C enumerator.
+enum PyWalkResult : std::underlying_type_t<MlirWalkResult> {
+ MlirWalkResultAdvance = MlirWalkResultAdvance,
+ MlirWalkResultInterrupt = MlirWalkResultInterrupt,
+ MlirWalkResultSkip = MlirWalkResultSkip
+};
+
+/// Traversal order for operation walk. Mirror of the C `MlirWalkOrder` enum;
+/// each enumerator is initialized from the same-named C enumerator.
+enum PyWalkOrder : std::underlying_type_t<MlirWalkOrder> {
+ MlirWalkPreOrder = MlirWalkPreOrder,
+ MlirWalkPostOrder = MlirWalkPostOrder
+};
+
/// Python class mirroring the C MlirDiagnostic struct. Note that these structs
/// are only valid for the duration of a diagnostic callback and attempting
/// to access them outside of that will raise an exception. This applies to
/// nested diagnostics (in the notes) as well.
-class PyDiagnostic {
+class MLIR_PYTHON_API_EXPORTED PyDiagnostic {
public:
PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}
void invalidate();
bool isValid() { return valid; }
- MlirDiagnosticSeverity getSeverity();
+ PyDiagnosticSeverity getSeverity();
PyLocation getLocation();
nanobind::str getMessage();
nanobind::tuple getNotes();
@@ -340,7 +365,7 @@ class PyDiagnostic {
/// Materialized diagnostic information. This is safe to access outside the
/// diagnostic callback.
struct DiagnosticInfo {
- MlirDiagnosticSeverity severity;
+ PyDiagnosticSeverity severity;
PyLocation location;
std::string message;
std::vector<DiagnosticInfo> notes;
@@ -377,7 +402,7 @@ class PyDiagnostic {
/// The object may remain live from a Python perspective for an arbitrary time
/// after detachment, but there is nothing the user can do with it (since there
/// is no way to attach an existing handler object).
-class PyDiagnosticHandler {
+class MLIR_PYTHON_API_EXPORTED PyDiagnosticHandler {
public:
PyDiagnosticHandler(MlirContext context, nanobind::object callback);
~PyDiagnosticHandler();
@@ -405,7 +430,7 @@ class PyDiagnosticHandler {
/// RAII object that captures any error diagnostics emitted to the provided
/// context.
-struct PyMlirContext::ErrorCapture {
+struct MLIR_PYTHON_API_EXPORTED PyMlirContext::ErrorCapture {
ErrorCapture(PyMlirContextRef ctx)
: ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler(
ctx->get(), handler, /*userData=*/this,
@@ -432,7 +457,7 @@ struct PyMlirContext::ErrorCapture {
/// plugins which extend dialect functionality through extension python code.
/// This should be seen as the "low-level" object and `Dialect` as the
/// high-level, user facing object.
-class PyDialectDescriptor : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyDialectDescriptor : public BaseContextObject {
public:
PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
: BaseContextObject(std::move(contextRef)), dialect(dialect) {}
@@ -445,7 +470,7 @@ class PyDialectDescriptor : public BaseContextObject {
/// User-level object for accessing dialects with dotted syntax such as:
/// ctx.dialect.std
-class PyDialects : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyDialects : public BaseContextObject {
public:
PyDialects(PyMlirContextRef contextRef)
: BaseContextObject(std::move(contextRef)) {}
@@ -456,7 +481,7 @@ class PyDialects : public BaseContextObject {
/// User-level dialect object. For dialects that have a registered extension,
/// this will be the base class of the extension dialect type. For un-extended,
/// objects of this type will be returned directly.
-class PyDialect {
+class MLIR_PYTHON_API_EXPORTED PyDialect {
public:
PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {}
@@ -469,7 +494,7 @@ class PyDialect {
/// Wrapper around an MlirDialectRegistry.
/// Upon construction, the Python wrapper takes ownership of the
/// underlying MlirDialectRegistry.
-class PyDialectRegistry {
+class MLIR_PYTHON_API_EXPORTED PyDialectRegistry {
public:
PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {}
PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {}
@@ -495,7 +520,7 @@ class PyDialectRegistry {
/// Used in function arguments when None should resolve to the current context
/// manager set instance.
-class DefaultingPyLocation
+class MLIR_PYTHON_API_EXPORTED DefaultingPyLocation
: public Defaulting<DefaultingPyLocation, PyLocation> {
public:
using Defaulting::Defaulting;
@@ -509,7 +534,7 @@ class DefaultingPyLocation
/// This is the top-level, user-owned object that contains regions/ops/blocks.
class PyModule;
using PyModuleRef = PyObjectRef<PyModule>;
-class PyModule : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyModule : public BaseContextObject {
public:
/// Returns a PyModule reference for the given MlirModule. This always returns
/// a new object.
@@ -549,7 +574,7 @@ class PyAsmState;
/// Base class for PyOperation and PyOpView which exposes the primary, user
/// visible methods for manipulating it.
-class PyOperationBase {
+class MLIR_PYTHON_API_EXPORTED PyOperationBase {
public:
virtual ~PyOperationBase() = default;
/// Implements the bound 'print' method and helps with others.
@@ -571,8 +596,8 @@ class PyOperationBase {
std::optional<int64_t> bytecodeVersion);
// Implement the walk method.
- void walk(std::function<MlirWalkResult(MlirOperation)> callback,
- MlirWalkOrder walkOrder);
+ void walk(std::function<PyWalkResult(MlirOperation)> callback,
+ PyWalkOrder walkOrder);
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
@@ -602,7 +627,8 @@ class PyOperationBase {
class PyOperation;
class PyOpView;
using PyOperationRef = PyObjectRef<PyOperation>;
-class PyOperation : public PyOperationBase, public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyOperation : public PyOperationBase,
+ public BaseContextObject {
public:
~PyOperation() override;
PyOperation &getOperation() override { return *this; }
@@ -627,32 +653,17 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
/// Detaches the operation from its parent block and updates its state
/// accordingly.
- void detachFromParent() {
- mlirOperationRemoveFromParent(getOperation());
- setDetached();
- parentKeepAlive = nanobind::object();
- }
+ void detachFromParent();
/// Gets the backing operation.
operator MlirOperation() const { return get(); }
- MlirOperation get() const {
- checkValid();
- return operation;
- }
+ MlirOperation get() const;
- PyOperationRef getRef() {
- return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle));
- }
+ PyOperationRef getRef();
bool isAttached() { return attached; }
- void setAttached(const nanobind::object &parent = nanobind::object()) {
- assert(!attached && "operation already attached");
- attached = true;
- }
- void setDetached() {
- assert(attached && "operation already detached");
- attached = false;
- }
+ void setAttached(const nanobind::object &parent = nanobind::object());
+ void setDetached();
void checkValid() const;
/// Gets the owning block or raises an exception if the operation has no
@@ -720,7 +731,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
/// custom ODS-style operation classes. Since this class is subclass on the
/// python side, it must present an __init__ method that operates in pure
/// python types.
-class PyOpView : public PyOperationBase {
+class MLIR_PYTHON_API_EXPORTED PyOpView : public PyOperationBase {
public:
PyOpView(const nanobind::object &operationObject);
PyOperation &getOperation() override { return operation; }
@@ -756,7 +767,7 @@ class PyOpView : public PyOperationBase {
/// Wrapper around an MlirRegion.
/// Regions are managed completely by their containing operation. Unlike the
/// C++ API, the python API does not support detached regions.
-class PyRegion {
+class MLIR_PYTHON_API_EXPORTED PyRegion {
public:
PyRegion(PyOperationRef parentOperation, MlirRegion region)
: parentOperation(std::move(parentOperation)), region(region) {
@@ -775,26 +786,10 @@ class PyRegion {
};
/// Wrapper around an MlirAsmState.
-class PyAsmState {
+class MLIR_PYTHON_API_EXPORTED PyAsmState {
public:
- PyAsmState(MlirValue value, bool useLocalScope) {
- flags = mlirOpPrintingFlagsCreate();
- // The OpPrintingFlags are not exposed Python side, create locally and
- // associate lifetime with the state.
- if (useLocalScope)
- mlirOpPrintingFlagsUseLocalScope(flags);
- state = mlirAsmStateCreateForValue(value, flags);
- }
-
- PyAsmState(PyOperationBase &operation, bool useLocalScope) {
- flags = mlirOpPrintingFlagsCreate();
- // The OpPrintingFlags are not exposed Python side, create locally and
- // associate lifetime with the state.
- if (useLocalScope)
- mlirOpPrintingFlagsUseLocalScope(flags);
- state =
- mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
- }
+ PyAsmState(MlirValue value, bool useLocalScope);
+ PyAsmState(PyOperationBase &operation, bool useLocalScope);
~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); }
// Delete copy constructors.
PyAsmState(PyAsmState &other) = delete;
@@ -810,7 +805,7 @@ class PyAsmState {
/// Wrapper around an MlirBlock.
/// Blocks are managed completely by their containing operation. Unlike the
/// C++ API, the python API does not support detached blocks.
-class PyBlock {
+class MLIR_PYTHON_API_EXPORTED PyBlock {
public:
PyBlock(PyOperationRef parentOperation, MlirBlock block)
: parentOperation(std::move(parentOperation)), block(block) {
@@ -834,7 +829,7 @@ class PyBlock {
/// Calls to insert() will insert a new operation before the
/// reference operation. If the reference operation is null, then appends to
/// the end of the block.
-class PyInsertionPoint {
+class MLIR_PYTHON_API_EXPORTED PyInsertionPoint {
public:
/// Creates an insertion point positioned after the last operation in the
/// block, but still inside the block.
@@ -873,9 +868,10 @@ class PyInsertionPoint {
std::optional<PyOperationRef> refOperation;
PyBlock block;
};
+
/// Wrapper around the generic MlirType.
/// The lifetime of a type is bound by the PyContext that created it.
-class PyType : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyType : public BaseContextObject {
public:
PyType(PyMlirContextRef contextRef, MlirType type)
: BaseContextObject(std::move(contextRef)), type(type) {}
@@ -901,7 +897,7 @@ class PyType : public BaseContextObject {
/// A TypeID provides an efficient and unique identifier for a specific C++
/// type. This allows for a C++ type to be compared, hashed, and stored in an
/// opaque context. This class wraps around the generic MlirTypeID.
-class PyTypeID {
+class MLIR_PYTHON_API_EXPORTED PyTypeID {
public:
PyTypeID(MlirTypeID typeID) : typeID(typeID) {}
// Note, this tests whether the underlying TypeIDs are the same,
@@ -927,7 +923,7 @@ class PyTypeID {
/// concrete type class extends PyType); however, intermediate python-visible
/// base classes can be modeled by specifying a BaseTy.
template <typename DerivedTy, typename BaseTy = PyType>
-class PyConcreteType : public BaseTy {
+class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
@@ -1005,7 +1001,7 @@ class PyConcreteType : public BaseTy {
/// Wrapper around the generic MlirAttribute.
/// The lifetime of a type is bound by the PyContext that created it.
-class PyAttribute : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyAttribute : public BaseContextObject {
public:
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
: BaseContextObject(std::move(contextRef)), attr(attr) {}
@@ -1031,7 +1027,7 @@ class PyAttribute : public BaseContextObject {
/// Represents a Python MlirNamedAttr, carrying an optional owned name.
/// TODO: Refactor this and the C-API to be based on an Identifier owned
/// by the context so as to avoid ownership issues here.
-class PyNamedAttribute {
+class MLIR_PYTHON_API_EXPORTED PyNamedAttribute {
public:
/// Constructs a PyNamedAttr that retains an owned name. This should be
/// used in any code that originates an MlirNamedAttribute from a python
@@ -1057,7 +1053,7 @@ class PyNamedAttribute {
/// concrete attribute class extends PyAttribute); however, intermediate
/// python-visible base classes can be modeled by specifying a BaseTy.
template <typename DerivedTy, typename BaseTy = PyAttribute>
-class PyConcreteAttribute : public BaseTy {
+class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
@@ -1147,7 +1143,8 @@ class PyConcreteAttribute : public BaseTy {
static void bindDerived(ClassTy &m) {}
};
-class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
+class MLIR_PYTHON_API_EXPORTED PyStringAttribute
+ : public PyConcreteAttribute<PyStringAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
static constexpr const char *pyClassName = "StringAttr";
@@ -1164,7 +1161,7 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
/// value. For block argument values, this is the operation that contains the
/// block to which the value is an argument (blocks cannot be detached in Python
/// bindings so such operation always exists).
-class PyValue {
+class MLIR_PYTHON_API_EXPORTED PyValue {
public:
// The virtual here is "load bearing" in that it enables RTTI
// for PyConcreteValue CRTP classes that support maybeDownCast.
@@ -1194,7 +1191,7 @@ class PyValue {
};
/// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
-class PyAffineExpr : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyAffineExpr : public BaseContextObject {
public:
PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
: BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}
@@ -1221,7 +1218,7 @@ class PyAffineExpr : public BaseContextObject {
MlirAffineExpr affineExpr;
};
-class PyAffineMap : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyAffineMap : public BaseContextObject {
public:
PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)
: BaseContextObject(std::move(contextRef)), affineMap(affineMap) {}
@@ -1242,7 +1239,7 @@ class PyAffineMap : public BaseContextObject {
MlirAffineMap affineMap;
};
-class PyIntegerSet : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyIntegerSet : public BaseContextObject {
public:
PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
: BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
@@ -1263,7 +1260,7 @@ class PyIntegerSet : public BaseContextObject {
};
/// Bindings for MLIR symbol tables.
-class PySymbolTable {
+class MLIR_PYTHON_API_EXPORTED PySymbolTable {
public:
/// Constructs a symbol table for the given operation.
explicit PySymbolTable(PyOperationBase &operation);
@@ -1315,7 +1312,7 @@ class PySymbolTable {
/// Custom exception that allows access to error diagnostic information. This is
/// converted to the `ir.MLIRError` python exception when thrown.
-struct MLIRError {
+struct MLIR_PYTHON_API_EXPORTED MLIRError {
MLIRError(llvm::Twine message,
std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {})
: message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {}
@@ -1323,12 +1320,492 @@ struct MLIRError {
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
};
-void populateIRAffine(nanobind::module_ &m);
-void populateIRAttributes(nanobind::module_ &m);
-void populateIRCore(nanobind::module_ &m);
-void populateIRInterfaces(nanobind::module_ &m);
-void populateIRTypes(nanobind::module_ &m);
+//------------------------------------------------------------------------------
+// Utilities.
+//------------------------------------------------------------------------------
+
+inline MlirStringRef toMlirStringRef(const std::string &s) {
+ return mlirStringRefCreate(s.data(), s.size());
+}
+
+inline MlirStringRef toMlirStringRef(std::string_view s) {
+ return mlirStringRefCreate(s.data(), s.size());
+}
+
+inline MlirStringRef toMlirStringRef(const nanobind::bytes &s) {
+ return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
+}
+
+/// Create a block, using the current location context if no locations are
+/// specified via `pyArgLocs`.
+MLIR_PYTHON_API_EXPORTED MlirBlock
+createBlock(const nanobind::sequence &pyArgTypes,
+            const std::optional<nanobind::sequence> &pyArgLocs);
+
+struct MLIR_PYTHON_API_EXPORTED PyAttrBuilderMap {
+ static bool dunderContains(const std::string &attributeKind);
+ static nanobind::callable
+ dunderGetItemNamed(const std::string &attributeKind);
+ static void dunderSetItemNamed(const std::string &attributeKind,
+ nanobind::callable func, bool replace);
+
+ static void bind(nanobind::module_ &m);
+};
+
+//------------------------------------------------------------------------------
+// Collections.
+//------------------------------------------------------------------------------
+
+class MLIR_PYTHON_API_EXPORTED PyRegionIterator {
+public:
+  // Takes intptr_t (like the member and PyRegionList indices) to avoid narrowing.
+  PyRegionIterator(PyOperationRef operation, intptr_t nextIndex)
+      : operation(std::move(operation)), nextIndex(nextIndex) {}
+  PyRegionIterator &dunderIter() { return *this; }
+
+  PyRegion dunderNext();
+
+  static void bind(nanobind::module_ &m);
+
+private:
+  PyOperationRef operation;
+  intptr_t nextIndex = 0;
+};
+
+/// Regions of an op are fixed length and indexed numerically so are represented
+/// with a sequence-like container.
+class MLIR_PYTHON_API_EXPORTED PyRegionList
+ : public Sliceable<PyRegionList, PyRegion> {
+public:
+ static constexpr const char *pyClassName = "RegionSequence";
+
+ PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
+ intptr_t length = -1, intptr_t step = 1);
+
+ PyRegionIterator dunderIter();
+
+ static void bindDerived(ClassTy &c);
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyRegionList, PyRegion>;
+
+ intptr_t getRawNumElements();
+
+ PyRegion getRawElement(intptr_t pos);
+
+ PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) const;
+
+ PyOperationRef operation;
+};
+
+class MLIR_PYTHON_API_EXPORTED PyBlockIterator {
+public:
+ PyBlockIterator(PyOperationRef operation, MlirBlock next)
+ : operation(std::move(operation)), next(next) {}
+
+ PyBlockIterator &dunderIter() { return *this; }
+
+ PyBlock dunderNext();
+
+ static void bind(nanobind::module_ &m);
+
+private:
+ PyOperationRef operation;
+ MlirBlock next;
+};
+
+/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
+/// we present them as a more full-featured list-like container but optimize
+/// it for forward iteration. Blocks are always owned by a region.
+class MLIR_PYTHON_API_EXPORTED PyBlockList {
+public:
+ PyBlockList(PyOperationRef operation, MlirRegion region)
+ : operation(std::move(operation)), region(region) {}
+
+ PyBlockIterator dunderIter();
+
+ intptr_t dunderLen();
+
+ PyBlock dunderGetItem(intptr_t index);
+
+ PyBlock appendBlock(const nanobind::args &pyArgTypes,
+ const std::optional<nanobind::sequence> &pyArgLocs);
+
+ static void bind(nanobind::module_ &m);
+
+private:
+ PyOperationRef operation;
+ MlirRegion region;
+};
+
+class MLIR_PYTHON_API_EXPORTED PyOperationIterator {
+public:
+ PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
+ : parentOperation(std::move(parentOperation)), next(next) {}
+
+ PyOperationIterator &dunderIter() { return *this; }
+
+ nanobind::typed<nanobind::object, PyOpView> dunderNext();
+
+ static void bind(nanobind::module_ &m);
+
+private:
+ PyOperationRef parentOperation;
+ MlirOperation next;
+};
+
+/// Operations are exposed by the C-API as a forward-only linked list. In
+/// Python, we present them as a more full-featured list-like container but
+/// optimize it for forward iteration. Iterable operations are always owned
+/// by a block.
+class MLIR_PYTHON_API_EXPORTED PyOperationList {
+public:
+ PyOperationList(PyOperationRef parentOperation, MlirBlock block)
+ : parentOperation(std::move(parentOperation)), block(block) {}
+
+ PyOperationIterator dunderIter();
+
+ intptr_t dunderLen();
+
+ nanobind::typed<nanobind::object, PyOpView> dunderGetItem(intptr_t index);
+
+ static void bind(nanobind::module_ &m);
+
+private:
+ PyOperationRef parentOperation;
+ MlirBlock block;
+};
+
+class MLIR_PYTHON_API_EXPORTED PyOpOperand {
+public:
+ PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
+
+ nanobind::typed<nanobind::object, PyOpView> getOwner() const;
+
+ size_t getOperandNumber() const;
+
+ static void bind(nanobind::module_ &m);
+
+private:
+ MlirOpOperand opOperand;
+};
+
+class MLIR_PYTHON_API_EXPORTED PyOpOperandIterator {
+public:
+ PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
+
+ PyOpOperandIterator &dunderIter() { return *this; }
+
+ PyOpOperand dunderNext();
+
+ static void bind(nanobind::module_ &m);
+
+private:
+ MlirOpOperand opOperand;
+};
+
+/// CRTP base class for Python MLIR values that subclass Value and should be
+/// castable from it. The value hierarchy is one level deep and is not supposed
+/// to accommodate other levels unless core MLIR changes.
+template <typename DerivedTy>
+class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
+public:
+ // Derived classes must define statics for:
+ // IsAFunctionTy isaFunction
+ // const char *pyClassName
+ // and redefine bindDerived.
+ using ClassTy = nanobind::class_<DerivedTy, PyValue>;
+ using IsAFunctionTy = bool (*)(MlirValue);
+
+ PyConcreteValue() = default;
+ PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+ : PyValue(operationRef, value) {}
+ PyConcreteValue(PyValue &orig)
+ : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+ /// Attempts to cast the original value to the derived type and throws on
+ /// type mismatches.
+ static MlirValue castFrom(PyValue &orig) {
+ if (!DerivedTy::isaFunction(orig.get())) {
+ auto origRepr =
+ nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig)));
+ throw nanobind::value_error((Twine("Cannot cast value to ") +
+ DerivedTy::pyClassName + " (from " +
+ origRepr + ")")
+ .str()
+ .c_str());
+ }
+ return orig.get();
+ }
+
+ /// Binds the Python module objects to functions of this class.
+ static void bind(nanobind::module_ &m) {
+ auto cls = ClassTy(
+ m, DerivedTy::pyClassName, nanobind::is_generic(),
+ nanobind::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
+ .str()
+ .c_str()));
+ cls.def(nanobind::init<PyValue &>(), nanobind::keep_alive<0, 1>(),
+ nanobind::arg("value"));
+ cls.def_static(
+ "isinstance",
+ [](PyValue &otherValue) -> bool {
+ return DerivedTy::isaFunction(otherValue);
+ },
+ nanobind::arg("other_value"));
+ cls.def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {
+ return self.maybeDownCast();
+ });
+ DerivedTy::bindDerived(cls);
+ }
+
+ /// Implemented by derived classes to add methods to the Python subclass.
+ static void bindDerived(ClassTy &m) {}
+};
+
+/// Python wrapper for MlirOpResult.
+class MLIR_PYTHON_API_EXPORTED PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+ static constexpr const char *pyClassName = "OpResult";
+ using PyConcreteValue::PyConcreteValue;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// A list of operation results. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) result list is associated
+/// with the operation whose results these are, and thus extends the lifetime of
+/// this operation.
+class MLIR_PYTHON_API_EXPORTED PyOpResultList
+ : public Sliceable<PyOpResultList, PyOpResult> {
+public:
+ static constexpr const char *pyClassName = "OpResultList";
+ using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
+
+ PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
+ intptr_t length = -1, intptr_t step = 1);
+
+ static void bindDerived(ClassTy &c);
+
+ PyOperationRef &getOperation() { return operation; }
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyOpResultList, PyOpResult>;
+
+ intptr_t getRawNumElements();
+
+ PyOpResult getRawElement(intptr_t index);
+
+ PyOpResultList slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const;
+
+ PyOperationRef operation;
+};
+
+/// Python wrapper for MlirBlockArgument.
+class MLIR_PYTHON_API_EXPORTED PyBlockArgument
+ : public PyConcreteValue<PyBlockArgument> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
+ static constexpr const char *pyClassName = "BlockArgument";
+ using PyConcreteValue::PyConcreteValue;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// A list of block arguments. Internally, these are stored as consecutive
+/// elements, random access is cheap. The argument list is associated with the
+/// operation that contains the block (detached blocks are not allowed in
+/// Python bindings) and extends its lifetime.
+class MLIR_PYTHON_API_EXPORTED PyBlockArgumentList
+    : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
+public:
+  static constexpr const char *pyClassName = "BlockArgumentList";
+  using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
+
+  PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
+                      intptr_t startIndex = 0, intptr_t length = -1,
+                      intptr_t step = 1);
+
+  static void bindDerived(ClassTy &c);
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
+
+  /// Returns the number of arguments in the list.
+  intptr_t getRawNumElements();
+
+  /// Returns the `pos`-th element in the list.
+  PyBlockArgument getRawElement(intptr_t pos) const;
+
+  /// Returns a sublist of this list.
+  PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) const;
+
+  PyOperationRef operation;
+  MlirBlock block;
+};
+
+/// A list of operation operands. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) operand list is associated
+/// with the operation whose operands these are, and thus extends the lifetime
+/// of this operation.
+class MLIR_PYTHON_API_EXPORTED PyOpOperandList
+ : public Sliceable<PyOpOperandList, PyValue> {
+public:
+ static constexpr const char *pyClassName = "OpOperandList";
+ using SliceableT = Sliceable<PyOpOperandList, PyValue>;
+
+ PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
+ intptr_t length = -1, intptr_t step = 1);
+
+ void dunderSetItem(intptr_t index, PyValue value);
+
+ static void bindDerived(ClassTy &c);
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyOpOperandList, PyValue>;
+ intptr_t getRawNumElements();
+
+ PyValue getRawElement(intptr_t pos);
+
+ PyOpOperandList slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const;
+
+ PyOperationRef operation;
+};
+
+/// A list of operation successors. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) successor list is
+/// associated with the operation whose successors these are, and thus extends
+/// the lifetime of this operation.
+class MLIR_PYTHON_API_EXPORTED PyOpSuccessors
+ : public Sliceable<PyOpSuccessors, PyBlock> {
+public:
+ static constexpr const char *pyClassName = "OpSuccessors";
+
+ PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
+ intptr_t length = -1, intptr_t step = 1);
+
+ void dunderSetItem(intptr_t index, PyBlock block);
+
+ static void bindDerived(ClassTy &c);
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyOpSuccessors, PyBlock>;
+
+ intptr_t getRawNumElements();
+
+ PyBlock getRawElement(intptr_t pos);
+
+ PyOpSuccessors slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const;
+
+ PyOperationRef operation;
+};
+
+/// A list of block successors. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) successor list is
+/// associated with the operation and block whose successors these are, and thus
+/// extends the lifetime of this operation and block.
+class MLIR_PYTHON_API_EXPORTED PyBlockSuccessors
+ : public Sliceable<PyBlockSuccessors, PyBlock> {
+public:
+ static constexpr const char *pyClassName = "BlockSuccessors";
+
+ PyBlockSuccessors(PyBlock block, PyOperationRef operation,
+ intptr_t startIndex = 0, intptr_t length = -1,
+ intptr_t step = 1);
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyBlockSuccessors, PyBlock>;
+
+ intptr_t getRawNumElements();
+
+ PyBlock getRawElement(intptr_t pos);
+
+ PyBlockSuccessors slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const;
+
+ PyOperationRef operation;
+ PyBlock block;
+};
+
+/// A list of block predecessors. The (returned) predecessor list is
+/// associated with the operation and block whose predecessors these are, and
+/// thus extends the lifetime of this operation and block.
+///
+/// WARNING: This Sliceable is more expensive than the others here because
+/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
+/// operands) anew for each indexed access.
+class MLIR_PYTHON_API_EXPORTED PyBlockPredecessors
+ : public Sliceable<PyBlockPredecessors, PyBlock> {
+public:
+ static constexpr const char *pyClassName = "BlockPredecessors";
+
+ PyBlockPredecessors(PyBlock block, PyOperationRef operation,
+ intptr_t startIndex = 0, intptr_t length = -1,
+ intptr_t step = 1);
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyBlockPredecessors, PyBlock>;
+
+ intptr_t getRawNumElements();
+
+ PyBlock getRawElement(intptr_t pos);
+
+ PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const;
+
+ PyOperationRef operation;
+ PyBlock block;
+};
+
+/// A list of operation attributes. Can be indexed by name, producing
+/// attributes, or by index, producing named attributes.
+class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
+public:
+ PyOpAttributeMap(PyOperationRef operation)
+ : operation(std::move(operation)) {}
+
+ nanobind::typed<nanobind::object, PyAttribute>
+ dunderGetItemNamed(const std::string &name);
+
+ PyNamedAttribute dunderGetItemIndexed(intptr_t index);
+
+ void dunderSetItem(const std::string &name, const PyAttribute &attr);
+
+ void dunderDelItem(const std::string &name);
+
+ intptr_t dunderLen();
+
+ bool dunderContains(const std::string &name);
+
+ static void
+ forEachAttr(MlirOperation op,
+ llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn);
+
+ static void bind(nanobind::module_ &m);
+
+private:
+ PyOperationRef operation;
+};
+
+MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
+MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
+MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
@@ -1336,13 +1813,18 @@ namespace nanobind {
namespace detail {
template <>
-struct type_caster<mlir::python::DefaultingPyMlirContext>
- : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {};
+struct type_caster<
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext>
+ : MlirDefaultingCaster<
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext> {
+};
template <>
-struct type_caster<mlir::python::DefaultingPyLocation>
- : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};
+struct type_caster<
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyLocation>
+ : MlirDefaultingCaster<
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyLocation> {};
} // namespace detail
} // namespace nanobind
-#endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
+#endif // MLIR_BINDINGS_PYTHON_IRCORE_H
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index ba9642cf2c6a2..87e0e10764bd8 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -12,9 +12,11 @@
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace mlir {
-
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Shaped Type Interface - ShapedType
-class PyShapedType : public python::PyConcreteType<PyShapedType> {
+class MLIR_PYTHON_API_EXPORTED PyShapedType
+ : public PyConcreteType<PyShapedType> {
public:
static const IsAFunctionTy isaFunction;
static constexpr const char *pyClassName = "ShapedType";
@@ -25,7 +27,8 @@ class PyShapedType : public python::PyConcreteType<PyShapedType> {
private:
void requireHasRank();
};
-
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
} // namespace mlir
#endif // MLIR_BINDINGS_PYTHON_IRTYPES_H
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
similarity index 100%
rename from mlir/lib/Bindings/Python/NanobindUtils.h
rename to mlir/include/mlir/Bindings/Python/NanobindUtils.h
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index 0d1d9e89f92f6..a87918a05b126 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Dialect/SMT.h"
#include "mlir-c/IR.h"
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/Globals.cpp
similarity index 95%
rename from mlir/lib/Bindings/Python/IRModule.cpp
rename to mlir/lib/Bindings/Python/Globals.cpp
index 0de2f1711829b..e2e8693ba45f3 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -6,25 +6,29 @@
//
//===----------------------------------------------------------------------===//
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include <optional>
#include <vector>
-#include "Globals.h"
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/Globals.h"
+// clang-format off
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h"
+// clang-format on
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
namespace nb = nanobind;
using namespace mlir;
-using namespace mlir::python;
// -----------------------------------------------------------------------------
// PyGlobals
// -----------------------------------------------------------------------------
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
PyGlobals *PyGlobals::instance = nullptr;
PyGlobals::PyGlobals() {
@@ -37,6 +41,11 @@ PyGlobals::PyGlobals() {
PyGlobals::~PyGlobals() { instance = nullptr; }
+PyGlobals &PyGlobals::get() {
+ assert(instance && "PyGlobals is null");
+ return *instance;
+}
+
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
{
nb::ft_lock_guard lock(mutex);
@@ -265,3 +274,6 @@ bool PyGlobals::TracebackLoc::isUserTracebackFilename(
}
return isUserTracebackFilenameCache[file];
}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 7147f2cbad149..ce235470bbdc7 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -13,11 +13,13 @@
#include <utility>
#include <vector>
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
+#include "mlir/Bindings/Python/IRCore.h"
+// clang-format off
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
#include "mlir-c/IntegerSet.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Support/LLVM.h"
@@ -28,7 +30,7 @@
namespace nb = nanobind;
using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
using llvm::SmallVector;
using llvm::StringRef;
@@ -78,7 +80,9 @@ static bool isPermutation(const std::vector<PermutationTy> &permutation) {
return true;
}
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
/// and should be castable from it. Intermediate hierarchy classes can be
@@ -356,7 +360,9 @@ class PyAffineCeilDivExpr
}
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
return mlirAffineExprEqual(affineExpr, other.affineExpr);
@@ -378,7 +384,9 @@ PyAffineExpr PyAffineExpr::createFromCapsule(const nb::object &capsule) {
//------------------------------------------------------------------------------
// PyAffineMap and utilities.
//------------------------------------------------------------------------------
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// A list of expressions contained in an affine map. Internally these are
/// stored as a consecutive array leading to inexpensive random access. Both
@@ -414,7 +422,9 @@ class PyAffineMapExprList
PyAffineMap affineMap;
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
bool PyAffineMap::operator==(const PyAffineMap &other) const {
return mlirAffineMapEqual(affineMap, other.affineMap);
@@ -436,7 +446,9 @@ PyAffineMap PyAffineMap::createFromCapsule(const nb::object &capsule) {
//------------------------------------------------------------------------------
// PyIntegerSet and utilities.
//------------------------------------------------------------------------------
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
class PyIntegerSetConstraint {
public:
@@ -490,7 +502,9 @@ class PyIntegerSetConstraintList
PyIntegerSet set;
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
bool PyIntegerSet::operator==(const PyIntegerSet &other) const {
return mlirIntegerSetEqual(integerSet, other.integerSet);
@@ -509,7 +523,10 @@ PyIntegerSet PyIntegerSet::createFromCapsule(const nb::object &capsule) {
rawIntegerSet);
}
-void mlir::python::populateIRAffine(nb::module_ &m) {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+void populateIRAffine(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of PyAffineExpr and derived classes.
//----------------------------------------------------------------------------
@@ -995,3 +1012,6 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
PyIntegerSetConstraint::bind(m);
PyIntegerSetConstraintList::bind(m);
}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c0a945e3f4f3b..f0f0ae9ba741e 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -12,19 +12,19 @@
#include <string_view>
#include <utility>
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"
namespace nb = nanobind;
using namespace nanobind::literals;
using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
using llvm::SmallVector;
@@ -121,7 +121,9 @@ subsequent processing.
type or if the buffer does not meet expectations.
)";
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
struct nb_buffer_info {
void *ptr = nullptr;
@@ -228,14 +230,6 @@ struct nb_format_descriptor<double> {
static const char *format() { return "d"; }
};
-static MlirStringRef toMlirStringRef(const std::string &s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
-static MlirStringRef toMlirStringRef(const nb::bytes &s) {
- return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
-}
-
class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
@@ -1753,7 +1747,9 @@ nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
throw nb::type_error(msg.c_str());
}
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
void PyStringAttribute::bindDerived(ClassTy &c) {
c.def_static(
@@ -1799,7 +1795,10 @@ void PyStringAttribute::bindDerived(ClassTy &c) {
"Returns the value of the string attribute as `bytes`");
}
-void mlir::python::populateIRAttributes(nb::module_ &m) {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+void populateIRAttributes(nb::module_ &m) {
PyAffineMapAttribute::bind(m);
PyDenseBoolArrayAttribute::bind(m);
PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
@@ -1852,3 +1851,6 @@ void mlir::python::populateIRAttributes(nb::module_ &m) {
PyStridedLayoutAttribute::bind(m);
}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 168c57955af07..45e2cda4c91e2 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//
-#include "Globals.h"
-#include "IRModule.h"
-#include "NanobindUtils.h"
+// clang-format off
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
@@ -27,7 +29,6 @@
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlir;
-using namespace mlir::python;
using llvm::SmallVector;
using llvm::StringRef;
@@ -64,44 +65,41 @@ static nb::object classmethod(Func f, Args... args) {
static nb::object
createCustomDialectWrapper(const std::string &dialectNamespace,
nb::object dialectDescriptor) {
- auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
+ auto dialectClass =
+ python::MLIR_BINDINGS_PYTHON_DOMAIN::PyGlobals::get().lookupDialectClass(
+ dialectNamespace);
if (!dialectClass) {
// Use the base class.
- return nb::cast(PyDialect(std::move(dialectDescriptor)));
+ return nb::cast(python::MLIR_BINDINGS_PYTHON_DOMAIN::PyDialect(
+ std::move(dialectDescriptor)));
}
// Create the custom implementation.
return (*dialectClass)(std::move(dialectDescriptor));
}
-static MlirStringRef toMlirStringRef(const std::string &s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
-static MlirStringRef toMlirStringRef(std::string_view s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
-static MlirStringRef toMlirStringRef(const nb::bytes &s) {
- return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
-}
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-/// Create a block, using the current location context if no locations are
-/// specified.
-static MlirBlock createBlock(const nb::sequence &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
+MlirBlock createBlock(const nb::sequence &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
SmallVector<MlirType> argTypes;
argTypes.reserve(nb::len(pyArgTypes));
for (const auto &pyType : pyArgTypes)
- argTypes.push_back(nb::cast<PyType &>(pyType));
+ argTypes.push_back(
+ nb::cast<python::MLIR_BINDINGS_PYTHON_DOMAIN::PyType &>(pyType));
SmallVector<MlirLocation> argLocs;
if (pyArgLocs) {
argLocs.reserve(nb::len(*pyArgLocs));
for (const auto &pyLoc : *pyArgLocs)
- argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
+ argLocs.push_back(
+ nb::cast<python::MLIR_BINDINGS_PYTHON_DOMAIN::PyLocation &>(pyLoc));
} else if (!argTypes.empty()) {
- argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
+ argLocs.assign(
+ argTypes.size(),
+ python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyLocation::resolve());
}
if (argTypes.size() != argLocs.size())
@@ -112,82 +110,77 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes,
return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
}
-/// Wrapper for the global LLVM debugging flag.
-struct PyGlobalDebugFlag {
- static void set(nb::object &o, bool enable) {
- nb::ft_lock_guard lock(mutex);
- mlirEnableGlobalDebug(enable);
- }
-
- static bool get(const nb::object &) {
- nb::ft_lock_guard lock(mutex);
- return mlirIsGlobalDebugEnabled();
- }
+/// Setter behind the static `_GlobalDebug.flag` property: turns LLVM's global
+/// debug output on or off. The leading object parameter is the one nanobind
+/// hands to static-property setters and is not used here.
+void PyGlobalDebugFlag::set(nb::object &o, bool enable) {
+  // `mutex` serializes every global-debug C-API call made by this class.
+  nb::ft_lock_guard guard(PyGlobalDebugFlag::mutex);
+  mlirEnableGlobalDebug(enable);
+}
- static void bind(nb::module_ &m) {
- // Debug flags.
- nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
- .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
- &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
- .def_static(
- "set_types",
- [](const std::string &type) {
- nb::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugType(type.c_str());
- },
- "types"_a, "Sets specific debug types to be produced by LLVM.")
- .def_static(
- "set_types",
- [](const std::vector<std::string> &types) {
- std::vector<const char *> pointers;
- pointers.reserve(types.size());
- for (const std::string &str : types)
- pointers.push_back(str.c_str());
- nb::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
- },
- "types"_a,
- "Sets multiple specific debug types to be produced by LLVM.");
- }
+/// Getter behind the static `_GlobalDebug.flag` property: reports whether
+/// LLVM's global debug output is currently enabled.
+bool PyGlobalDebugFlag::get(const nb::object &) {
+  nb::ft_lock_guard guard(PyGlobalDebugFlag::mutex);
+  const bool enabled = mlirIsGlobalDebugEnabled();
+  return enabled;
+}
-private:
- static nb::ft_mutex mutex;
-};
+/// Registers the `_GlobalDebug` class on `m`: a static read/write `flag`
+/// property backed by get()/set(), plus two `set_types` overloads (a single
+/// debug type name, or a list of names). Each C-API call holds `mutex`.
+void PyGlobalDebugFlag::bind(nb::module_ &m) {
+  // Overload taking one debug type name.
+  auto setOneType = [](const std::string &type) {
+    nb::ft_lock_guard guard(PyGlobalDebugFlag::mutex);
+    mlirSetGlobalDebugType(type.c_str());
+  };
+  // Overload taking several names. The C API wants an array of C strings;
+  // they point into `types`, which outlives the call.
+  auto setManyTypes = [](const std::vector<std::string> &types) {
+    std::vector<const char *> cNames;
+    cNames.reserve(types.size());
+    for (const std::string &name : types)
+      cNames.push_back(name.c_str());
+    nb::ft_lock_guard guard(PyGlobalDebugFlag::mutex);
+    mlirSetGlobalDebugTypes(cNames.data(), cNames.size());
+  };
+
+  // Debug flags.
+  nb::class_<PyGlobalDebugFlag> cls(m, "_GlobalDebug");
+  cls.def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
+                         &PyGlobalDebugFlag::set, "LLVM-wide debug flag.");
+  cls.def_static("set_types", setOneType, "types"_a,
+                 "Sets specific debug types to be produced by LLVM.");
+  cls.def_static("set_types", setManyTypes, "types"_a,
+                 "Sets multiple specific debug types to be produced by LLVM.");
+}
nb::ft_mutex PyGlobalDebugFlag::mutex;
-struct PyAttrBuilderMap {
- static bool dunderContains(const std::string &attributeKind) {
- return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
- }
- static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
- auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
- if (!builder)
- throw nb::key_error(attributeKind.c_str());
- return *builder;
- }
- static void dunderSetItemNamed(const std::string &attributeKind,
- nb::callable func, bool replace) {
- PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
- replace);
- }
+/// `AttrBuilder.contains`: true when the global registry holds a builder for
+/// `attributeKind`.
+bool PyAttrBuilderMap::dunderContains(const std::string &attributeKind) {
+  auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
+  return builder.has_value();
+}
- static void bind(nb::module_ &m) {
- nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
- .def_static("contains", &PyAttrBuilderMap::dunderContains,
- "attribute_kind"_a,
- "Checks whether an attribute builder is registered for the "
- "given attribute kind.")
- .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
- "attribute_kind"_a,
- "Gets the registered attribute builder for the given "
- "attribute kind.")
- .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
- "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
- "Register an attribute builder for building MLIR "
- "attributes from Python values.");
- }
-};
+/// `AttrBuilder.get`: returns the builder registered for `attributeKind`.
+/// Raises a Python KeyError carrying the kind when none is registered.
+nb::callable
+PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
+  if (auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind))
+    return *builder;
+  throw nb::key_error(attributeKind.c_str());
+}
+
+/// `AttrBuilder.insert`: records `func` as the builder for `attributeKind`.
+/// `replace` is forwarded unchanged to PyGlobals::registerAttributeBuilder
+/// (presumably it allows overwriting an existing entry — see that method).
+void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
+                                          nb::callable func, bool replace) {
+  PyGlobals &globals = PyGlobals::get();
+  globals.registerAttributeBuilder(attributeKind, std::move(func), replace);
+}
+
+/// Exposes the global attribute-builder registry to Python as the class
+/// `AttrBuilder`, whose static `contains`, `get`, and `insert` methods forward
+/// to the dunder helpers above.
+void PyAttrBuilderMap::bind(nb::module_ &m) {
+  nb::class_<PyAttrBuilderMap> cls(m, "AttrBuilder");
+  cls.def_static("contains", &PyAttrBuilderMap::dunderContains,
+                 "attribute_kind"_a,
+                 "Checks whether an attribute builder is registered for the "
+                 "given attribute kind.");
+  cls.def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
+                 "attribute_kind"_a,
+                 "Gets the registered attribute builder for the given "
+                 "attribute kind.");
+  cls.def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
+                 "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
+                 "Register an attribute builder for building MLIR "
+                 "attributes from Python values.");
+}
//------------------------------------------------------------------------------
// PyBlock
@@ -201,335 +194,252 @@ nb::object PyBlock::getCapsule() {
// Collections.
//------------------------------------------------------------------------------
-namespace {
-
-class PyRegionIterator {
-public:
- PyRegionIterator(PyOperationRef operation, int nextIndex)
- : operation(std::move(operation)), nextIndex(nextIndex) {}
-
- PyRegionIterator &dunderIter() { return *this; }
-
- PyRegion dunderNext() {
- operation->checkValid();
- if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
- throw nb::stop_iteration();
- }
- MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
- return PyRegion(operation, region);
+PyRegion PyRegionIterator::dunderNext() {
+ operation->checkValid();
+ if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
+ throw nb::stop_iteration();
}
+ MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
+ return PyRegion(operation, region);
+}
- static void bind(nb::module_ &m) {
- nb::class_<PyRegionIterator>(m, "RegionIterator")
- .def("__iter__", &PyRegionIterator::dunderIter,
- "Returns an iterator over the regions in the operation.")
- .def("__next__", &PyRegionIterator::dunderNext,
- "Returns the next region in the iteration.");
- }
+/// Registers `RegionIterator`, the type returned when a region list is
+/// iterated (see PyRegionList::dunderIter).
+void PyRegionIterator::bind(nb::module_ &m) {
+  nb::class_<PyRegionIterator> cls(m, "RegionIterator");
+  cls.def("__iter__", &PyRegionIterator::dunderIter,
+          "Returns an iterator over the regions in the operation.");
+  cls.def("__next__", &PyRegionIterator::dunderNext,
+          "Returns the next region in the iteration.");
+}
-private:
- PyOperationRef operation;
- intptr_t nextIndex = 0;
-};
-
-/// Regions of an op are fixed length and indexed numerically so are represented
-/// with a sequence-like container.
-class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
-public:
- static constexpr const char *pyClassName = "RegionSequence";
-
- PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumRegions(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
-
- PyRegionIterator dunderIter() {
- operation->checkValid();
- return PyRegionIterator(operation, startIndex);
- }
+/// Creates a view over the regions of `operation`, starting at `startIndex`.
+/// A `length` of -1 selects every region the operation currently has.
+PyRegionList::PyRegionList(PyOperationRef operation, intptr_t startIndex,
+                           intptr_t length, intptr_t step)
+    // Base subobjects are initialized before members, so the base can still
+    // read the `operation` parameter before it is moved into the member.
+    : Sliceable(startIndex,
+                length != -1 ? length
+                             : mlirOperationGetNumRegions(operation->get()),
+                step),
+      operation(std::move(operation)) {}
- static void bindDerived(ClassTy &c) {
- c.def("__iter__", &PyRegionList::dunderIter,
- "Returns an iterator over the regions in the sequence.");
- }
+/// Returns an iterator that starts at this view's first region index.
+/// NOTE(review): only `startIndex` is forwarded, and
+/// PyRegionIterator::dunderNext runs until mlirOperationGetNumRegions, so
+/// iterating a slice appears to ignore its length and step — confirm.
+PyRegionIterator PyRegionList::dunderIter() {
+  this->operation->checkValid();
+  return PyRegionIterator(this->operation, this->startIndex);
+}
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyRegionList, PyRegion>;
+/// Adds `__iter__` to the Python class bound for region lists.
+void PyRegionList::bindDerived(ClassTy &c) {
+  auto iterate = &PyRegionList::dunderIter;
+  c.def("__iter__", iterate,
+        "Returns an iterator over the regions in the sequence.");
+}
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumRegions(operation->get());
- }
+/// Sliceable hook: the total number of regions on the underlying operation.
+intptr_t PyRegionList::getRawNumElements() {
+  operation->checkValid();
+  MlirOperation op = operation->get();
+  return mlirOperationGetNumRegions(op);
+}
- PyRegion getRawElement(intptr_t pos) {
- operation->checkValid();
- return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
- }
+/// Sliceable hook: wraps region number `pos` of the underlying operation.
+PyRegion PyRegionList::getRawElement(intptr_t pos) {
+  operation->checkValid();
+  MlirRegion region = mlirOperationGetRegion(operation->get(), pos);
+  return PyRegion(operation, region);
+}
- PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyRegionList(operation, startIndex, length, step);
- }
+/// Sliceable hook: builds a new view over the same operation with the given
+/// start, length, and step.
+PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length,
+                                 intptr_t step) const {
+  const PyOperationRef &op = this->operation;
+  return PyRegionList(op, startIndex, length, step);
+}
- PyOperationRef operation;
-};
+PyBlock PyBlockIterator::dunderNext() {
+ operation->checkValid();
+ if (mlirBlockIsNull(next)) {
+ throw nb::stop_iteration();
+ }
-class PyBlockIterator {
-public:
- PyBlockIterator(PyOperationRef operation, MlirBlock next)
- : operation(std::move(operation)), next(next) {}
+ PyBlock returnBlock(operation, next);
+ next = mlirBlockGetNextInRegion(next);
+ return returnBlock;
+}
- PyBlockIterator &dunderIter() { return *this; }
+/// Registers `BlockIterator`, the type returned by PyBlockList::dunderIter.
+void PyBlockIterator::bind(nb::module_ &m) {
+  nb::class_<PyBlockIterator> cls(m, "BlockIterator");
+  cls.def("__iter__", &PyBlockIterator::dunderIter,
+          "Returns an iterator over the blocks in the operation's region.");
+  cls.def("__next__", &PyBlockIterator::dunderNext,
+          "Returns the next block in the iteration.");
+}
- PyBlock dunderNext() {
- operation->checkValid();
- if (mlirBlockIsNull(next)) {
- throw nb::stop_iteration();
- }
+/// Starts a forward walk at the region's first block.
+PyBlockIterator PyBlockList::dunderIter() {
+  operation->checkValid();
+  MlirBlock first = mlirRegionGetFirstBlock(region);
+  return PyBlockIterator(operation, first);
+}
- PyBlock returnBlock(operation, next);
- next = mlirBlockGetNextInRegion(next);
- return returnBlock;
+intptr_t PyBlockList::dunderLen() {
+ operation->checkValid();
+ intptr_t count = 0;
+ MlirBlock block = mlirRegionGetFirstBlock(region);
+ while (!mlirBlockIsNull(block)) {
+ count += 1;
+ block = mlirBlockGetNextInRegion(block);
}
+ return count;
+}
- static void bind(nb::module_ &m) {
- nb::class_<PyBlockIterator>(m, "BlockIterator")
- .def("__iter__", &PyBlockIterator::dunderIter,
- "Returns an iterator over the blocks in the operation's region.")
- .def("__next__", &PyBlockIterator::dunderNext,
- "Returns the next block in the iteration.");
+PyBlock PyBlockList::dunderGetItem(intptr_t index) {
+ operation->checkValid();
+ if (index < 0) {
+ index += dunderLen();
}
-
-private:
- PyOperationRef operation;
- MlirBlock next;
-};
-
-/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
-/// we present them as a more full-featured list-like container but optimize
-/// it for forward iteration. Blocks are always owned by a region.
-class PyBlockList {
-public:
- PyBlockList(PyOperationRef operation, MlirRegion region)
- : operation(std::move(operation)), region(region) {}
-
- PyBlockIterator dunderIter() {
- operation->checkValid();
- return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
+ if (index < 0) {
+ throw nb::index_error("attempt to access out of bounds block");
}
-
- intptr_t dunderLen() {
- operation->checkValid();
- intptr_t count = 0;
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- count += 1;
- block = mlirBlockGetNextInRegion(block);
+ MlirBlock block = mlirRegionGetFirstBlock(region);
+ while (!mlirBlockIsNull(block)) {
+ if (index == 0) {
+ return PyBlock(operation, block);
}
- return count;
- }
-
- PyBlock dunderGetItem(intptr_t index) {
- operation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nb::index_error("attempt to access out of bounds block");
- }
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- if (index == 0) {
- return PyBlock(operation, block);
- }
- block = mlirBlockGetNextInRegion(block);
- index -= 1;
- }
- throw nb::index_error("attempt to access out of bounds block");
+ block = mlirBlockGetNextInRegion(block);
+ index -= 1;
}
+ throw nb::index_error("attempt to access out of bounds block");
+}
- PyBlock appendBlock(const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- operation->checkValid();
- MlirBlock block =
- createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
- mlirRegionAppendOwnedBlock(region, block);
- return PyBlock(operation, block);
- }
+/// `BlockList.append`: creates a block whose argument types are the
+/// positional arguments (with optional per-argument locations), hands it to
+/// the region via mlirRegionAppendOwnedBlock, and returns it.
+PyBlock PyBlockList::appendBlock(const nb::args &pyArgTypes,
+                                 const std::optional<nb::sequence> &pyArgLocs) {
+  operation->checkValid();
+  auto argTypes = nb::cast<nb::sequence>(pyArgTypes);
+  MlirBlock created = createBlock(argTypes, pyArgLocs);
+  mlirRegionAppendOwnedBlock(region, created);
+  return PyBlock(operation, created);
+}
- static void bind(nb::module_ &m) {
- nb::class_<PyBlockList>(m, "BlockList")
- .def("__getitem__", &PyBlockList::dunderGetItem,
- "Returns the block at the specified index.")
- .def("__iter__", &PyBlockList::dunderIter,
- "Returns an iterator over blocks in the operation's region.")
- .def("__len__", &PyBlockList::dunderLen,
- "Returns the number of blocks in the operation's region.")
- .def("append", &PyBlockList::appendBlock,
- R"(
+void PyBlockList::bind(nb::module_ &m) {
+ nb::class_<PyBlockList>(m, "BlockList")
+ .def("__getitem__", &PyBlockList::dunderGetItem,
+ "Returns the block at the specified index.")
+ .def("__iter__", &PyBlockList::dunderIter,
+ "Returns an iterator over blocks in the operation's region.")
+ .def("__len__", &PyBlockList::dunderLen,
+ "Returns the number of blocks in the operation's region.")
+ .def("append", &PyBlockList::appendBlock,
+ R"(
Appends a new block, with argument types as positional args.
Returns:
The created block.
)",
- nb::arg("args"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt);
- }
-
-private:
- PyOperationRef operation;
- MlirRegion region;
-};
-
-class PyOperationIterator {
-public:
- PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
- : parentOperation(std::move(parentOperation)), next(next) {}
+ "args"_a, nb::kw_only(), "arg_locs"_a = std::nullopt);
+}
- PyOperationIterator &dunderIter() { return *this; }
+nb::typed<nb::object, PyOpView> PyOperationIterator::dunderNext() {
+ parentOperation->checkValid();
+ if (mlirOperationIsNull(next)) {
+ throw nb::stop_iteration();
+ }
- nb::typed<nb::object, PyOpView> dunderNext() {
- parentOperation->checkValid();
- if (mlirOperationIsNull(next)) {
- throw nb::stop_iteration();
- }
+ PyOperationRef returnOperation =
+ PyOperation::forOperation(parentOperation->getContext(), next);
+ next = mlirOperationGetNextInBlock(next);
+ return returnOperation->createOpView();
+}
- PyOperationRef returnOperation =
- PyOperation::forOperation(parentOperation->getContext(), next);
- next = mlirOperationGetNextInBlock(next);
- return returnOperation->createOpView();
- }
+/// Registers `OperationIterator`, the type returned by
+/// PyOperationList::dunderIter.
+void PyOperationIterator::bind(nb::module_ &m) {
+  nb::class_<PyOperationIterator> cls(m, "OperationIterator");
+  cls.def("__iter__", &PyOperationIterator::dunderIter,
+          "Returns an iterator over the operations in an operation's block.");
+  cls.def("__next__", &PyOperationIterator::dunderNext,
+          "Returns the next operation in the iteration.");
+}
- static void bind(nb::module_ &m) {
- nb::class_<PyOperationIterator>(m, "OperationIterator")
- .def("__iter__", &PyOperationIterator::dunderIter,
- "Returns an iterator over the operations in an operation's block.")
- .def("__next__", &PyOperationIterator::dunderNext,
- "Returns the next operation in the iteration.");
- }
+/// Starts a forward walk at the block's first operation.
+PyOperationIterator PyOperationList::dunderIter() {
+  parentOperation->checkValid();
+  MlirOperation first = mlirBlockGetFirstOperation(block);
+  return PyOperationIterator(parentOperation, first);
+}
-private:
- PyOperationRef parentOperation;
- MlirOperation next;
-};
-
-/// Operations are exposed by the C-API as a forward-only linked list. In
-/// Python, we present them as a more full-featured list-like container but
-/// optimize it for forward iteration. Iterable operations are always owned
-/// by a block.
-class PyOperationList {
-public:
- PyOperationList(PyOperationRef parentOperation, MlirBlock block)
- : parentOperation(std::move(parentOperation)), block(block) {}
-
- PyOperationIterator dunderIter() {
- parentOperation->checkValid();
- return PyOperationIterator(parentOperation,
- mlirBlockGetFirstOperation(block));
+intptr_t PyOperationList::dunderLen() {
+ parentOperation->checkValid();
+ intptr_t count = 0;
+ MlirOperation childOp = mlirBlockGetFirstOperation(block);
+ while (!mlirOperationIsNull(childOp)) {
+ count += 1;
+ childOp = mlirOperationGetNextInBlock(childOp);
}
+ return count;
+}
- intptr_t dunderLen() {
- parentOperation->checkValid();
- intptr_t count = 0;
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- count += 1;
- childOp = mlirOperationGetNextInBlock(childOp);
- }
- return count;
+nb::typed<nb::object, PyOpView> PyOperationList::dunderGetItem(intptr_t index) {
+ parentOperation->checkValid();
+ if (index < 0) {
+ index += dunderLen();
}
-
- nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
- parentOperation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nb::index_error("attempt to access out of bounds operation");
- }
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- if (index == 0) {
- return PyOperation::forOperation(parentOperation->getContext(), childOp)
- ->createOpView();
- }
- childOp = mlirOperationGetNextInBlock(childOp);
- index -= 1;
- }
+ if (index < 0) {
throw nb::index_error("attempt to access out of bounds operation");
}
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOperationList>(m, "OperationList")
- .def("__getitem__", &PyOperationList::dunderGetItem,
- "Returns the operation at the specified index.")
- .def("__iter__", &PyOperationList::dunderIter,
- "Returns an iterator over operations in the list.")
- .def("__len__", &PyOperationList::dunderLen,
- "Returns the number of operations in the list.");
- }
-
-private:
- PyOperationRef parentOperation;
- MlirBlock block;
-};
-
-class PyOpOperand {
-public:
- PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
-
- nb::typed<nb::object, PyOpView> getOwner() {
- MlirOperation owner = mlirOpOperandGetOwner(opOperand);
- PyMlirContextRef context =
- PyMlirContext::forContext(mlirOperationGetContext(owner));
- return PyOperation::forOperation(context, owner)->createOpView();
+ MlirOperation childOp = mlirBlockGetFirstOperation(block);
+ while (!mlirOperationIsNull(childOp)) {
+ if (index == 0) {
+ return PyOperation::forOperation(parentOperation->getContext(), childOp)
+ ->createOpView();
+ }
+ childOp = mlirOperationGetNextInBlock(childOp);
+ index -= 1;
}
+ throw nb::index_error("attempt to access out of bounds operation");
+}
- size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
+/// Registers `OperationList` with indexing, iteration, and length support.
+void PyOperationList::bind(nb::module_ &m) {
+  nb::class_<PyOperationList> cls(m, "OperationList");
+  cls.def("__getitem__", &PyOperationList::dunderGetItem,
+          "Returns the operation at the specified index.");
+  cls.def("__iter__", &PyOperationList::dunderIter,
+          "Returns an iterator over operations in the list.");
+  cls.def("__len__", &PyOperationList::dunderLen,
+          "Returns the number of operations in the list.");
+}
- static void bind(nb::module_ &m) {
- nb::class_<PyOpOperand>(m, "OpOperand")
- .def_prop_ro("owner", &PyOpOperand::getOwner,
- "Returns the operation that owns this operand.")
- .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
- "Returns the operand number in the owning operation.");
- }
+/// Returns the op view of the operation that owns this operand, resolved in
+/// the owner's own MLIR context.
+nb::typed<nb::object, PyOpView> PyOpOperand::getOwner() const {
+  MlirOperation ownerOp = mlirOpOperandGetOwner(opOperand);
+  MlirContext ownerCtx = mlirOperationGetContext(ownerOp);
+  PyMlirContextRef ctxRef = PyMlirContext::forContext(ownerCtx);
+  PyOperationRef ownerRef = PyOperation::forOperation(ctxRef, ownerOp);
+  return ownerRef->createOpView();
+}
-private:
- MlirOpOperand opOperand;
-};
+/// Position of this operand within its owner's operand list.
+size_t PyOpOperand::getOperandNumber() const {
+  auto number = mlirOpOperandGetOperandNumber(opOperand);
+  return number;
+}
-class PyOpOperandIterator {
-public:
- PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
+/// Registers `OpOperand` with read-only `owner` and `operand_number`
+/// properties.
+void PyOpOperand::bind(nb::module_ &m) {
+  nb::class_<PyOpOperand> cls(m, "OpOperand");
+  cls.def_prop_ro("owner", &PyOpOperand::getOwner,
+                  "Returns the operation that owns this operand.");
+  cls.def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
+                  "Returns the operand number in the owning operation.");
+}
- PyOpOperandIterator &dunderIter() { return *this; }
+PyOpOperand PyOpOperandIterator::dunderNext() {
+ if (mlirOpOperandIsNull(opOperand))
+ throw nb::stop_iteration();
- PyOpOperand dunderNext() {
- if (mlirOpOperandIsNull(opOperand))
- throw nb::stop_iteration();
+ PyOpOperand returnOpOperand(opOperand);
+ opOperand = mlirOpOperandGetNextUse(opOperand);
+ return returnOpOperand;
+}
- PyOpOperand returnOpOperand(opOperand);
- opOperand = mlirOpOperandGetNextUse(opOperand);
- return returnOpOperand;
- }
+/// Registers `OpOperandIterator`, which yields PyOpOperand values by
+/// following mlirOpOperandGetNextUse.
+void PyOpOperandIterator::bind(nb::module_ &m) {
+  nb::class_<PyOpOperandIterator> cls(m, "OpOperandIterator");
+  cls.def("__iter__", &PyOpOperandIterator::dunderIter,
+          "Returns an iterator over operands.");
+  cls.def("__next__", &PyOpOperandIterator::dunderNext,
+          "Returns the next operand in the iteration.");
+}
- static void bind(nb::module_ &m) {
- nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
- .def("__iter__", &PyOpOperandIterator::dunderIter,
- "Returns an iterator over operands.")
- .def("__next__", &PyOpOperandIterator::dunderNext,
- "Returns the next operand in the iteration.");
- }
+//------------------------------------------------------------------------------
+// PyThreadPool
+//------------------------------------------------------------------------------
-private:
- MlirOpOperand opOperand;
-};
+/// Creates the wrapper together with the llvm::DefaultThreadPool it owns.
+/// The pool is constructed directly in the member-initializer list instead
+/// of default-constructing `ownedThreadPool` and assigning it in the body.
+PyThreadPool::PyThreadPool()
+    : ownedThreadPool(std::make_unique<llvm::DefaultThreadPool>()) {}
-} // namespace
+/// Returns the address of the owned thread pool, rendered as text by the
+/// stream's pointer formatting.
+std::string PyThreadPool::_mlir_thread_pool_ptr() const {
+  std::ostringstream out;
+  out << ownedThreadPool.get();
+  return out.str();
+}
//------------------------------------------------------------------------------
// PyMlirContext
@@ -554,6 +464,10 @@ PyMlirContext::~PyMlirContext() {
mlirContextDestroy(context);
}
+/// Returns a PyMlirContextRef pairing this context with the Python object
+/// that nb::cast(this) resolves to.
+PyMlirContextRef PyMlirContext::getRef() {
+  nb::object pySelf = nb::cast(this);
+  return PyMlirContextRef(this, std::move(pySelf));
+}
+
nb::object PyMlirContext::getCapsule() {
return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
}
@@ -662,7 +576,8 @@ MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
if (self->ctx->emitErrorDiagnostics)
return mlirLogicalResultFailure();
- if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
+ if (mlirDiagnosticGetSeverity(diag) !=
+ MlirDiagnosticSeverity::MlirDiagnosticError)
return mlirLogicalResultFailure();
self->errors.emplace_back(PyDiagnostic(diag).getInfo());
@@ -849,9 +764,10 @@ void PyDiagnostic::checkValid() {
}
}
+/// Returns the severity of this diagnostic as the Python-side enum. Throws
+/// via checkValid() if the diagnostic is no longer valid.
-MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
+PyDiagnosticSeverity PyDiagnostic::getSeverity() {
checkValid();
- return mlirDiagnosticGetSeverity(diagnostic);
+ // NOTE(review): this static_cast assumes PyDiagnosticSeverity mirrors the
+ // enumerator values of MlirDiagnosticSeverity — TODO confirm in IRCore.h.
+ return static_cast<PyDiagnosticSeverity>(
+ mlirDiagnosticGetSeverity(diagnostic));
}
PyLocation PyDiagnostic::getLocation() {
@@ -1088,6 +1004,31 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
return PyOperation::createDetached(std::move(contextRef), op);
}
+/// Removes this operation from its parent in the IR, marks the wrapper as
+/// detached, and resets `parentKeepAlive` (presumably the reference that kept
+/// the parent's Python object alive while attached — confirm in IRCore.h).
+void PyOperation::detachFromParent() {
+ mlirOperationRemoveFromParent(getOperation());
+ setDetached();
+ parentKeepAlive = nb::object();
+}
+
+/// Returns the wrapped MlirOperation. Throws std::runtime_error (from
+/// checkValid) if the operation has been invalidated.
+MlirOperation PyOperation::get() const {
+ checkValid();
+ return operation;
+}
+
+/// Returns a reference pairing this operation with a new reference to its
+/// Python object (`handle`, borrowed via nb::borrow).
+PyOperationRef PyOperation::getRef() {
+ return PyOperationRef(this, nb::borrow<nb::object>(handle));
+}
+
+/// Marks this operation as attached; asserts it was not already attached.
+/// NOTE(review): `parent` is not used by this implementation.
+void PyOperation::setAttached(const nb::object &parent) {
+ assert(!attached && "operation already attached");
+ attached = true;
+}
+
+/// Marks this operation as detached; asserts it was attached beforehand.
+void PyOperation::setDetached() {
+ assert(attached && "operation already detached");
+ attached = false;
+}
+
void PyOperation::checkValid() const {
if (!valid) {
throw std::runtime_error("the operation has been invalidated");
@@ -1164,13 +1105,12 @@ void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
.c_str());
}
-void PyOperationBase::walk(
- std::function<MlirWalkResult(MlirOperation)> callback,
- MlirWalkOrder walkOrder) {
+void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
+ PyWalkOrder walkOrder) {
PyOperation &operation = getOperation();
operation.checkValid();
struct UserData {
- std::function<MlirWalkResult(MlirOperation)> callback;
+ std::function<PyWalkResult(MlirOperation)> callback;
bool gotException;
std::string exceptionWhat;
nb::object exceptionType;
@@ -1180,7 +1120,7 @@ void PyOperationBase::walk(
void *userData) {
UserData *calleeUserData = static_cast<UserData *>(userData);
try {
- return (calleeUserData->callback)(op);
+ return static_cast<MlirWalkResult>((calleeUserData->callback)(op));
} catch (nb::python_error &e) {
calleeUserData->gotException = true;
calleeUserData->exceptionWhat = std::string(e.what());
@@ -1188,7 +1128,8 @@ void PyOperationBase::walk(
return MlirWalkResult::MlirWalkResultInterrupt;
}
};
- mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
+ mlirOperationWalk(operation, walkCallback, &userData,
+ static_cast<MlirWalkOrder>(walkOrder));
if (userData.gotException) {
std::string message("Exception raised in callback: ");
message.append(userData.exceptionWhat);
@@ -1448,93 +1389,22 @@ void PyOperation::erase() {
mlirOperationDestroy(operation);
}
-namespace {
-/// CRTP base class for Python MLIR values that subclass Value and should be
-/// castable from it. The value hierarchy is one level deep and is not supposed
-/// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy>
-class PyConcreteValue : public PyValue {
-public:
- // Derived classes must define statics for:
- // IsAFunctionTy isaFunction
- // const char *pyClassName
- // and redefine bindDerived.
- using ClassTy = nb::class_<DerivedTy, PyValue>;
- using IsAFunctionTy = bool (*)(MlirValue);
-
- PyConcreteValue() = default;
- PyConcreteValue(PyOperationRef operationRef, MlirValue value)
- : PyValue(operationRef, value) {}
- PyConcreteValue(PyValue &orig)
- : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
-
- /// Attempts to cast the original value to the derived type and throws on
- /// type mismatches.
- static MlirValue castFrom(PyValue &orig) {
- if (!DerivedTy::isaFunction(orig.get())) {
- auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
- throw nb::value_error((Twine("Cannot cast value to ") +
- DerivedTy::pyClassName + " (from " + origRepr +
- ")")
- .str()
- .c_str());
- }
- return orig.get();
- }
-
- /// Binds the Python module objects to functions of this class.
- static void bind(nb::module_ &m) {
- auto cls = ClassTy(
- m, DerivedTy::pyClassName, nb::is_generic(),
- nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
- .str()
- .c_str()));
- cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
- cls.def_static(
- "isinstance",
- [](PyValue &otherValue) -> bool {
- return DerivedTy::isaFunction(otherValue);
- },
- nb::arg("other_value"));
- cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
- return self.maybeDownCast();
- });
- DerivedTy::bindDerived(cls);
- }
-
- /// Implemented by derived classes to add methods to the Python subclass.
- static void bindDerived(ClassTy &m) {}
-};
-
-} // namespace
-
-/// Python wrapper for MlirOpResult.
-class PyOpResult : public PyConcreteValue<PyOpResult> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
- static constexpr const char *pyClassName = "OpResult";
- using PyConcreteValue::PyConcreteValue;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "owner",
- [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
- assert(mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match that in "
- "the IR");
- return self.getParentOperation()->createOpView();
- },
- "Returns the operation that produces this result.");
- c.def_prop_ro(
- "result_number",
- [](PyOpResult &self) {
- return mlirOpResultGetResultNumber(self.get());
- },
- "Returns the position of this result in the operation's result list.");
- }
-};
+/// Adds OpResult-specific Python properties: `owner` (the producing operation,
+/// returned as an OpView) and `result_number` (position in the result list).
+void PyOpResult::bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "owner",
+ [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
+ // Debug-only consistency check between the Python-side parent and the IR.
+ assert(mlirOperationEqual(self.getParentOperation()->get(),
+ mlirOpResultGetOwner(self.get())) &&
+ "expected the owner of the value in Python to match that in "
+ "the IR");
+ return self.getParentOperation()->createOpView();
+ },
+ "Returns the operation that produces this result.");
+ c.def_prop_ro(
+ "result_number",
+ [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); },
+ "Returns the position of this result in the operation's result list.");
+}
/// Returns the list of types of the values held by container.
template <typename Container>
@@ -1550,60 +1420,43 @@ getValueTypes(Container &container, PyMlirContextRef &context) {
return result;
}
-/// A list of operation results. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) result list is associated
-/// with the operation whose results these are, and thus extends the lifetime of
-/// this operation.
-class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
-public:
- static constexpr const char *pyClassName = "OpResultList";
- using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
-
- PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumResults(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "types",
- [](PyOpResultList &self) {
- return getValueTypes(self, self.operation->getContext());
- },
- "Returns a list of types for all results in this result list.");
- c.def_prop_ro(
- "owner",
- [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
- return self.operation->createOpView();
- },
- "Returns the operation that owns this result list.");
- }
-
- PyOperationRef &getOperation() { return operation; }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpResultList, PyOpResult>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumResults(operation->get());
- }
+/// Constructs a (possibly sliced) view over `operation`'s results. A `length`
+/// of -1 means "all results" (mlirOperationGetNumResults). The Sliceable base
+/// is initialized before members, so `operation` is still usable there before
+/// it is moved into the member.
+PyOpResultList::PyOpResultList(PyOperationRef operation, intptr_t startIndex,
+ intptr_t length, intptr_t step)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumResults(operation->get())
+ : length,
+ step),
+ operation(std::move(operation)) {}
+
+/// Adds OpResultList-specific Python properties: `types` (the result types)
+/// and `owner` (the owning operation as an OpView).
+void PyOpResultList::bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "types",
+ [](PyOpResultList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ },
+ "Returns a list of types for all results in this result list.");
+ c.def_prop_ro(
+ "owner",
+ [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
+ return self.operation->createOpView();
+ },
+ "Returns the operation that owns this result list.");
+}
- PyOpResult getRawElement(intptr_t index) {
- PyValue value(operation, mlirOperationGetResult(operation->get(), index));
- return PyOpResult(value);
- }
+/// Sliceable hook: total number of results of the (validated) operation.
+intptr_t PyOpResultList::getRawNumElements() {
+ operation->checkValid();
+ return mlirOperationGetNumResults(operation->get());
+}
- PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpResultList(operation, startIndex, length, step);
- }
+/// Sliceable hook: returns result `index` of the operation as a PyOpResult.
+PyOpResult PyOpResultList::getRawElement(intptr_t index) {
+ PyValue value(operation, mlirOperationGetResult(operation->get(), index));
+ return PyOpResult(value);
+}
- PyOperationRef operation;
-};
+/// Sliceable hook: returns a new list over the same operation restricted to
+/// the given start/length/step.
+PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const {
+ return PyOpResultList(operation, startIndex, length, step);
+}
//------------------------------------------------------------------------------
// PyOpView
@@ -1706,7 +1559,7 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
}
}
-static MlirValue getUniqueResult(MlirOperation operation) {
+MlirValue getUniqueResult(MlirOperation operation) {
auto numResults = mlirOperationGetNumResults(operation);
if (numResults != 1) {
auto name = mlirIdentifierStr(mlirOperationGetName(operation));
@@ -1938,6 +1791,28 @@ PyOpView::PyOpView(const nb::object &operationObject)
: operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
operationObject(operation.getRef().getObject()) {}
+//------------------------------------------------------------------------------
+// PyAsmState
+//------------------------------------------------------------------------------
+
+/// Creates the OpPrintingFlags used by a PyAsmState. The flags are not
+/// exposed on the Python side, so they are created here and their lifetime
+/// is associated with the owning PyAsmState.
+static MlirOpPrintingFlags createAsmStatePrintingFlags(bool useLocalScope) {
+  MlirOpPrintingFlags printingFlags = mlirOpPrintingFlagsCreate();
+  if (useLocalScope)
+    mlirOpPrintingFlagsUseLocalScope(printingFlags);
+  return printingFlags;
+}
+
+/// Creates an AsmState for printing `value`, optionally enabling local-scope
+/// printing (mlirOpPrintingFlagsUseLocalScope).
+PyAsmState::PyAsmState(MlirValue value, bool useLocalScope) {
+  flags = createAsmStatePrintingFlags(useLocalScope);
+  state = mlirAsmStateCreateForValue(value, flags);
+}
+
+/// Creates an AsmState for printing `operation`, optionally enabling
+/// local-scope printing (mlirOpPrintingFlagsUseLocalScope).
+PyAsmState::PyAsmState(PyOperationBase &operation, bool useLocalScope) {
+  flags = createAsmStatePrintingFlags(useLocalScope);
+  state = mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
+}
+
//------------------------------------------------------------------------------
// PyInsertionPoint.
//------------------------------------------------------------------------------
@@ -2319,420 +2194,318 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
}
}
-namespace {
-
-/// Python wrapper for MlirBlockArgument.
-class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
- static constexpr const char *pyClassName = "BlockArgument";
- using PyConcreteValue::PyConcreteValue;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "owner",
- [](PyBlockArgument &self) {
- return PyBlock(self.getParentOperation(),
- mlirBlockArgumentGetOwner(self.get()));
- },
- "Returns the block that owns this argument.");
- c.def_prop_ro(
- "arg_number",
- [](PyBlockArgument &self) {
- return mlirBlockArgumentGetArgNumber(self.get());
- },
- "Returns the position of this argument in the block's argument list.");
- c.def(
- "set_type",
- [](PyBlockArgument &self, PyType type) {
- return mlirBlockArgumentSetType(self.get(), type);
- },
- nb::arg("type"), "Sets the type of this block argument.");
- c.def(
- "set_location",
- [](PyBlockArgument &self, PyLocation loc) {
- return mlirBlockArgumentSetLocation(self.get(), loc);
- },
- nb::arg("loc"), "Sets the location of this block argument.");
- }
-};
-
-/// A list of block arguments. Internally, these are stored as consecutive
-/// elements, random access is cheap. The argument list is associated with the
-/// operation that contains the block (detached blocks are not allowed in
-/// Python bindings) and extends its lifetime.
-class PyBlockArgumentList
- : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
-public:
- static constexpr const char *pyClassName = "BlockArgumentList";
- using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
-
- PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
- intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumArguments(block) : length,
- step),
- operation(std::move(operation)), block(block) {}
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "types",
- [](PyBlockArgumentList &self) {
- return getValueTypes(self, self.operation->getContext());
- },
- "Returns a list of types for all arguments in this argument list.");
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
+/// Adds BlockArgument-specific Python API: `owner` (owning block),
+/// `arg_number` (position in the block's argument list), and the mutators
+/// `set_type` / `set_location`.
+void PyBlockArgument::bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "owner",
+ [](PyBlockArgument &self) {
+ return PyBlock(self.getParentOperation(),
+ mlirBlockArgumentGetOwner(self.get()));
+ },
+ "Returns the block that owns this argument.");
+ c.def_prop_ro(
+ "arg_number",
+ [](PyBlockArgument &self) {
+ return mlirBlockArgumentGetArgNumber(self.get());
+ },
+ "Returns the position of this argument in the block's argument list.");
+ c.def(
+ "set_type",
+ [](PyBlockArgument &self, PyType type) {
+ return mlirBlockArgumentSetType(self.get(), type);
+ },
+ "type"_a, "Sets the type of this block argument.");
+ c.def(
+ "set_location",
+ [](PyBlockArgument &self, PyLocation loc) {
+ return mlirBlockArgumentSetLocation(self.get(), loc);
+ },
+ "loc"_a, "Sets the location of this block argument.");
+}
- /// Returns the number of arguments in the list.
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirBlockGetNumArguments(block);
- }
+/// Constructs a (possibly sliced) view over the arguments of `block`, which
+/// belongs to `operation`. A `length` of -1 means "all arguments".
+PyBlockArgumentList::PyBlockArgumentList(PyOperationRef operation,
+ MlirBlock block, intptr_t startIndex,
+ intptr_t length, intptr_t step)
+ : Sliceable(startIndex,
+ length == -1 ? mlirBlockGetNumArguments(block) : length, step),
+ operation(std::move(operation)), block(block) {}
+
+/// Adds the `types` property (types of all arguments in this list).
+void PyBlockArgumentList::bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "types",
+ [](PyBlockArgumentList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ },
+ "Returns a list of types for all arguments in this argument list.");
+}
- /// Returns `pos`-the element in the list.
- PyBlockArgument getRawElement(intptr_t pos) {
- MlirValue argument = mlirBlockGetArgument(block, pos);
- return PyBlockArgument(operation, argument);
- }
+/// Sliceable hook: number of arguments of `block`, after validating the
+/// owning operation.
+intptr_t PyBlockArgumentList::getRawNumElements() {
+ operation->checkValid();
+ return mlirBlockGetNumArguments(block);
+}
- /// Returns a sublist of this list.
- PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
- intptr_t step) {
- return PyBlockArgumentList(operation, block, startIndex, length, step);
- }
+/// Sliceable hook: returns argument `pos` of `block` as a PyBlockArgument.
+PyBlockArgument PyBlockArgumentList::getRawElement(intptr_t pos) const {
+ MlirValue argument = mlirBlockGetArgument(block, pos);
+ return PyBlockArgument(operation, argument);
+}
- PyOperationRef operation;
- MlirBlock block;
-};
-
-/// A list of operation operands. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) operand list is associated
-/// with the operation whose operands these are, and thus extends the lifetime
-/// of this operation.
-class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
-public:
- static constexpr const char *pyClassName = "OpOperandList";
- using SliceableT = Sliceable<PyOpOperandList, PyValue>;
-
- PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumOperands(operation->get())
- : length,
- step),
- operation(operation) {}
-
- void dunderSetItem(intptr_t index, PyValue value) {
- index = wrapIndex(index);
- mlirOperationSetOperand(operation->get(), index, value.get());
- }
+/// Sliceable hook: returns a new list over the same block restricted to the
+/// given start/length/step.
+PyBlockArgumentList PyBlockArgumentList::slice(intptr_t startIndex,
+ intptr_t length,
+ intptr_t step) const {
+ return PyBlockArgumentList(operation, block, startIndex, length, step);
+}
- static void bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpOperandList::dunderSetItem, nb::arg("index"),
- nb::arg("value"),
- "Sets the operand at the specified index to a new value.");
- }
+/// Constructs a (possibly sliced) view over `operation`'s operands. A `length`
+/// of -1 means "all operands".
+/// NOTE(review): copies `operation` into the member, whereas PyOpResultList's
+/// constructor moves it; std::move would likely avoid an extra refcount
+/// round-trip — confirm PyOperationRef's copy cost.
+PyOpOperandList::PyOpOperandList(PyOperationRef operation, intptr_t startIndex,
+ intptr_t length, intptr_t step)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumOperands(operation->get())
+ : length,
+ step),
+ operation(operation) {}
+
+/// Implements `__setitem__`: replaces the operand at `index` with `value`.
+/// `index` is first normalized by Sliceable::wrapIndex (not shown here;
+/// presumably handles negative indices and bounds checking).
+void PyOpOperandList::dunderSetItem(intptr_t index, PyValue value) {
+ index = wrapIndex(index);
+ mlirOperationSetOperand(operation->get(), index, value.get());
+}
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpOperandList, PyValue>;
+/// Adds `__setitem__` so operands can be replaced by index from Python.
+void PyOpOperandList::bindDerived(ClassTy &c) {
+ c.def("__setitem__", &PyOpOperandList::dunderSetItem, "index"_a, "value"_a,
+ "Sets the operand at the specified index to a new value.");
+}
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumOperands(operation->get());
- }
+/// Sliceable hook: total number of operands of the (validated) operation.
+intptr_t PyOpOperandList::getRawNumElements() {
+ operation->checkValid();
+ return mlirOperationGetNumOperands(operation->get());
+}
- PyValue getRawElement(intptr_t pos) {
- MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
- MlirOperation owner;
- if (mlirValueIsAOpResult(operand))
- owner = mlirOpResultGetOwner(operand);
- else if (mlirValueIsABlockArgument(operand))
- owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
- else
- assert(false && "Value must be an block arg or op result.");
- PyOperationRef pyOwner =
- PyOperation::forOperation(operation->getContext(), owner);
- return PyValue(pyOwner, operand);
- }
+/// Sliceable hook: returns operand `pos` wrapped together with a reference to
+/// the operation owning the operand's value (the defining op for an OpResult,
+/// the parent op of the owning block for a BlockArgument).
+PyValue PyOpOperandList::getRawElement(intptr_t pos) {
+  MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
+  MlirOperation owner;
+  if (mlirValueIsAOpResult(operand)) {
+    owner = mlirOpResultGetOwner(operand);
+  } else if (mlirValueIsABlockArgument(operand)) {
+    owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
+  } else {
+    // Unreachable for well-formed IR. Throw after the assert so that builds
+    // with assertions disabled never fall through and read `owner`
+    // uninitialized.
+    assert(false && "Value must be an block arg or op result.");
+    throw std::runtime_error("Value must be a block arg or op result.");
+  }
+  PyOperationRef pyOwner =
+      PyOperation::forOperation(operation->getContext(), owner);
+  return PyValue(pyOwner, operand);
+}
- PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpOperandList(operation, startIndex, length, step);
- }
+/// Sliceable hook: returns a new list over the same operation restricted to
+/// the given start/length/step.
+PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const {
+ return PyOpOperandList(operation, startIndex, length, step);
+}
- PyOperationRef operation;
-};
-
-/// A list of operation successors. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) successor list is
-/// associated with the operation whose successors these are, and thus extends
-/// the lifetime of this operation.
-class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
-public:
- static constexpr const char *pyClassName = "OpSuccessors";
-
- PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumSuccessors(operation->get())
- : length,
- step),
- operation(operation) {}
-
- void dunderSetItem(intptr_t index, PyBlock block) {
- index = wrapIndex(index);
- mlirOperationSetSuccessor(operation->get(), index, block.get());
- }
+/// Constructs a (possibly sliced) view over `operation`'s successor blocks. A
+/// `length` of -1 means "all successors".
+PyOpSuccessors::PyOpSuccessors(PyOperationRef operation, intptr_t startIndex,
+ intptr_t length, intptr_t step)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumSuccessors(operation->get())
+ : length,
+ step),
+ operation(operation) {}
+
+/// Implements `__setitem__`: sets successor `index` of the operation to
+/// `block`. `index` is first normalized by Sliceable::wrapIndex (not shown
+/// here; presumably handles negative indices and bounds checking).
+void PyOpSuccessors::dunderSetItem(intptr_t index, PyBlock block) {
+ index = wrapIndex(index);
+ mlirOperationSetSuccessor(operation->get(), index, block.get());
+}
- static void bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nb::arg("index"),
- nb::arg("block"), "Sets the successor block at the specified index.");
- }
+/// Adds `__setitem__` so successor blocks can be replaced by index.
+void PyOpSuccessors::bindDerived(ClassTy &c) {
+ c.def("__setitem__", &PyOpSuccessors::dunderSetItem, "index"_a, "block"_a,
+ "Sets the successor block at the specified index.");
+}
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpSuccessors, PyBlock>;
+/// Sliceable hook: number of successors of the (validated) operation.
+intptr_t PyOpSuccessors::getRawNumElements() {
+ operation->checkValid();
+ return mlirOperationGetNumSuccessors(operation->get());
+}
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumSuccessors(operation->get());
- }
+/// Sliceable hook: returns successor `pos` as a PyBlock tied to the operation.
+PyBlock PyOpSuccessors::getRawElement(intptr_t pos) {
+ MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
+ return PyBlock(operation, block);
+}
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
- return PyBlock(operation, block);
- }
+/// Sliceable hook: returns a new view over the same operation restricted to
+/// the given start/length/step.
+PyOpSuccessors PyOpSuccessors::slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const {
+ return PyOpSuccessors(operation, startIndex, length, step);
+}
- PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpSuccessors(operation, startIndex, length, step);
- }
+/// Constructs a (possibly sliced) view over the successors of `block`, kept
+/// alive through `operation`. A `length` of -1 means "all successors".
+PyBlockSuccessors::PyBlockSuccessors(PyBlock block, PyOperationRef operation,
+ intptr_t startIndex, intptr_t length,
+ intptr_t step)
+ : Sliceable(startIndex,
+ length == -1 ? mlirBlockGetNumSuccessors(block.get()) : length,
+ step),
+ operation(operation), block(block) {}
+
+/// Sliceable hook: number of successors of the (validated) block.
+intptr_t PyBlockSuccessors::getRawNumElements() {
+ block.checkValid();
+ return mlirBlockGetNumSuccessors(block.get());
+}
- PyOperationRef operation;
-};
-
-/// A list of block successors. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) successor list is
-/// associated with the operation and block whose successors these are, and thus
-/// extends the lifetime of this operation and block.
-class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
-public:
- static constexpr const char *pyClassName = "BlockSuccessors";
-
- PyBlockSuccessors(PyBlock block, PyOperationRef operation,
- intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumSuccessors(block.get())
- : length,
- step),
- operation(operation), block(block) {}
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyBlockSuccessors, PyBlock>;
-
- intptr_t getRawNumElements() {
- block.checkValid();
- return mlirBlockGetNumSuccessors(block.get());
- }
+/// Sliceable hook: returns successor `pos` of the block. The local `block`
+/// shadows the member, hence the explicit `this->block`.
+PyBlock PyBlockSuccessors::getRawElement(intptr_t pos) {
+ MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
+ return PyBlock(operation, block);
+}
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
- return PyBlock(operation, block);
- }
+/// Sliceable hook: returns a new view over the same block restricted to the
+/// given start/length/step.
+PyBlockSuccessors PyBlockSuccessors::slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const {
+ return PyBlockSuccessors(block, operation, startIndex, length, step);
+}
- PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyBlockSuccessors(block, operation, startIndex, length, step);
- }
+/// Constructs a (possibly sliced) view over the predecessors of `block`, kept
+/// alive through `operation`. A `length` of -1 means "all predecessors".
+PyBlockPredecessors::PyBlockPredecessors(PyBlock block,
+ PyOperationRef operation,
+ intptr_t startIndex, intptr_t length,
+ intptr_t step)
+ : Sliceable(startIndex,
+ length == -1 ? mlirBlockGetNumPredecessors(block.get())
+ : length,
+ step),
+ operation(operation), block(block) {}
+
+/// Sliceable hook: number of predecessors of the (validated) block.
+intptr_t PyBlockPredecessors::getRawNumElements() {
+ block.checkValid();
+ return mlirBlockGetNumPredecessors(block.get());
+}
- PyOperationRef operation;
- PyBlock block;
-};
-
-/// A list of block predecessors. The (returned) predecessor list is
-/// associated with the operation and block whose predecessors these are, and
-/// thus extends the lifetime of this operation and block.
-///
-/// WARNING: This Sliceable is more expensive than the others here because
-/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
-/// operands) anew for each indexed access.
-class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
-public:
- static constexpr const char *pyClassName = "BlockPredecessors";
-
- PyBlockPredecessors(PyBlock block, PyOperationRef operation,
- intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumPredecessors(block.get())
- : length,
- step),
- operation(operation), block(block) {}
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyBlockPredecessors, PyBlock>;
-
- intptr_t getRawNumElements() {
- block.checkValid();
- return mlirBlockGetNumPredecessors(block.get());
- }
+/// Sliceable hook: returns predecessor `pos` of the block.
+/// WARNING: mlirBlockGetPredecessor iterates the use-def chain of block
+/// operands anew for each indexed access, so this is more expensive than the
+/// other Sliceable lookups in this file.
+PyBlock PyBlockPredecessors::getRawElement(intptr_t pos) {
+ MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
+ return PyBlock(operation, block);
+}
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
- return PyBlock(operation, block);
- }
+/// Sliceable hook: returns a new view over the same block restricted to the
+/// given start/length/step.
+PyBlockPredecessors PyBlockPredecessors::slice(intptr_t startIndex,
+ intptr_t length,
+ intptr_t step) const {
+ return PyBlockPredecessors(block, operation, startIndex, length, step);
+}
- PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
- intptr_t step) {
- return PyBlockPredecessors(block, operation, startIndex, length, step);
+/// Implements `__getitem__` by name: returns the attribute named `name`,
+/// downcast to its concrete Python type. Raises KeyError if absent.
+nb::typed<nb::object, PyAttribute>
+PyOpAttributeMap::dunderGetItemNamed(const std::string &name) {
+ MlirAttribute attr =
+ mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name));
+ if (mlirAttributeIsNull(attr)) {
+ throw nb::key_error("attempt to access a non-existent attribute");
}
+ return PyAttribute(operation->getContext(), attr).maybeDownCast();
+}
- PyOperationRef operation;
- PyBlock block;
-};
-
-/// A list of operation attributes. Can be indexed by name, producing
-/// attributes, or by index, producing named attributes.
-class PyOpAttributeMap {
-public:
- PyOpAttributeMap(PyOperationRef operation)
- : operation(std::move(operation)) {}
-
- nb::typed<nb::object, PyAttribute>
- dunderGetItemNamed(const std::string &name) {
- MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
- toMlirStringRef(name));
- if (mlirAttributeIsNull(attr)) {
- throw nb::key_error("attempt to access a non-existent attribute");
- }
- return PyAttribute(operation->getContext(), attr).maybeDownCast();
+/// Implements `__getitem__` by position: returns attribute `index` as a
+/// PyNamedAttribute. Negative indices count from the end; indices still out
+/// of range afterwards raise IndexError.
+PyNamedAttribute PyOpAttributeMap::dunderGetItemIndexed(intptr_t index) {
+ if (index < 0) {
+ index += dunderLen();
}
-
- PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0 || index >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds attribute");
- }
- MlirNamedAttribute namedAttr =
- mlirOperationGetAttribute(operation->get(), index);
- return PyNamedAttribute(
- namedAttr.attribute,
- std::string(mlirIdentifierStr(namedAttr.name).data,
- mlirIdentifierStr(namedAttr.name).length));
+ if (index < 0 || index >= dunderLen()) {
+ throw nb::index_error("attempt to access out of bounds attribute");
}
+ MlirNamedAttribute namedAttr =
+ mlirOperationGetAttribute(operation->get(), index);
+ // Copy the identifier text into an owned std::string for the wrapper.
+ return PyNamedAttribute(
+ namedAttr.attribute,
+ std::string(mlirIdentifierStr(namedAttr.name).data,
+ mlirIdentifierStr(namedAttr.name).length));
+}
- void dunderSetItem(const std::string &name, const PyAttribute &attr) {
- mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
- attr);
- }
+/// Implements `__setitem__`: sets (or overwrites) the attribute `name` on the
+/// operation to `attr`.
+void PyOpAttributeMap::dunderSetItem(const std::string &name,
+ const PyAttribute &attr) {
+ mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
+ attr);
+}
- void dunderDelItem(const std::string &name) {
- int removed = mlirOperationRemoveAttributeByName(operation->get(),
- toMlirStringRef(name));
- if (!removed)
- throw nb::key_error("attempt to delete a non-existent attribute");
- }
+/// Implements `__delitem__`: removes attribute `name`. Raises KeyError when
+/// the C API reports that nothing was removed.
+void PyOpAttributeMap::dunderDelItem(const std::string &name) {
+ int removed = mlirOperationRemoveAttributeByName(operation->get(),
+ toMlirStringRef(name));
+ if (!removed)
+ throw nb::key_error("attempt to delete a non-existent attribute");
+}
- intptr_t dunderLen() {
- return mlirOperationGetNumAttributes(operation->get());
- }
+/// Implements `__len__`: number of attributes on the operation.
+intptr_t PyOpAttributeMap::dunderLen() {
+ return mlirOperationGetNumAttributes(operation->get());
+}
- bool dunderContains(const std::string &name) {
- return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
- operation->get(), toMlirStringRef(name)));
- }
+/// Implements `__contains__`: true iff the operation has an attribute `name`.
+bool PyOpAttributeMap::dunderContains(const std::string &name) {
+ return !mlirAttributeIsNull(
+ mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)));
+}
- static void
- forEachAttr(MlirOperation op,
- llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
- intptr_t n = mlirOperationGetNumAttributes(op);
- for (intptr_t i = 0; i < n; ++i) {
- MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
- MlirStringRef name = mlirIdentifierStr(na.name);
- fn(name, na.attribute);
- }
+/// Invokes `fn(name, attribute)` for every attribute of `op`, in positional
+/// order. The attribute count is read once before iterating.
+void PyOpAttributeMap::forEachAttr(
+ MlirOperation op,
+ llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
+ intptr_t n = mlirOperationGetNumAttributes(op);
+ for (intptr_t i = 0; i < n; ++i) {
+ MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
+ MlirStringRef name = mlirIdentifierStr(na.name);
+ fn(name, na.attribute);
}
+}
- static void bind(nb::module_ &m) {
- nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
- .def("__contains__", &PyOpAttributeMap::dunderContains, nb::arg("name"),
- "Checks if an attribute with the given name exists in the map.")
- .def("__len__", &PyOpAttributeMap::dunderLen,
- "Returns the number of attributes in the map.")
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
- nb::arg("name"), "Gets an attribute by name.")
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
- nb::arg("index"), "Gets a named attribute by index.")
- .def("__setitem__", &PyOpAttributeMap::dunderSetItem, nb::arg("name"),
- nb::arg("attr"), "Sets an attribute with the given name.")
- .def("__delitem__", &PyOpAttributeMap::dunderDelItem, nb::arg("name"),
- "Deletes an attribute with the given name.")
- .def(
- "__iter__",
- [](PyOpAttributeMap &self) {
- nb::list keys;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute) {
- keys.append(nb::str(name.data, name.length));
- });
- return nb::iter(keys);
- },
- "Iterates over attribute names.")
- .def(
- "keys",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute) {
- out.append(nb::str(name.data, name.length));
- });
- return out;
- },
- "Returns a list of attribute names.")
- .def(
- "values",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef, MlirAttribute attr) {
- out.append(PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast());
- });
- return out;
- },
- "Returns a list of attribute values.")
- .def(
- "items",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute attr) {
- out.append(nb::make_tuple(
- nb::str(name.data, name.length),
- PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast()));
- });
- return out;
- },
- "Returns a list of `(name, attribute)` tuples.");
- }
+/// Registers the Python `OpAttributeMap` class on module `m`: a mapping from
+/// attribute names to attributes that also supports positional `__getitem__`,
+/// plus dict-like `keys` / `values` / `items`.
+void PyOpAttributeMap::bind(nb::module_ &m) {
+ nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
+ .def("__contains__", &PyOpAttributeMap::dunderContains, "name"_a,
+ "Checks if an attribute with the given name exists in the map.")
+ .def("__len__", &PyOpAttributeMap::dunderLen,
+ "Returns the number of attributes in the map.")
+ // Two `__getitem__` overloads: by name (str) and by position (int).
+ .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed, "name"_a,
+ "Gets an attribute by name.")
+ .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed, "index"_a,
+ "Gets a named attribute by index.")
+ .def("__setitem__", &PyOpAttributeMap::dunderSetItem, "name"_a, "attr"_a,
+ "Sets an attribute with the given name.")
+ .def("__delitem__", &PyOpAttributeMap::dunderDelItem, "name"_a,
+ "Deletes an attribute with the given name.")
+ // The names are collected into a list first, so the returned iterator
+ // walks a snapshot of the attribute names taken at call time.
+ .def(
+ "__iter__",
+ [](PyOpAttributeMap &self) {
+ nb::list keys;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
+ keys.append(nb::str(name.data, name.length));
+ });
+ return nb::iter(keys);
+ },
+ "Iterates over attribute names.")
+ .def(
+ "keys",
+ [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
+ out.append(nb::str(name.data, name.length));
+ });
+ return out;
+ },
+ "Returns a list of attribute names.")
+ .def(
+ "values",
+ [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
+ out.append(PyAttribute(self.operation->getContext(), attr)
+ .maybeDownCast());
+ });
+ return out;
+ },
+ "Returns a list of attribute values.")
+ .def(
+ "items",
+ [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute attr) {
+ out.append(nb::make_tuple(
+ nb::str(name.data, name.length),
+ PyAttribute(self.operation->getContext(), attr)
+ .maybeDownCast()));
+ });
+ return out;
+ },
+ "Returns a list of `(name, attribute)` tuples.");
+}
-private:
- PyOperationRef operation;
-};
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+namespace {
// see
// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
@@ -2799,6 +2572,8 @@ PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
#endif // Python 3.9.0b1
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
+
MlirLocation tracebackToLocation(MlirContext ctx) {
size_t framesLimit =
PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
@@ -2882,30 +2657,151 @@ maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
return {ref, mlirLoc};
}
-
} // namespace
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+// Populates the root extension module: binds PyGlobals as `_Globals`,
+// installs a `globals` instance of it, and defines the registration
+// decorators (register_dialect, register_operation, and the type/value
+// caster registration functions named by the MLIR_PYTHON_CAPI_* macros).
+//
+// Statement order matters here: the `T`/`U` type variables are referenced by
+// the nb::sig strings below, and the `_Globals` binding is created before the
+// `globals` instance is converted with nb::cast.
+void populateRoot(nb::module_ &m) {
+ // Type variables used in the generated signatures (e.g. `type[T]`,
+ // `typing.Callable[[T], U]`) of the decorator functions below.
+ m.attr("T") = nb::type_var("T");
+ m.attr("U") = nb::type_var("U");
+
+ // Python-visible wrapper around PyGlobals. Methods prefixed with `_` are
+ // testing/internal hooks per their docstrings.
+ nb::class_<PyGlobals>(m, "_Globals")
+ .def_prop_rw("dialect_search_modules",
+ &PyGlobals::getDialectSearchPrefixes,
+ &PyGlobals::setDialectSearchPrefixes)
+ .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
+ "module_name"_a)
+ .def(
+ "_check_dialect_module_loaded",
+ [](PyGlobals &self, const std::string &dialectNamespace) {
+ return self.loadDialectModule(dialectNamespace);
+ },
+ "dialect_namespace"_a)
+ .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
+ "dialect_namespace"_a, "dialect_class"_a,
+ "Testing hook for directly registering a dialect")
+ .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
+ "operation_name"_a, "operation_class"_a, nb::kw_only(),
+ "replace"_a = false,
+ "Testing hook for directly registering an operation")
+ // The traceback-location methods below are thin wrappers forwarding to
+ // PyGlobals::getTracebackLoc().
+ .def("loc_tracebacks_enabled",
+ [](PyGlobals &self) {
+ return self.getTracebackLoc().locTracebacksEnabled();
+ })
+ .def("set_loc_tracebacks_enabled",
+ [](PyGlobals &self, bool enabled) {
+ self.getTracebackLoc().setLocTracebacksEnabled(enabled);
+ })
+ .def("loc_tracebacks_frame_limit",
+ [](PyGlobals &self) {
+ return self.getTracebackLoc().locTracebackFramesLimit();
+ })
+ .def("set_loc_tracebacks_frame_limit",
+ [](PyGlobals &self, std::optional<int> n) {
+ // Passing None resets the limit to TracebackLoc::kMaxFrames.
+ self.getTracebackLoc().setLocTracebackFramesLimit(
+ n.value_or(PyGlobals::TracebackLoc::kMaxFrames));
+ })
+ .def("register_traceback_file_inclusion",
+ [](PyGlobals &self, const std::string &filename) {
+ self.getTracebackLoc().registerTracebackFileInclusion(filename);
+ })
+ .def("register_traceback_file_exclusion",
+ [](PyGlobals &self, const std::string &filename) {
+ self.getTracebackLoc().registerTracebackFileExclusion(filename);
+ });
+
+ // Aside from making the globals accessible to python, having python manage
+ // it is necessary to make sure it is destroyed (and releases its python
+ // resources) properly.
+ // take_ownership hands the heap-allocated PyGlobals to the Python object, so
+ // it is deleted when that object is collected rather than leaked.
+ m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
+
+ // Registration decorators.
+ // register_dialect: class decorator that reads DIALECT_NAMESPACE from the
+ // decorated class, registers it with PyGlobals, and returns it unchanged.
+ m.def(
+ "register_dialect",
+ [](nb::type_object pyClass) {
+ std::string dialectNamespace =
+ nb::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
+ PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
+ return pyClass;
+ },
+ "dialect_class"_a,
+ "Class decorator for registering a custom Dialect wrapper");
+ // register_operation: returns a decorator (a cpp_function capturing the
+ // dialect class and `replace` flag). The decorator registers the op class
+ // under its OPERATION_NAME, attaches it to the dialect class by __name__,
+ // and returns the op class unchanged.
+ m.def(
+ "register_operation",
+ [](const nb::type_object &dialectClass, bool replace) -> nb::object {
+ return nb::cpp_function(
+ [dialectClass,
+ replace](nb::type_object opClass) -> nb::type_object {
+ std::string operationName =
+ nb::cast<std::string>(opClass.attr("OPERATION_NAME"));
+ PyGlobals::get().registerOperationImpl(operationName, opClass,
+ replace);
+ // Dict-stuff the new opClass by name onto the dialect class.
+ nb::object opClassName = opClass.attr("__name__");
+ dialectClass.attr(opClassName) = opClass;
+ return opClass;
+ });
+ },
+ // clang-format off
+ nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) "
+ "-> typing.Callable[[type[T]], type[T]]"),
+ // clang-format on
+ "dialect_class"_a, nb::kw_only(), "replace"_a = false,
+ "Produce a class decorator for registering an Operation class as part of "
+ "a dialect");
+ // Type-caster registration: returns a decorator that registers the decorated
+ // callable as the type caster for `mlirTypeID` and returns it unchanged.
+ // NOTE(review): the function name in nb::sig is hard-coded; presumably it
+ // must stay in sync with MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR -- confirm.
+ m.def(
+ MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
+ [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
+ return nb::cpp_function([mlirTypeID, replace](
+ nb::callable typeCaster) -> nb::object {
+ PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
+ return typeCaster;
+ });
+ },
+ // clang-format off
+ nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
+ "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
+ // clang-format on
+ "typeid"_a, nb::kw_only(), "replace"_a = false,
+ "Register a type caster for casting MLIR types to custom user types.");
+ // Value-caster registration: same shape as the type-caster decorator above,
+ // but forwards to PyGlobals::registerValueCaster.
+ // NOTE(review): as above, the nb::sig name is hard-coded; presumably it must
+ // match MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR -- confirm.
+ m.def(
+ MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
+ [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
+ return nb::cpp_function(
+ [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
+ PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
+ replace);
+ return valueCaster;
+ });
+ },
+ // clang-format off
+ nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
+ "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
+ // clang-format on
+ "typeid"_a, nb::kw_only(), "replace"_a = false,
+ "Register a value caster for casting MLIR values to custom user values.");
+}
+
//------------------------------------------------------------------------------
// Populates the core exports of the 'ir' submodule.
//------------------------------------------------------------------------------
-
-void mlir::python::populateIRCore(nb::module_ &m) {
- // disable leak warnings which tend to be false positives.
- nb::set_leak_warnings(false);
+void populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
// Enums.
//----------------------------------------------------------------------------
- nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
+ nb::enum_<PyDiagnosticSeverity>(m, "DiagnosticSeverity")
.value("ERROR", MlirDiagnosticError)
.value("WARNING", MlirDiagnosticWarning)
.value("NOTE", MlirDiagnosticNote)
.value("REMARK", MlirDiagnosticRemark);
- nb::enum_<MlirWalkOrder>(m, "WalkOrder")
+ nb::enum_<PyWalkOrder>(m, "WalkOrder")
.value("PRE_ORDER", MlirWalkPreOrder)
.value("POST_ORDER", MlirWalkPostOrder);
- nb::enum_<MlirWalkResult>(m, "WalkResult")
+ nb::enum_<PyWalkResult>(m, "WalkResult")
.value("ADVANCE", MlirWalkResultAdvance)
.value("INTERRUPT", MlirWalkResultInterrupt)
.value("SKIP", MlirWalkResultSkip);
@@ -2961,9 +2857,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"handling.")
.def("__enter__", &PyDiagnosticHandler::contextEnter,
"Enters the diagnostic handler as a context manager.")
- .def("__exit__", &PyDiagnosticHandler::contextExit,
- nb::arg("exc_type").none(), nb::arg("exc_value").none(),
- nb::arg("traceback").none(),
+ .def("__exit__", &PyDiagnosticHandler::contextExit, "exc_type"_a.none(),
+ "exc_value"_a.none(), "traceback"_a.none(),
"Exits the diagnostic handler context manager.");
// Expose DefaultThreadPool to python
@@ -3008,8 +2903,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Creates a Context from a capsule wrapping MlirContext.")
.def("__enter__", &PyMlirContext::contextEnter,
"Enters the context as a context manager.")
- .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
- nb::arg("exc_value").none(), nb::arg("traceback").none(),
+ .def("__exit__", &PyMlirContext::contextExit, "exc_type"_a.none(),
+ "exc_value"_a.none(), "traceback"_a.none(),
"Exits the context manager.")
.def_prop_ro_static(
"current",
@@ -3041,7 +2936,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
}
return PyDialectDescriptor(self.getRef(), dialect);
},
- nb::arg("dialect_name"),
+ "dialect_name"_a,
"Gets or loads a dialect by name, returning its descriptor object.")
.def_prop_rw(
"allow_unregistered_dialects",
@@ -3053,14 +2948,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
"Controls whether unregistered dialects are allowed in this context.")
.def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
- nb::arg("callback"),
+ "callback"_a,
"Attaches a diagnostic handler that will receive callbacks.")
.def(
"enable_multithreading",
[](PyMlirContext &self, bool enable) {
mlirContextEnableMultithreading(self.get(), enable);
},
- nb::arg("enable"),
+ "enable"_a,
R"(
Enables or disables multi-threading support in the context.
@@ -3105,7 +3000,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return mlirContextIsRegisteredOperation(
self.get(), MlirStringRef{name.data(), name.size()});
},
- nb::arg("operation_name"),
+ "operation_name"_a,
R"(
Checks whether an operation with the given name is registered.
@@ -3119,7 +3014,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyMlirContext &self, PyDialectRegistry ®istry) {
mlirContextAppendDialectRegistry(self.get(), registry);
},
- nb::arg("registry"),
+ "registry"_a,
R"(
Appends the contents of a dialect registry to the context.
@@ -3195,7 +3090,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// Mapping of PyDialect
//----------------------------------------------------------------------------
nb::class_<PyDialect>(m, "Dialect")
- .def(nb::init<nb::object>(), nb::arg("descriptor"),
+ .def(nb::init<nb::object>(), "descriptor"_a,
"Creates a Dialect from a DialectDescriptor.")
.def_prop_ro(
"descriptor", [](PyDialect &self) { return self.getDescriptor(); },
@@ -3234,8 +3129,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Creates a Location from a capsule wrapping MlirLocation.")
.def("__enter__", &PyLocation::contextEnter,
"Enters the location as a context manager.")
- .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
- nb::arg("exc_value").none(), nb::arg("traceback").none(),
+ .def("__exit__", &PyLocation::contextExit, "exc_type"_a.none(),
+ "exc_value"_a.none(), "traceback"_a.none(),
"Exits the location context manager.")
.def(
"__eq__",
@@ -3264,7 +3159,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return PyLocation(context->getRef(),
mlirLocationUnknownGet(context->get()));
},
- nb::arg("context") = nb::none(),
+ "context"_a = nb::none(),
"Gets a Location representing an unknown location.")
.def_static(
"callsite",
@@ -3279,7 +3174,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return PyLocation(context->getRef(),
mlirLocationCallSiteGet(callee.get(), caller));
},
- nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(),
+ "callee"_a, "frames"_a, "context"_a = nb::none(),
"Gets a Location representing a caller and callsite.")
.def("is_a_callsite", mlirLocationIsACallSite,
"Returns True if this location is a CallSiteLoc.")
@@ -3306,8 +3201,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirLocationFileLineColGet(
context->get(), toMlirStringRef(filename), line, col));
},
- nb::arg("filename"), nb::arg("line"), nb::arg("col"),
- nb::arg("context") = nb::none(),
+ "filename"_a, "line"_a, "col"_a, "context"_a = nb::none(),
"Gets a Location representing a file, line and column.")
.def_static(
"file",
@@ -3318,9 +3212,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
context->get(), toMlirStringRef(filename),
startLine, startCol, endLine, endCol));
},
- nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
- nb::arg("end_line"), nb::arg("end_col"),
- nb::arg("context") = nb::none(),
+ "filename"_a, "start_line"_a, "start_col"_a, "end_line"_a,
+ "end_col"_a, "context"_a = nb::none(),
"Gets a Location representing a file, line and column range.")
.def("is_a_file", mlirLocationIsAFileLineColRange,
"Returns True if this location is a FileLineColLoc.")
@@ -3353,8 +3246,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
metadata ? metadata->get() : MlirAttribute{0});
return PyLocation(context->getRef(), location);
},
- nb::arg("locations"), nb::arg("metadata") = nb::none(),
- nb::arg("context") = nb::none(),
+ "locations"_a, "metadata"_a = nb::none(), "context"_a = nb::none(),
"Gets a Location representing a fused location with optional "
"metadata.")
.def("is_a_fused", mlirLocationIsAFused,
@@ -3384,8 +3276,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
childLoc ? childLoc->get()
: mlirLocationUnknownGet(context->get())));
},
- nb::arg("name"), nb::arg("childLoc") = nb::none(),
- nb::arg("context") = nb::none(),
+ "name"_a, "childLoc"_a = nb::none(), "context"_a = nb::none(),
"Gets a Location representing a named location with optional child "
"location.")
.def("is_a_name", mlirLocationIsAName,
@@ -3409,7 +3300,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return PyLocation(context->getRef(),
mlirLocationFromAttribute(attribute));
},
- nb::arg("attribute"), nb::arg("context") = nb::none(),
+ "attribute"_a, "context"_a = nb::none(),
"Gets a Location from a `LocationAttr`.")
.def_prop_ro(
"context",
@@ -3429,7 +3320,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyLocation &self, std::string message) {
mlirEmitError(self, message.c_str());
},
- nb::arg("message"),
+ "message"_a,
R"(
Emits an error diagnostic at this location.
@@ -3474,8 +3365,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
throw MLIRError("Unable to parse module assembly", errors.take());
return PyModule::forModule(module).releaseObject();
},
- nb::arg("asm"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
+ "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
.def_static(
"parse",
[](nb::bytes moduleAsm, DefaultingPyMlirContext context)
@@ -3487,8 +3377,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
throw MLIRError("Unable to parse module assembly", errors.take());
return PyModule::forModule(module).releaseObject();
},
- nb::arg("asm"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
+ "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
.def_static(
"parseFile",
[](const std::string &path, DefaultingPyMlirContext context)
@@ -3500,8 +3389,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
throw MLIRError("Unable to parse module assembly", errors.take());
return PyModule::forModule(module).releaseObject();
},
- nb::arg("path"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
+ "path"_a, "context"_a = nb::none(), kModuleParseDocstring)
.def_static(
"create",
[](const std::optional<PyLocation> &loc)
@@ -3510,7 +3398,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
return PyModule::forModule(module).releaseObject();
},
- nb::arg("loc") = nb::none(), "Creates an empty module.")
+ "loc"_a = nb::none(), "Creates an empty module.")
.def_prop_ro(
"context",
[](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
@@ -3689,8 +3577,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def("print",
nb::overload_cast<PyAsmState &, nb::object, bool>(
&PyOperationBase::print),
- nb::arg("state"), nb::arg("file") = nb::none(),
- nb::arg("binary") = false,
+ "state"_a, "file"_a = nb::none(), "binary"_a = false,
R"(
Prints the assembly form of the operation to a file like object.
@@ -3703,15 +3590,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
bool, bool, bool, bool, bool, bool, nb::object,
bool, bool>(&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
- nb::arg("large_elements_limit") = nb::none(),
- nb::arg("large_resource_limit") = nb::none(),
- nb::arg("enable_debug_info") = false,
- nb::arg("pretty_debug_info") = false,
- nb::arg("print_generic_op_form") = false,
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- nb::arg("assume_verified") = false, nb::arg("file") = nb::none(),
- nb::arg("binary") = false, nb::arg("skip_regions") = false,
+ "large_elements_limit"_a = nb::none(),
+ "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
+ "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
+ "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
+ "assume_verified"_a = false, "file"_a = nb::none(),
+ "binary"_a = false, "skip_regions"_a = false,
R"(
Prints the assembly form of the operation to a file like object.
@@ -3743,8 +3627,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
file: The file like object to write to. Defaults to sys.stdout.
binary: Whether to write bytes (True) or str (False). Defaults to False.
skip_regions: Whether to skip printing regions. Defaults to False.)")
- .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
- nb::arg("desired_version") = nb::none(),
+ .def("write_bytecode", &PyOperationBase::writeBytecode, "file"_a,
+ "desired_version"_a = nb::none(),
R"(
Write the bytecode form of the operation to a file like object.
@@ -3755,15 +3639,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
The bytecode writer status.)")
.def("get_asm", &PyOperationBase::getAsm,
// Careful: Lots of arguments must match up with get_asm method.
- nb::arg("binary") = false,
- nb::arg("large_elements_limit") = nb::none(),
- nb::arg("large_resource_limit") = nb::none(),
- nb::arg("enable_debug_info") = false,
- nb::arg("pretty_debug_info") = false,
- nb::arg("print_generic_op_form") = false,
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
+ "binary"_a = false, "large_elements_limit"_a = nb::none(),
+ "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
+ "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
+ "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
+ "assume_verified"_a = false, "skip_regions"_a = false,
R"(
Gets the assembly form of the operation with all options available.
@@ -3778,14 +3658,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def("verify", &PyOperationBase::verify,
"Verify the operation. Raises MLIRError if verification fails, and "
"returns true otherwise.")
- .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
+ .def("move_after", &PyOperationBase::moveAfter, "other"_a,
"Puts self immediately after the other operation in its parent "
"block.")
- .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
+ .def("move_before", &PyOperationBase::moveBefore, "other"_a,
"Puts self immediately before the other operation in its parent "
"block.")
- .def("is_before_in_block", &PyOperationBase::isBeforeInBlock,
- nb::arg("other"),
+ .def("is_before_in_block", &PyOperationBase::isBeforeInBlock, "other"_a,
R"(
Checks if this operation is before another in the same block.
@@ -3800,7 +3679,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
return self.getOperation().clone(ip);
},
- nb::arg("ip") = nb::none(),
+ "ip"_a = nb::none(),
R"(
Creates a deep copy of the operation.
@@ -3838,8 +3717,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
Note:
After erasing, any Python references to the operation become invalid.)")
- .def("walk", &PyOperationBase::walk, nb::arg("callback"),
- nb::arg("walk_order") = MlirWalkPostOrder,
+ .def("walk", &PyOperationBase::walk, "callback"_a,
+ "walk_order"_a = PyWalkOrder::MlirWalkPostOrder,
// clang-format off
nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
// clang-format on
@@ -3877,11 +3756,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
successors, regions, pyLoc, maybeIp,
inferType);
},
- nb::arg("name"), nb::arg("results") = nb::none(),
- nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(), nb::arg("regions") = 0,
- nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
- nb::arg("infer_type") = false,
+ "name"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
+ "attributes"_a = nb::none(), "successors"_a = nb::none(),
+ "regions"_a = 0, "loc"_a = nb::none(), "ip"_a = nb::none(),
+ "infer_type"_a = false,
R"(
Creates a new operation.
@@ -3905,8 +3783,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return PyOperation::parse(context->getRef(), sourceStr, sourceName)
->createOpView();
},
- nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
- nb::arg("context") = nb::none(),
+ "source"_a, nb::kw_only(), "source_name"_a = "",
+ "context"_a = nb::none(),
"Parses an operation. Supports both text assembly format and binary "
"bytecode format.")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule,
@@ -3952,8 +3830,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
auto opViewClass =
nb::class_<PyOpView, PyOperationBase>(m, "OpView")
- .def(nb::init<nb::typed<nb::object, PyOperation>>(),
- nb::arg("operation"))
+ .def(nb::init<nb::typed<nb::object, PyOperation>>(), "operation"_a)
.def(
"__init__",
[](PyOpView *self, std::string_view name,
@@ -3972,14 +3849,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
resultSegmentSpecObj, resultTypeList, operandList,
attributes, successors, regions, pyLoc, maybeIp));
},
- nb::arg("name"), nb::arg("opRegionSpec"),
- nb::arg("operandSegmentSpecObj") = nb::none(),
- nb::arg("resultSegmentSpecObj") = nb::none(),
- nb::arg("results") = nb::none(), nb::arg("operands") = nb::none(),
- nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(),
- nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(),
- nb::arg("ip") = nb::none())
+ "name"_a, "opRegionSpec"_a,
+ "operandSegmentSpecObj"_a = nb::none(),
+ "resultSegmentSpecObj"_a = nb::none(), "results"_a = nb::none(),
+ "operands"_a = nb::none(), "attributes"_a = nb::none(),
+ "successors"_a = nb::none(), "regions"_a = nb::none(),
+ "loc"_a = nb::none(), "ip"_a = nb::none())
.def_prop_ro(
"operation",
[](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
@@ -4025,10 +3900,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
operandList, attributes, successors,
regions, pyLoc, maybeIp);
},
- nb::arg("cls"), nb::arg("results") = nb::none(),
- nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(),
- nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
+ "cls"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
+ "attributes"_a = nb::none(), "successors"_a = nb::none(),
+ "regions"_a = nb::none(), "loc"_a = nb::none(), "ip"_a = nb::none(),
"Builds a specific, generated OpView based on class level attributes.");
opViewClass.attr("parse") = classmethod(
[](const nb::object &cls, const std::string &sourceStr,
@@ -4052,8 +3926,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
parsedOpName + "'");
return PyOpView::constructDerived(cls, parsed.getObject());
},
- nb::arg("cls"), nb::arg("source"), nb::kw_only(),
- nb::arg("source_name") = "", nb::arg("context") = nb::none(),
+ "cls"_a, "source"_a, nb::kw_only(), "source_name"_a = "",
+ "context"_a = nb::none(),
"Parses a specific, generated OpView based on class level attributes.");
//----------------------------------------------------------------------------
@@ -4136,7 +4010,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyBlock &self, unsigned index) {
return mlirBlockEraseArgument(self.get(), index);
},
- nb::arg("index"),
+ "index"_a,
R"(
Erases the argument at the specified index.
@@ -4157,8 +4031,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirRegionInsertOwnedBlock(parent, 0, block);
return PyBlock(parent.getParentOperation(), block);
},
- nb::arg("parent"), nb::arg("arg_types") = nb::list(),
- nb::arg("arg_locs") = std::nullopt,
+ "parent"_a, "arg_types"_a = nb::list(), "arg_locs"_a = std::nullopt,
"Creates and returns a new Block at the beginning of the given "
"region (with given argument types and locations).")
.def(
@@ -4169,7 +4042,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirBlockDetach(b);
mlirRegionAppendOwnedBlock(region.get(), b);
},
- nb::arg("region"),
+ "region"_a,
R"(
Appends this block to a region.
@@ -4188,8 +4061,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
return PyBlock(self.getParentOperation(), block);
},
- nb::arg("arg_types"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt,
+ "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
"Creates and returns a new Block before this block "
"(with given argument types and locations).")
.def(
@@ -4203,8 +4075,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
return PyBlock(self.getParentOperation(), block);
},
- nb::arg("arg_types"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt,
+ "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
"Creates and returns a new Block after this block "
"(with given argument types and locations).")
.def(
@@ -4253,7 +4124,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
operation.getOperation().setAttached(
self.getParentOperation().getObject());
},
- nb::arg("operation"),
+ "operation"_a,
R"(
Appends an operation to this block.
@@ -4279,13 +4150,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
nb::class_<PyInsertionPoint>(m, "InsertionPoint")
- .def(nb::init<PyBlock &>(), nb::arg("block"),
+ .def(nb::init<PyBlock &>(), "block"_a,
"Inserts after the last operation but still inside the block.")
.def("__enter__", &PyInsertionPoint::contextEnter,
"Enters the insertion point as a context manager.")
- .def("__exit__", &PyInsertionPoint::contextExit,
- nb::arg("exc_type").none(), nb::arg("exc_value").none(),
- nb::arg("traceback").none(),
+ .def("__exit__", &PyInsertionPoint::contextExit, "exc_type"_a.none(),
+ "exc_value"_a.none(), "traceback"_a.none(),
"Exits the insertion point context manager.")
.def_prop_ro_static(
"current",
@@ -4298,10 +4168,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::sig("def current(/) -> InsertionPoint"),
"Gets the InsertionPoint bound to the current thread or raises "
"ValueError if none has been set.")
- .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
+ .def(nb::init<PyOperationBase &>(), "beforeOperation"_a,
"Inserts before a referenced operation.")
- .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
- nb::arg("block"),
+ .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, "block"_a,
R"(
Creates an insertion point at the beginning of a block.
@@ -4311,7 +4180,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
Returns:
An InsertionPoint at the block's beginning.)")
.def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
- nb::arg("block"),
+ "block"_a,
R"(
Creates an insertion point before a block's terminator.
@@ -4323,7 +4192,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
Raises:
ValueError: If the block has no terminator.)")
- .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
+ .def_static("after", &PyInsertionPoint::after, "operation"_a,
R"(
Creates an insertion point immediately after an operation.
@@ -4332,7 +4201,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
Returns:
An InsertionPoint after the operation.)")
- .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
+ .def("insert", &PyInsertionPoint::insert, "operation"_a,
R"(
Inserts an operation at this insertion point.
@@ -4360,7 +4229,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::class_<PyAttribute>(m, "Attribute")
// Delegate to the PyAttribute copy constructor, which will also lifetime
// extend the backing context which owns the MlirAttribute.
- .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
+ .def(nb::init<PyAttribute &>(), "cast_from_type"_a,
"Casts the passed attribute to the generic `Attribute`.")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule,
"Gets a capsule wrapping the MlirAttribute.")
@@ -4378,7 +4247,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
throw MLIRError("Unable to parse attribute", errors.take());
return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
},
- nb::arg("asm"), nb::arg("context") = nb::none(),
+ "asm"_a, "context"_a = nb::none(),
"Parses an attribute from an assembly form. Raises an `MLIRError` on "
"failure.")
.def_prop_ro(
@@ -4504,7 +4373,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::class_<PyType>(m, "Type")
// Delegate to the PyType copy constructor, which will also lifetime
// extend the backing context which owns the MlirType.
- .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
+ .def(nb::init<PyType &>(), "cast_from_type"_a,
"Casts the passed type to the generic `Type`.")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule,
"Gets a capsule wrapping the `MlirType`.")
@@ -4521,7 +4390,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
throw MLIRError("Unable to parse type", errors.take());
return PyType(context.get()->getRef(), type).maybeDownCast();
},
- nb::arg("asm"), nb::arg("context") = nb::none(),
+ "asm"_a, "context"_a = nb::none(),
R"(
Parses the assembly form of a type.
@@ -4539,7 +4408,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Compares two types for equality.")
.def(
"__eq__", [](PyType &self, nb::object &other) { return false; },
- nb::arg("other").none(),
+ "other"_a.none(),
"Compares type with non-type object (always returns False).")
.def(
"__hash__",
@@ -4625,11 +4494,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of Value.
//----------------------------------------------------------------------------
- m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
+ m.attr("_T") = nb::type_var("_T", "bound"_a = m.attr("Type"));
nb::class_<PyValue>(m, "Value", nb::is_generic(),
nb::sig("class Value(Generic[_T])"))
- .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"),
+ .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), "value"_a,
"Creates a Value reference from another `Value`.")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule,
"Gets a capsule wrapping the `MlirValue`.")
@@ -4724,8 +4593,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirAsmStateDestroy(valueState);
return printAccum.join();
},
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
+ "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
R"(
Returns the string form of value as an operand.
@@ -4745,7 +4613,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
printAccum.getUserData());
return printAccum.join();
},
- nb::arg("state"),
+ "state"_a,
"Returns the string form of value as an operand (i.e., the ValueID).")
.def_prop_ro(
"type",
@@ -4760,7 +4628,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyValue &self, const PyType &type) {
mlirValueSetType(self.get(), type);
},
- nb::arg("type"), "Sets the type of the value.",
+ "type"_a, "Sets the type of the value.",
nb::sig("def set_type(self, type: _T)"))
.def(
"replace_all_uses_with",
@@ -4775,8 +4643,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
MlirOperation exceptedUser = exception.get();
mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
},
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
+ "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
[](PyValue &self, PyValue &with, const nb::list &exceptions) {
@@ -4790,16 +4657,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
self, with, static_cast<intptr_t>(exceptionOps.size()),
exceptionOps.data());
},
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
+ "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
[](PyValue &self, PyValue &with, PyOperation &exception) {
MlirOperation exceptedUser = exception.get();
mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
},
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
+ "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
[](PyValue &self, PyValue &with,
@@ -4812,8 +4677,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
self, with, static_cast<intptr_t>(exceptionOps.size()),
exceptionOps.data());
},
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
+ "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
.def(
MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](PyValue &self) -> nb::typed<nb::object, PyValue> {
@@ -4834,16 +4698,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyOpOperand::bind(m);
nb::class_<PyAsmState>(m, "AsmState")
- .def(nb::init<PyValue &, bool>(), nb::arg("value"),
- nb::arg("use_local_scope") = false,
+ .def(nb::init<PyValue &, bool>(), "value"_a, "use_local_scope"_a = false,
R"(
Creates an `AsmState` for consistent SSA value naming.
Args:
value: The value to create state for.
use_local_scope: Whether to use local scope for naming.)")
- .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
- nb::arg("use_local_scope") = false,
+ .def(nb::init<PyOperationBase &, bool>(), "op"_a,
+ "use_local_scope"_a = false,
R"(
Creates an AsmState for consistent SSA value naming.
@@ -4881,7 +4744,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
Raises:
KeyError: If the symbol is not found.)")
- .def("insert", &PySymbolTable::insert, nb::arg("operation"),
+ .def("insert", &PySymbolTable::insert, "operation"_a,
R"(
Inserts a symbol operation into the symbol table.
@@ -4893,7 +4756,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
Raises:
ValueError: If the operation does not have a symbol name.)")
- .def("erase", &PySymbolTable::erase, nb::arg("operation"),
+ .def("erase", &PySymbolTable::erase, "operation"_a,
R"(
Erases a symbol operation from the symbol table.
@@ -4912,26 +4775,22 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
"Checks if a symbol with the given name exists in the table.")
// Static helpers.
- .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
- nb::arg("symbol"), nb::arg("name"),
- "Sets the symbol name for a symbol operation.")
- .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
- nb::arg("symbol"),
+ .def_static("set_symbol_name", &PySymbolTable::setSymbolName, "symbol"_a,
+ "name"_a, "Sets the symbol name for a symbol operation.")
+ .def_static("get_symbol_name", &PySymbolTable::getSymbolName, "symbol"_a,
"Gets the symbol name from a symbol operation.")
- .def_static("get_visibility", &PySymbolTable::getVisibility,
- nb::arg("symbol"),
+ .def_static("get_visibility", &PySymbolTable::getVisibility, "symbol"_a,
"Gets the visibility attribute of a symbol operation.")
- .def_static("set_visibility", &PySymbolTable::setVisibility,
- nb::arg("symbol"), nb::arg("visibility"),
+ .def_static("set_visibility", &PySymbolTable::setVisibility, "symbol"_a,
+ "visibility"_a,
"Sets the visibility attribute of a symbol operation.")
.def_static("replace_all_symbol_uses",
- &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
- nb::arg("new_symbol"), nb::arg("from_op"),
+ &PySymbolTable::replaceAllSymbolUses, "old_symbol"_a,
+ "new_symbol"_a, "from_op"_a,
"Replaces all uses of a symbol with a new symbol name within "
"the given operation.")
.def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
- nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
- nb::arg("callback"),
+ "from_op"_a, "all_sym_uses_visible"_a, "callback"_a,
"Walks symbol tables starting from an operation with a "
"callback function.");
@@ -4956,18 +4815,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// Attribute builder getter.
PyAttrBuilderMap::bind(m);
-
- nb::register_exception_translator([](const std::exception_ptr &p,
- void *payload) {
- // We can't define exceptions with custom fields through pybind, so instead
- // the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
- });
}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 31d4798ffb906..09112d4989ae4 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -12,11 +12,11 @@
#include <utility>
#include <vector>
-#include "IRModule.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -25,7 +25,7 @@ namespace nb = nanobind;
namespace mlir {
namespace python {
-
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
constexpr static const char *constructorDoc =
R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
@@ -469,6 +469,6 @@ void populateIRInterfaces(nb::module_ &m) {
PyShapedTypeComponents::bind(m);
PyInferShapedTypeOpInterface::bind(m);
}
-
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 34c5b8dd86a66..7350046f428c7 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -7,26 +7,27 @@
//===----------------------------------------------------------------------===//
// clang-format off
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/IRTypes.h"
// clang-format on
#include <optional>
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace nb = nanobind;
using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
using llvm::SmallVector;
using llvm::Twine;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Checks whether the given type is an integer or float type.
static int mlirTypeIsAIntegerOrFloat(MlirType type) {
@@ -509,10 +510,12 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
}
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
// Shaped Type Interface - ShapedType
-void mlir::PyShapedType::bindDerived(ClassTy &c) {
+void PyShapedType::bindDerived(ClassTy &c) {
c.def_prop_ro(
"element_type",
[](PyShapedType &self) -> nb::typed<nb::object, PyType> {
@@ -617,17 +620,18 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
"shaped types.");
}
-void mlir::PyShapedType::requireHasRank() {
+void PyShapedType::requireHasRank() {
if (!mlirShapedTypeHasRank(*this)) {
throw nb::value_error(
"calling this method requires that the type has a rank.");
}
}
-const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction =
- mlirTypeIsAShaped;
+const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Vector Type subclass - VectorType.
class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
@@ -1099,10 +1103,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
}
};
-static MlirStringRef toMlirStringRef(const std::string &s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
/// Opaque Type subclass - OpaqueType.
class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
public:
@@ -1142,9 +1142,14 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
}
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
-void mlir::python::populateIRTypes(nb::module_ &m) {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+void populateIRTypes(nb::module_ &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
PyIndexType::bind(m);
@@ -1176,3 +1181,6 @@ void mlir::python::populateIRTypes(nb::module_ &m) {
PyFunctionType::bind(m);
PyOpaqueType::bind(m);
}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index ba767ad6692cf..b2c9380bc1d73 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,143 +6,37 @@
//
//===----------------------------------------------------------------------===//
-#include "Globals.h"
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "Pass.h"
#include "Rewrite.h"
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
namespace nb = nanobind;
-using namespace mlir;
-using namespace nb::literals;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
+
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+void populateIRAffine(nb::module_ &m);
+void populateIRAttributes(nb::module_ &m);
+void populateIRInterfaces(nb::module_ &m);
+void populateIRTypes(nb::module_ &m);
+void populateIRCore(nb::module_ &m);
+void populateRoot(nb::module_ &m);
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
-
NB_MODULE(_mlir, m) {
- m.doc() = "MLIR Python Native Extension";
- m.attr("T") = nb::type_var("T");
- m.attr("U") = nb::type_var("U");
-
- nb::class_<PyGlobals>(m, "_Globals")
- .def_prop_rw("dialect_search_modules",
- &PyGlobals::getDialectSearchPrefixes,
- &PyGlobals::setDialectSearchPrefixes)
- .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
- "module_name"_a)
- .def(
- "_check_dialect_module_loaded",
- [](PyGlobals &self, const std::string &dialectNamespace) {
- return self.loadDialectModule(dialectNamespace);
- },
- "dialect_namespace"_a)
- .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
- "dialect_namespace"_a, "dialect_class"_a,
- "Testing hook for directly registering a dialect")
- .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
- "operation_name"_a, "operation_class"_a, nb::kw_only(),
- "replace"_a = false,
- "Testing hook for directly registering an operation")
- .def("loc_tracebacks_enabled",
- [](PyGlobals &self) {
- return self.getTracebackLoc().locTracebacksEnabled();
- })
- .def("set_loc_tracebacks_enabled",
- [](PyGlobals &self, bool enabled) {
- self.getTracebackLoc().setLocTracebacksEnabled(enabled);
- })
- .def("loc_tracebacks_frame_limit",
- [](PyGlobals &self) {
- return self.getTracebackLoc().locTracebackFramesLimit();
- })
- .def("set_loc_tracebacks_frame_limit",
- [](PyGlobals &self, std::optional<int> n) {
- self.getTracebackLoc().setLocTracebackFramesLimit(
- n.value_or(PyGlobals::TracebackLoc::kMaxFrames));
- })
- .def("register_traceback_file_inclusion",
- [](PyGlobals &self, const std::string &filename) {
- self.getTracebackLoc().registerTracebackFileInclusion(filename);
- })
- .def("register_traceback_file_exclusion",
- [](PyGlobals &self, const std::string &filename) {
- self.getTracebackLoc().registerTracebackFileExclusion(filename);
- });
-
- // Aside from making the globals accessible to python, having python manage
- // it is necessary to make sure it is destroyed (and releases its python
- // resources) properly.
- m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
-
- // Registration decorators.
- m.def(
- "register_dialect",
- [](nb::type_object pyClass) {
- std::string dialectNamespace =
- nanobind::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
- PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
- return pyClass;
- },
- "dialect_class"_a,
- "Class decorator for registering a custom Dialect wrapper");
- m.def(
- "register_operation",
- [](const nb::type_object &dialectClass, bool replace) -> nb::object {
- return nb::cpp_function(
- [dialectClass,
- replace](nb::type_object opClass) -> nb::type_object {
- std::string operationName =
- nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
- PyGlobals::get().registerOperationImpl(operationName, opClass,
- replace);
- // Dict-stuff the new opClass by name onto the dialect class.
- nb::object opClassName = opClass.attr("__name__");
- dialectClass.attr(opClassName) = opClass;
- return opClass;
- });
- },
- // clang-format off
- nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) "
- "-> typing.Callable[[type[T]], type[T]]"),
- // clang-format on
- "dialect_class"_a, nb::kw_only(), "replace"_a = false,
- "Produce a class decorator for registering an Operation class as part of "
- "a dialect");
- m.def(
- MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
- [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
- return nb::cpp_function([mlirTypeID, replace](
- nb::callable typeCaster) -> nb::object {
- PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
- return typeCaster;
- });
- },
- // clang-format off
- nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
- "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
- // clang-format on
- "typeid"_a, nb::kw_only(), "replace"_a = false,
- "Register a type caster for casting MLIR types to custom user types.");
- m.def(
- MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
- [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
- return nb::cpp_function(
- [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
- PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
- replace);
- return valueCaster;
- });
- },
- // clang-format off
- nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
- "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
- // clang-format on
- "typeid"_a, nb::kw_only(), "replace"_a = false,
- "Register a value caster for casting MLIR values to custom user values.");
+ // Disable nanobind leak warnings, which tend to be false positives.
+ nb::set_leak_warnings(false);
+ m.doc() = "MLIR Python Native Extension";
+ populateRoot(m);
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
populateIRCore(irModule);
@@ -158,4 +52,18 @@ NB_MODULE(_mlir, m) {
auto passManagerModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passManagerModule);
+ nb::register_exception_translator(
+ [](const std::exception_ptr &p, void * /*payload*/) {
+ // We can't define exceptions with custom fields through nanobind, so
+ // instead the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nb::object obj =
+ nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+ });
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cdf01fff28cf2..b4a256d847ad5 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,9 +8,9 @@
#include "Pass.h"
-#include "Globals.h"
-#include "IRModule.h"
#include "mlir-c/Pass.h"
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRCore.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
@@ -19,9 +19,11 @@
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Owning Wrapper around a PassManager.
class PyPassManager {
@@ -53,23 +55,29 @@ class PyPassManager {
MlirPassManager passManager;
};
-} // namespace
+enum PyMlirPassDisplayMode : std::underlying_type_t<MlirPassDisplayMode> {
+ MLIR_PASS_DISPLAY_MODE_LIST = MLIR_PASS_DISPLAY_MODE_LIST,
+ MLIR_PASS_DISPLAY_MODE_PIPELINE = MLIR_PASS_DISPLAY_MODE_PIPELINE
+};
+
+struct PyMlirExternalPass : MlirExternalPass {};
/// Create the `mlir.passmanager` here.
-void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
+void populatePassManagerSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of enumerated types
//----------------------------------------------------------------------------
- nb::enum_<MlirPassDisplayMode>(m, "PassDisplayMode")
+ nb::enum_<PyMlirPassDisplayMode>(m, "PassDisplayMode")
.value("LIST", MLIR_PASS_DISPLAY_MODE_LIST)
.value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE);
//----------------------------------------------------------------------------
// Mapping of MlirExternalPass
//----------------------------------------------------------------------------
- nb::class_<MlirExternalPass>(m, "ExternalPass")
- .def("signal_pass_failure",
- [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
+ nb::class_<PyMlirExternalPass>(m, "ExternalPass")
+ .def("signal_pass_failure", [](PyMlirExternalPass pass) {
+ mlirExternalPassSignalFailure(pass);
+ });
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
@@ -148,11 +156,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"Enable pass timing.")
.def(
"enable_statistics",
- [](PyPassManager &passManager, MlirPassDisplayMode displayMode) {
- mlirPassManagerEnableStatistics(passManager.get(), displayMode);
+ [](PyPassManager &passManager, PyMlirPassDisplayMode displayMode) {
+ mlirPassManagerEnableStatistics(
+ passManager.get(),
+ static_cast<MlirPassDisplayMode>(displayMode));
},
- "displayMode"_a =
- MlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE,
+ "displayMode"_a = MLIR_PASS_DISPLAY_MODE_PIPELINE,
"Enable pass statistics.")
.def_static(
"parse",
@@ -211,7 +220,8 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
};
callbacks.run = [](MlirOperation op, MlirExternalPass pass,
void *userData) {
- nb::handle(static_cast<PyObject *>(userData))(op, pass);
+ nb::handle(static_cast<PyObject *>(userData))(
+ op, PyMlirExternalPass{pass.ptr});
};
auto externalPass = mlirCreateExternalPass(
passID, mlirStringRefCreate(name->data(), name->length()),
@@ -267,3 +277,6 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"Print the textual representation for this PassManager, suitable to "
"be passed to `parse` for round-tripping.");
}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index bc40943521829..1a311666ebecd 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -9,12 +9,13 @@
#ifndef MLIR_BINDINGS_PYTHON_PASS_H
#define MLIR_BINDINGS_PYTHON_PASS_H
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace mlir {
namespace python {
-
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populatePassManagerSubmodule(nanobind::module_ &m);
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index dc6dc7f7c9b72..c282f4b6996e5 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -8,10 +8,10 @@
#include "Rewrite.h"
-#include "IRModule.h"
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
@@ -22,9 +22,11 @@
namespace nb = nanobind;
using namespace mlir;
using namespace nb::literals;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
class PyPatternRewriter {
public:
@@ -60,6 +62,8 @@ class PyPatternRewriter {
PyMlirContextRef ctx;
};
+struct PyMlirPDLResultList : MlirPDLResultList {};
+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
static nb::object objectFromPDLValue(MlirPDLValue value) {
if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
@@ -118,7 +122,7 @@ class PyPDLPatternModule {
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
- f(PyPatternRewriter(rewriter), results,
+ f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
objectsFromPDLValues(nValues, values)));
},
fn.ptr());
@@ -133,7 +137,7 @@ class PyPDLPatternModule {
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
- f(PyPatternRewriter(rewriter), results,
+ f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
objectsFromPDLValues(nValues, values)));
},
fn.ptr());
@@ -223,6 +227,25 @@ class PyRewritePatternSet {
MlirContext ctx;
};
+enum PyGreedyRewriteStrictness : std::underlying_type_t<
+ MlirGreedyRewriteStrictness> {
+ MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP = MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP,
+ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS =
+ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS,
+ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS =
+ MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS,
+};
+
+enum PyGreedySimplifyRegionLevel : std::underlying_type_t<
+ MlirGreedySimplifyRegionLevel> {
+ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED =
+ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED,
+ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL =
+ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL,
+ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE =
+ MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
+};
+
/// Owning Wrapper around a GreedyRewriteDriverConfig.
class PyGreedyRewriteDriverConfig {
public:
@@ -255,12 +278,14 @@ class PyGreedyRewriteDriverConfig {
mlirGreedyRewriteDriverConfigEnableFolding(config, enable);
}
- void setStrictness(MlirGreedyRewriteStrictness strictness) {
- mlirGreedyRewriteDriverConfigSetStrictness(config, strictness);
+ void setStrictness(PyGreedyRewriteStrictness strictness) {
+ mlirGreedyRewriteDriverConfigSetStrictness(
+ config, static_cast<MlirGreedyRewriteStrictness>(strictness));
}
- void setRegionSimplificationLevel(MlirGreedySimplifyRegionLevel level) {
- mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(config, level);
+ void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) {
+ mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
+ config, static_cast<MlirGreedySimplifyRegionLevel>(level));
}
void enableConstantCSE(bool enable) {
@@ -283,12 +308,14 @@ class PyGreedyRewriteDriverConfig {
return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config);
}
- MlirGreedyRewriteStrictness getStrictness() {
- return mlirGreedyRewriteDriverConfigGetStrictness(config);
+ PyGreedyRewriteStrictness getStrictness() {
+ return static_cast<PyGreedyRewriteStrictness>(
+ mlirGreedyRewriteDriverConfigGetStrictness(config));
}
- MlirGreedySimplifyRegionLevel getRegionSimplificationLevel() {
- return mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config);
+ PyGreedySimplifyRegionLevel getRegionSimplificationLevel() {
+ return static_cast<PyGreedySimplifyRegionLevel>(
+ mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config));
}
bool isConstantCSEEnabled() {
@@ -299,22 +326,19 @@ class PyGreedyRewriteDriverConfig {
MlirGreedyRewriteDriverConfig config;
};
-} // namespace
-
/// Create the `mlir.rewrite` here.
-void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+void populateRewriteSubmodule(nb::module_ &m) {
// Enum definitions
- nb::enum_<MlirGreedyRewriteStrictness>(m, "GreedyRewriteStrictness")
+ nb::enum_<PyGreedyRewriteStrictness>(m, "GreedyRewriteStrictness")
.value("ANY_OP", MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP)
.value("EXISTING_AND_NEW_OPS",
MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS)
.value("EXISTING_OPS", MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS);
- nb::enum_<MlirGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel")
+ nb::enum_<PyGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel")
.value("DISABLED", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED)
.value("NORMAL", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL)
.value("AGGRESSIVE", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE);
-
//----------------------------------------------------------------------------
// Mapping of the PatternRewriter
//----------------------------------------------------------------------------
@@ -403,10 +427,10 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
- nb::class_<MlirPDLResultList>(m, "PDLResultList")
+ nb::class_<PyMlirPDLResultList>(m, "PDLResultList")
.def(
"append",
- [](MlirPDLResultList results, const PyValue &value) {
+ [](PyMlirPDLResultList results, const PyValue &value) {
mlirPDLResultListPushBackValue(results, value);
},
// clang-format off
@@ -415,7 +439,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
)
.def(
"append",
- [](MlirPDLResultList results, const PyOperation &op) {
+ [](PyMlirPDLResultList results, const PyOperation &op) {
mlirPDLResultListPushBackOperation(results, op);
},
// clang-format off
@@ -424,7 +448,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
)
.def(
"append",
- [](MlirPDLResultList results, const PyType &type) {
+ [](PyMlirPDLResultList results, const PyType &type) {
mlirPDLResultListPushBackType(results, type);
},
// clang-format off
@@ -433,7 +457,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
)
.def(
"append",
- [](MlirPDLResultList results, const PyAttribute &attr) {
+ [](PyMlirPDLResultList results, const PyAttribute &attr) {
mlirPDLResultListPushBackAttribute(results, attr);
},
// clang-format off
@@ -443,9 +467,9 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::class_<PyPDLPatternModule>(m, "PDLModule")
.def(
"__init__",
- [](PyPDLPatternModule &self, MlirModule module) {
- new (&self)
- PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
+ [](PyPDLPatternModule &self, PyModule &module) {
+ new (&self) PyPDLPatternModule(
+ mlirPDLPatternModuleFromModule(module.get()));
},
// clang-format off
nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
@@ -533,22 +557,6 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
// clang-format on
"Applys the given patterns to the given module greedily while folding "
"results.")
- .def(
- "apply_patterns_and_fold_greedily",
- [](PyModule &module, MlirFrozenRewritePatternSet set) {
- auto status = mlirApplyPatternsAndFoldGreedily(
- module.get(), set, mlirGreedyRewriteDriverConfigCreate());
- if (mlirLogicalResultIsFailure(status))
- throw std::runtime_error(
- "pattern application failed to converge");
- },
- "module"_a, "set"_a,
- // clang-format off
- nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"),
- // clang-format on
- "Applys the given patterns to the given module greedily while "
- "folding "
- "results.")
.def(
"apply_patterns_and_fold_greedily",
[](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
@@ -565,21 +573,6 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
// clang-format on
"Applys the given patterns to the given op greedily while folding "
"results.")
- .def(
- "apply_patterns_and_fold_greedily",
- [](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
- auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
- op.getOperation(), set, mlirGreedyRewriteDriverConfigCreate());
- if (mlirLogicalResultIsFailure(status))
- throw std::runtime_error(
- "pattern application failed to converge");
- },
- "op"_a, "set"_a,
- // clang-format off
- nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
- // clang-format on
- "Applys the given patterns to the given op greedily while folding "
- "results.")
.def(
"walk_and_apply_patterns",
[](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
@@ -592,3 +585,6 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
"Applies the given patterns to the given op by a fast walk-based "
"driver.");
}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
index ae89e2b9589f1..d287f19187708 100644
--- a/mlir/lib/Bindings/Python/Rewrite.h
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -9,13 +9,13 @@
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
#define MLIR_BINDINGS_PYTHON_REWRITE_H
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace mlir {
namespace python {
-
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateRewriteSubmodule(nanobind::module_ &m);
-
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 1e9f1e11d4d06..4a9fb127ee08c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -3,6 +3,8 @@ include(AddMLIRPython)
# Specifies that all MLIR packages are co-located under the `MLIR_PYTHON_PACKAGE_PREFIX.`
# top level package (the API has been embedded in a relocatable way).
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
+set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
+set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python")
################################################################################
# Structural groupings.
@@ -524,7 +526,6 @@ declare_mlir_dialect_python_bindings(
# dependencies.
################################################################################
-set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python")
declare_mlir_python_extension(MLIRPythonExtension.Core
MODULE_NAME _mlir
ADD_TO_PARENT MLIRPythonSources.Core
@@ -533,18 +534,13 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
MainModule.cpp
IRAffine.cpp
IRAttributes.cpp
- IRCore.cpp
IRInterfaces.cpp
- IRModule.cpp
IRTypes.cpp
Pass.cpp
Rewrite.cpp
# Headers must be included explicitly so they are installed.
- Globals.h
- IRModule.h
Pass.h
- NanobindUtils.h
Rewrite.h
PRIVATE_LINK_LIBS
LLVMSupport
@@ -752,8 +748,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Nanobind
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
DialectSMT.cpp
- # Headers must be included explicitly so they are installed.
- NanobindUtils.h
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
@@ -790,7 +784,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.AMDGPU.Nanobind
MODULE_NAME _mlirDialectsAMDGPU
ADD_TO_PARENT MLIRPythonSources.Dialects.amdgpu
ROOT_DIR "${PYTHON_SOURCE_DIR}"
- PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectAMDGPU.cpp
PRIVATE_LINK_LIBS
@@ -847,6 +840,16 @@ if(MLIR_INCLUDE_TESTS)
)
endif()
+declare_mlir_python_extension(MLIRPythonExtension.MLIRPythonSupport
+ _PRIVATE_SUPPORT_LIB
+ MODULE_NAME MLIRPythonSupport
+ ADD_TO_PARENT MLIRPythonSources.Core
+ ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ SOURCES
+ IRCore.cpp
+ Globals.cpp
+)
+
################################################################################
# Common CAPI dependency DSO.
# All python extensions must link through one DSO which exports the CAPI, and
@@ -860,7 +863,6 @@ endif()
# once ready.
################################################################################
-set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
add_mlir_python_common_capi_library(MLIRPythonCAPI
INSTALL_COMPONENT MLIRPythonModules
INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs"
diff --git a/mlir/test/Examples/standalone/test.wheel.toy b/mlir/test/Examples/standalone/test.wheel.toy
index b60347ba687d0..46f170579a977 100644
--- a/mlir/test/Examples/standalone/test.wheel.toy
+++ b/mlir/test/Examples/standalone/test.wheel.toy
@@ -37,6 +37,8 @@
# CHECK: %[[V0:.*]] = standalone.foo %[[C2]] : i32
# CHECK: }
+# CHECK: !standalone.custom<"foo">
+
# CHECK: Testing mlir package
# CHECK-NOT: RuntimeWarning: nanobind: type '{{.*}}' was already registered!
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 7bba20931e675..9c0966d2d8798 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -586,9 +586,18 @@ def testCustomAttribute():
try:
TestAttr(42)
except TypeError as e:
- assert "Expected an MLIR object (got 42)" in str(e)
- except ValueError as e:
- assert "Cannot cast attribute to TestAttr (from 42)" in str(e)
+ assert (
+ "__init__(): incompatible function arguments. The following argument types are supported"
+ in str(e)
+ )
+ assert (
+ "__init__(self, cast_from_attr: mlir._mlir_libs._mlir.ir.Attribute) -> None"
+ in str(e)
+ )
+ assert (
+ "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestAttr, int"
+ in str(e)
+ )
else:
raise
@@ -613,12 +622,6 @@ def testCustomType():
b = TestType(a)
# Instance custom types should have typeids
assert isinstance(b.typeid, TypeID)
- # Subclasses of ir.Type should not have a static_typeid
- # CHECK: 'TestType' object has no attribute 'static_typeid'
- try:
- b.static_typeid
- except AttributeError as e:
- print(e)
i8 = IntegerType.get_signless(8)
try:
@@ -633,9 +636,18 @@ def testCustomType():
try:
TestType(42)
except TypeError as e:
- assert "Expected an MLIR object (got 42)" in str(e)
- except ValueError as e:
- assert "Cannot cast type to TestType (from 42)" in str(e)
+ assert (
+ "__init__(): incompatible function arguments. The following argument types are supported"
+ in str(e)
+ )
+ assert (
+ "__init__(self, cast_from_type: mlir._mlir_libs._mlir.ir.Type) -> None"
+ in str(e)
+ )
+ assert (
+ "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestType, int"
+ in str(e)
+ )
else:
raise
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index a497fcccf13d7..43573cbc305fa 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -14,6 +14,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
@@ -26,6 +27,49 @@ static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
}
+struct PyTestType
+ : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyTestType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPythonTestTestType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirPythonTestTestTypeGetTypeID;
+ static constexpr const char *pyClassName = "TestType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ context) {
+ return PyTestType(context->getRef(),
+ mlirPythonTestTestTypeGet(context.get()->get()));
+ },
+ nb::arg("context").none() = nb::none());
+ }
+};
+
+class PyTestAttr
+ : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
+ PyTestAttr> {
+public:
+ static constexpr IsAFunctionTy isaFunction =
+ mlirAttributeIsAPythonTestTestAttribute;
+ static constexpr const char *pyClassName = "TestAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirPythonTestTestAttributeGetTypeID;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ context) {
+ return PyTestAttr(context->getRef(), mlirPythonTestTestAttributeGet(
+ context.get()->get()));
+ },
+ nb::arg("context").none() = nb::none());
+ }
+};
+
NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"register_python_test_dialect",
@@ -65,30 +109,8 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
nb::sig("def test_diagnostics_with_errors_and_notes(arg: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", /) -> None"));
// clang-format on
- mlir_attribute_subclass(m, "TestAttr",
- mlirAttributeIsAPythonTestTestAttribute,
- mlirPythonTestTestAttributeGetTypeID)
- .def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPythonTestTestAttributeGet(ctx));
- },
- // clang-format off
- nb::sig("def get(cls: object, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
- // clang-format on
- nb::arg("cls"), nb::arg("context").none() = nb::none());
-
- mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
- mlirPythonTestTestTypeGetTypeID)
- .def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPythonTestTestTypeGet(ctx));
- },
- // clang-format off
- nb::sig("def get(cls: object, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
- // clang-format on
- nb::arg("cls"), nb::arg("context").none() = nb::none());
+ PyTestAttr::bind(m);
+ PyTestType::bind(m);
auto typeCls =
mlir_type_subclass(m, "TestIntegerRankedTensorType",
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 41223b72a7d10..2d98c9ce376b1 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1172,11 +1172,11 @@ PYBIND11_FEATURES = [
filegroup(
name = "MLIRBindingsPythonSourceFiles",
srcs = [
+ "lib/Bindings/Python/Globals.cpp",
"lib/Bindings/Python/IRAffine.cpp",
"lib/Bindings/Python/IRAttributes.cpp",
"lib/Bindings/Python/IRCore.cpp",
"lib/Bindings/Python/IRInterfaces.cpp",
- "lib/Bindings/Python/IRModule.cpp",
"lib/Bindings/Python/IRTypes.cpp",
"lib/Bindings/Python/Pass.cpp",
"lib/Bindings/Python/Rewrite.cpp",
More information about the Mlir-commits mailing list