[llvm] [mlir] [mlir][Python] port dialect extensions to use core PyConcreteType, PyConcreteAttribute (PR #174156)
Maksim Levental via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 1 11:13:45 PST 2026
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/174156
depends on https://github.com/llvm/llvm-project/pull/174118
>From 9ba8371d08800ee328bc58afee34e4e54048260f Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 30 Dec 2025 22:09:33 -0800
Subject: [PATCH 01/38] [mlir][python] fix flatnamespace
---
mlir/cmake/modules/AddMLIRPython.cmake | 6 +++
mlir/examples/standalone/CMakeLists.txt | 4 +-
mlir/examples/standalone/test/lit.cfg.py | 12 ++---
.../standalone/test/python/smoketest.py | 53 ++++++++++++++++---
mlir/test/Examples/standalone/test.toy | 2 +
mlir/test/Examples/standalone/test.wheel.toy | 14 ++++-
6 files changed, 74 insertions(+), 17 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index ca90151e76268..8c301faf0941a 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -766,6 +766,12 @@ function(add_mlir_python_extension libname extname)
FREE_THREADED
${ARG_SOURCES}
)
+ if(APPLE)
+ # In llvm/cmake/modules/HandleLLVMOptions.cmake:268 we set -Wl,-flat_namespace which breaks
+ # the default name spacing on MacOS and causes "cross-wired" symbol resolution when multiple
+ # bindings packages are loaded.
+ target_link_options(${libname} PRIVATE "LINKER:-twolevel_namespace")
+ endif()
if (NOT MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES
AND (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL))
diff --git a/mlir/examples/standalone/CMakeLists.txt b/mlir/examples/standalone/CMakeLists.txt
index c6c49fde12d2e..955c9ec7a7b4c 100644
--- a/mlir/examples/standalone/CMakeLists.txt
+++ b/mlir/examples/standalone/CMakeLists.txt
@@ -71,7 +71,9 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
endif()
add_subdirectory(python)
endif()
-add_subdirectory(test)
+if(MLIR_INCLUDE_TESTS)
+ add_subdirectory(test)
+endif()
add_subdirectory(standalone-opt)
if(NOT WIN32)
add_subdirectory(standalone-plugin)
diff --git a/mlir/examples/standalone/test/lit.cfg.py b/mlir/examples/standalone/test/lit.cfg.py
index e27dddd7fb0b9..89cdd6889a1f2 100644
--- a/mlir/examples/standalone/test/lit.cfg.py
+++ b/mlir/examples/standalone/test/lit.cfg.py
@@ -61,10 +61,8 @@
llvm_config.add_tool_substitutions(tools, tool_dirs)
-llvm_config.with_environment(
- "PYTHONPATH",
- [
- os.path.join(config.mlir_obj_dir, "python_packages", "standalone"),
- ],
- append_path=True,
-)
+python_path = [os.path.join(config.mlir_obj_dir, "python_packages", "standalone")]
+if "PYTHONPATH" in os.environ:
+ python_path += [os.environ["PYTHONPATH"]]
+
+llvm_config.with_environment("PYTHONPATH", python_path, append_path=True)
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index f8819841fac45..addd767f53592 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -1,16 +1,55 @@
-# RUN: %python %s nanobind | FileCheck %s
+# RUN: %python %s 2>&1 | FileCheck %s
+import sys
-from mlir_standalone.ir import *
+# CHECK: Testing mlir_standalone package
+print("Testing mlir_standalone package", file=sys.stderr)
+
+import mlir_standalone.ir
from mlir_standalone.dialects import standalone_nanobind as standalone_d
-with Context():
+with mlir_standalone.ir.Context():
standalone_d.register_dialects()
- module = Module.parse(
+ standalone_module = mlir_standalone.ir.Module.parse(
"""
%0 = arith.constant 2 : i32
%1 = standalone.foo %0 : i32
"""
)
- # CHECK: %[[C:.*]] = arith.constant 2 : i32
- # CHECK: standalone.foo %[[C]] : i32
- print(str(module))
+ # CHECK: %[[C2:.*]] = arith.constant 2 : i32
+ # CHECK: standalone.foo %[[C2]] : i32
+ print(str(standalone_module), file=sys.stderr)
+
+
+# CHECK: Testing mlir package
+print("Testing mlir package", file=sys.stderr)
+
+import mlir.ir
+from mlir.dialects import (
+ amdgpu,
+ gpu,
+ irdl,
+ llvm,
+ nvgpu,
+ pdl,
+ quant,
+ smt,
+ sparse_tensor,
+ transform,
+ # Note: uncommenting linalg below will cause
+ # LLVM ERROR: Attempting to attach an interface to an unregistered operation builtin.unrealized_conversion_cast.
+ # unless you have built both mlir and mlir_standalone with
+ # -DCMAKE_C_VISIBILITY_PRESET=hidden -DCMAKE_CXX_VISIBILITY_PRESET=hidden -DCMAKE_VISIBILITY_INLINES_HIDDEN=ON
+ # which is highly recommended.
+ # linalg,
+)
+
+# CHECK-NOT: RuntimeWarning: nanobind: type '{{.*}}' was already registered!
+
+with mlir.ir.Context():
+ mlir_module = mlir.ir.Module.parse(
+ """
+ %0 = arith.constant 3 : i32
+ """
+ )
+ # CHECK: %[[C3:.*]] = arith.constant 3 : i32
+ print(str(mlir_module), file=sys.stderr)
diff --git a/mlir/test/Examples/standalone/test.toy b/mlir/test/Examples/standalone/test.toy
index a88c115ebf197..dc3c17f3da3d9 100644
--- a/mlir/test/Examples/standalone/test.toy
+++ b/mlir/test/Examples/standalone/test.toy
@@ -4,8 +4,10 @@
# RUN: -DLLVM_ENABLE_LIBCXX=%enable_libcxx -DMLIR_DIR=%mlir_cmake_dir \
# RUN: -DLLVM_USE_LINKER=%llvm_use_linker \
# RUN: -DMLIR_PYTHON_PACKAGE_PREFIX=mlir_standalone \
+# RUN: -DMLIR_INCLUDE_TESTS=ON \
# RUN: -DPython3_EXECUTABLE=%python \
# RUN: -DPython_EXECUTABLE=%python
+# RUN: export PYTHONPATH="%mlir_obj_root/python_packages/mlir_core"
# RUN: "%cmake_exe" --build . --target check-standalone | tee %t
# RUN: FileCheck --input-file=%t %s
diff --git a/mlir/test/Examples/standalone/test.wheel.toy b/mlir/test/Examples/standalone/test.wheel.toy
index c8d188a3cacd0..e9232e3f16098 100644
--- a/mlir/test/Examples/standalone/test.wheel.toy
+++ b/mlir/test/Examples/standalone/test.wheel.toy
@@ -14,21 +14,31 @@
# RUN: export CMAKE_GENERATOR=%cmake_generator
# RUN: export LLVM_USE_LINKER=%llvm_use_linker
# RUN: export MLIR_DIR="%mlir_cmake_dir"
+# RUN: export MLIR_INCLUDE_TESTS=ON
# RUN: %python -m pip wheel "%mlir_src_root/examples/standalone" -w "%mlir_obj_root/wheelhouse" -v | tee %t
# RUN: rm -rf "%mlir_obj_root/standalone-python-bindings-install"
# RUN: %python -m pip install standalone_python_bindings -f "%mlir_obj_root/wheelhouse" --target "%mlir_obj_root/standalone-python-bindings-install" -v | tee -a %t
-# RUN: export PYTHONPATH="%mlir_obj_root/standalone-python-bindings-install"
-# RUN: %python "%mlir_src_root/examples/standalone/test/python/smoketest.py" nanobind | tee -a %t
+# RUN: export PYTHONPATH="%mlir_obj_root/standalone-python-bindings-install:%mlir_obj_root/python_packages/mlir_core"
+# RUN: %python "%mlir_src_root/examples/standalone/test/python/smoketest.py" 2>&1 | tee -a %t
# RUN: FileCheck --input-file=%t %s
# CHECK: Successfully built standalone-python-bindings
+# CHECK: Testing mlir_standalone package
+
# CHECK: module {
# CHECK: %[[C2:.*]] = arith.constant 2 : i32
# CHECK: %[[V0:.*]] = standalone.foo %[[C2]] : i32
# CHECK: }
+# CHECK: Testing mlir package
+
+# CHECK-NOT: RuntimeWarning: nanobind: type '{{.*}}' was already registered!
+
+# CHECK: module {
+# CHECK: %[[C3:.*]] = arith.constant 3 : i32
+# CHECK: }
>From f7fcd17b870226126253ffe0d79b34920dffb1bb Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 31 Dec 2025 00:28:32 -0800
Subject: [PATCH 02/38] add doc
---
mlir/docs/Bindings/Python.md | 12 ++++++++++++
.../standalone/test/python/smoketest.py | 17 ++++-------------
mlir/test/Examples/standalone/test.wheel.toy | 4 ----
3 files changed, 16 insertions(+), 17 deletions(-)
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 877ae5170d68c..4f4f531f7723c 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -25,6 +25,18 @@
multiple Python implementations, setting this explicitly to the preferred
`python3` executable is strongly recommended.
+* **`CMAKE_C_VISIBILITY_PRESET`**: `STRING`
+* **`CMAKE_CXX_VISIBILITY_PRESET`**: `STRING`
+* **`CMAKE_VISIBILITY_INLINES_HIDDEN`**: `BOOL`
+
+  It is **highly** recommended that these be set to `hidden`, `hidden`, and `ON` (respectively) if the built package
+  is intended to be used alongside other bindings packages (i.e., multiple bindings packages loaded in the same
+  Python interpreter session). Failing to do so can lead to incorrect or ambiguous symbol resolution, which
+  typically manifests as an `LLVM ERROR` such as:
+ ```
+  LLVM ERROR: ... unregistered/uninitialized dialect/type/pass ...
+ ```
+
### Recommended development practices
It is recommended to use a Python virtual environment. Many ways exist for this,
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index addd767f53592..319251b063773 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -4,12 +4,12 @@
# CHECK: Testing mlir_standalone package
print("Testing mlir_standalone package", file=sys.stderr)
-import mlir_standalone.ir
+from mlir_standalone.ir import *
from mlir_standalone.dialects import standalone_nanobind as standalone_d
-with mlir_standalone.ir.Context():
+with Context():
standalone_d.register_dialects()
- standalone_module = mlir_standalone.ir.Module.parse(
+ module = Module.parse(
"""
%0 = arith.constant 2 : i32
%1 = standalone.foo %0 : i32
@@ -17,7 +17,7 @@
)
# CHECK: %[[C2:.*]] = arith.constant 2 : i32
# CHECK: standalone.foo %[[C2]] : i32
- print(str(standalone_module), file=sys.stderr)
+ print(str(module), file=sys.stderr)
# CHECK: Testing mlir package
@@ -44,12 +44,3 @@
)
# CHECK-NOT: RuntimeWarning: nanobind: type '{{.*}}' was already registered!
-
-with mlir.ir.Context():
- mlir_module = mlir.ir.Module.parse(
- """
- %0 = arith.constant 3 : i32
- """
- )
- # CHECK: %[[C3:.*]] = arith.constant 3 : i32
- print(str(mlir_module), file=sys.stderr)
diff --git a/mlir/test/Examples/standalone/test.wheel.toy b/mlir/test/Examples/standalone/test.wheel.toy
index e9232e3f16098..ebf26b26b4cba 100644
--- a/mlir/test/Examples/standalone/test.wheel.toy
+++ b/mlir/test/Examples/standalone/test.wheel.toy
@@ -38,7 +38,3 @@
# CHECK: Testing mlir package
# CHECK-NOT: RuntimeWarning: nanobind: type '{{.*}}' was already registered!
-
-# CHECK: module {
-# CHECK: %[[C3:.*]] = arith.constant 3 : i32
-# CHECK: }
>From 190e8847cc91b2cc751a8903f2204cb98375e06c Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 31 Dec 2025 02:26:04 -0800
Subject: [PATCH 03/38] fixup
---
mlir/cmake/modules/AddMLIR.cmake | 9 +++++
mlir/examples/standalone/CMakeLists.txt | 3 ++
mlir/examples/standalone/pyproject.toml | 7 +++-
.../standalone/test/python/smoketest.py | 35 +++++++------------
mlir/test/Examples/standalone/test.toy | 1 +
mlir/test/Examples/standalone/test.wheel.toy | 6 ++++
6 files changed, 38 insertions(+), 23 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index 6589458ab7894..92d558d86c754 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -445,6 +445,15 @@ function(add_mlir_library name)
MLIR_AGGREGATE_DEP_LIBS_IMPORTED "${CURRENT_LINK_LIBRARIES}"
)
+ # On MacOS, all template instantiations become weak symbols - this causes incorrect symbol
+ # resolution in cases where multiple aggregates are loaded in the same process (such as when multiple Python
+ # bindings packages are loaded, each with their own C API aggregate).
+ if(APPLE AND TARGET "obj.${name}" AND (NOT BUILD_SHARED_LIBS))
+ set_target_properties("obj.${name}" PROPERTIES
+ C_VISIBILITY_PRESET hidden
+ CXX_VISIBILITY_PRESET hidden
+ VISIBILITY_INLINES_HIDDEN YES)
+ endif()
# In order for out-of-tree projects to build aggregates of this library,
# we need to install the OBJECT library.
if(TARGET "obj.${name}" AND MLIR_INSTALL_AGGREGATE_OBJECTS AND NOT ARG_DISABLE_INSTALL)
diff --git a/mlir/examples/standalone/CMakeLists.txt b/mlir/examples/standalone/CMakeLists.txt
index 955c9ec7a7b4c..17d712f6a1064 100644
--- a/mlir/examples/standalone/CMakeLists.txt
+++ b/mlir/examples/standalone/CMakeLists.txt
@@ -66,6 +66,9 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
if(NOT MLIR_PYTHON_PACKAGE_PREFIX)
set(MLIR_PYTHON_PACKAGE_PREFIX "mlir_standalone" CACHE STRING "" FORCE)
endif()
+ if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
+ set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir_standalone" CACHE STRING "" FORCE)
+ endif()
if(NOT MLIR_BINDINGS_PYTHON_INSTALL_PREFIX)
set(MLIR_BINDINGS_PYTHON_INSTALL_PREFIX "python_packages/standalone/${MLIR_PYTHON_PACKAGE_PREFIX}" CACHE STRING "" FORCE)
endif()
diff --git a/mlir/examples/standalone/pyproject.toml b/mlir/examples/standalone/pyproject.toml
index c4194153743ef..8fc8d7d8266c3 100644
--- a/mlir/examples/standalone/pyproject.toml
+++ b/mlir/examples/standalone/pyproject.toml
@@ -6,7 +6,7 @@
[project]
name = "standalone-python-bindings"
dynamic = ["version"]
-requires-python = ">=3.8,<=3.14"
+requires-python = ">=3.8"
dependencies = [
"numpy>=1.19.5, <=2.1.2",
"PyYAML>=5.4.0, <=6.0.1",
@@ -56,9 +56,14 @@ MLIR_DIR = { env = "MLIR_DIR", default = "" }
# Non-optional
CMAKE_BUILD_TYPE = { env = "CMAKE_BUILD_TYPE", default = "Release" }
MLIR_ENABLE_BINDINGS_PYTHON = "ON"
+
# Effectively non-optional (any downstream project should specify this).
+MLIR_BINDINGS_PYTHON_NB_DOMAIN = "mlir_standalone"
MLIR_PYTHON_PACKAGE_PREFIX = "mlir_standalone"
+
# This specifies the directory in the install directory (i.e., /tmp/pip-wheel/platlib) where _mlir_libs, dialects, etc.
# are installed. Thus, this will be the package location (and the name of the package) that pip assumes is
# the root package.
MLIR_BINDINGS_PYTHON_INSTALL_PREFIX = "mlir_standalone"
+# Optional
+MLIR_INCLUDE_TESTS = { env = "MLIR_INCLUDE_TESTS", default = "ON" }
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index 319251b063773..3ec0e7a33f233 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -4,12 +4,12 @@
# CHECK: Testing mlir_standalone package
print("Testing mlir_standalone package", file=sys.stderr)
-from mlir_standalone.ir import *
+import mlir_standalone.ir
from mlir_standalone.dialects import standalone_nanobind as standalone_d
-with Context():
+with mlir_standalone.ir.Context():
standalone_d.register_dialects()
- module = Module.parse(
+ standalone_module = mlir_standalone.ir.Module.parse(
"""
%0 = arith.constant 2 : i32
%1 = standalone.foo %0 : i32
@@ -17,30 +17,21 @@
)
# CHECK: %[[C2:.*]] = arith.constant 2 : i32
# CHECK: standalone.foo %[[C2]] : i32
- print(str(module), file=sys.stderr)
+ print(str(standalone_module), file=sys.stderr)
# CHECK: Testing mlir package
print("Testing mlir package", file=sys.stderr)
import mlir.ir
-from mlir.dialects import (
- amdgpu,
- gpu,
- irdl,
- llvm,
- nvgpu,
- pdl,
- quant,
- smt,
- sparse_tensor,
- transform,
- # Note: uncommenting linalg below will cause
- # LLVM ERROR: Attempting to attach an interface to an unregistered operation builtin.unrealized_conversion_cast.
- # unless you have built both mlir and mlir_standalone with
- # -DCMAKE_C_VISIBILITY_PRESET=hidden -DCMAKE_CXX_VISIBILITY_PRESET=hidden -DCMAKE_VISIBILITY_INLINES_HIDDEN=ON
- # which is highly recommended.
- # linalg,
-)
# CHECK-NOT: RuntimeWarning: nanobind: type '{{.*}}' was already registered!
+
+with mlir.ir.Context():
+ mlir_module = mlir.ir.Module.parse(
+ """
+ %0 = arith.constant 3 : i32
+ """
+ )
+ # CHECK: %[[C3:.*]] = arith.constant 3 : i32
+ print(str(mlir_module), file=sys.stderr)
diff --git a/mlir/test/Examples/standalone/test.toy b/mlir/test/Examples/standalone/test.toy
index dc3c17f3da3d9..8836adde72ed8 100644
--- a/mlir/test/Examples/standalone/test.toy
+++ b/mlir/test/Examples/standalone/test.toy
@@ -4,6 +4,7 @@
# RUN: -DLLVM_ENABLE_LIBCXX=%enable_libcxx -DMLIR_DIR=%mlir_cmake_dir \
# RUN: -DLLVM_USE_LINKER=%llvm_use_linker \
# RUN: -DMLIR_PYTHON_PACKAGE_PREFIX=mlir_standalone \
+# RUN: -DMLIR_BINDINGS_PYTHON_NB_DOMAIN=mlir_standalone \
# RUN: -DMLIR_INCLUDE_TESTS=ON \
# RUN: -DPython3_EXECUTABLE=%python \
# RUN: -DPython_EXECUTABLE=%python
diff --git a/mlir/test/Examples/standalone/test.wheel.toy b/mlir/test/Examples/standalone/test.wheel.toy
index ebf26b26b4cba..68ad479de8395 100644
--- a/mlir/test/Examples/standalone/test.wheel.toy
+++ b/mlir/test/Examples/standalone/test.wheel.toy
@@ -15,6 +15,8 @@
# RUN: export LLVM_USE_LINKER=%llvm_use_linker
# RUN: export MLIR_DIR="%mlir_cmake_dir"
# RUN: export MLIR_INCLUDE_TESTS=ON
+# RUN: export MLIR_PYTHON_PACKAGE_PREFIX=mlir_standalone
+# RUN: export MLIR_BINDINGS_PYTHON_NB_DOMAIN=mlir_standalone
# RUN: %python -m pip wheel "%mlir_src_root/examples/standalone" -w "%mlir_obj_root/wheelhouse" -v | tee %t
@@ -38,3 +40,7 @@
# CHECK: Testing mlir package
# CHECK-NOT: RuntimeWarning: nanobind: type '{{.*}}' was already registered!
+
+# CHECK: module {
+# CHECK: %[[C3:.*]] = arith.constant 3 : i32
+# CHECK: }
>From 10edd430f57e73620a7c9a13df44cef6e11877d5 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 31 Dec 2025 12:01:16 -0800
Subject: [PATCH 04/38] dont create module
---
mlir/examples/standalone/test/python/smoketest.py | 11 +----------
mlir/test/Examples/standalone/test.wheel.toy | 3 ---
2 files changed, 1 insertion(+), 13 deletions(-)
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index 3ec0e7a33f233..09040eb2f45dc 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -23,15 +23,6 @@
# CHECK: Testing mlir package
print("Testing mlir package", file=sys.stderr)
-import mlir.ir
+from mlir.ir import *
# CHECK-NOT: RuntimeWarning: nanobind: type '{{.*}}' was already registered!
-
-with mlir.ir.Context():
- mlir_module = mlir.ir.Module.parse(
- """
- %0 = arith.constant 3 : i32
- """
- )
- # CHECK: %[[C3:.*]] = arith.constant 3 : i32
- print(str(mlir_module), file=sys.stderr)
diff --git a/mlir/test/Examples/standalone/test.wheel.toy b/mlir/test/Examples/standalone/test.wheel.toy
index 68ad479de8395..b60347ba687d0 100644
--- a/mlir/test/Examples/standalone/test.wheel.toy
+++ b/mlir/test/Examples/standalone/test.wheel.toy
@@ -41,6 +41,3 @@
# CHECK-NOT: RuntimeWarning: nanobind: type '{{.*}}' was already registered!
-# CHECK: module {
-# CHECK: %[[C3:.*]] = arith.constant 3 : i32
-# CHECK: }
>From ed388a1f329804e0612e3b77cb8565926b501153 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Wed, 10 Dec 2025 23:57:13 -0800
Subject: [PATCH 05/38] [mlir][Python] create MLIRPythonSupport
---
mlir/python/CMakeLists.txt | 65 ++++++++++++++++++++++++++++++--------
1 file changed, 52 insertions(+), 13 deletions(-)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 1e9f1e11d4d06..e9b1aff0455e6 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -3,6 +3,8 @@ include(AddMLIRPython)
# Specifies that all MLIR packages are co-located under the `MLIR_PYTHON_PACKAGE_PREFIX.`
# top level package (the API has been embedded in a relocatable way).
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
+set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
+set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python")
################################################################################
# Structural groupings.
@@ -524,27 +526,17 @@ declare_mlir_dialect_python_bindings(
# dependencies.
################################################################################
-set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python")
declare_mlir_python_extension(MLIRPythonExtension.Core
MODULE_NAME _mlir
ADD_TO_PARENT MLIRPythonSources.Core
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
MainModule.cpp
- IRAffine.cpp
- IRAttributes.cpp
- IRCore.cpp
- IRInterfaces.cpp
- IRModule.cpp
- IRTypes.cpp
Pass.cpp
Rewrite.cpp
# Headers must be included explicitly so they are installed.
- Globals.h
- IRModule.h
Pass.h
- NanobindUtils.h
Rewrite.h
PRIVATE_LINK_LIBS
LLVMSupport
@@ -752,8 +744,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Nanobind
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
DialectSMT.cpp
- # Headers must be included explicitly so they are installed.
- NanobindUtils.h
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
@@ -860,7 +850,6 @@ endif()
# once ready.
################################################################################
-set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
add_mlir_python_common_capi_library(MLIRPythonCAPI
INSTALL_COMPONENT MLIRPythonModules
INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs"
@@ -997,3 +986,53 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
add_dependencies(MLIRPythonModules "${_mlirPythonTestNanobind_typestub_gen_target}")
endif()
endif()
+
+get_property(NB_LIBRARY_TARGET_NAME TARGET MLIRPythonModules.extension._mlir.dso PROPERTY LINK_LIBRARIES)
+list(GET NB_LIBRARY_TARGET_NAME 0 NB_LIBRARY_TARGET_NAME)
+add_mlir_library_install(${NB_LIBRARY_TARGET_NAME})
+add_mlir_library(MLIRPythonSupport
+ ${PYTHON_SOURCE_DIR}/Globals.cpp
+ ${PYTHON_SOURCE_DIR}/IRAffine.cpp
+ ${PYTHON_SOURCE_DIR}/IRAttributes.cpp
+ ${PYTHON_SOURCE_DIR}/IRCore.cpp
+ ${PYTHON_SOURCE_DIR}/IRInterfaces.cpp
+ ${PYTHON_SOURCE_DIR}/IRTypes.cpp
+ EXCLUDE_FROM_LIBMLIR
+ SHARED
+ LINK_COMPONENTS
+ Support
+ LINK_LIBS
+ ${NB_LIBRARY_TARGET_NAME}
+ MLIRCAPIIR
+)
+target_link_libraries(MLIRPythonSupport PUBLIC ${NB_LIBRARY_TARGET_NAME})
+nanobind_link_options(MLIRPythonSupport)
+set_target_properties(MLIRPythonSupport PROPERTIES
+ LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+ BINARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+ # Needed for windows (and doesn't hurt others).
+ RUNTIME_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+ ARCHIVE_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+)
+set(eh_rtti_enable)
+if(MSVC)
+ set(eh_rtti_enable /EHsc /GR)
+elseif(LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
+ set(eh_rtti_enable -frtti -fexceptions)
+endif()
+target_compile_options(MLIRPythonSupport PRIVATE ${eh_rtti_enable})
+if(APPLE)
+ # NanobindAdaptors.h uses PyClassMethod_New to build `pure_subclass`es but nanobind
+ # doesn't declare this API as undefined in its linker flags. So we need to declare it as such
+ # for downstream users that do not do something like `-undefined dynamic_lookup`.
+ # Same for the rest.
+ target_link_options(MLIRPythonSupport PUBLIC
+ "LINKER:-U,_PyClassMethod_New"
+ "LINKER:-U,_PyCode_Addr2Location"
+ "LINKER:-U,_PyFrame_GetLasti"
+ )
+endif()
+target_link_libraries(
+ MLIRPythonModules.extension._mlir.dso
+ PUBLIC MLIRPythonSupport)
+
>From a5e1569486bfb1dfa673b6ae2e4793fdff390061 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 11 Dec 2025 13:19:23 -0800
Subject: [PATCH 06/38] kind of working
---
mlir/cmake/modules/AddMLIRPython.cmake | 1 +
.../mlir}/Bindings/Python/Globals.h | 48 +-
.../mlir/Bindings/Python/IRCore.h} | 1025 ++++-
.../mlir}/Bindings/Python/NanobindUtils.h | 0
mlir/lib/Bindings/Python/DialectSMT.cpp | 2 +-
.../Python/{IRModule.cpp => Globals.cpp} | 14 +-
mlir/lib/Bindings/Python/IRAffine.cpp | 10 +-
mlir/lib/Bindings/Python/IRAttributes.cpp | 21 +-
mlir/lib/Bindings/Python/IRCore.cpp | 3317 +----------------
mlir/lib/Bindings/Python/IRInterfaces.cpp | 2 +-
mlir/lib/Bindings/Python/IRTypes.cpp | 22 +-
mlir/lib/Bindings/Python/MainModule.cpp | 2277 ++++++++++-
mlir/lib/Bindings/Python/Pass.cpp | 4 +-
mlir/lib/Bindings/Python/Pass.h | 2 +-
mlir/lib/Bindings/Python/Rewrite.cpp | 2 +-
mlir/lib/Bindings/Python/Rewrite.h | 2 +-
mlir/python/CMakeLists.txt | 19 +-
17 files changed, 3428 insertions(+), 3340 deletions(-)
rename mlir/{lib => include/mlir}/Bindings/Python/Globals.h (82%)
rename mlir/{lib/Bindings/Python/IRModule.h => include/mlir/Bindings/Python/IRCore.h} (57%)
rename mlir/{lib => include/mlir}/Bindings/Python/NanobindUtils.h (100%)
rename mlir/lib/Bindings/Python/{IRModule.cpp => Globals.cpp} (97%)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 8c301faf0941a..111ef45609160 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -764,6 +764,7 @@ function(add_mlir_python_extension libname extname)
nanobind_add_module(${libname}
NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
FREE_THREADED
+ NB_SHARED
${ARG_SOURCES}
)
if(APPLE)
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
similarity index 82%
rename from mlir/lib/Bindings/Python/Globals.h
rename to mlir/include/mlir/Bindings/Python/Globals.h
index 1e81f53e465ac..fea7a201453ce 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -15,10 +15,12 @@
#include <unordered_set>
#include <vector>
-#include "NanobindUtils.h"
+#include "mlir-c/Debug.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir/CAPI/Support.h"
+
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
@@ -200,6 +202,50 @@ class PyGlobals {
TypeIDAllocator typeIDAllocator;
};
+/// Exposes the LLVM-wide debug flag and debug types to Python (`_GlobalDebug`).
+struct PyGlobalDebugFlag {
+  static void set(nanobind::object &o, bool enable) {
+    nanobind::ft_lock_guard lock(mutex);
+    mlirEnableGlobalDebug(enable);
+  }
+  /// Property getter; the (unnamed) object argument is not used.
+  static bool get(const nanobind::object &) {
+    nanobind::ft_lock_guard lock(mutex);
+    return mlirIsGlobalDebugEnabled();
+  }
+  /// Registers `_GlobalDebug` with a static `flag` property and `set_types`.
+  static void bind(nanobind::module_ &m) {
+    // Every accessor below takes `mutex` before calling the C debug API.
+    nanobind::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
+        .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
+                            &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
+        .def_static(
+            "set_types",
+            [](const std::string &type) {
+              nanobind::ft_lock_guard lock(mutex);
+              mlirSetGlobalDebugType(type.c_str());
+            },
+            nanobind::arg("types"),
+            "Sets specific debug types to be produced by LLVM.")
+        .def_static(
+            "set_types",
+            [](const std::vector<std::string> &types) {
+              std::vector<const char *> pointers;
+              pointers.reserve(types.size());
+              for (const std::string &str : types)
+                pointers.push_back(str.c_str());
+              nanobind::ft_lock_guard lock(mutex);
+              mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
+            },
+            nanobind::arg("types"),
+            "Sets multiple specific debug types to be produced by LLVM.");
+  }
+  // Non-inline static member: requires one out-of-line definition elsewhere.
+private:
+  static nanobind::ft_mutex mutex;
+};
+
+
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/include/mlir/Bindings/Python/IRCore.h
similarity index 57%
rename from mlir/lib/Bindings/Python/IRModule.h
rename to mlir/include/mlir/Bindings/Python/IRCore.h
index e706be3b4d32a..488196ea42e44 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1,4 +1,4 @@
-//===- IRModules.h - IR Submodules of pybind module -----------------------===//
+//===- IRCore.h - IR helpers of python bindings ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,8 +7,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===----------------------------------------------------------------------===//
-#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
-#define MLIR_BINDINGS_PYTHON_IRMODULES_H
+#ifndef MLIR_BINDINGS_PYTHON_IRCORE_H
+#define MLIR_BINDINGS_PYTHON_IRCORE_H
#include <optional>
#include <sstream>
@@ -20,12 +20,14 @@
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/ThreadPool.h"
@@ -1323,12 +1325,1017 @@ struct MLIRError {
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
};
-void populateIRAffine(nanobind::module_ &m);
-void populateIRAttributes(nanobind::module_ &m);
-void populateIRCore(nanobind::module_ &m);
-void populateIRInterfaces(nanobind::module_ &m);
-void populateIRTypes(nanobind::module_ &m);
+//------------------------------------------------------------------------------
+// Utilities.
+//------------------------------------------------------------------------------
+
+/// Helper for creating an @classmethod.
+///
+/// Wraps `f` (with nanobind annotations `args`) in a nanobind function object
+/// and passes it to `PyClassMethod_New`. That call returns a *new* reference,
+/// so ownership is stolen into the result; borrowing it would add a second
+/// reference and leak one per call. Throws `nanobind::python_error` if
+/// `PyClassMethod_New` fails.
+///
+/// `inline` rather than `static`: this is defined in a header.
+template <class Func, typename... Args>
+inline nanobind::object classmethod(Func f, Args... args) {
+  nanobind::object cf = nanobind::cpp_function(f, args...);
+  PyObject *method = PyClassMethod_New(cf.ptr());
+  if (!method)
+    throw nanobind::python_error();
+  return nanobind::steal<nanobind::object>(method);
+}
+
+/// Wraps `dialectDescriptor` in the Python class registered for
+/// `dialectNamespace` in `PyGlobals`. Falls back to the generic `PyDialect`
+/// wrapper when no custom class is registered.
+///
+/// `inline` rather than `static`: a `static` definition in a header gives each
+/// including translation unit a private copy and triggers unused-function
+/// warnings where it is not called.
+inline nanobind::object
+createCustomDialectWrapper(const std::string &dialectNamespace,
+                           nanobind::object dialectDescriptor) {
+  auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
+  if (!dialectClass) {
+    // Use the base class.
+    return nanobind::cast(PyDialect(std::move(dialectDescriptor)));
+  }
+
+  // Create the custom implementation.
+  return (*dialectClass)(std::move(dialectDescriptor));
+}
+
+/// Builds an `MlirStringRef` that points directly at `s`'s character data (no
+/// copy is made). `s` must outlive every use of the returned reference.
+///
+/// These overloads are `inline` rather than `static`: they are called from
+/// in-class (hence inline) member functions in this header. An inline function
+/// must not refer to an internal-linkage function, which would be a distinct
+/// entity in every translation unit.
+inline MlirStringRef toMlirStringRef(const std::string &s) {
+  return mlirStringRefCreate(s.data(), s.size());
+}
+
+inline MlirStringRef toMlirStringRef(std::string_view s) {
+  return mlirStringRefCreate(s.data(), s.size());
+}
+
+inline MlirStringRef toMlirStringRef(const nanobind::bytes &s) {
+  return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
+}
+
+/// Creates a new block (not yet inserted into any region) whose arguments have
+/// the types in `pyArgTypes`.
+///
+/// If `pyArgLocs` is given, it must hold one location per argument type.
+/// Otherwise every argument gets `DefaultingPyLocation::resolve()`, the
+/// location from the current location context.
+///
+/// Throws `nanobind::value_error` when the number of locations differs from
+/// the number of argument types.
+///
+/// `inline` rather than `static`: it is called from inline member functions
+/// in this header (e.g. block appending), which must not refer to a
+/// per-translation-unit copy.
+inline MlirBlock
+createBlock(const nanobind::sequence &pyArgTypes,
+            const std::optional<nanobind::sequence> &pyArgLocs) {
+  SmallVector<MlirType> argTypes;
+  argTypes.reserve(nanobind::len(pyArgTypes));
+  for (const auto &pyType : pyArgTypes)
+    argTypes.push_back(nanobind::cast<PyType &>(pyType));
+
+  SmallVector<MlirLocation> argLocs;
+  if (pyArgLocs) {
+    argLocs.reserve(nanobind::len(*pyArgLocs));
+    for (const auto &pyLoc : *pyArgLocs)
+      argLocs.push_back(nanobind::cast<PyLocation &>(pyLoc));
+  } else if (!argTypes.empty()) {
+    // The default location is only resolved when there is at least one
+    // argument (presumably so argument-less blocks need no Location context).
+    argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
+  }
+
+  if (argTypes.size() != argLocs.size())
+    throw nanobind::value_error(("Expected " + Twine(argTypes.size()) +
+                                 " locations, got: " + Twine(argLocs.size()))
+                                    .str()
+                                    .c_str());
+  return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
+}
+
+/// Static facade over the attribute-builder registry held by `PyGlobals`,
+/// exposed to Python as the `AttrBuilder` class.
+struct PyAttrBuilderMap {
+  /// Returns whether a builder is registered under `attributeKind`.
+  static bool dunderContains(const std::string &attributeKind) {
+    auto found = PyGlobals::get().lookupAttributeBuilder(attributeKind);
+    return found.has_value();
+  }
+
+  /// Returns the builder registered under `attributeKind`; throws
+  /// `nanobind::key_error` when there is none.
+  static nanobind::callable
+  dunderGetItemNamed(const std::string &attributeKind) {
+    if (auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind))
+      return *builder;
+    throw nanobind::key_error(attributeKind.c_str());
+  }
+
+  /// Registers `builder` under `attributeKind`; `replace` is forwarded to the
+  /// registry unchanged.
+  static void dunderSetItemNamed(const std::string &attributeKind,
+                                 nanobind::callable builder, bool replace) {
+    PyGlobals::get().registerAttributeBuilder(attributeKind,
+                                              std::move(builder), replace);
+  }
+
+  static void bind(nanobind::module_ &m) {
+    namespace nb = nanobind;
+    nb::class_<PyAttrBuilderMap> cls(m, "AttrBuilder");
+    cls.def_static("contains", &PyAttrBuilderMap::dunderContains,
+                   nb::arg("attribute_kind"),
+                   "Checks whether an attribute builder is registered for the "
+                   "given attribute kind.");
+    cls.def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
+                   nb::arg("attribute_kind"),
+                   "Gets the registered attribute builder for the given "
+                   "attribute kind.");
+    cls.def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
+                   nb::arg("attribute_kind"), nb::arg("attr_builder"),
+                   nb::arg("replace") = false,
+                   "Register an attribute builder for building MLIR "
+                   "attributes from Python values.");
+  }
+};
+
+//------------------------------------------------------------------------------
+// PyBlock
+//------------------------------------------------------------------------------
+
+/// Returns a Python capsule for the underlying `MlirBlock`. The object produced
+/// by `mlirPythonBlockToCapsule` is taken over (stolen), not borrowed.
+inline nanobind::object PyBlock::getCapsule() {
+  PyObject *capsule = mlirPythonBlockToCapsule(get());
+  return nanobind::steal<nanobind::object>(capsule);
+}
+
+//------------------------------------------------------------------------------
+// Collections.
+//------------------------------------------------------------------------------
+
+/// Python iterator over an operation's regions. It yields regions in order,
+/// from a starting index through the operation's last region.
+class PyRegionIterator {
+public:
+  /// `nextIndex` is `intptr_t` (not `int`) to match the member it initializes
+  /// and the `intptr_t` slice indices that `PyRegionList` passes in, so start
+  /// indices are never narrowed.
+  PyRegionIterator(PyOperationRef operation, intptr_t nextIndex)
+      : operation(std::move(operation)), nextIndex(nextIndex) {}
+
+  PyRegionIterator &dunderIter() { return *this; }
+
+  /// Returns the next region, or raises StopIteration once the index reaches
+  /// the operation's region count.
+  PyRegion dunderNext() {
+    operation->checkValid();
+    if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
+      throw nanobind::stop_iteration();
+    }
+    MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
+    return PyRegion(operation, region);
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyRegionIterator>(m, "RegionIterator")
+        .def("__iter__", &PyRegionIterator::dunderIter,
+             "Returns an iterator over the regions in the operation.")
+        .def("__next__", &PyRegionIterator::dunderNext,
+             "Returns the next region in the iteration.");
+  }
+
+private:
+  PyOperationRef operation;
+  intptr_t nextIndex = 0;
+};
+
+/// Sequence-like view over an operation's regions. An operation has a fixed
+/// number of regions, so they are addressed by numeric index. Slicing comes
+/// from the `Sliceable` CRTP base.
+class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
+public:
+  static constexpr const char *pyClassName = "RegionSequence";
+
+  PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
+               intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length != -1 ? length
+                               : mlirOperationGetNumRegions(operation->get()),
+                  step),
+        operation(std::move(operation)) {}
+
+  /// Returns an iterator that starts at this view's first region.
+  /// NOTE(review): the iterator then runs to the operation's last region and
+  /// does not apply this view's length or step. Confirm this is intended for
+  /// sliced views.
+  PyRegionIterator dunderIter() {
+    operation->checkValid();
+    return PyRegionIterator(operation, startIndex);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__iter__", &PyRegionList::dunderIter,
+          "Returns an iterator over the regions in the sequence.");
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyRegionList, PyRegion>;
+
+  /// Total number of regions on the underlying operation.
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumRegions(operation->get());
+  }
+
+  /// Region at raw index `pos` of the underlying operation.
+  PyRegion getRawElement(intptr_t pos) {
+    operation->checkValid();
+    MlirRegion region = mlirOperationGetRegion(operation->get(), pos);
+    return PyRegion(operation, region);
+  }
+
+  /// New view over the same operation with the given slice parameters.
+  PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyRegionList(operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+};
+
+/// Python iterator that walks a region's blocks forward, starting at a given
+/// block and following `mlirBlockGetNextInRegion`.
+class PyBlockIterator {
+public:
+  PyBlockIterator(PyOperationRef operation, MlirBlock next)
+      : operation(std::move(operation)), next(next) {}
+
+  PyBlockIterator &dunderIter() { return *this; }
+
+  /// Yields the current block and advances; raises StopIteration once the
+  /// end of the region's block list (a null block) is reached.
+  PyBlock dunderNext() {
+    operation->checkValid();
+    if (!mlirBlockIsNull(next)) {
+      PyBlock current(operation, next);
+      next = mlirBlockGetNextInRegion(next);
+      return current;
+    }
+    throw nanobind::stop_iteration();
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyBlockIterator> cls(m, "BlockIterator");
+    cls.def("__iter__", &PyBlockIterator::dunderIter,
+            "Returns an iterator over the blocks in the operation's region.");
+    cls.def("__next__", &PyBlockIterator::dunderNext,
+            "Returns the next block in the iteration.");
+  }
+
+private:
+  PyOperationRef operation;
+  MlirBlock next;
+};
+
+/// List-like Python view over the blocks of a region. The C API only offers a
+/// forward linked list, so length and indexing walk the list from the first
+/// block; forward iteration is the cheap path. Blocks are always owned by a
+/// region.
+class PyBlockList {
+public:
+  PyBlockList(PyOperationRef operation, MlirRegion region)
+      : operation(std::move(operation)), region(region) {}
+
+  PyBlockIterator dunderIter() {
+    operation->checkValid();
+    MlirBlock first = mlirRegionGetFirstBlock(region);
+    return PyBlockIterator(operation, first);
+  }
+
+  /// Counts the blocks by walking the region's block list.
+  intptr_t dunderLen() {
+    operation->checkValid();
+    intptr_t numBlocks = 0;
+    for (MlirBlock b = mlirRegionGetFirstBlock(region); !mlirBlockIsNull(b);
+         b = mlirBlockGetNextInRegion(b))
+      ++numBlocks;
+    return numBlocks;
+  }
+
+  /// Returns the block at `index` (negative values count from the end);
+  /// raises IndexError when out of range.
+  PyBlock dunderGetItem(intptr_t index) {
+    operation->checkValid();
+    if (index < 0)
+      index += dunderLen();
+    if (index < 0)
+      throw nanobind::index_error("attempt to access out of bounds block");
+    for (MlirBlock b = mlirRegionGetFirstBlock(region); !mlirBlockIsNull(b);
+         b = mlirBlockGetNextInRegion(b), --index) {
+      if (index == 0)
+        return PyBlock(operation, b);
+    }
+    throw nanobind::index_error("attempt to access out of bounds block");
+  }
+
+  /// Creates a block from the positional argument types (and optional
+  /// locations), appends it to the region, and returns it.
+  PyBlock appendBlock(const nanobind::args &pyArgTypes,
+                      const std::optional<nanobind::sequence> &pyArgLocs) {
+    operation->checkValid();
+    auto argTypes = nanobind::cast<nanobind::sequence>(pyArgTypes);
+    MlirBlock created = createBlock(argTypes, pyArgLocs);
+    mlirRegionAppendOwnedBlock(region, created);
+    return PyBlock(operation, created);
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyBlockList> cls(m, "BlockList");
+    cls.def("__getitem__", &PyBlockList::dunderGetItem,
+            "Returns the block at the specified index.");
+    cls.def("__iter__", &PyBlockList::dunderIter,
+            "Returns an iterator over blocks in the operation's region.");
+    cls.def("__len__", &PyBlockList::dunderLen,
+            "Returns the number of blocks in the operation's region.");
+    cls.def("append", &PyBlockList::appendBlock,
+            R"(
+ Appends a new block, with argument types as positional args.
+
+ Returns:
+ The created block.
+ )",
+            nanobind::arg("args"), nanobind::kw_only(),
+            nanobind::arg("arg_locs") = std::nullopt);
+  }
+
+private:
+  PyOperationRef operation;
+  MlirRegion region;
+};
+
+/// Python iterator that walks a block's operations forward, following
+/// `mlirOperationGetNextInBlock`, and yields each one as its OpView.
+class PyOperationIterator {
+public:
+  PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
+      : parentOperation(std::move(parentOperation)), next(next) {}
+
+  PyOperationIterator &dunderIter() { return *this; }
+
+  /// Yields the current operation and advances; raises StopIteration at the
+  /// end of the block (a null operation).
+  nanobind::typed<nanobind::object, PyOpView> dunderNext() {
+    parentOperation->checkValid();
+    if (!mlirOperationIsNull(next)) {
+      MlirOperation current = next;
+      PyOperationRef wrapped =
+          PyOperation::forOperation(parentOperation->getContext(), current);
+      next = mlirOperationGetNextInBlock(current);
+      return wrapped->createOpView();
+    }
+    throw nanobind::stop_iteration();
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyOperationIterator> cls(m, "OperationIterator");
+    cls.def("__iter__", &PyOperationIterator::dunderIter,
+            "Returns an iterator over the operations in an operation's block.");
+    cls.def("__next__", &PyOperationIterator::dunderNext,
+            "Returns the next operation in the iteration.");
+  }
+
+private:
+  PyOperationRef parentOperation;
+  MlirOperation next;
+};
+
+/// List-like Python view over the operations of a block. The C API only offers
+/// a forward linked list, so length and indexing walk the list from the first
+/// operation; forward iteration is the cheap path. Iterable operations are
+/// always owned by a block.
+class PyOperationList {
+public:
+  PyOperationList(PyOperationRef parentOperation, MlirBlock block)
+      : parentOperation(std::move(parentOperation)), block(block) {}
+
+  PyOperationIterator dunderIter() {
+    parentOperation->checkValid();
+    MlirOperation first = mlirBlockGetFirstOperation(block);
+    return PyOperationIterator(parentOperation, first);
+  }
+
+  /// Counts the operations by walking the block's operation list.
+  intptr_t dunderLen() {
+    parentOperation->checkValid();
+    intptr_t numOps = 0;
+    for (MlirOperation op = mlirBlockGetFirstOperation(block);
+         !mlirOperationIsNull(op); op = mlirOperationGetNextInBlock(op))
+      ++numOps;
+    return numOps;
+  }
+
+  /// Returns the operation at `index` (negative values count from the end) as
+  /// its OpView; raises IndexError when out of range.
+  nanobind::typed<nanobind::object, PyOpView> dunderGetItem(intptr_t index) {
+    parentOperation->checkValid();
+    if (index < 0)
+      index += dunderLen();
+    if (index < 0)
+      throw nanobind::index_error("attempt to access out of bounds operation");
+    for (MlirOperation op = mlirBlockGetFirstOperation(block);
+         !mlirOperationIsNull(op);
+         op = mlirOperationGetNextInBlock(op), --index) {
+      if (index == 0)
+        return PyOperation::forOperation(parentOperation->getContext(), op)
+            ->createOpView();
+    }
+    throw nanobind::index_error("attempt to access out of bounds operation");
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyOperationList> cls(m, "OperationList");
+    cls.def("__getitem__", &PyOperationList::dunderGetItem,
+            "Returns the operation at the specified index.");
+    cls.def("__iter__", &PyOperationList::dunderIter,
+            "Returns an iterator over operations in the list.");
+    cls.def("__len__", &PyOperationList::dunderLen,
+            "Returns the number of operations in the list.");
+  }
+
+private:
+  PyOperationRef parentOperation;
+  MlirBlock block;
+};
+
+/// Python wrapper around a single `MlirOpOperand`, i.e. one use of a value by
+/// some operation.
+class PyOpOperand {
+public:
+  PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
+
+  /// Returns the operation owning this operand, as its OpView.
+  nanobind::typed<nanobind::object, PyOpView> getOwner() {
+    MlirOperation ownerOp = mlirOpOperandGetOwner(opOperand);
+    MlirContext ctx = mlirOperationGetContext(ownerOp);
+    PyOperationRef wrapped =
+        PyOperation::forOperation(PyMlirContext::forContext(ctx), ownerOp);
+    return wrapped->createOpView();
+  }
+
+  /// Returns this operand's position in its owner's operand list.
+  size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyOpOperand> cls(m, "OpOperand");
+    cls.def_prop_ro("owner", &PyOpOperand::getOwner,
+                    "Returns the operation that owns this operand.");
+    cls.def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
+                    "Returns the operand number in the owning operation.");
+  }
+
+private:
+  MlirOpOperand opOperand;
+};
+
+/// Python iterator over a chain of operand uses, following
+/// `mlirOpOperandGetNextUse` until a null operand is reached.
+class PyOpOperandIterator {
+public:
+  PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
+
+  PyOpOperandIterator &dunderIter() { return *this; }
+
+  /// Yields the current use and advances; raises StopIteration once the
+  /// chain is exhausted.
+  PyOpOperand dunderNext() {
+    if (!mlirOpOperandIsNull(opOperand)) {
+      PyOpOperand current(opOperand);
+      opOperand = mlirOpOperandGetNextUse(opOperand);
+      return current;
+    }
+    throw nanobind::stop_iteration();
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyOpOperandIterator> cls(m, "OpOperandIterator");
+    cls.def("__iter__", &PyOpOperandIterator::dunderIter,
+            "Returns an iterator over operands.");
+    cls.def("__next__", &PyOpOperandIterator::dunderNext,
+            "Returns the next operand in the iteration.");
+  }
+
+private:
+  MlirOpOperand opOperand;
+};
+
+/// CRTP base class for Python MLIR values that subclass Value and should be
+/// castable from it. The value hierarchy is one level deep and is not supposed
+/// to accommodate other levels unless core MLIR changes.
+///
+/// A derived class must provide:
+///   - `static constexpr IsAFunctionTy isaFunction`, used for checked casts;
+///   - `static constexpr const char *pyClassName`, the Python class name;
+/// and may redefine `static void bindDerived(ClassTy &)` to add members.
+template <typename DerivedTy>
+class PyConcreteValue : public PyValue {
+public:
+  using ClassTy = nanobind::class_<DerivedTy, PyValue>;
+  using IsAFunctionTy = bool (*)(MlirValue);
+
+  PyConcreteValue() = default;
+  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+      : PyValue(operationRef, value) {}
+  /// Checked conversion from a generic value (see `castFrom`).
+  PyConcreteValue(PyValue &orig)
+      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+  /// Returns `orig`'s underlying value when `DerivedTy::isaFunction` accepts
+  /// it. Otherwise throws `nanobind::value_error` naming the target class and
+  /// the repr of `orig`.
+  static MlirValue castFrom(PyValue &orig) {
+    if (DerivedTy::isaFunction(orig.get()))
+      return orig.get();
+
+    std::string origRepr =
+        nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig)));
+    throw nanobind::value_error((Twine("Cannot cast value to ") +
+                                 DerivedTy::pyClassName + " (from " + origRepr +
+                                 ")")
+                                    .str()
+                                    .c_str());
+  }
+
+  /// Registers `DerivedTy` on module `m` as a generic subclass of `Value` with
+  /// a checked constructor, `isinstance`, and the maybe-downcast hook, then
+  /// lets the derived class add its own members via `bindDerived`.
+  static void bind(nanobind::module_ &m) {
+    std::string signature =
+        (Twine("class ") + DerivedTy::pyClassName + "(Value[_T])").str();
+    ClassTy cls(m, DerivedTy::pyClassName, nanobind::is_generic(),
+                nanobind::sig(signature.c_str()));
+    cls.def(nanobind::init<PyValue &>(), nanobind::keep_alive<0, 1>(),
+            nanobind::arg("value"));
+
+    auto isInstance = [](PyValue &otherValue) -> bool {
+      return DerivedTy::isaFunction(otherValue);
+    };
+    cls.def_static("isinstance", isInstance, nanobind::arg("other_value"));
+
+    auto downCast =
+        [](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {
+      return self.maybeDownCast();
+    };
+    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, downCast);
+
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
+/// Python wrapper for MlirOpResult: a value produced by an operation.
+class PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+  static constexpr const char *pyClassName = "OpResult";
+  using PyConcreteValue::PyConcreteValue;
+
+  static void bindDerived(ClassTy &c) {
+    auto getOwner =
+        [](PyOpResult &self) -> nanobind::typed<nanobind::object, PyOperation> {
+      // The Python-side parent operation must be the result's IR owner.
+      assert(mlirOperationEqual(self.getParentOperation()->get(),
+                                mlirOpResultGetOwner(self.get())) &&
+             "expected the owner of the value in Python to match that in "
+             "the IR");
+      return self.getParentOperation().getObject();
+    };
+    auto getResultNumber = [](PyOpResult &self) {
+      return mlirOpResultGetResultNumber(self.get());
+    };
+
+    c.def_prop_ro("owner", getOwner,
+                  "Returns the operation that produces this result.");
+    c.def_prop_ro(
+        "result_number", getResultNumber,
+        "Returns the position of this result in the operation's result list.");
+  }
+};
+
+/// Returns the (downcast) types of the values held by `container`, in order.
+///
+/// `Container` must provide `size()` and `getElement(i)`, whose result has
+/// `get()` returning an `MlirValue` (as the result and argument lists in this
+/// header do). The index is `intptr_t` to match `size()` and avoid narrowing.
+///
+/// `inline` rather than `static`: it is used from inline members in this
+/// header.
+template <typename Container>
+inline std::vector<nanobind::typed<nanobind::object, PyType>>
+getValueTypes(Container &container, PyMlirContextRef &context) {
+  std::vector<nanobind::typed<nanobind::object, PyType>> result;
+  const intptr_t numValues = container.size();
+  result.reserve(numValues);
+  for (intptr_t i = 0; i < numValues; ++i) {
+    MlirType type = mlirValueGetType(container.getElement(i).get());
+    result.push_back(PyType(context->getRef(), type).maybeDownCast());
+  }
+  return result;
+}
+
+/// A list of operation results. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) result list is associated
+/// with the operation whose results these are, and thus extends the lifetime of
+/// this operation.
+class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
+public:
+  static constexpr const char *pyClassName = "OpResultList";
+  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
+
+  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
+                 intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumResults(operation->get())
+                               : length,
+                  step),
+        operation(std::move(operation)) {}
+
+  static void bindDerived(ClassTy &c) {
+    c.def_prop_ro(
+        "types",
+        [](PyOpResultList &self) {
+          return getValueTypes(self, self.operation->getContext());
+        },
+        "Returns a list of types for all results in this result list.");
+    c.def_prop_ro(
+        "owner",
+        [](PyOpResultList &self)
+            -> nanobind::typed<nanobind::object, PyOpView> {
+          return self.operation->createOpView();
+        },
+        "Returns the operation that owns this result list.");
+  }
+
+  PyOperationRef &getOperation() { return operation; }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpResultList, PyOpResult>;
+
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumResults(operation->get());
+  }
+
+  /// Returns the result at raw index `index`. Like the other operation-backed
+  /// lists, it validates the operation before calling into the C API.
+  PyOpResult getRawElement(intptr_t index) {
+    operation->checkValid();
+    PyValue value(operation, mlirOperationGetResult(operation->get(), index));
+    return PyOpResult(value);
+  }
+
+  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpResultList(operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+};
+
+/// Python wrapper for MlirBlockArgument: a value defined as an argument of a
+/// block.
+class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
+  static constexpr const char *pyClassName = "BlockArgument";
+  using PyConcreteValue::PyConcreteValue;
+
+  static void bindDerived(ClassTy &c) {
+    auto getOwner = [](PyBlockArgument &self) {
+      MlirBlock ownerBlock = mlirBlockArgumentGetOwner(self.get());
+      return PyBlock(self.getParentOperation(), ownerBlock);
+    };
+    auto getArgNumber = [](PyBlockArgument &self) {
+      return mlirBlockArgumentGetArgNumber(self.get());
+    };
+    auto setType = [](PyBlockArgument &self, PyType type) {
+      mlirBlockArgumentSetType(self.get(), type);
+    };
+    auto setLocation = [](PyBlockArgument &self, PyLocation loc) {
+      mlirBlockArgumentSetLocation(self.get(), loc);
+    };
+
+    c.def_prop_ro("owner", getOwner,
+                  "Returns the block that owns this argument.");
+    c.def_prop_ro(
+        "arg_number", getArgNumber,
+        "Returns the position of this argument in the block's argument list.");
+    c.def("set_type", setType, nanobind::arg("type"),
+          "Sets the type of this block argument.");
+    c.def("set_location", setLocation, nanobind::arg("loc"),
+          "Sets the location of this block argument.");
+  }
+};
+
+/// A list of block arguments. Internally, these are stored as consecutive
+/// elements, random access is cheap. The argument list is associated with the
+/// operation that contains the block (detached blocks are not allowed in
+/// Python bindings) and extends its lifetime.
+class PyBlockArgumentList
+    : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
+public:
+  static constexpr const char *pyClassName = "BlockArgumentList";
+  using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
+
+  PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
+                      intptr_t startIndex = 0, intptr_t length = -1,
+                      intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirBlockGetNumArguments(block) : length,
+                  step),
+        operation(std::move(operation)), block(block) {}
+
+  static void bindDerived(ClassTy &c) {
+    c.def_prop_ro(
+        "types",
+        [](PyBlockArgumentList &self) {
+          return getValueTypes(self, self.operation->getContext());
+        },
+        "Returns a list of types for all arguments in this argument list.");
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
+
+  /// Returns the number of arguments in the list.
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirBlockGetNumArguments(block);
+  }
+
+  /// Returns the `pos`-th argument of the block. Validates the owning
+  /// operation first, matching `getRawNumElements`.
+  PyBlockArgument getRawElement(intptr_t pos) {
+    operation->checkValid();
+    MlirValue argument = mlirBlockGetArgument(block, pos);
+    return PyBlockArgument(operation, argument);
+  }
+
+  /// Returns a sublist of this list.
+  PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) {
+    return PyBlockArgumentList(operation, block, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+  MlirBlock block;
+};
+
+/// A list of operation operands. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) operand list is associated
+/// with the operation whose operands these are, and thus extends the lifetime
+/// of this operation.
+class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
+public:
+  static constexpr const char *pyClassName = "OpOperandList";
+  using SliceableT = Sliceable<PyOpOperandList, PyValue>;
+
+  PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
+                  intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumOperands(operation->get())
+                               : length,
+                  step),
+        operation(std::move(operation)) {}
+
+  /// Replaces the operand at `index` (normalized through `wrapIndex`) with
+  /// `value`, after validating the operation.
+  void dunderSetItem(intptr_t index, PyValue value) {
+    operation->checkValid();
+    index = wrapIndex(index);
+    mlirOperationSetOperand(operation->get(), index, value.get());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__setitem__", &PyOpOperandList::dunderSetItem,
+          nanobind::arg("index"), nanobind::arg("value"),
+          "Sets the operand at the specified index to a new value.");
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpOperandList, PyValue>;
+
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumOperands(operation->get());
+  }
+
+  /// Returns the operand at raw index `pos`, wrapped with the operation that
+  /// defines it: the producing op for a result, or the parent op of the owning
+  /// block for a block argument.
+  PyValue getRawElement(intptr_t pos) {
+    operation->checkValid();
+    MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
+    MlirOperation owner;
+    if (mlirValueIsAOpResult(operand)) {
+      owner = mlirOpResultGetOwner(operand);
+    } else if (mlirValueIsABlockArgument(operand)) {
+      owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
+    } else {
+      // Unreachable for well-formed IR. Throw instead of falling through so
+      // that builds with assertions disabled never read an uninitialized
+      // `owner`.
+      assert(false && "Value must be an block arg or op result.");
+      throw nanobind::value_error(
+          "operand is neither a block argument nor an op result");
+    }
+    PyOperationRef pyOwner =
+        PyOperation::forOperation(operation->getContext(), owner);
+    return PyValue(pyOwner, operand);
+  }
+
+  PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpOperandList(operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+};
+
+/// A list of operation successors. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) successor list is
+/// associated with the operation whose successors these are, and thus extends
+/// the lifetime of this operation.
+class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "OpSuccessors";
+
+  PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
+                 intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumSuccessors(operation->get())
+                               : length,
+                  step),
+        operation(std::move(operation)) {}
+
+  /// Replaces the successor at `index` (normalized through `wrapIndex`) with
+  /// `block`. Both the operation and the new block are validated first.
+  void dunderSetItem(intptr_t index, PyBlock block) {
+    operation->checkValid();
+    block.checkValid();
+    index = wrapIndex(index);
+    mlirOperationSetSuccessor(operation->get(), index, block.get());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nanobind::arg("index"),
+          nanobind::arg("block"),
+          "Sets the successor block at the specified index.");
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpSuccessors, PyBlock>;
+
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumSuccessors(operation->get());
+  }
+
+  /// Returns the successor at raw index `pos`, after validating the operation.
+  PyBlock getRawElement(intptr_t pos) {
+    operation->checkValid();
+    MlirBlock successor = mlirOperationGetSuccessor(operation->get(), pos);
+    return PyBlock(operation, successor);
+  }
+
+  PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpSuccessors(operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+};
+
+/// A list of block successors. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) successor list is
+/// associated with the operation and block whose successors these are, and thus
+/// extends the lifetime of this operation and block.
+class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "BlockSuccessors";
+
+  PyBlockSuccessors(PyBlock block, PyOperationRef operation,
+                    intptr_t startIndex = 0, intptr_t length = -1,
+                    intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirBlockGetNumSuccessors(block.get())
+                               : length,
+                  step),
+        operation(std::move(operation)), block(std::move(block)) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockSuccessors, PyBlock>;
+
+  intptr_t getRawNumElements() {
+    block.checkValid();
+    return mlirBlockGetNumSuccessors(block.get());
+  }
+
+  /// Returns the successor at raw index `pos`, after validating the block.
+  /// The local is named `successor` so it does not shadow the `block` member.
+  PyBlock getRawElement(intptr_t pos) {
+    block.checkValid();
+    MlirBlock successor = mlirBlockGetSuccessor(block.get(), pos);
+    return PyBlock(operation, successor);
+  }
+
+  PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyBlockSuccessors(block, operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+  PyBlock block;
+};
+
+/// A list of block predecessors. The (returned) predecessor list is
+/// associated with the operation and block whose predecessors these are, and
+/// thus extends the lifetime of this operation and block.
+///
+/// WARNING: This Sliceable is more expensive than the others here because
+/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
+/// operands) anew for each indexed access.
+class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "BlockPredecessors";
+
+  PyBlockPredecessors(PyBlock block, PyOperationRef operation,
+                      intptr_t startIndex = 0, intptr_t length = -1,
+                      intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirBlockGetNumPredecessors(block.get())
+                               : length,
+                  step),
+        operation(std::move(operation)), block(std::move(block)) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockPredecessors, PyBlock>;
+
+  intptr_t getRawNumElements() {
+    block.checkValid();
+    return mlirBlockGetNumPredecessors(block.get());
+  }
+
+  /// Returns the predecessor at raw index `pos`, after validating the block.
+  /// The local is named `predecessor` so it does not shadow the `block` member.
+  PyBlock getRawElement(intptr_t pos) {
+    block.checkValid();
+    MlirBlock predecessor = mlirBlockGetPredecessor(block.get(), pos);
+    return PyBlock(operation, predecessor);
+  }
+
+  PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) {
+    return PyBlockPredecessors(block, operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+  PyBlock block;
+};
+
+/// A list of operation attributes. Can be indexed by name, producing
+/// attributes, or by index, producing named attributes.
+///
+/// Every accessor calls `operation->checkValid()` before touching the
+/// underlying MlirOperation, as the other operation-backed containers in this
+/// header do, so a map that outlives an erased operation does not read freed
+/// IR.
+class PyOpAttributeMap {
+public:
+  PyOpAttributeMap(PyOperationRef operation)
+      : operation(std::move(operation)) {}
+
+  /// Returns the attribute named `name` (downcast); raises KeyError if absent.
+  nanobind::typed<nanobind::object, PyAttribute>
+  dunderGetItemNamed(const std::string &name) {
+    operation->checkValid();
+    MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
+                                                         toMlirStringRef(name));
+    if (mlirAttributeIsNull(attr)) {
+      throw nanobind::key_error("attempt to access a non-existent attribute");
+    }
+    return PyAttribute(operation->getContext(), attr).maybeDownCast();
+  }
+
+  /// Returns the named attribute at `index` (negative values count from the
+  /// end); raises IndexError when out of range.
+  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
+    // dunderLen() validates the operation; compute the length only once.
+    const intptr_t numAttrs = dunderLen();
+    if (index < 0)
+      index += numAttrs;
+    if (index < 0 || index >= numAttrs)
+      throw nanobind::index_error("attempt to access out of bounds attribute");
+    MlirNamedAttribute namedAttr =
+        mlirOperationGetAttribute(operation->get(), index);
+    MlirStringRef attrName = mlirIdentifierStr(namedAttr.name);
+    return PyNamedAttribute(namedAttr.attribute,
+                            std::string(attrName.data, attrName.length));
+  }
+
+  /// Sets (or overwrites) the attribute named `name`.
+  void dunderSetItem(const std::string &name, const PyAttribute &attr) {
+    operation->checkValid();
+    mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
+                                    attr);
+  }
+
+  /// Removes the attribute named `name`; raises KeyError if it is absent.
+  void dunderDelItem(const std::string &name) {
+    operation->checkValid();
+    int removed = mlirOperationRemoveAttributeByName(operation->get(),
+                                                     toMlirStringRef(name));
+    if (!removed)
+      throw nanobind::key_error("attempt to delete a non-existent attribute");
+  }
+
+  intptr_t dunderLen() {
+    operation->checkValid();
+    return mlirOperationGetNumAttributes(operation->get());
+  }
+
+  bool dunderContains(const std::string &name) {
+    operation->checkValid();
+    return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
+        operation->get(), toMlirStringRef(name)));
+  }
+
+  /// Calls `fn(name, attribute)` for each attribute of `op`, in index order.
+  /// The caller is responsible for `op` being valid.
+  static void
+  forEachAttr(MlirOperation op,
+              llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
+    intptr_t n = mlirOperationGetNumAttributes(op);
+    for (intptr_t i = 0; i < n; ++i) {
+      MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
+      MlirStringRef name = mlirIdentifierStr(na.name);
+      fn(name, na.attribute);
+    }
+  }
+
+  static void bind(nanobind::module_ &m) {
+    nanobind::class_<PyOpAttributeMap>(m, "OpAttributeMap")
+        .def("__contains__", &PyOpAttributeMap::dunderContains,
+             nanobind::arg("name"),
+             "Checks if an attribute with the given name exists in the map.")
+        .def("__len__", &PyOpAttributeMap::dunderLen,
+             "Returns the number of attributes in the map.")
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
+             nanobind::arg("name"), "Gets an attribute by name.")
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
+             nanobind::arg("index"), "Gets a named attribute by index.")
+        .def("__setitem__", &PyOpAttributeMap::dunderSetItem,
+             nanobind::arg("name"), nanobind::arg("attr"),
+             "Sets an attribute with the given name.")
+        .def("__delitem__", &PyOpAttributeMap::dunderDelItem,
+             nanobind::arg("name"), "Deletes an attribute with the given name.")
+        .def(
+            "__iter__",
+            [](PyOpAttributeMap &self) {
+              self.operation->checkValid();
+              nanobind::list keys;
+              PyOpAttributeMap::forEachAttr(
+                  self.operation->get(),
+                  [&](MlirStringRef name, MlirAttribute) {
+                    keys.append(nanobind::str(name.data, name.length));
+                  });
+              return nanobind::iter(keys);
+            },
+            "Iterates over attribute names.")
+        .def(
+            "keys",
+            [](PyOpAttributeMap &self) {
+              self.operation->checkValid();
+              nanobind::list out;
+              PyOpAttributeMap::forEachAttr(
+                  self.operation->get(),
+                  [&](MlirStringRef name, MlirAttribute) {
+                    out.append(nanobind::str(name.data, name.length));
+                  });
+              return out;
+            },
+            "Returns a list of attribute names.")
+        .def(
+            "values",
+            [](PyOpAttributeMap &self) {
+              self.operation->checkValid();
+              nanobind::list out;
+              PyOpAttributeMap::forEachAttr(
+                  self.operation->get(),
+                  [&](MlirStringRef, MlirAttribute attr) {
+                    out.append(PyAttribute(self.operation->getContext(), attr)
+                                   .maybeDownCast());
+                  });
+              return out;
+            },
+            "Returns a list of attribute values.")
+        .def(
+            "items",
+            [](PyOpAttributeMap &self) {
+              self.operation->checkValid();
+              nanobind::list out;
+              PyOpAttributeMap::forEachAttr(
+                  self.operation->get(),
+                  [&](MlirStringRef name, MlirAttribute attr) {
+                    out.append(nanobind::make_tuple(
+                        nanobind::str(name.data, name.length),
+                        PyAttribute(self.operation->getContext(), attr)
+                            .maybeDownCast()));
+                  });
+              return out;
+            },
+            "Returns a list of `(name, attribute)` tuples.");
+  }
+
+private:
+  PyOperationRef operation;
+};
+
+/// Returns the result value of `operation`. The definition is not in this
+/// header; presumably it requires the operation to have exactly one result and
+/// reports an error otherwise (TODO confirm against the implementation).
+MlirValue getUniqueResult(MlirOperation operation);
} // namespace python
} // namespace mlir
@@ -1345,4 +2352,4 @@ struct type_caster<mlir::python::DefaultingPyLocation>
} // namespace detail
} // namespace nanobind
-#endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
+#endif // MLIR_BINDINGS_PYTHON_IRCORE_H
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
similarity index 100%
rename from mlir/lib/Bindings/Python/NanobindUtils.h
rename to mlir/include/mlir/Bindings/Python/NanobindUtils.h
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index 0d1d9e89f92f6..a87918a05b126 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Dialect/SMT.h"
#include "mlir-c/IR.h"
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/Globals.cpp
similarity index 97%
rename from mlir/lib/Bindings/Python/IRModule.cpp
rename to mlir/lib/Bindings/Python/Globals.cpp
index 0de2f1711829b..bc6b210426221 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -6,25 +6,27 @@
//
//===----------------------------------------------------------------------===//
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include <optional>
#include <vector>
-#include "Globals.h"
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/Globals.h"
+// clang-format off
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h"
+// clang-format on
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
namespace nb = nanobind;
using namespace mlir;
-using namespace mlir::python;
// -----------------------------------------------------------------------------
// PyGlobals
// -----------------------------------------------------------------------------
+namespace mlir::python {
PyGlobals *PyGlobals::instance = nullptr;
PyGlobals::PyGlobals() {
@@ -265,3 +267,8 @@ bool PyGlobals::TracebackLoc::isUserTracebackFilename(
   }
   return isUserTracebackFilenameCache[file];
 }
+
+// Storage for PyGlobalDebugFlag's static mutex (guards the LLVM debug flag).
+nb::ft_mutex PyGlobalDebugFlag::mutex;
+
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 7147f2cbad149..624d8f0fa57ce 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -13,11 +13,13 @@
#include <utility>
#include <vector>
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
+#include "mlir/Bindings/Python/IRCore.h"
+// clang-format off
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
#include "mlir-c/IntegerSet.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Support/LLVM.h"
@@ -509,7 +511,8 @@ PyIntegerSet PyIntegerSet::createFromCapsule(const nb::object &capsule) {
rawIntegerSet);
}
-void mlir::python::populateIRAffine(nb::module_ &m) {
+namespace mlir::python {
+void populateIRAffine(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of PyAffineExpr and derived classes.
//----------------------------------------------------------------------------
@@ -995,3 +998,4 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
PyIntegerSetConstraint::bind(m);
PyIntegerSetConstraintList::bind(m);
}
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c0a945e3f4f3b..36367e658697c 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -12,12 +12,12 @@
#include <string_view>
#include <utility>
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"
@@ -1799,7 +1799,8 @@ void PyStringAttribute::bindDerived(ClassTy &c) {
"Returns the value of the string attribute as `bytes`");
}
-void mlir::python::populateIRAttributes(nb::module_ &m) {
+namespace mlir::python {
+void populateIRAttributes(nb::module_ &m) {
PyAffineMapAttribute::bind(m);
PyDenseBoolArrayAttribute::bind(m);
PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
@@ -1851,4 +1852,18 @@ void mlir::python::populateIRAttributes(nb::module_ &m) {
PyUnitAttribute::bind(m);
PyStridedLayoutAttribute::bind(m);
+ nb::register_exception_translator([](const std::exception_ptr &p,
+ void *payload) {
+ // Translate C++ MLIRError into the Python-defined `ir.MLIRError`; nanobind
+ // can't define exceptions with custom fields, so that class is imported.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+ });
}
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 168c57955af07..88cffb64906d7 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//
-#include "Globals.h"
-#include "IRModule.h"
-#include "NanobindUtils.h"
+// clang-format off
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
@@ -22,6 +24,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
+#include <iostream>
#include <optional>
namespace nb = nanobind;
@@ -33,504 +36,7 @@ using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
-static const char kModuleParseDocstring[] =
- R"(Parses a module's assembly format from a string.
-
-Returns a new MlirModule or raises an MLIRError if the parsing fails.
-
-See also: https://mlir.llvm.org/docs/LangRef/
-)";
-
-static const char kDumpDocstring[] =
- "Dumps a debug representation of the object to stderr.";
-
-static const char kValueReplaceAllUsesExceptDocstring[] =
- R"(Replace all uses of this value with the `with` value, except for those
-in `exceptions`. `exceptions` can be either a single operation or a list of
-operations.
-)";
-
-//------------------------------------------------------------------------------
-// Utilities.
-//------------------------------------------------------------------------------
-
-/// Helper for creating an @classmethod.
-template <class Func, typename... Args>
-static nb::object classmethod(Func f, Args... args) {
- nb::object cf = nb::cpp_function(f, args...);
- return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
-}
-
-static nb::object
-createCustomDialectWrapper(const std::string &dialectNamespace,
- nb::object dialectDescriptor) {
- auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
- if (!dialectClass) {
- // Use the base class.
- return nb::cast(PyDialect(std::move(dialectDescriptor)));
- }
-
- // Create the custom implementation.
- return (*dialectClass)(std::move(dialectDescriptor));
-}
-
-static MlirStringRef toMlirStringRef(const std::string &s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
-static MlirStringRef toMlirStringRef(std::string_view s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
-static MlirStringRef toMlirStringRef(const nb::bytes &s) {
- return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
-}
-
-/// Create a block, using the current location context if no locations are
-/// specified.
-static MlirBlock createBlock(const nb::sequence &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- SmallVector<MlirType> argTypes;
- argTypes.reserve(nb::len(pyArgTypes));
- for (const auto &pyType : pyArgTypes)
- argTypes.push_back(nb::cast<PyType &>(pyType));
-
- SmallVector<MlirLocation> argLocs;
- if (pyArgLocs) {
- argLocs.reserve(nb::len(*pyArgLocs));
- for (const auto &pyLoc : *pyArgLocs)
- argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
- } else if (!argTypes.empty()) {
- argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
- }
-
- if (argTypes.size() != argLocs.size())
- throw nb::value_error(("Expected " + Twine(argTypes.size()) +
- " locations, got: " + Twine(argLocs.size()))
- .str()
- .c_str());
- return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
-}
-
-/// Wrapper for the global LLVM debugging flag.
-struct PyGlobalDebugFlag {
- static void set(nb::object &o, bool enable) {
- nb::ft_lock_guard lock(mutex);
- mlirEnableGlobalDebug(enable);
- }
-
- static bool get(const nb::object &) {
- nb::ft_lock_guard lock(mutex);
- return mlirIsGlobalDebugEnabled();
- }
-
- static void bind(nb::module_ &m) {
- // Debug flags.
- nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
- .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
- &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
- .def_static(
- "set_types",
- [](const std::string &type) {
- nb::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugType(type.c_str());
- },
- "types"_a, "Sets specific debug types to be produced by LLVM.")
- .def_static(
- "set_types",
- [](const std::vector<std::string> &types) {
- std::vector<const char *> pointers;
- pointers.reserve(types.size());
- for (const std::string &str : types)
- pointers.push_back(str.c_str());
- nb::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
- },
- "types"_a,
- "Sets multiple specific debug types to be produced by LLVM.");
- }
-
-private:
- static nb::ft_mutex mutex;
-};
-
-nb::ft_mutex PyGlobalDebugFlag::mutex;
-
-struct PyAttrBuilderMap {
- static bool dunderContains(const std::string &attributeKind) {
- return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
- }
- static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
- auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
- if (!builder)
- throw nb::key_error(attributeKind.c_str());
- return *builder;
- }
- static void dunderSetItemNamed(const std::string &attributeKind,
- nb::callable func, bool replace) {
- PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
- replace);
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
- .def_static("contains", &PyAttrBuilderMap::dunderContains,
- "attribute_kind"_a,
- "Checks whether an attribute builder is registered for the "
- "given attribute kind.")
- .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
- "attribute_kind"_a,
- "Gets the registered attribute builder for the given "
- "attribute kind.")
- .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
- "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
- "Register an attribute builder for building MLIR "
- "attributes from Python values.");
- }
-};
-
-//------------------------------------------------------------------------------
-// PyBlock
-//------------------------------------------------------------------------------
-
-nb::object PyBlock::getCapsule() {
- return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
-}
-
-//------------------------------------------------------------------------------
-// Collections.
-//------------------------------------------------------------------------------
-
-namespace {
-
-class PyRegionIterator {
-public:
- PyRegionIterator(PyOperationRef operation, int nextIndex)
- : operation(std::move(operation)), nextIndex(nextIndex) {}
-
- PyRegionIterator &dunderIter() { return *this; }
-
- PyRegion dunderNext() {
- operation->checkValid();
- if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
- throw nb::stop_iteration();
- }
- MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
- return PyRegion(operation, region);
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyRegionIterator>(m, "RegionIterator")
- .def("__iter__", &PyRegionIterator::dunderIter,
- "Returns an iterator over the regions in the operation.")
- .def("__next__", &PyRegionIterator::dunderNext,
- "Returns the next region in the iteration.");
- }
-
-private:
- PyOperationRef operation;
- intptr_t nextIndex = 0;
-};
-
-/// Regions of an op are fixed length and indexed numerically so are represented
-/// with a sequence-like container.
-class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
-public:
- static constexpr const char *pyClassName = "RegionSequence";
-
- PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumRegions(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
-
- PyRegionIterator dunderIter() {
- operation->checkValid();
- return PyRegionIterator(operation, startIndex);
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__iter__", &PyRegionList::dunderIter,
- "Returns an iterator over the regions in the sequence.");
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyRegionList, PyRegion>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumRegions(operation->get());
- }
-
- PyRegion getRawElement(intptr_t pos) {
- operation->checkValid();
- return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
- }
-
- PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyRegionList(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
-class PyBlockIterator {
-public:
- PyBlockIterator(PyOperationRef operation, MlirBlock next)
- : operation(std::move(operation)), next(next) {}
-
- PyBlockIterator &dunderIter() { return *this; }
-
- PyBlock dunderNext() {
- operation->checkValid();
- if (mlirBlockIsNull(next)) {
- throw nb::stop_iteration();
- }
-
- PyBlock returnBlock(operation, next);
- next = mlirBlockGetNextInRegion(next);
- return returnBlock;
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyBlockIterator>(m, "BlockIterator")
- .def("__iter__", &PyBlockIterator::dunderIter,
- "Returns an iterator over the blocks in the operation's region.")
- .def("__next__", &PyBlockIterator::dunderNext,
- "Returns the next block in the iteration.");
- }
-
-private:
- PyOperationRef operation;
- MlirBlock next;
-};
-
-/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
-/// we present them as a more full-featured list-like container but optimize
-/// it for forward iteration. Blocks are always owned by a region.
-class PyBlockList {
-public:
- PyBlockList(PyOperationRef operation, MlirRegion region)
- : operation(std::move(operation)), region(region) {}
-
- PyBlockIterator dunderIter() {
- operation->checkValid();
- return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
- }
-
- intptr_t dunderLen() {
- operation->checkValid();
- intptr_t count = 0;
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- count += 1;
- block = mlirBlockGetNextInRegion(block);
- }
- return count;
- }
-
- PyBlock dunderGetItem(intptr_t index) {
- operation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nb::index_error("attempt to access out of bounds block");
- }
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- if (index == 0) {
- return PyBlock(operation, block);
- }
- block = mlirBlockGetNextInRegion(block);
- index -= 1;
- }
- throw nb::index_error("attempt to access out of bounds block");
- }
-
- PyBlock appendBlock(const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- operation->checkValid();
- MlirBlock block =
- createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
- mlirRegionAppendOwnedBlock(region, block);
- return PyBlock(operation, block);
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyBlockList>(m, "BlockList")
- .def("__getitem__", &PyBlockList::dunderGetItem,
- "Returns the block at the specified index.")
- .def("__iter__", &PyBlockList::dunderIter,
- "Returns an iterator over blocks in the operation's region.")
- .def("__len__", &PyBlockList::dunderLen,
- "Returns the number of blocks in the operation's region.")
- .def("append", &PyBlockList::appendBlock,
- R"(
- Appends a new block, with argument types as positional args.
-
- Returns:
- The created block.
- )",
- nb::arg("args"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt);
- }
-
-private:
- PyOperationRef operation;
- MlirRegion region;
-};
-
-class PyOperationIterator {
-public:
- PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
- : parentOperation(std::move(parentOperation)), next(next) {}
-
- PyOperationIterator &dunderIter() { return *this; }
-
- nb::typed<nb::object, PyOpView> dunderNext() {
- parentOperation->checkValid();
- if (mlirOperationIsNull(next)) {
- throw nb::stop_iteration();
- }
-
- PyOperationRef returnOperation =
- PyOperation::forOperation(parentOperation->getContext(), next);
- next = mlirOperationGetNextInBlock(next);
- return returnOperation->createOpView();
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOperationIterator>(m, "OperationIterator")
- .def("__iter__", &PyOperationIterator::dunderIter,
- "Returns an iterator over the operations in an operation's block.")
- .def("__next__", &PyOperationIterator::dunderNext,
- "Returns the next operation in the iteration.");
- }
-
-private:
- PyOperationRef parentOperation;
- MlirOperation next;
-};
-
-/// Operations are exposed by the C-API as a forward-only linked list. In
-/// Python, we present them as a more full-featured list-like container but
-/// optimize it for forward iteration. Iterable operations are always owned
-/// by a block.
-class PyOperationList {
-public:
- PyOperationList(PyOperationRef parentOperation, MlirBlock block)
- : parentOperation(std::move(parentOperation)), block(block) {}
-
- PyOperationIterator dunderIter() {
- parentOperation->checkValid();
- return PyOperationIterator(parentOperation,
- mlirBlockGetFirstOperation(block));
- }
-
- intptr_t dunderLen() {
- parentOperation->checkValid();
- intptr_t count = 0;
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- count += 1;
- childOp = mlirOperationGetNextInBlock(childOp);
- }
- return count;
- }
-
- nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
- parentOperation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nb::index_error("attempt to access out of bounds operation");
- }
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- if (index == 0) {
- return PyOperation::forOperation(parentOperation->getContext(), childOp)
- ->createOpView();
- }
- childOp = mlirOperationGetNextInBlock(childOp);
- index -= 1;
- }
- throw nb::index_error("attempt to access out of bounds operation");
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOperationList>(m, "OperationList")
- .def("__getitem__", &PyOperationList::dunderGetItem,
- "Returns the operation at the specified index.")
- .def("__iter__", &PyOperationList::dunderIter,
- "Returns an iterator over operations in the list.")
- .def("__len__", &PyOperationList::dunderLen,
- "Returns the number of operations in the list.");
- }
-
-private:
- PyOperationRef parentOperation;
- MlirBlock block;
-};
-
-class PyOpOperand {
-public:
- PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
-
- nb::typed<nb::object, PyOpView> getOwner() {
- MlirOperation owner = mlirOpOperandGetOwner(opOperand);
- PyMlirContextRef context =
- PyMlirContext::forContext(mlirOperationGetContext(owner));
- return PyOperation::forOperation(context, owner)->createOpView();
- }
-
- size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOpOperand>(m, "OpOperand")
- .def_prop_ro("owner", &PyOpOperand::getOwner,
- "Returns the operation that owns this operand.")
- .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
- "Returns the operand number in the owning operation.");
- }
-
-private:
- MlirOpOperand opOperand;
-};
-
-class PyOpOperandIterator {
-public:
- PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
-
- PyOpOperandIterator &dunderIter() { return *this; }
-
- PyOpOperand dunderNext() {
- if (mlirOpOperandIsNull(opOperand))
- throw nb::stop_iteration();
-
- PyOpOperand returnOpOperand(opOperand);
- opOperand = mlirOpOperandGetNextUse(opOperand);
- return returnOpOperand;
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
- .def("__iter__", &PyOpOperandIterator::dunderIter,
- "Returns an iterator over operands.")
- .def("__next__", &PyOpOperandIterator::dunderNext,
- "Returns the next operand in the iteration.");
- }
-
-private:
- MlirOpOperand opOperand;
-};
-
-} // namespace
-
+namespace mlir::python {
//------------------------------------------------------------------------------
// PyMlirContext
//------------------------------------------------------------------------------
@@ -1413,8 +919,9 @@ nb::object PyOperation::create(std::string_view name,
   // Construct the operation.
   PyMlirContext::ErrorCapture errors(location.getContext());
   MlirOperation operation = mlirOperationCreate(&state);
+  // On failure, hand the errors collected by `errors` to the raised MLIRError.
   if (!operation.ptr)
     throw MLIRError("Operation creation failed", errors.take());
   PyOperationRef created =
       PyOperation::createDetached(location.getContext(), operation);
   maybeInsertOperation(created, maybeIp);
@@ -1448,163 +958,6 @@ void PyOperation::erase() {
mlirOperationDestroy(operation);
}
-namespace {
-/// CRTP base class for Python MLIR values that subclass Value and should be
-/// castable from it. The value hierarchy is one level deep and is not supposed
-/// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy>
-class PyConcreteValue : public PyValue {
-public:
- // Derived classes must define statics for:
- // IsAFunctionTy isaFunction
- // const char *pyClassName
- // and redefine bindDerived.
- using ClassTy = nb::class_<DerivedTy, PyValue>;
- using IsAFunctionTy = bool (*)(MlirValue);
-
- PyConcreteValue() = default;
- PyConcreteValue(PyOperationRef operationRef, MlirValue value)
- : PyValue(operationRef, value) {}
- PyConcreteValue(PyValue &orig)
- : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
-
- /// Attempts to cast the original value to the derived type and throws on
- /// type mismatches.
- static MlirValue castFrom(PyValue &orig) {
- if (!DerivedTy::isaFunction(orig.get())) {
- auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
- throw nb::value_error((Twine("Cannot cast value to ") +
- DerivedTy::pyClassName + " (from " + origRepr +
- ")")
- .str()
- .c_str());
- }
- return orig.get();
- }
-
- /// Binds the Python module objects to functions of this class.
- static void bind(nb::module_ &m) {
- auto cls = ClassTy(
- m, DerivedTy::pyClassName, nb::is_generic(),
- nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
- .str()
- .c_str()));
- cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
- cls.def_static(
- "isinstance",
- [](PyValue &otherValue) -> bool {
- return DerivedTy::isaFunction(otherValue);
- },
- nb::arg("other_value"));
- cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
- return self.maybeDownCast();
- });
- DerivedTy::bindDerived(cls);
- }
-
- /// Implemented by derived classes to add methods to the Python subclass.
- static void bindDerived(ClassTy &m) {}
-};
-
-} // namespace
-
-/// Python wrapper for MlirOpResult.
-class PyOpResult : public PyConcreteValue<PyOpResult> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
- static constexpr const char *pyClassName = "OpResult";
- using PyConcreteValue::PyConcreteValue;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "owner",
- [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
- assert(mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match that in "
- "the IR");
- return self.getParentOperation()->createOpView();
- },
- "Returns the operation that produces this result.");
- c.def_prop_ro(
- "result_number",
- [](PyOpResult &self) {
- return mlirOpResultGetResultNumber(self.get());
- },
- "Returns the position of this result in the operation's result list.");
- }
-};
-
-/// Returns the list of types of the values held by container.
-template <typename Container>
-static std::vector<nb::typed<nb::object, PyType>>
-getValueTypes(Container &container, PyMlirContextRef &context) {
- std::vector<nb::typed<nb::object, PyType>> result;
- result.reserve(container.size());
- for (int i = 0, e = container.size(); i < e; ++i) {
- result.push_back(PyType(context->getRef(),
- mlirValueGetType(container.getElement(i).get()))
- .maybeDownCast());
- }
- return result;
-}
-
-/// A list of operation results. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) result list is associated
-/// with the operation whose results these are, and thus extends the lifetime of
-/// this operation.
-class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
-public:
- static constexpr const char *pyClassName = "OpResultList";
- using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
-
- PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumResults(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "types",
- [](PyOpResultList &self) {
- return getValueTypes(self, self.operation->getContext());
- },
- "Returns a list of types for all results in this result list.");
- c.def_prop_ro(
- "owner",
- [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
- return self.operation->createOpView();
- },
- "Returns the operation that owns this result list.");
- }
-
- PyOperationRef &getOperation() { return operation; }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpResultList, PyOpResult>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumResults(operation->get());
- }
-
- PyOpResult getRawElement(intptr_t index) {
- PyValue value(operation, mlirOperationGetResult(operation->get(), index));
- return PyOpResult(value);
- }
-
- PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpResultList(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
//------------------------------------------------------------------------------
// PyOpView
//------------------------------------------------------------------------------
@@ -1706,7 +1059,7 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
}
}
-static MlirValue getUniqueResult(MlirOperation operation) {
+MlirValue getUniqueResult(MlirOperation operation) {
auto numResults = mlirOperationGetNumResults(operation);
if (numResults != 1) {
auto name = mlirIdentifierStr(mlirOperationGetName(operation));
@@ -2319,2648 +1672,11 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
}
}
-namespace {
-
-/// Python wrapper for MlirBlockArgument.
-class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
- static constexpr const char *pyClassName = "BlockArgument";
- using PyConcreteValue::PyConcreteValue;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "owner",
- [](PyBlockArgument &self) {
- return PyBlock(self.getParentOperation(),
- mlirBlockArgumentGetOwner(self.get()));
- },
- "Returns the block that owns this argument.");
- c.def_prop_ro(
- "arg_number",
- [](PyBlockArgument &self) {
- return mlirBlockArgumentGetArgNumber(self.get());
- },
- "Returns the position of this argument in the block's argument list.");
- c.def(
- "set_type",
- [](PyBlockArgument &self, PyType type) {
- return mlirBlockArgumentSetType(self.get(), type);
- },
- nb::arg("type"), "Sets the type of this block argument.");
- c.def(
- "set_location",
- [](PyBlockArgument &self, PyLocation loc) {
- return mlirBlockArgumentSetLocation(self.get(), loc);
- },
- nb::arg("loc"), "Sets the location of this block argument.");
- }
-};
-
-/// A list of block arguments. Internally, these are stored as consecutive
-/// elements, random access is cheap. The argument list is associated with the
-/// operation that contains the block (detached blocks are not allowed in
-/// Python bindings) and extends its lifetime.
-class PyBlockArgumentList
- : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
-public:
- static constexpr const char *pyClassName = "BlockArgumentList";
- using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
-
- PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
- intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumArguments(block) : length,
- step),
- operation(std::move(operation)), block(block) {}
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "types",
- [](PyBlockArgumentList &self) {
- return getValueTypes(self, self.operation->getContext());
- },
- "Returns a list of types for all arguments in this argument list.");
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
-
- /// Returns the number of arguments in the list.
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirBlockGetNumArguments(block);
- }
-
- /// Returns `pos`-the element in the list.
- PyBlockArgument getRawElement(intptr_t pos) {
- MlirValue argument = mlirBlockGetArgument(block, pos);
- return PyBlockArgument(operation, argument);
- }
-
- /// Returns a sublist of this list.
- PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
- intptr_t step) {
- return PyBlockArgumentList(operation, block, startIndex, length, step);
- }
-
- PyOperationRef operation;
- MlirBlock block;
-};
-
-/// A list of operation operands. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) operand list is associated
-/// with the operation whose operands these are, and thus extends the lifetime
-/// of this operation.
-class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
-public:
- static constexpr const char *pyClassName = "OpOperandList";
- using SliceableT = Sliceable<PyOpOperandList, PyValue>;
-
- PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumOperands(operation->get())
- : length,
- step),
- operation(operation) {}
-
- void dunderSetItem(intptr_t index, PyValue value) {
- index = wrapIndex(index);
- mlirOperationSetOperand(operation->get(), index, value.get());
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpOperandList::dunderSetItem, nb::arg("index"),
- nb::arg("value"),
- "Sets the operand at the specified index to a new value.");
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpOperandList, PyValue>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumOperands(operation->get());
- }
-
- PyValue getRawElement(intptr_t pos) {
- MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
- MlirOperation owner;
- if (mlirValueIsAOpResult(operand))
- owner = mlirOpResultGetOwner(operand);
- else if (mlirValueIsABlockArgument(operand))
- owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
- else
- assert(false && "Value must be an block arg or op result.");
- PyOperationRef pyOwner =
- PyOperation::forOperation(operation->getContext(), owner);
- return PyValue(pyOwner, operand);
- }
-
- PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpOperandList(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
-/// A list of operation successors. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) successor list is
-/// associated with the operation whose successors these are, and thus extends
-/// the lifetime of this operation.
-class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
-public:
- static constexpr const char *pyClassName = "OpSuccessors";
-
- PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumSuccessors(operation->get())
- : length,
- step),
- operation(operation) {}
-
- void dunderSetItem(intptr_t index, PyBlock block) {
- index = wrapIndex(index);
- mlirOperationSetSuccessor(operation->get(), index, block.get());
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nb::arg("index"),
- nb::arg("block"), "Sets the successor block at the specified index.");
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpSuccessors, PyBlock>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumSuccessors(operation->get());
- }
-
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
- return PyBlock(operation, block);
- }
-
- PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpSuccessors(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
-/// A list of block successors. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) successor list is
-/// associated with the operation and block whose successors these are, and thus
-/// extends the lifetime of this operation and block.
-class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
-public:
- static constexpr const char *pyClassName = "BlockSuccessors";
-
- PyBlockSuccessors(PyBlock block, PyOperationRef operation,
- intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumSuccessors(block.get())
- : length,
- step),
- operation(operation), block(block) {}
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyBlockSuccessors, PyBlock>;
-
- intptr_t getRawNumElements() {
- block.checkValid();
- return mlirBlockGetNumSuccessors(block.get());
- }
-
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
- return PyBlock(operation, block);
- }
-
- PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyBlockSuccessors(block, operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
- PyBlock block;
-};
-
-/// A list of block predecessors. The (returned) predecessor list is
-/// associated with the operation and block whose predecessors these are, and
-/// thus extends the lifetime of this operation and block.
-///
-/// WARNING: This Sliceable is more expensive than the others here because
-/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
-/// operands) anew for each indexed access.
-class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
-public:
- static constexpr const char *pyClassName = "BlockPredecessors";
-
- PyBlockPredecessors(PyBlock block, PyOperationRef operation,
- intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumPredecessors(block.get())
- : length,
- step),
- operation(operation), block(block) {}
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyBlockPredecessors, PyBlock>;
-
- intptr_t getRawNumElements() {
- block.checkValid();
- return mlirBlockGetNumPredecessors(block.get());
- }
-
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
- return PyBlock(operation, block);
- }
-
- PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
- intptr_t step) {
- return PyBlockPredecessors(block, operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
- PyBlock block;
-};
-
-/// A list of operation attributes. Can be indexed by name, producing
-/// attributes, or by index, producing named attributes.
-class PyOpAttributeMap {
-public:
- PyOpAttributeMap(PyOperationRef operation)
- : operation(std::move(operation)) {}
-
- nb::typed<nb::object, PyAttribute>
- dunderGetItemNamed(const std::string &name) {
- MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
- toMlirStringRef(name));
- if (mlirAttributeIsNull(attr)) {
- throw nb::key_error("attempt to access a non-existent attribute");
- }
- return PyAttribute(operation->getContext(), attr).maybeDownCast();
- }
-
- PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0 || index >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds attribute");
- }
- MlirNamedAttribute namedAttr =
- mlirOperationGetAttribute(operation->get(), index);
- return PyNamedAttribute(
- namedAttr.attribute,
- std::string(mlirIdentifierStr(namedAttr.name).data,
- mlirIdentifierStr(namedAttr.name).length));
- }
-
- void dunderSetItem(const std::string &name, const PyAttribute &attr) {
- mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
- attr);
- }
-
- void dunderDelItem(const std::string &name) {
- int removed = mlirOperationRemoveAttributeByName(operation->get(),
- toMlirStringRef(name));
- if (!removed)
- throw nb::key_error("attempt to delete a non-existent attribute");
- }
-
- intptr_t dunderLen() {
- return mlirOperationGetNumAttributes(operation->get());
- }
-
- bool dunderContains(const std::string &name) {
- return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
- operation->get(), toMlirStringRef(name)));
- }
-
- static void
- forEachAttr(MlirOperation op,
- llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
- intptr_t n = mlirOperationGetNumAttributes(op);
- for (intptr_t i = 0; i < n; ++i) {
- MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
- MlirStringRef name = mlirIdentifierStr(na.name);
- fn(name, na.attribute);
- }
- }
-
- static void bind(nb::module_ &m) {
- nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
- .def("__contains__", &PyOpAttributeMap::dunderContains, nb::arg("name"),
- "Checks if an attribute with the given name exists in the map.")
- .def("__len__", &PyOpAttributeMap::dunderLen,
- "Returns the number of attributes in the map.")
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
- nb::arg("name"), "Gets an attribute by name.")
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
- nb::arg("index"), "Gets a named attribute by index.")
- .def("__setitem__", &PyOpAttributeMap::dunderSetItem, nb::arg("name"),
- nb::arg("attr"), "Sets an attribute with the given name.")
- .def("__delitem__", &PyOpAttributeMap::dunderDelItem, nb::arg("name"),
- "Deletes an attribute with the given name.")
- .def(
- "__iter__",
- [](PyOpAttributeMap &self) {
- nb::list keys;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute) {
- keys.append(nb::str(name.data, name.length));
- });
- return nb::iter(keys);
- },
- "Iterates over attribute names.")
- .def(
- "keys",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute) {
- out.append(nb::str(name.data, name.length));
- });
- return out;
- },
- "Returns a list of attribute names.")
- .def(
- "values",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef, MlirAttribute attr) {
- out.append(PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast());
- });
- return out;
- },
- "Returns a list of attribute values.")
- .def(
- "items",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute attr) {
- out.append(nb::make_tuple(
- nb::str(name.data, name.length),
- PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast()));
- });
- return out;
- },
- "Returns a list of `(name, attribute)` tuples.");
- }
-
-private:
- PyOperationRef operation;
-};
-
-// see
-// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
-
-#ifndef _Py_CAST
-#define _Py_CAST(type, expr) ((type)(expr))
-#endif
-
-// Static inline functions should use _Py_NULL rather than using directly NULL
-// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
-// _Py_NULL is defined as nullptr.
-#ifndef _Py_NULL
-#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \
- (defined(__cplusplus) && __cplusplus >= 201103)
-#define _Py_NULL nullptr
-#else
-#define _Py_NULL NULL
-#endif
-#endif
-
-// Python 3.10.0a3
-#if PY_VERSION_HEX < 0x030A00A3
-
-// bpo-42262 added Py_XNewRef()
-#if !defined(Py_XNewRef)
-[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
- Py_XINCREF(obj);
- return obj;
-}
-#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
-#endif
-
-// bpo-42262 added Py_NewRef()
-#if !defined(Py_NewRef)
-[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
- Py_INCREF(obj);
- return obj;
-}
-#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
-#endif
-
-#endif // Python 3.10.0a3
-
-// Python 3.9.0b1
-#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
-
-// bpo-40429 added PyThreadState_GetFrame()
-PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
- assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
- return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
-}
-
-// bpo-40421 added PyFrame_GetBack()
-PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
- assert(frame != _Py_NULL && "expected frame != _Py_NULL");
- return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
-}
-
-// bpo-40421 added PyFrame_GetCode()
-PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
- assert(frame != _Py_NULL && "expected frame != _Py_NULL");
- assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
- return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
-}
-
-#endif // Python 3.9.0b1
-
-MlirLocation tracebackToLocation(MlirContext ctx) {
- size_t framesLimit =
- PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
- // Use a thread_local here to avoid requiring a large amount of space.
- thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
- frames;
- size_t count = 0;
-
- nb::gil_scoped_acquire acquire;
- PyThreadState *tstate = PyThreadState_GET();
- PyFrameObject *next;
- PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
- // In the increment expression:
- // 1. get the next prev frame;
- // 2. decrement the ref count on the current frame (in order that it can get
- // gc'd, along with any objects in its closure and etc);
- // 3. set current = next.
- for (; pyFrame != nullptr && count < framesLimit;
- next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
- PyCodeObject *code = PyFrame_GetCode(pyFrame);
- auto fileNameStr =
- nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
- llvm::StringRef fileName(fileNameStr);
- if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
- continue;
-
- // co_qualname and PyCode_Addr2Location added in py3.11
-#if PY_VERSION_HEX < 0x030B00F0
- std::string name =
- nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
- llvm::StringRef funcName(name);
- int startLine = PyFrame_GetLineNumber(pyFrame);
- MlirLocation loc =
- mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
-#else
- std::string name =
- nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
- llvm::StringRef funcName(name);
- int startLine, startCol, endLine, endCol;
- int lasti = PyFrame_GetLasti(pyFrame);
- if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
- &endCol)) {
- throw nb::python_error();
- }
- MlirLocation loc = mlirLocationFileLineColRangeGet(
- ctx, wrap(fileName), startLine, startCol, endLine, endCol);
-#endif
-
- frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
- ++count;
- }
- // When the loop breaks (after the last iter), current frame (if non-null)
- // is leaked without this.
- Py_XDECREF(pyFrame);
-
- if (count == 0)
- return mlirLocationUnknownGet(ctx);
-
- MlirLocation callee = frames[0];
- assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
- if (count == 1)
- return callee;
-
- MlirLocation caller = frames[count - 1];
- assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
- for (int i = count - 2; i >= 1; i--)
- caller = mlirLocationCallSiteGet(frames[i], caller);
-
- return mlirLocationCallSiteGet(callee, caller);
-}
-
-PyLocation
-maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
- if (location.has_value())
- return location.value();
- if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
- return DefaultingPyLocation::resolve();
-
- PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
- MlirLocation mlirLoc = tracebackToLocation(ctx.get());
- PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
- return {ref, mlirLoc};
-}
-
-} // namespace
-
-//------------------------------------------------------------------------------
-// Populates the core exports of the 'ir' submodule.
-//------------------------------------------------------------------------------
-
-void mlir::python::populateIRCore(nb::module_ &m) {
- // disable leak warnings which tend to be false positives.
- nb::set_leak_warnings(false);
- //----------------------------------------------------------------------------
- // Enums.
- //----------------------------------------------------------------------------
- nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
- .value("ERROR", MlirDiagnosticError)
- .value("WARNING", MlirDiagnosticWarning)
- .value("NOTE", MlirDiagnosticNote)
- .value("REMARK", MlirDiagnosticRemark);
-
- nb::enum_<MlirWalkOrder>(m, "WalkOrder")
- .value("PRE_ORDER", MlirWalkPreOrder)
- .value("POST_ORDER", MlirWalkPostOrder);
-
- nb::enum_<MlirWalkResult>(m, "WalkResult")
- .value("ADVANCE", MlirWalkResultAdvance)
- .value("INTERRUPT", MlirWalkResultInterrupt)
- .value("SKIP", MlirWalkResultSkip);
-
- //----------------------------------------------------------------------------
- // Mapping of Diagnostics.
- //----------------------------------------------------------------------------
- nb::class_<PyDiagnostic>(m, "Diagnostic")
- .def_prop_ro("severity", &PyDiagnostic::getSeverity,
- "Returns the severity of the diagnostic.")
- .def_prop_ro("location", &PyDiagnostic::getLocation,
- "Returns the location associated with the diagnostic.")
- .def_prop_ro("message", &PyDiagnostic::getMessage,
- "Returns the message text of the diagnostic.")
- .def_prop_ro("notes", &PyDiagnostic::getNotes,
- "Returns a tuple of attached note diagnostics.")
- .def(
- "__str__",
- [](PyDiagnostic &self) -> nb::str {
- if (!self.isValid())
- return nb::str("<Invalid Diagnostic>");
- return self.getMessage();
- },
- "Returns the diagnostic message as a string.");
-
- nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
- .def(
- "__init__",
- [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
- new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
- },
- "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
- .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
- "The severity level of the diagnostic.")
- .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
- "The location associated with the diagnostic.")
- .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
- "The message text of the diagnostic.")
- .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
- "List of attached note diagnostics.")
- .def(
- "__str__",
- [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
- "Returns the diagnostic message as a string.");
-
- nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
- .def("detach", &PyDiagnosticHandler::detach,
- "Detaches the diagnostic handler from the context.")
- .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
- "Returns True if the handler is attached to a context.")
- .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
- "Returns True if an error was encountered during diagnostic "
- "handling.")
- .def("__enter__", &PyDiagnosticHandler::contextEnter,
- "Enters the diagnostic handler as a context manager.")
- .def("__exit__", &PyDiagnosticHandler::contextExit,
- nb::arg("exc_type").none(), nb::arg("exc_value").none(),
- nb::arg("traceback").none(),
- "Exits the diagnostic handler context manager.");
-
- // Expose DefaultThreadPool to python
- nb::class_<PyThreadPool>(m, "ThreadPool")
- .def(
- "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
- "Creates a new thread pool with default concurrency.")
- .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
- "Returns the maximum number of threads in the pool.")
- .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
- "Returns the raw pointer to the LLVM thread pool as a string.");
-
- nb::class_<PyMlirContext>(m, "Context")
- .def(
- "__init__",
- [](PyMlirContext &self) {
- MlirContext context = mlirContextCreateWithThreading(false);
- new (&self) PyMlirContext(context);
- },
- R"(
- Creates a new MLIR context.
-
- The context is the top-level container for all MLIR objects. It owns the storage
- for types, attributes, locations, and other core IR objects. A context can be
- configured to allow or disallow unregistered dialects and can have dialects
- loaded on-demand.)")
- .def_static("_get_live_count", &PyMlirContext::getLiveCount,
- "Gets the number of live Context objects.")
- .def(
- "_get_context_again",
- [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
- PyMlirContextRef ref = PyMlirContext::forContext(self.get());
- return ref.releaseObject();
- },
- "Gets another reference to the same context.")
- .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
- "Gets the number of live modules owned by this context.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule,
- "Gets a capsule wrapping the MlirContext.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyMlirContext::createFromCapsule,
- "Creates a Context from a capsule wrapping MlirContext.")
- .def("__enter__", &PyMlirContext::contextEnter,
- "Enters the context as a context manager.")
- .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
- nb::arg("exc_value").none(), nb::arg("traceback").none(),
- "Exits the context manager.")
- .def_prop_ro_static(
- "current",
- [](nb::object & /*class*/)
- -> std::optional<nb::typed<nb::object, PyMlirContext>> {
- auto *context = PyThreadContextEntry::getDefaultContext();
- if (!context)
- return {};
- return nb::cast(context);
- },
- nb::sig("def current(/) -> Context | None"),
- "Gets the Context bound to the current thread or returns None if no "
- "context is set.")
- .def_prop_ro(
- "dialects",
- [](PyMlirContext &self) { return PyDialects(self.getRef()); },
- "Gets a container for accessing dialects by name.")
- .def_prop_ro(
- "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
- "Alias for `dialects`.")
- .def(
- "get_dialect_descriptor",
- [=](PyMlirContext &self, std::string &name) {
- MlirDialect dialect = mlirContextGetOrLoadDialect(
- self.get(), {name.data(), name.size()});
- if (mlirDialectIsNull(dialect)) {
- throw nb::value_error(
- (Twine("Dialect '") + name + "' not found").str().c_str());
- }
- return PyDialectDescriptor(self.getRef(), dialect);
- },
- nb::arg("dialect_name"),
- "Gets or loads a dialect by name, returning its descriptor object.")
- .def_prop_rw(
- "allow_unregistered_dialects",
- [](PyMlirContext &self) -> bool {
- return mlirContextGetAllowUnregisteredDialects(self.get());
- },
- [](PyMlirContext &self, bool value) {
- mlirContextSetAllowUnregisteredDialects(self.get(), value);
- },
- "Controls whether unregistered dialects are allowed in this context.")
- .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
- nb::arg("callback"),
- "Attaches a diagnostic handler that will receive callbacks.")
- .def(
- "enable_multithreading",
- [](PyMlirContext &self, bool enable) {
- mlirContextEnableMultithreading(self.get(), enable);
- },
- nb::arg("enable"),
- R"(
- Enables or disables multi-threading support in the context.
-
- Args:
- enable: Whether to enable (True) or disable (False) multi-threading.
- )")
- .def(
- "set_thread_pool",
- [](PyMlirContext &self, PyThreadPool &pool) {
- // we should disable multi-threading first before setting
- // new thread pool otherwise the assert in
- // MLIRContext::setThreadPool will be raised.
- mlirContextEnableMultithreading(self.get(), false);
- mlirContextSetThreadPool(self.get(), pool.get());
- },
- R"(
- Sets a custom thread pool for the context to use.
-
- Args:
- pool: A ThreadPool object to use for parallel operations.
-
- Note:
- Multi-threading is automatically disabled before setting the thread pool.)")
- .def(
- "get_num_threads",
- [](PyMlirContext &self) {
- return mlirContextGetNumThreads(self.get());
- },
- "Gets the number of threads in the context's thread pool.")
- .def(
- "_mlir_thread_pool_ptr",
- [](PyMlirContext &self) {
- MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
- std::stringstream ss;
- ss << pool.ptr;
- return ss.str();
- },
- "Gets the raw pointer to the LLVM thread pool as a string.")
- .def(
- "is_registered_operation",
- [](PyMlirContext &self, std::string &name) {
- return mlirContextIsRegisteredOperation(
- self.get(), MlirStringRef{name.data(), name.size()});
- },
- nb::arg("operation_name"),
- R"(
- Checks whether an operation with the given name is registered.
-
- Args:
- operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
-
- Returns:
- True if the operation is registered, False otherwise.)")
- .def(
- "append_dialect_registry",
- [](PyMlirContext &self, PyDialectRegistry &registry) {
- mlirContextAppendDialectRegistry(self.get(), registry);
- },
- nb::arg("registry"),
- R"(
- Appends the contents of a dialect registry to the context.
-
- Args:
- registry: A DialectRegistry containing dialects to append.)")
- .def_prop_rw("emit_error_diagnostics",
- &PyMlirContext::getEmitErrorDiagnostics,
- &PyMlirContext::setEmitErrorDiagnostics,
- R"(
- Controls whether error diagnostics are emitted to diagnostic handlers.
-
- By default, error diagnostics are captured and reported through MLIRError exceptions.)")
- .def(
- "load_all_available_dialects",
- [](PyMlirContext &self) {
- mlirContextLoadAllAvailableDialects(self.get());
- },
- R"(
- Loads all dialects available in the registry into the context.
-
- This eagerly loads all dialects that have been registered, making them
- immediately available for use.)");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialectDescriptor
- //----------------------------------------------------------------------------
- nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
- .def_prop_ro(
- "namespace",
- [](PyDialectDescriptor &self) {
- MlirStringRef ns = mlirDialectGetNamespace(self.get());
- return nb::str(ns.data, ns.length);
- },
- "Returns the namespace of the dialect.")
- .def(
- "__repr__",
- [](PyDialectDescriptor &self) {
- MlirStringRef ns = mlirDialectGetNamespace(self.get());
- std::string repr("<DialectDescriptor ");
- repr.append(ns.data, ns.length);
- repr.append(">");
- return repr;
- },
- nb::sig("def __repr__(self) -> str"),
- "Returns a string representation of the dialect descriptor.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialects
- //----------------------------------------------------------------------------
- nb::class_<PyDialects>(m, "Dialects")
- .def(
- "__getitem__",
- [=](PyDialects &self, std::string keyName) {
- MlirDialect dialect =
- self.getDialectForKey(keyName, /*attrError=*/false);
- nb::object descriptor =
- nb::cast(PyDialectDescriptor{self.getContext(), dialect});
- return createCustomDialectWrapper(keyName, std::move(descriptor));
- },
- "Gets a dialect by name using subscript notation.")
- .def(
- "__getattr__",
- [=](PyDialects &self, std::string attrName) {
- MlirDialect dialect =
- self.getDialectForKey(attrName, /*attrError=*/true);
- nb::object descriptor =
- nb::cast(PyDialectDescriptor{self.getContext(), dialect});
- return createCustomDialectWrapper(attrName, std::move(descriptor));
- },
- "Gets a dialect by name using attribute notation.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialect
- //----------------------------------------------------------------------------
- nb::class_<PyDialect>(m, "Dialect")
- .def(nb::init<nb::object>(), nb::arg("descriptor"),
- "Creates a Dialect from a DialectDescriptor.")
- .def_prop_ro(
- "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
- "Returns the DialectDescriptor for this dialect.")
- .def(
- "__repr__",
- [](const nb::object &self) {
- auto clazz = self.attr("__class__");
- return nb::str("<Dialect ") +
- self.attr("descriptor").attr("namespace") +
- nb::str(" (class ") + clazz.attr("__module__") +
- nb::str(".") + clazz.attr("__name__") + nb::str(")>");
- },
- nb::sig("def __repr__(self) -> str"),
- "Returns a string representation of the dialect.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialectRegistry
- //----------------------------------------------------------------------------
- nb::class_<PyDialectRegistry>(m, "DialectRegistry")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule,
- "Gets a capsule wrapping the MlirDialectRegistry.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyDialectRegistry::createFromCapsule,
- "Creates a DialectRegistry from a capsule wrapping "
- "`MlirDialectRegistry`.")
- .def(nb::init<>(), "Creates a new empty dialect registry.");
-
- //----------------------------------------------------------------------------
- // Mapping of Location
- //----------------------------------------------------------------------------
- nb::class_<PyLocation>(m, "Location")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule,
- "Gets a capsule wrapping the MlirLocation.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule,
- "Creates a Location from a capsule wrapping MlirLocation.")
- .def("__enter__", &PyLocation::contextEnter,
- "Enters the location as a context manager.")
- .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
- nb::arg("exc_value").none(), nb::arg("traceback").none(),
- "Exits the location context manager.")
- .def(
- "__eq__",
- [](PyLocation &self, PyLocation &other) -> bool {
- return mlirLocationEqual(self, other);
- },
- "Compares two locations for equality.")
- .def(
- "__eq__", [](PyLocation &self, nb::object other) { return false; },
- "Compares location with non-location object (always returns False).")
- .def_prop_ro_static(
- "current",
- [](nb::object & /*class*/) -> std::optional<PyLocation *> {
- auto *loc = PyThreadContextEntry::getDefaultLocation();
- if (!loc)
- return std::nullopt;
- return loc;
- },
- // clang-format off
- nb::sig("def current(/) -> Location | None"),
- // clang-format on
- "Gets the Location bound to the current thread or raises ValueError.")
- .def_static(
- "unknown",
- [](DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationUnknownGet(context->get()));
- },
- nb::arg("context") = nb::none(),
- "Gets a Location representing an unknown location.")
- .def_static(
- "callsite",
- [](PyLocation callee, const std::vector<PyLocation> &frames,
- DefaultingPyMlirContext context) {
- if (frames.empty())
- throw nb::value_error("No caller frames provided.");
- MlirLocation caller = frames.back().get();
- for (const PyLocation &frame :
- llvm::reverse(llvm::ArrayRef(frames).drop_back()))
- caller = mlirLocationCallSiteGet(frame.get(), caller);
- return PyLocation(context->getRef(),
- mlirLocationCallSiteGet(callee.get(), caller));
- },
- nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(),
- "Gets a Location representing a caller and callsite.")
- .def("is_a_callsite", mlirLocationIsACallSite,
- "Returns True if this location is a CallSiteLoc.")
- .def_prop_ro(
- "callee",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationCallSiteGetCallee(self));
- },
- "Gets the callee location from a CallSiteLoc.")
- .def_prop_ro(
- "caller",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationCallSiteGetCaller(self));
- },
- "Gets the caller location from a CallSiteLoc.")
- .def_static(
- "file",
- [](std::string filename, int line, int col,
- DefaultingPyMlirContext context) {
- return PyLocation(
- context->getRef(),
- mlirLocationFileLineColGet(
- context->get(), toMlirStringRef(filename), line, col));
- },
- nb::arg("filename"), nb::arg("line"), nb::arg("col"),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a file, line and column.")
- .def_static(
- "file",
- [](std::string filename, int startLine, int startCol, int endLine,
- int endCol, DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationFileLineColRangeGet(
- context->get(), toMlirStringRef(filename),
- startLine, startCol, endLine, endCol));
- },
- nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
- nb::arg("end_line"), nb::arg("end_col"),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a file, line and column range.")
- .def("is_a_file", mlirLocationIsAFileLineColRange,
- "Returns True if this location is a FileLineColLoc.")
- .def_prop_ro(
- "filename",
- [](MlirLocation loc) {
- return mlirIdentifierStr(
- mlirLocationFileLineColRangeGetFilename(loc));
- },
- "Gets the filename from a FileLineColLoc.")
- .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
- "Gets the start line number from a `FileLineColLoc`.")
- .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
- "Gets the start column number from a `FileLineColLoc`.")
- .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
- "Gets the end line number from a `FileLineColLoc`.")
- .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
- "Gets the end column number from a `FileLineColLoc`.")
- .def_static(
- "fused",
- [](const std::vector<PyLocation> &pyLocations,
- std::optional<PyAttribute> metadata,
- DefaultingPyMlirContext context) {
- llvm::SmallVector<MlirLocation, 4> locations;
- locations.reserve(pyLocations.size());
- for (auto &pyLocation : pyLocations)
- locations.push_back(pyLocation.get());
- MlirLocation location = mlirLocationFusedGet(
- context->get(), locations.size(), locations.data(),
- metadata ? metadata->get() : MlirAttribute{0});
- return PyLocation(context->getRef(), location);
- },
- nb::arg("locations"), nb::arg("metadata") = nb::none(),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a fused location with optional "
- "metadata.")
- .def("is_a_fused", mlirLocationIsAFused,
- "Returns True if this location is a `FusedLoc`.")
- .def_prop_ro(
- "locations",
- [](PyLocation &self) {
- unsigned numLocations = mlirLocationFusedGetNumLocations(self);
- std::vector<MlirLocation> locations(numLocations);
- if (numLocations)
- mlirLocationFusedGetLocations(self, locations.data());
- std::vector<PyLocation> pyLocations{};
- pyLocations.reserve(numLocations);
- for (unsigned i = 0; i < numLocations; ++i)
- pyLocations.emplace_back(self.getContext(), locations[i]);
- return pyLocations;
- },
- "Gets the list of locations from a `FusedLoc`.")
- .def_static(
- "name",
- [](std::string name, std::optional<PyLocation> childLoc,
- DefaultingPyMlirContext context) {
- return PyLocation(
- context->getRef(),
- mlirLocationNameGet(
- context->get(), toMlirStringRef(name),
- childLoc ? childLoc->get()
- : mlirLocationUnknownGet(context->get())));
- },
- nb::arg("name"), nb::arg("childLoc") = nb::none(),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a named location with optional child "
- "location.")
- .def("is_a_name", mlirLocationIsAName,
- "Returns True if this location is a `NameLoc`.")
- .def_prop_ro(
- "name_str",
- [](MlirLocation loc) {
- return mlirIdentifierStr(mlirLocationNameGetName(loc));
- },
- "Gets the name string from a `NameLoc`.")
- .def_prop_ro(
- "child_loc",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationNameGetChildLoc(self));
- },
- "Gets the child location from a `NameLoc`.")
- .def_static(
- "from_attr",
- [](PyAttribute &attribute, DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationFromAttribute(attribute));
- },
- nb::arg("attribute"), nb::arg("context") = nb::none(),
- "Gets a Location from a `LocationAttr`.")
- .def_prop_ro(
- "context",
- [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that owns the `Location`.")
- .def_prop_ro(
- "attr",
- [](PyLocation &self) {
- return PyAttribute(self.getContext(),
- mlirLocationGetAttribute(self));
- },
- "Get the underlying `LocationAttr`.")
- .def(
- "emit_error",
- [](PyLocation &self, std::string message) {
- mlirEmitError(self, message.c_str());
- },
- nb::arg("message"),
- R"(
- Emits an error diagnostic at this location.
-
- Args:
- message: The error message to emit.)")
- .def(
- "__repr__",
- [](PyLocation &self) {
- PyPrintAccumulator printAccum;
- mlirLocationPrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly representation of the location.");
-
- //----------------------------------------------------------------------------
- // Mapping of Module
- //----------------------------------------------------------------------------
- nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule,
- "Gets a capsule wrapping the MlirModule.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
- R"(
- Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
-
- This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
- prevent double-frees (of the underlying `mlir::Module`).)")
- .def("_clear_mlir_module", &PyModule::clearMlirModule,
- R"(
- Clears the internal MLIR module reference.
-
- This is used internally to prevent double-free when ownership is transferred
- via the C API capsule mechanism. Not intended for normal use.)")
- .def_static(
- "parse",
- [](const std::string &moduleAsm, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyModule> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirModule module = mlirModuleCreateParse(
- context->get(), toMlirStringRef(moduleAsm));
- if (mlirModuleIsNull(module))
- throw MLIRError("Unable to parse module assembly", errors.take());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
- .def_static(
- "parse",
- [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyModule> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirModule module = mlirModuleCreateParse(
- context->get(), toMlirStringRef(moduleAsm));
- if (mlirModuleIsNull(module))
- throw MLIRError("Unable to parse module assembly", errors.take());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
- .def_static(
- "parseFile",
- [](const std::string &path, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyModule> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirModule module = mlirModuleCreateParseFromFile(
- context->get(), toMlirStringRef(path));
- if (mlirModuleIsNull(module))
- throw MLIRError("Unable to parse module assembly", errors.take());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("path"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
- .def_static(
- "create",
- [](const std::optional<PyLocation> &loc)
- -> nb::typed<nb::object, PyModule> {
- PyLocation pyLoc = maybeGetTracebackLocation(loc);
- MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("loc") = nb::none(), "Creates an empty module.")
- .def_prop_ro(
- "context",
- [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that created the `Module`.")
- .def_prop_ro(
- "operation",
- [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
- return PyOperation::forOperation(self.getContext(),
- mlirModuleGetOperation(self.get()),
- self.getRef().releaseObject())
- .releaseObject();
- },
- "Accesses the module as an operation.")
- .def_prop_ro(
- "body",
- [](PyModule &self) {
- PyOperationRef moduleOp = PyOperation::forOperation(
- self.getContext(), mlirModuleGetOperation(self.get()),
- self.getRef().releaseObject());
- PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
- return returnBlock;
- },
- "Return the block for this module.")
- .def(
- "dump",
- [](PyModule &self) {
- mlirOperationDump(mlirModuleGetOperation(self.get()));
- },
- kDumpDocstring)
- .def(
- "__str__",
- [](const nb::object &self) {
- // Defer to the operation's __str__.
- return self.attr("operation").attr("__str__")();
- },
- nb::sig("def __str__(self) -> str"),
- R"(
- Gets the assembly form of the operation with default options.
-
- If more advanced control over the assembly formatting or I/O options is needed,
- use the dedicated print or get_asm method, which supports keyword arguments to
- customize behavior.
- )")
- .def(
- "__eq__",
- [](PyModule &self, PyModule &other) {
- return mlirModuleEqual(self.get(), other.get());
- },
- "other"_a, "Compares two modules for equality.")
- .def(
- "__hash__",
- [](PyModule &self) { return mlirModuleHashValue(self.get()); },
- "Returns the hash value of the module.");
-
- //----------------------------------------------------------------------------
- // Mapping of Operation.
- //----------------------------------------------------------------------------
- nb::class_<PyOperationBase>(m, "_OperationBase")
- .def_prop_ro(
- MLIR_PYTHON_CAPI_PTR_ATTR,
- [](PyOperationBase &self) {
- return self.getOperation().getCapsule();
- },
- "Gets a capsule wrapping the `MlirOperation`.")
- .def(
- "__eq__",
- [](PyOperationBase &self, PyOperationBase &other) {
- return mlirOperationEqual(self.getOperation().get(),
- other.getOperation().get());
- },
- "Compares two operations for equality.")
- .def(
- "__eq__",
- [](PyOperationBase &self, nb::object other) { return false; },
- "Compares operation with non-operation object (always returns "
- "False).")
- .def(
- "__hash__",
- [](PyOperationBase &self) {
- return mlirOperationHashValue(self.getOperation().get());
- },
- "Returns the hash value of the operation.")
- .def_prop_ro(
- "attributes",
- [](PyOperationBase &self) {
- return PyOpAttributeMap(self.getOperation().getRef());
- },
- "Returns a dictionary-like map of operation attributes.")
- .def_prop_ro(
- "context",
- [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
- PyOperation &concreteOperation = self.getOperation();
- concreteOperation.checkValid();
- return concreteOperation.getContext().getObject();
- },
- "Context that owns the operation.")
- .def_prop_ro(
- "name",
- [](PyOperationBase &self) {
- auto &concreteOperation = self.getOperation();
- concreteOperation.checkValid();
- MlirOperation operation = concreteOperation.get();
- return mlirIdentifierStr(mlirOperationGetName(operation));
- },
- "Returns the fully qualified name of the operation.")
- .def_prop_ro(
- "operands",
- [](PyOperationBase &self) {
- return PyOpOperandList(self.getOperation().getRef());
- },
- "Returns the list of operation operands.")
- .def_prop_ro(
- "regions",
- [](PyOperationBase &self) {
- return PyRegionList(self.getOperation().getRef());
- },
- "Returns the list of operation regions.")
- .def_prop_ro(
- "results",
- [](PyOperationBase &self) {
- return PyOpResultList(self.getOperation().getRef());
- },
- "Returns the list of Operation results.")
- .def_prop_ro(
- "result",
- [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
- auto &operation = self.getOperation();
- return PyOpResult(operation.getRef(), getUniqueResult(operation))
- .maybeDownCast();
- },
- "Shortcut to get an op result if it has only one (throws an error "
- "otherwise).")
- .def_prop_rw(
- "location",
- [](PyOperationBase &self) {
- PyOperation &operation = self.getOperation();
- return PyLocation(operation.getContext(),
- mlirOperationGetLocation(operation.get()));
- },
- [](PyOperationBase &self, const PyLocation &location) {
- PyOperation &operation = self.getOperation();
- mlirOperationSetLocation(operation.get(), location.get());
- },
- nb::for_getter("Returns the source location the operation was "
- "defined or derived from."),
- nb::for_setter("Sets the source location the operation was defined "
- "or derived from."))
- .def_prop_ro(
- "parent",
- [](PyOperationBase &self)
- -> std::optional<nb::typed<nb::object, PyOperation>> {
- auto parent = self.getOperation().getParentOperation();
- if (parent)
- return parent->getObject();
- return {};
- },
- "Returns the parent operation, or `None` if at top level.")
- .def(
- "__str__",
- [](PyOperationBase &self) {
- return self.getAsm(/*binary=*/false,
- /*largeElementsLimit=*/std::nullopt,
- /*largeResourceLimit=*/std::nullopt,
- /*enableDebugInfo=*/false,
- /*prettyDebugInfo=*/false,
- /*printGenericOpForm=*/false,
- /*useLocalScope=*/false,
- /*useNameLocAsPrefix=*/false,
- /*assumeVerified=*/false,
- /*skipRegions=*/false);
- },
- nb::sig("def __str__(self) -> str"),
- "Returns the assembly form of the operation.")
- .def("print",
- nb::overload_cast<PyAsmState &, nb::object, bool>(
- &PyOperationBase::print),
- nb::arg("state"), nb::arg("file") = nb::none(),
- nb::arg("binary") = false,
- R"(
- Prints the assembly form of the operation to a file like object.
-
- Args:
- state: `AsmState` capturing the operation numbering and flags.
- file: Optional file like object to write to. Defaults to sys.stdout.
- binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
- .def("print",
- nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
- bool, bool, bool, bool, bool, bool, nb::object,
- bool, bool>(&PyOperationBase::print),
- // Careful: Lots of arguments must match up with print method.
- nb::arg("large_elements_limit") = nb::none(),
- nb::arg("large_resource_limit") = nb::none(),
- nb::arg("enable_debug_info") = false,
- nb::arg("pretty_debug_info") = false,
- nb::arg("print_generic_op_form") = false,
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- nb::arg("assume_verified") = false, nb::arg("file") = nb::none(),
- nb::arg("binary") = false, nb::arg("skip_regions") = false,
- R"(
- Prints the assembly form of the operation to a file like object.
-
- Args:
- large_elements_limit: Whether to elide elements attributes above this
- number of elements. Defaults to None (no limit).
- large_resource_limit: Whether to elide resource attributes above this
- number of characters. Defaults to None (no limit). If large_elements_limit
- is set and this is None, the behavior will be to use large_elements_limit
- as large_resource_limit.
- enable_debug_info: Whether to print debug/location information. Defaults
- to False.
- pretty_debug_info: Whether to format debug information for easier reading
- by a human (warning: the result is unparseable). Defaults to False.
- print_generic_op_form: Whether to print the generic assembly forms of all
- ops. Defaults to False.
- use_local_scope: Whether to print in a way that is more optimized for
- multi-threaded access but may not be consistent with how the overall
- module prints.
- use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
- prefixes for the SSA identifiers. Defaults to False.
- assume_verified: By default, if not printing generic form, the verifier
- will be run and if it fails, generic form will be printed with a comment
- about failed verification. While a reasonable default for interactive use,
- for systematic use, it is often better for the caller to verify explicitly
- and report failures in a more robust fashion. Set this to True if doing this
- in order to avoid running a redundant verification. If the IR is actually
- invalid, behavior is undefined.
- file: The file like object to write to. Defaults to sys.stdout.
- binary: Whether to write bytes (True) or str (False). Defaults to False.
- skip_regions: Whether to skip printing regions. Defaults to False.)")
- .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
- nb::arg("desired_version") = nb::none(),
- R"(
- Write the bytecode form of the operation to a file like object.
-
- Args:
- file: The file like object to write to.
- desired_version: Optional version of bytecode to emit.
- Returns:
- The bytecode writer status.)")
- .def("get_asm", &PyOperationBase::getAsm,
- // Careful: Lots of arguments must match up with get_asm method.
- nb::arg("binary") = false,
- nb::arg("large_elements_limit") = nb::none(),
- nb::arg("large_resource_limit") = nb::none(),
- nb::arg("enable_debug_info") = false,
- nb::arg("pretty_debug_info") = false,
- nb::arg("print_generic_op_form") = false,
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
- R"(
- Gets the assembly form of the operation with all options available.
-
- Args:
- binary: Whether to return a bytes (True) or str (False) object. Defaults to
- False.
- ... others ...: See the print() method for common keyword arguments for
- configuring the printout.
- Returns:
- Either a bytes or str object, depending on the setting of the `binary`
- argument.)")
- .def("verify", &PyOperationBase::verify,
- "Verify the operation. Raises MLIRError if verification fails, and "
- "returns true otherwise.")
- .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
- "Puts self immediately after the other operation in its parent "
- "block.")
- .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
- "Puts self immediately before the other operation in its parent "
- "block.")
- .def("is_before_in_block", &PyOperationBase::isBeforeInBlock,
- nb::arg("other"),
- R"(
- Checks if this operation is before another in the same block.
-
- Args:
- other: Another operation in the same parent block.
-
- Returns:
- True if this operation is before `other` in the operation list of the parent block.)")
- .def(
- "clone",
- [](PyOperationBase &self,
- const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
- return self.getOperation().clone(ip);
- },
- nb::arg("ip") = nb::none(),
- R"(
- Creates a deep copy of the operation.
-
- Args:
- ip: Optional insertion point where the cloned operation should be inserted.
- If None, the current insertion point is used. If False, the operation
- remains detached.
-
- Returns:
- A new Operation that is a clone of this operation.)")
- .def(
- "detach_from_parent",
- [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
- PyOperation &operation = self.getOperation();
- operation.checkValid();
- if (!operation.isAttached())
- throw nb::value_error("Detached operation has no parent.");
-
- operation.detachFromParent();
- return operation.createOpView();
- },
- "Detaches the operation from its parent block.")
- .def_prop_ro(
- "attached",
- [](PyOperationBase &self) {
- PyOperation &operation = self.getOperation();
- operation.checkValid();
- return operation.isAttached();
- },
- "Reports if the operation is attached to its parent block.")
- .def(
- "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
- R"(
- Erases the operation and frees its memory.
-
- Note:
- After erasing, any Python references to the operation become invalid.)")
- .def("walk", &PyOperationBase::walk, nb::arg("callback"),
- nb::arg("walk_order") = MlirWalkPostOrder,
- // clang-format off
- nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
- // clang-format on
- R"(
- Walks the operation tree with a callback function.
-
- Args:
- callback: A callable that takes an Operation and returns a WalkResult.
- walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
-
- nb::class_<PyOperation, PyOperationBase>(m, "Operation")
- .def_static(
- "create",
- [](std::string_view name,
- std::optional<std::vector<PyType *>> results,
- std::optional<std::vector<PyValue *>> operands,
- std::optional<nb::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors, int regions,
- const std::optional<PyLocation> &location,
- const nb::object &maybeIp,
- bool inferType) -> nb::typed<nb::object, PyOperation> {
- // Unpack/validate operands.
- llvm::SmallVector<MlirValue, 4> mlirOperands;
- if (operands) {
- mlirOperands.reserve(operands->size());
- for (PyValue *operand : *operands) {
- if (!operand)
- throw nb::value_error("operand value cannot be None");
- mlirOperands.push_back(operand->get());
- }
- }
-
- PyLocation pyLoc = maybeGetTracebackLocation(location);
- return PyOperation::create(name, results, mlirOperands, attributes,
- successors, regions, pyLoc, maybeIp,
- inferType);
- },
- nb::arg("name"), nb::arg("results") = nb::none(),
- nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(), nb::arg("regions") = 0,
- nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
- nb::arg("infer_type") = false,
- R"(
- Creates a new operation.
-
- Args:
- name: Operation name (e.g. `dialect.operation`).
- results: Optional sequence of Type representing op result types.
- operands: Optional operands of the operation.
- attributes: Optional Dict of {str: Attribute}.
- successors: Optional List of Block for the operation's successors.
- regions: Number of regions to create (default = 0).
- location: Optional Location object (defaults to resolve from context manager).
- ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
- infer_type: Whether to infer result types (default = False).
- Returns:
- A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
- .def_static(
- "parse",
- [](const std::string &sourceStr, const std::string &sourceName,
- DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyOpView> {
- return PyOperation::parse(context->getRef(), sourceStr, sourceName)
- ->createOpView();
- },
- nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
- nb::arg("context") = nb::none(),
- "Parses an operation. Supports both text assembly format and binary "
- "bytecode format.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule,
- "Gets a capsule wrapping the MlirOperation.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyOperation::createFromCapsule,
- "Creates an Operation from a capsule wrapping MlirOperation.")
- .def_prop_ro(
- "operation",
- [](nb::object self) -> nb::typed<nb::object, PyOperation> {
- return self;
- },
- "Returns self (the operation).")
- .def_prop_ro(
- "opview",
- [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
- return self.createOpView();
- },
- R"(
- Returns an OpView of this operation.
-
- Note:
- If the operation has a registered and loaded dialect then this OpView will
- be concrete wrapper class.)")
- .def_prop_ro("block", &PyOperation::getBlock,
- "Returns the block containing this operation.")
- .def_prop_ro(
- "successors",
- [](PyOperationBase &self) {
- return PyOpSuccessors(self.getOperation().getRef());
- },
- "Returns the list of Operation successors.")
- .def(
- "replace_uses_of_with",
- [](PyOperation &self, PyValue &of, PyValue &with) {
- mlirOperationReplaceUsesOfWith(self.get(), of.get(), with.get());
- },
- "of"_a, "with_"_a,
- "Replaces uses of the 'of' value with the 'with' value inside the "
- "operation.")
- .def("_set_invalid", &PyOperation::setInvalid,
- "Invalidate the operation.");
-
- auto opViewClass =
- nb::class_<PyOpView, PyOperationBase>(m, "OpView")
- .def(nb::init<nb::typed<nb::object, PyOperation>>(),
- nb::arg("operation"))
- .def(
- "__init__",
- [](PyOpView *self, std::string_view name,
- std::tuple<int, bool> opRegionSpec,
- nb::object operandSegmentSpecObj,
- nb::object resultSegmentSpecObj,
- std::optional<nb::list> resultTypeList, nb::list operandList,
- std::optional<nb::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions,
- const std::optional<PyLocation> &location,
- const nb::object &maybeIp) {
- PyLocation pyLoc = maybeGetTracebackLocation(location);
- new (self) PyOpView(PyOpView::buildGeneric(
- name, opRegionSpec, operandSegmentSpecObj,
- resultSegmentSpecObj, resultTypeList, operandList,
- attributes, successors, regions, pyLoc, maybeIp));
- },
- nb::arg("name"), nb::arg("opRegionSpec"),
- nb::arg("operandSegmentSpecObj") = nb::none(),
- nb::arg("resultSegmentSpecObj") = nb::none(),
- nb::arg("results") = nb::none(), nb::arg("operands") = nb::none(),
- nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(),
- nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(),
- nb::arg("ip") = nb::none())
- .def_prop_ro(
- "operation",
- [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
- return self.getOperationObject();
- })
- .def_prop_ro("opview",
- [](nb::object self) -> nb::typed<nb::object, PyOpView> {
- return self;
- })
- .def(
- "__str__",
- [](PyOpView &self) { return nb::str(self.getOperationObject()); })
- .def_prop_ro(
- "successors",
- [](PyOperationBase &self) {
- return PyOpSuccessors(self.getOperation().getRef());
- },
- "Returns the list of Operation successors.")
- .def(
- "_set_invalid",
- [](PyOpView &self) { self.getOperation().setInvalid(); },
- "Invalidate the operation.");
- opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
- opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
- opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
- // It is faster to pass the operation_name, ods_regions, and
- // ods_operand_segments/ods_result_segments as arguments to the constructor,
- // rather than to access them as attributes.
- opViewClass.attr("build_generic") = classmethod(
- [](nb::handle cls, std::optional<nb::list> resultTypeList,
- nb::list operandList, std::optional<nb::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, std::optional<PyLocation> location,
- const nb::object &maybeIp) {
- std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
- std::tuple<int, bool> opRegionSpec =
- nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
- nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
- nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
- PyLocation pyLoc = maybeGetTracebackLocation(location);
- return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
- resultSegmentSpec, resultTypeList,
- operandList, attributes, successors,
- regions, pyLoc, maybeIp);
- },
- nb::arg("cls"), nb::arg("results") = nb::none(),
- nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(),
- nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
- "Builds a specific, generated OpView based on class level attributes.");
- opViewClass.attr("parse") = classmethod(
- [](const nb::object &cls, const std::string &sourceStr,
- const std::string &sourceName,
- DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
- PyOperationRef parsed =
- PyOperation::parse(context->getRef(), sourceStr, sourceName);
-
- // Check if the expected operation was parsed, and cast to to the
- // appropriate `OpView` subclass if successful.
- // NOTE: This accesses attributes that have been automatically added to
- // `OpView` subclasses, and is not intended to be used on `OpView`
- // directly.
- std::string clsOpName =
- nb::cast<std::string>(cls.attr("OPERATION_NAME"));
- MlirStringRef identifier =
- mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
- std::string_view parsedOpName(identifier.data, identifier.length);
- if (clsOpName != parsedOpName)
- throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
- parsedOpName + "'");
- return PyOpView::constructDerived(cls, parsed.getObject());
- },
- nb::arg("cls"), nb::arg("source"), nb::kw_only(),
- nb::arg("source_name") = "", nb::arg("context") = nb::none(),
- "Parses a specific, generated OpView based on class level attributes.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyRegion.
- //----------------------------------------------------------------------------
- nb::class_<PyRegion>(m, "Region")
- .def_prop_ro(
- "blocks",
- [](PyRegion &self) {
- return PyBlockList(self.getParentOperation(), self.get());
- },
- "Returns a forward-optimized sequence of blocks.")
- .def_prop_ro(
- "owner",
- [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
- return self.getParentOperation()->createOpView();
- },
- "Returns the operation owning this region.")
- .def(
- "__iter__",
- [](PyRegion &self) {
- self.checkValid();
- MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
- return PyBlockIterator(self.getParentOperation(), firstBlock);
- },
- "Iterates over blocks in the region.")
- .def(
- "__eq__",
- [](PyRegion &self, PyRegion &other) {
- return self.get().ptr == other.get().ptr;
- },
- "Compares two regions for pointer equality.")
- .def(
- "__eq__", [](PyRegion &self, nb::object &other) { return false; },
- "Compares region with non-region object (always returns False).");
-
- //----------------------------------------------------------------------------
- // Mapping of PyBlock.
- //----------------------------------------------------------------------------
- nb::class_<PyBlock>(m, "Block")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule,
- "Gets a capsule wrapping the MlirBlock.")
- .def_prop_ro(
- "owner",
- [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
- return self.getParentOperation()->createOpView();
- },
- "Returns the owning operation of this block.")
- .def_prop_ro(
- "region",
- [](PyBlock &self) {
- MlirRegion region = mlirBlockGetParentRegion(self.get());
- return PyRegion(self.getParentOperation(), region);
- },
- "Returns the owning region of this block.")
- .def_prop_ro(
- "arguments",
- [](PyBlock &self) {
- return PyBlockArgumentList(self.getParentOperation(), self.get());
- },
- "Returns a list of block arguments.")
- .def(
- "add_argument",
- [](PyBlock &self, const PyType &type, const PyLocation &loc) {
- return PyBlockArgument(self.getParentOperation(),
- mlirBlockAddArgument(self.get(), type, loc));
- },
- "type"_a, "loc"_a,
- R"(
- Appends an argument of the specified type to the block.
-
- Args:
- type: The type of the argument to add.
- loc: The source location for the argument.
-
- Returns:
- The newly added block argument.)")
- .def(
- "erase_argument",
- [](PyBlock &self, unsigned index) {
- return mlirBlockEraseArgument(self.get(), index);
- },
- nb::arg("index"),
- R"(
- Erases the argument at the specified index.
-
- Args:
- index: The index of the argument to erase.)")
- .def_prop_ro(
- "operations",
- [](PyBlock &self) {
- return PyOperationList(self.getParentOperation(), self.get());
- },
- "Returns a forward-optimized sequence of operations.")
- .def_static(
- "create_at_start",
- [](PyRegion &parent, const nb::sequence &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- parent.checkValid();
- MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
- mlirRegionInsertOwnedBlock(parent, 0, block);
- return PyBlock(parent.getParentOperation(), block);
- },
- nb::arg("parent"), nb::arg("arg_types") = nb::list(),
- nb::arg("arg_locs") = std::nullopt,
- "Creates and returns a new Block at the beginning of the given "
- "region (with given argument types and locations).")
- .def(
- "append_to",
- [](PyBlock &self, PyRegion ®ion) {
- MlirBlock b = self.get();
- if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
- mlirBlockDetach(b);
- mlirRegionAppendOwnedBlock(region.get(), b);
- },
- nb::arg("region"),
- R"(
- Appends this block to a region.
-
- Transfers ownership if the block is currently owned by another region.
-
- Args:
- region: The region to append the block to.)")
- .def(
- "create_before",
- [](PyBlock &self, const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- self.checkValid();
- MlirBlock block =
- createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
- MlirRegion region = mlirBlockGetParentRegion(self.get());
- mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
- return PyBlock(self.getParentOperation(), block);
- },
- nb::arg("arg_types"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt,
- "Creates and returns a new Block before this block "
- "(with given argument types and locations).")
- .def(
- "create_after",
- [](PyBlock &self, const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- self.checkValid();
- MlirBlock block =
- createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
- MlirRegion region = mlirBlockGetParentRegion(self.get());
- mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
- return PyBlock(self.getParentOperation(), block);
- },
- nb::arg("arg_types"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt,
- "Creates and returns a new Block after this block "
- "(with given argument types and locations).")
- .def(
- "__iter__",
- [](PyBlock &self) {
- self.checkValid();
- MlirOperation firstOperation =
- mlirBlockGetFirstOperation(self.get());
- return PyOperationIterator(self.getParentOperation(),
- firstOperation);
- },
- "Iterates over operations in the block.")
- .def(
- "__eq__",
- [](PyBlock &self, PyBlock &other) {
- return self.get().ptr == other.get().ptr;
- },
- "Compares two blocks for pointer equality.")
- .def(
- "__eq__", [](PyBlock &self, nb::object &other) { return false; },
- "Compares block with non-block object (always returns False).")
- .def(
- "__hash__",
- [](PyBlock &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the block.")
- .def(
- "__str__",
- [](PyBlock &self) {
- self.checkValid();
- PyPrintAccumulator printAccum;
- mlirBlockPrint(self.get(), printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly form of the block.")
- .def(
- "append",
- [](PyBlock &self, PyOperationBase &operation) {
- if (operation.getOperation().isAttached())
- operation.getOperation().detachFromParent();
-
- MlirOperation mlirOperation = operation.getOperation().get();
- mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
- operation.getOperation().setAttached(
- self.getParentOperation().getObject());
- },
- nb::arg("operation"),
- R"(
- Appends an operation to this block.
-
- If the operation is currently in another block, it will be moved.
-
- Args:
- operation: The operation to append to the block.)")
- .def_prop_ro(
- "successors",
- [](PyBlock &self) {
- return PyBlockSuccessors(self, self.getParentOperation());
- },
- "Returns the list of Block successors.")
- .def_prop_ro(
- "predecessors",
- [](PyBlock &self) {
- return PyBlockPredecessors(self, self.getParentOperation());
- },
- "Returns the list of Block predecessors.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyInsertionPoint.
- //----------------------------------------------------------------------------
-
- nb::class_<PyInsertionPoint>(m, "InsertionPoint")
- .def(nb::init<PyBlock &>(), nb::arg("block"),
- "Inserts after the last operation but still inside the block.")
- .def("__enter__", &PyInsertionPoint::contextEnter,
- "Enters the insertion point as a context manager.")
- .def("__exit__", &PyInsertionPoint::contextExit,
- nb::arg("exc_type").none(), nb::arg("exc_value").none(),
- nb::arg("traceback").none(),
- "Exits the insertion point context manager.")
- .def_prop_ro_static(
- "current",
- [](nb::object & /*class*/) {
- auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
- if (!ip)
- throw nb::value_error("No current InsertionPoint");
- return ip;
- },
- nb::sig("def current(/) -> InsertionPoint"),
- "Gets the InsertionPoint bound to the current thread or raises "
- "ValueError if none has been set.")
- .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
- "Inserts before a referenced operation.")
- .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
- nb::arg("block"),
- R"(
- Creates an insertion point at the beginning of a block.
-
- Args:
- block: The block at whose beginning operations should be inserted.
-
- Returns:
- An InsertionPoint at the block's beginning.)")
- .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
- nb::arg("block"),
- R"(
- Creates an insertion point before a block's terminator.
-
- Args:
- block: The block whose terminator to insert before.
-
- Returns:
- An InsertionPoint before the terminator.
-
- Raises:
- ValueError: If the block has no terminator.)")
- .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
- R"(
- Creates an insertion point immediately after an operation.
-
- Args:
- operation: The operation after which to insert.
-
- Returns:
- An InsertionPoint after the operation.)")
- .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
- R"(
- Inserts an operation at this insertion point.
-
- Args:
- operation: The operation to insert.)")
- .def_prop_ro(
- "block", [](PyInsertionPoint &self) { return self.getBlock(); },
- "Returns the block that this `InsertionPoint` points to.")
- .def_prop_ro(
- "ref_operation",
- [](PyInsertionPoint &self)
- -> std::optional<nb::typed<nb::object, PyOperation>> {
- auto refOperation = self.getRefOperation();
- if (refOperation)
- return refOperation->getObject();
- return {};
- },
- "The reference operation before which new operations are "
- "inserted, or None if the insertion point is at the end of "
- "the block.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyAttribute.
- //----------------------------------------------------------------------------
- nb::class_<PyAttribute>(m, "Attribute")
- // Delegate to the PyAttribute copy constructor, which will also lifetime
- // extend the backing context which owns the MlirAttribute.
- .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
- "Casts the passed attribute to the generic `Attribute`.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule,
- "Gets a capsule wrapping the MlirAttribute.")
- .def_static(
- MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule,
- "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
- .def_static(
- "parse",
- [](const std::string &attrSpec, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyAttribute> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute attr = mlirAttributeParseGet(
- context->get(), toMlirStringRef(attrSpec));
- if (mlirAttributeIsNull(attr))
- throw MLIRError("Unable to parse attribute", errors.take());
- return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- "Parses an attribute from an assembly form. Raises an `MLIRError` on "
- "failure.")
- .def_prop_ro(
- "context",
- [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that owns the `Attribute`.")
- .def_prop_ro(
- "type",
- [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirAttributeGetType(self))
- .maybeDownCast();
- },
- "Returns the type of the `Attribute`.")
- .def(
- "get_named",
- [](PyAttribute &self, std::string name) {
- return PyNamedAttribute(self, std::move(name));
- },
- nb::keep_alive<0, 1>(),
- R"(
- Binds a name to the attribute, creating a `NamedAttribute`.
-
- Args:
- name: The name to bind to the `Attribute`.
-
- Returns:
- A `NamedAttribute` with the given name and this attribute.)")
- .def(
- "__eq__",
- [](PyAttribute &self, PyAttribute &other) { return self == other; },
- "Compares two attributes for equality.")
- .def(
- "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
- "Compares attribute with non-attribute object (always returns "
- "False).")
- .def(
- "__hash__",
- [](PyAttribute &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the attribute.")
- .def(
- "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
- kDumpDocstring)
- .def(
- "__str__",
- [](PyAttribute &self) {
- PyPrintAccumulator printAccum;
- mlirAttributePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly form of the Attribute.")
- .def(
- "__repr__",
- [](PyAttribute &self) {
- // Generally, assembly formats are not printed for __repr__ because
- // this can cause exceptionally long debug output and exceptions.
- // However, attribute values are generally considered useful and
- // are printed. This may need to be re-evaluated if debug dumps end
- // up being excessive.
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Attribute(");
- mlirAttributePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- "Returns a string representation of the attribute.")
- .def_prop_ro(
- "typeid",
- [](PyAttribute &self) {
- MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
- assert(!mlirTypeIDIsNull(mlirTypeID) &&
- "mlirTypeID was expected to be non-null.");
- return PyTypeID(mlirTypeID);
- },
- "Returns the `TypeID` of the attribute.")
- .def(
- MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
- return self.maybeDownCast();
- },
- "Downcasts the attribute to a more specific attribute if possible.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyNamedAttribute
- //----------------------------------------------------------------------------
- nb::class_<PyNamedAttribute>(m, "NamedAttribute")
- .def(
- "__repr__",
- [](PyNamedAttribute &self) {
- PyPrintAccumulator printAccum;
- printAccum.parts.append("NamedAttribute(");
- printAccum.parts.append(
- nb::str(mlirIdentifierStr(self.namedAttr.name).data,
- mlirIdentifierStr(self.namedAttr.name).length));
- printAccum.parts.append("=");
- mlirAttributePrint(self.namedAttr.attribute,
- printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- "Returns a string representation of the named attribute.")
- .def_prop_ro(
- "name",
- [](PyNamedAttribute &self) {
- return mlirIdentifierStr(self.namedAttr.name);
- },
- "The name of the `NamedAttribute` binding.")
- .def_prop_ro(
- "attr",
- [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
- nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
- "The underlying generic attribute of the `NamedAttribute` binding.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyType.
- //----------------------------------------------------------------------------
- nb::class_<PyType>(m, "Type")
- // Delegate to the PyType copy constructor, which will also lifetime
- // extend the backing context which owns the MlirType.
- .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
- "Casts the passed type to the generic `Type`.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule,
- "Gets a capsule wrapping the `MlirType`.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule,
- "Creates a Type from a capsule wrapping `MlirType`.")
- .def_static(
- "parse",
- [](std::string typeSpec,
- DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType type =
- mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
- if (mlirTypeIsNull(type))
- throw MLIRError("Unable to parse type", errors.take());
- return PyType(context.get()->getRef(), type).maybeDownCast();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- R"(
- Parses the assembly form of a type.
-
- Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
-
- See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
- .def_prop_ro(
- "context",
- [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that owns the `Type`.")
- .def(
- "__eq__", [](PyType &self, PyType &other) { return self == other; },
- "Compares two types for equality.")
- .def(
- "__eq__", [](PyType &self, nb::object &other) { return false; },
- nb::arg("other").none(),
- "Compares type with non-type object (always returns False).")
- .def(
- "__hash__",
- [](PyType &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the `Type`.")
- .def(
- "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
- .def(
- "__str__",
- [](PyType &self) {
- PyPrintAccumulator printAccum;
- mlirTypePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly form of the `Type`.")
- .def(
- "__repr__",
- [](PyType &self) {
- // Generally, assembly formats are not printed for __repr__ because
- // this can cause exceptionally long debug output and exceptions.
- // However, types are an exception as they typically have compact
- // assembly forms and printing them is useful.
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Type(");
- mlirTypePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- "Returns a string representation of the `Type`.")
- .def(
- MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyType &self) -> nb::typed<nb::object, PyType> {
- return self.maybeDownCast();
- },
- "Downcasts the Type to a more specific `Type` if possible.")
- .def_prop_ro(
- "typeid",
- [](PyType &self) {
- MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
- if (!mlirTypeIDIsNull(mlirTypeID))
- return PyTypeID(mlirTypeID);
- auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
- throw nb::value_error(
- (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
- },
- "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
- "`Type` has no "
- "`TypeID`.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyTypeID.
- //----------------------------------------------------------------------------
- nb::class_<PyTypeID>(m, "TypeID")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule,
- "Gets a capsule wrapping the `MlirTypeID`.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule,
- "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
- // Note, this tests whether the underlying TypeIDs are the same,
- // not whether the wrapper MlirTypeIDs are the same, nor whether
- // the Python objects are the same (i.e., PyTypeID is a value type).
- .def(
- "__eq__",
- [](PyTypeID &self, PyTypeID &other) { return self == other; },
- "Compares two `TypeID`s for equality.")
- .def(
- "__eq__",
- [](PyTypeID &self, const nb::object &other) { return false; },
- "Compares TypeID with non-TypeID object (always returns False).")
- // Note, this gives the hash value of the underlying TypeID, not the
- // hash value of the Python object, nor the hash value of the
- // MlirTypeID wrapper.
- .def(
- "__hash__",
- [](PyTypeID &self) {
- return static_cast<size_t>(mlirTypeIDHashValue(self));
- },
- "Returns the hash value of the `TypeID`.");
-
- //----------------------------------------------------------------------------
- // Mapping of Value.
- //----------------------------------------------------------------------------
- m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
-
- nb::class_<PyValue>(m, "Value", nb::is_generic(),
- nb::sig("class Value(Generic[_T])"))
- .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"),
- "Creates a Value reference from another `Value`.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule,
- "Gets a capsule wrapping the `MlirValue`.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule,
- "Creates a `Value` from a capsule wrapping `MlirValue`.")
- .def_prop_ro(
- "context",
- [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getParentOperation()->getContext().getObject();
- },
- "Context in which the value lives.")
- .def(
- "dump", [](PyValue &self) { mlirValueDump(self.get()); },
- kDumpDocstring)
- .def_prop_ro(
- "owner",
- [](PyValue &self) -> nb::typed<nb::object, PyOpView> {
- MlirValue v = self.get();
- if (mlirValueIsAOpResult(v)) {
- assert(mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match "
- "that in "
- "the IR");
- return self.getParentOperation()->createOpView();
- }
-
- if (mlirValueIsABlockArgument(v)) {
- MlirBlock block = mlirBlockArgumentGetOwner(self.get());
- return nb::cast(PyBlock(self.getParentOperation(), block));
- }
-
- assert(false && "Value must be a block argument or an op result");
- return nb::none();
- },
- "Returns the owner of the value (`Operation` for results, `Block` "
- "for "
- "arguments).")
- .def_prop_ro(
- "uses",
- [](PyValue &self) {
- return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
- },
- "Returns an iterator over uses of this value.")
- .def(
- "__eq__",
- [](PyValue &self, PyValue &other) {
- return self.get().ptr == other.get().ptr;
- },
- "Compares two values for pointer equality.")
- .def(
- "__eq__", [](PyValue &self, nb::object other) { return false; },
- "Compares value with non-value object (always returns False).")
- .def(
- "__hash__",
- [](PyValue &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the value.")
- .def(
- "__str__",
- [](PyValue &self) {
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Value(");
- mlirValuePrint(self.get(), printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- R"(
- Returns the string form of the value.
-
- If the value is a block argument, this is the assembly form of its type and the
- position in the argument list. If the value is an operation result, this is
- equivalent to printing the operation that produced it.
- )")
- .def(
- "get_name",
- [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
- PyPrintAccumulator printAccum;
- MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
- if (useLocalScope)
- mlirOpPrintingFlagsUseLocalScope(flags);
- if (useNameLocAsPrefix)
- mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);
- MlirAsmState valueState =
- mlirAsmStateCreateForValue(self.get(), flags);
- mlirValuePrintAsOperand(self.get(), valueState,
- printAccum.getCallback(),
- printAccum.getUserData());
- mlirOpPrintingFlagsDestroy(flags);
- mlirAsmStateDestroy(valueState);
- return printAccum.join();
- },
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- R"(
- Returns the string form of value as an operand.
-
- Args:
- use_local_scope: Whether to use local scope for naming.
- use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
-
- Returns:
- The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
- .def(
- "get_name",
- [](PyValue &self, PyAsmState &state) {
- PyPrintAccumulator printAccum;
- MlirAsmState valueState = state.get();
- mlirValuePrintAsOperand(self.get(), valueState,
- printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- nb::arg("state"),
- "Returns the string form of value as an operand (i.e., the ValueID).")
- .def_prop_ro(
- "type",
- [](PyValue &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getParentOperation()->getContext(),
- mlirValueGetType(self.get()))
- .maybeDownCast();
- },
- "Returns the type of the value.")
- .def(
- "set_type",
- [](PyValue &self, const PyType &type) {
- mlirValueSetType(self.get(), type);
- },
- nb::arg("type"), "Sets the type of the value.",
- nb::sig("def set_type(self, type: _T)"))
- .def(
- "replace_all_uses_with",
- [](PyValue &self, PyValue &with) {
- mlirValueReplaceAllUsesOfWith(self.get(), with.get());
- },
- "Replace all uses of value with the new value, updating anything in "
- "the IR that uses `self` to use the other value instead.")
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with, PyOperation &exception) {
- MlirOperation exceptedUser = exception.get();
- mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with, const nb::list &exceptions) {
- // Convert Python list to a SmallVector of MlirOperations
- llvm::SmallVector<MlirOperation> exceptionOps;
- for (nb::handle exception : exceptions) {
- exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
- }
-
- mlirValueReplaceAllUsesExcept(
- self, with, static_cast<intptr_t>(exceptionOps.size()),
- exceptionOps.data());
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with, PyOperation &exception) {
- MlirOperation exceptedUser = exception.get();
- mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with,
- std::vector<PyOperation> &exceptions) {
- // Convert Python list to a SmallVector of MlirOperations
- llvm::SmallVector<MlirOperation> exceptionOps;
- for (PyOperation &exception : exceptions)
- exceptionOps.push_back(exception);
- mlirValueReplaceAllUsesExcept(
- self, with, static_cast<intptr_t>(exceptionOps.size()),
- exceptionOps.data());
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyValue &self) -> nb::typed<nb::object, PyValue> {
- return self.maybeDownCast();
- },
- "Downcasts the `Value` to a more specific kind if possible.")
- .def_prop_ro(
- "location",
- [](MlirValue self) {
- return PyLocation(
- PyMlirContext::forContext(mlirValueGetContext(self)),
- mlirValueGetLocation(self));
- },
- "Returns the source location of the value.");
-
- PyBlockArgument::bind(m);
- PyOpResult::bind(m);
- PyOpOperand::bind(m);
-
- nb::class_<PyAsmState>(m, "AsmState")
- .def(nb::init<PyValue &, bool>(), nb::arg("value"),
- nb::arg("use_local_scope") = false,
- R"(
- Creates an `AsmState` for consistent SSA value naming.
-
- Args:
- value: The value to create state for.
- use_local_scope: Whether to use local scope for naming.)")
- .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
- nb::arg("use_local_scope") = false,
- R"(
- Creates an AsmState for consistent SSA value naming.
-
- Args:
- op: The operation to create state for.
- use_local_scope: Whether to use local scope for naming.)");
-
- //----------------------------------------------------------------------------
- // Mapping of SymbolTable.
- //----------------------------------------------------------------------------
- nb::class_<PySymbolTable>(m, "SymbolTable")
- .def(nb::init<PyOperationBase &>(),
- R"(
- Creates a symbol table for an operation.
-
- Args:
- operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
-
- Raises:
- TypeError: If the operation is not a symbol table.)")
- .def(
- "__getitem__",
- [](PySymbolTable &self,
- const std::string &name) -> nb::typed<nb::object, PyOpView> {
- return self.dunderGetItem(name);
- },
- R"(
- Looks up a symbol by name in the symbol table.
-
- Args:
- name: The name of the symbol to look up.
-
- Returns:
- The operation defining the symbol.
-
- Raises:
- KeyError: If the symbol is not found.)")
- .def("insert", &PySymbolTable::insert, nb::arg("operation"),
- R"(
- Inserts a symbol operation into the symbol table.
-
- Args:
- operation: An operation with a symbol name to insert.
-
- Returns:
- The symbol name attribute of the inserted operation.
-
- Raises:
- ValueError: If the operation does not have a symbol name.)")
- .def("erase", &PySymbolTable::erase, nb::arg("operation"),
- R"(
- Erases a symbol operation from the symbol table.
-
- Args:
- operation: The symbol operation to erase.
-
- Note:
- The operation is also erased from the IR and invalidated.)")
- .def("__delitem__", &PySymbolTable::dunderDel,
- "Deletes a symbol by name from the symbol table.")
- .def(
- "__contains__",
- [](PySymbolTable &table, const std::string &name) {
- return !mlirOperationIsNull(mlirSymbolTableLookup(
- table, mlirStringRefCreate(name.data(), name.length())));
- },
- "Checks if a symbol with the given name exists in the table.")
- // Static helpers.
- .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
- nb::arg("symbol"), nb::arg("name"),
- "Sets the symbol name for a symbol operation.")
- .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
- nb::arg("symbol"),
- "Gets the symbol name from a symbol operation.")
- .def_static("get_visibility", &PySymbolTable::getVisibility,
- nb::arg("symbol"),
- "Gets the visibility attribute of a symbol operation.")
- .def_static("set_visibility", &PySymbolTable::setVisibility,
- nb::arg("symbol"), nb::arg("visibility"),
- "Sets the visibility attribute of a symbol operation.")
- .def_static("replace_all_symbol_uses",
- &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
- nb::arg("new_symbol"), nb::arg("from_op"),
- "Replaces all uses of a symbol with a new symbol name within "
- "the given operation.")
- .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
- nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
- nb::arg("callback"),
- "Walks symbol tables starting from an operation with a "
- "callback function.");
-
- // Container bindings.
- PyBlockArgumentList::bind(m);
- PyBlockIterator::bind(m);
- PyBlockList::bind(m);
- PyBlockSuccessors::bind(m);
- PyBlockPredecessors::bind(m);
- PyOperationIterator::bind(m);
- PyOperationList::bind(m);
- PyOpAttributeMap::bind(m);
- PyOpOperandIterator::bind(m);
- PyOpOperandList::bind(m);
- PyOpResultList::bind(m);
- PyOpSuccessors::bind(m);
- PyRegionIterator::bind(m);
- PyRegionList::bind(m);
-
- // Debug bindings.
- PyGlobalDebugFlag::bind(m);
-
- // Attribute builder getter.
- PyAttrBuilderMap::bind(m);
-
+void registerMLIRErrorInIRCore() {
nb::register_exception_translator([](const std::exception_ptr &p,
void *payload) {
- // We can't define exceptions with custom fields through pybind, so instead
- // the exception class is defined in python and imported here.
+ // We can't define exceptions with custom fields through pybind, so
+ // instead the exception class is defined in python and imported here.
try {
if (p)
std::rethrow_exception(p);
@@ -4971,3 +1687,4 @@ void mlir::python::populateIRCore(nb::module_ &m) {
}
});
}
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 31d4798ffb906..f1e494c375523 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -12,7 +12,7 @@
#include <utility>
#include <vector>
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 34c5b8dd86a66..294ab91a059e2 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -7,14 +7,13 @@
//===----------------------------------------------------------------------===//
// clang-format off
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/IRTypes.h"
// clang-format on
#include <optional>
-#include "IRModule.h"
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"
@@ -1144,7 +1143,8 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
} // namespace
-void mlir::python::populateIRTypes(nb::module_ &m) {
+namespace mlir::python {
+void populateIRTypes(nb::module_ &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
PyIndexType::bind(m);
@@ -1175,4 +1175,18 @@ void mlir::python::populateIRTypes(nb::module_ &m) {
PyTupleType::bind(m);
PyFunctionType::bind(m);
PyOpaqueType::bind(m);
+ nb::register_exception_translator([](const std::exception_ptr &p,
+ void *payload) {
+ // We can't define exceptions with custom fields through pybind, so
+ // instead the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+});
+}
}
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index ba767ad6692cf..686c55ee1e6a8 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,18 +6,2275 @@
//
//===----------------------------------------------------------------------===//
-#include "Globals.h"
-#include "IRModule.h"
-#include "NanobindUtils.h"
#include "Pass.h"
#include "Rewrite.h"
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace nb = nanobind;
using namespace mlir;
using namespace nb::literals;
using namespace mlir::python;
+constexpr char kModuleParseDocstring[] =
+    R"(Parses a module's assembly format from a string.
+
+Returns a new MlirModule or raises an MLIRError if the parsing fails.
+
+See also: https://mlir.llvm.org/docs/LangRef/
+)";
+// Docstring for the `dump` bindings (e.g. `Module.dump`).
+constexpr char kDumpDocstring[] =
+    "Dumps a debug representation of the object to stderr.";
+// Docstring for the value-level "replace all uses except" binding.
+constexpr char kValueReplaceAllUsesExceptDocstring[] =
+    R"(Replace all uses of this value with the `with` value, except for those
+in `exceptions`. `exceptions` can be either a single operation or a list of
+operations.
+)";
+
+namespace {
+// see
+// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
+
+#ifndef _Py_CAST
+#define _Py_CAST(type, expr) ((type)(expr))
+#endif
+
+// Static inline functions should use _Py_NULL rather than using directly NULL
+// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
+// _Py_NULL is defined as nullptr.
+#ifndef _Py_NULL
+#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \
+    (defined(__cplusplus) && __cplusplus >= 201103)
+#define _Py_NULL nullptr
+#else
+#define _Py_NULL NULL
+#endif
+#endif
+
+// Python 3.10.0a3
+#if PY_VERSION_HEX < 0x030A00A3
+
+// bpo-42262 added Py_XNewRef()
+#if !defined(Py_XNewRef)
+[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
+  Py_XINCREF(obj);
+  return obj;
+}
+#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
+#endif
+
+// bpo-42262 added Py_NewRef()
+#if !defined(Py_NewRef)
+[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
+  Py_INCREF(obj);
+  return obj;
+}
+#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
+#endif
+
+#endif // Python 3.10.0a3
+
+// Python 3.9.0b1
+#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
+
+// bpo-40429 added PyThreadState_GetFrame()
+PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
+  assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
+  return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
+}
+
+// bpo-40421 added PyFrame_GetBack()
+PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
+  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+  return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
+}
+
+// bpo-40421 added PyFrame_GetCode()
+PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
+  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+  assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
+  return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
+}
+
+#endif // Python 3.9.0b1
+
+// Converts up to `locTracebackFramesLimit()` user frames of the Python stack
+// into nested call-site locations (innermost = callee); unknown if none.
+MlirLocation tracebackToLocation(MlirContext ctx) {
+  size_t framesLimit =
+      PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
+  // Use a thread_local here to avoid requiring a large amount of space.
+  thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
+      frames;
+  size_t count = 0;
+
+  nb::gil_scoped_acquire acquire;
+  PyThreadState *tstate = PyThreadState_GET();
+  PyFrameObject *next;
+  PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
+  // Increment: fetch the caller frame, drop our reference to the current one
+  // (so it and its closure can be gc'd), then make the caller current.
+  for (; pyFrame != nullptr && count < framesLimit;
+       next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
+    // PyFrame_GetCode returns a strong reference; `codeRef` releases it.
+    PyCodeObject *code = PyFrame_GetCode(pyFrame);
+    nb::object codeRef = nb::steal(reinterpret_cast<PyObject *>(code));
+    auto fileNameStr =
+        nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
+    llvm::StringRef fileName(fileNameStr);
+    if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
+      continue;
+    // co_qualname and PyCode_Addr2Location added in py3.11
+#if PY_VERSION_HEX < 0x030B00F0
+    std::string name =
+        nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
+    llvm::StringRef funcName(name);
+    int startLine = PyFrame_GetLineNumber(pyFrame);
+    MlirLocation loc =
+        mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
+#else
+    std::string name =
+        nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
+    llvm::StringRef funcName(name);
+    int startLine, startCol, endLine, endCol;
+    int lasti = PyFrame_GetLasti(pyFrame);
+    if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
+                              &endCol)) {
+      throw nb::python_error();
+    }
+    MlirLocation loc = mlirLocationFileLineColRangeGet(
+        ctx, wrap(fileName), startLine, startCol, endLine, endCol);
+#endif
+    frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
+    ++count;
+  }
+  // Release the frame the loop stopped on (non-null if the limit was hit).
+  Py_XDECREF(pyFrame);
+
+  if (count == 0)
+    return mlirLocationUnknownGet(ctx);
+
+  MlirLocation callee = frames[0];
+  assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
+  if (count == 1)
+    return callee;
+
+  MlirLocation caller = frames[count - 1];
+  assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
+  for (int i = count - 2; i >= 1; i--)
+    caller = mlirLocationCallSiteGet(frames[i], caller);
+
+  return mlirLocationCallSiteGet(callee, caller);
+}
+
+// Returns `location` if set; otherwise the default location, or (when
+// traceback locations are enabled) one built from the Python call stack.
+PyLocation
+maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
+  if (location.has_value())
+    return location.value();
+  if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
+    return DefaultingPyLocation::resolve();
+
+  PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
+  MlirLocation mlirLoc = tracebackToLocation(ctx.get());
+  PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
+  return {ref, mlirLoc};
+}
+} // namespace
+
+//------------------------------------------------------------------------------
+// Populates the core exports of the 'ir' submodule.
+//------------------------------------------------------------------------------
+
+static void populateIRCore(nb::module_ &m) {
+ // disable leak warnings which tend to be false positives.
+ nb::set_leak_warnings(false);
+ //----------------------------------------------------------------------------
+ // Enums.
+ //----------------------------------------------------------------------------
+ nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
+ .value("ERROR", MlirDiagnosticError)
+ .value("WARNING", MlirDiagnosticWarning)
+ .value("NOTE", MlirDiagnosticNote)
+ .value("REMARK", MlirDiagnosticRemark);
+
+ nb::enum_<MlirWalkOrder>(m, "WalkOrder")
+ .value("PRE_ORDER", MlirWalkPreOrder)
+ .value("POST_ORDER", MlirWalkPostOrder);
+
+ nb::enum_<MlirWalkResult>(m, "WalkResult")
+ .value("ADVANCE", MlirWalkResultAdvance)
+ .value("INTERRUPT", MlirWalkResultInterrupt)
+ .value("SKIP", MlirWalkResultSkip);
+
+ //----------------------------------------------------------------------------
+ // Mapping of Diagnostics.
+ //----------------------------------------------------------------------------
+ nb::class_<PyDiagnostic>(m, "Diagnostic")
+ .def_prop_ro("severity", &PyDiagnostic::getSeverity,
+ "Returns the severity of the diagnostic.")
+ .def_prop_ro("location", &PyDiagnostic::getLocation,
+ "Returns the location associated with the diagnostic.")
+ .def_prop_ro("message", &PyDiagnostic::getMessage,
+ "Returns the message text of the diagnostic.")
+ .def_prop_ro("notes", &PyDiagnostic::getNotes,
+ "Returns a tuple of attached note diagnostics.")
+ .def(
+ "__str__",
+ [](PyDiagnostic &self) -> nb::str {
+ if (!self.isValid())
+ return nb::str("<Invalid Diagnostic>");
+ return self.getMessage();
+ },
+ "Returns the diagnostic message as a string.");
+
+ nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
+ .def(
+ "__init__",
+ [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
+ new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
+ },
+ "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
+ .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
+ "The severity level of the diagnostic.")
+ .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
+ "The location associated with the diagnostic.")
+ .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
+ "The message text of the diagnostic.")
+ .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
+ "List of attached note diagnostics.")
+ .def(
+ "__str__",
+ [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
+ "Returns the diagnostic message as a string.");
+
+ nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
+ .def("detach", &PyDiagnosticHandler::detach,
+ "Detaches the diagnostic handler from the context.")
+ .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
+ "Returns True if the handler is attached to a context.")
+ .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
+ "Returns True if an error was encountered during diagnostic "
+ "handling.")
+ .def("__enter__", &PyDiagnosticHandler::contextEnter,
+ "Enters the diagnostic handler as a context manager.")
+ .def("__exit__", &PyDiagnosticHandler::contextExit,
+ nb::arg("exc_type").none(), nb::arg("exc_value").none(),
+ nb::arg("traceback").none(),
+ "Exits the diagnostic handler context manager.");
+
+ // Expose DefaultThreadPool to python
+ nb::class_<PyThreadPool>(m, "ThreadPool")
+ .def(
+ "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
+ "Creates a new thread pool with default concurrency.")
+ .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
+ "Returns the maximum number of threads in the pool.")
+ .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
+ "Returns the raw pointer to the LLVM thread pool as a string.");
+
+ nb::class_<PyMlirContext>(m, "Context")
+ .def(
+ "__init__",
+ [](PyMlirContext &self) {
+ MlirContext context = mlirContextCreateWithThreading(false);
+ new (&self) PyMlirContext(context);
+ },
+ R"(
+ Creates a new MLIR context.
+
+ The context is the top-level container for all MLIR objects. It owns the storage
+ for types, attributes, locations, and other core IR objects. A context can be
+ configured to allow or disallow unregistered dialects and can have dialects
+ loaded on-demand.)")
+ .def_static("_get_live_count", &PyMlirContext::getLiveCount,
+ "Gets the number of live Context objects.")
+ .def(
+ "_get_context_again",
+ [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
+ PyMlirContextRef ref = PyMlirContext::forContext(self.get());
+ return ref.releaseObject();
+ },
+ "Gets another reference to the same context.")
+ .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
+ "Gets the number of live modules owned by this context.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule,
+ "Gets a capsule wrapping the MlirContext.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyMlirContext::createFromCapsule,
+ "Creates a Context from a capsule wrapping MlirContext.")
+ .def("__enter__", &PyMlirContext::contextEnter,
+ "Enters the context as a context manager.")
+ .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
+ nb::arg("exc_value").none(), nb::arg("traceback").none(),
+ "Exits the context manager.")
+ .def_prop_ro_static(
+ "current",
+ [](nb::object & /*class*/)
+ -> std::optional<nb::typed<nb::object, PyMlirContext>> {
+ auto *context = PyThreadContextEntry::getDefaultContext();
+ if (!context)
+ return {};
+ return nb::cast(context);
+ },
+ nb::sig("def current(/) -> Context | None"),
+ "Gets the Context bound to the current thread or returns None if no "
+ "context is set.")
+ .def_prop_ro(
+ "dialects",
+ [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+ "Gets a container for accessing dialects by name.")
+ .def_prop_ro(
+ "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+ "Alias for `dialects`.")
+ .def(
+ "get_dialect_descriptor",
+ [=](PyMlirContext &self, std::string &name) {
+ MlirDialect dialect = mlirContextGetOrLoadDialect(
+ self.get(), {name.data(), name.size()});
+ if (mlirDialectIsNull(dialect)) {
+ throw nb::value_error(
+ (Twine("Dialect '") + name + "' not found").str().c_str());
+ }
+ return PyDialectDescriptor(self.getRef(), dialect);
+ },
+ nb::arg("dialect_name"),
+ "Gets or loads a dialect by name, returning its descriptor object.")
+ .def_prop_rw(
+ "allow_unregistered_dialects",
+ [](PyMlirContext &self) -> bool {
+ return mlirContextGetAllowUnregisteredDialects(self.get());
+ },
+ [](PyMlirContext &self, bool value) {
+ mlirContextSetAllowUnregisteredDialects(self.get(), value);
+ },
+ "Controls whether unregistered dialects are allowed in this context.")
+ .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
+ nb::arg("callback"),
+ "Attaches a diagnostic handler that will receive callbacks.")
+ .def(
+ "enable_multithreading",
+ [](PyMlirContext &self, bool enable) {
+ mlirContextEnableMultithreading(self.get(), enable);
+ },
+ nb::arg("enable"),
+ R"(
+ Enables or disables multi-threading support in the context.
+
+ Args:
+ enable: Whether to enable (True) or disable (False) multi-threading.
+ )")
+ .def(
+ "set_thread_pool",
+ [](PyMlirContext &self, PyThreadPool &pool) {
+ // we should disable multi-threading first before setting
+ // new thread pool otherwise the assert in
+ // MLIRContext::setThreadPool will be raised.
+ mlirContextEnableMultithreading(self.get(), false);
+ mlirContextSetThreadPool(self.get(), pool.get());
+ },
+ R"(
+ Sets a custom thread pool for the context to use.
+
+ Args:
+ pool: A ThreadPool object to use for parallel operations.
+
+ Note:
+ Multi-threading is automatically disabled before setting the thread pool.)")
+ .def(
+ "get_num_threads",
+ [](PyMlirContext &self) {
+ return mlirContextGetNumThreads(self.get());
+ },
+ "Gets the number of threads in the context's thread pool.")
+ .def(
+ "_mlir_thread_pool_ptr",
+ [](PyMlirContext &self) {
+ MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
+ std::stringstream ss;
+ ss << pool.ptr;
+ return ss.str();
+ },
+ "Gets the raw pointer to the LLVM thread pool as a string.")
+ .def(
+ "is_registered_operation",
+ [](PyMlirContext &self, std::string &name) {
+ return mlirContextIsRegisteredOperation(
+ self.get(), MlirStringRef{name.data(), name.size()});
+ },
+ nb::arg("operation_name"),
+ R"(
+ Checks whether an operation with the given name is registered.
+
+ Args:
+ operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
+
+ Returns:
+ True if the operation is registered, False otherwise.)")
+ .def(
+ "append_dialect_registry",
+ [](PyMlirContext &self, PyDialectRegistry ®istry) {
+ mlirContextAppendDialectRegistry(self.get(), registry);
+ },
+ nb::arg("registry"),
+ R"(
+ Appends the contents of a dialect registry to the context.
+
+ Args:
+ registry: A DialectRegistry containing dialects to append.)")
+ .def_prop_rw("emit_error_diagnostics",
+ &PyMlirContext::getEmitErrorDiagnostics,
+ &PyMlirContext::setEmitErrorDiagnostics,
+ R"(
+ Controls whether error diagnostics are emitted to diagnostic handlers.
+
+ By default, error diagnostics are captured and reported through MLIRError exceptions.)")
+ .def(
+ "load_all_available_dialects",
+ [](PyMlirContext &self) {
+ mlirContextLoadAllAvailableDialects(self.get());
+ },
+ R"(
+ Loads all dialects available in the registry into the context.
+
+ This eagerly loads all dialects that have been registered, making them
+ immediately available for use.)");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialectDescriptor
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
+ .def_prop_ro(
+ "namespace",
+ [](PyDialectDescriptor &self) {
+ MlirStringRef ns = mlirDialectGetNamespace(self.get());
+ return nb::str(ns.data, ns.length);
+ },
+ "Returns the namespace of the dialect.")
+ .def(
+ "__repr__",
+ [](PyDialectDescriptor &self) {
+ MlirStringRef ns = mlirDialectGetNamespace(self.get());
+ std::string repr("<DialectDescriptor ");
+ repr.append(ns.data, ns.length);
+ repr.append(">");
+ return repr;
+ },
+ nb::sig("def __repr__(self) -> str"),
+ "Returns a string representation of the dialect descriptor.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialects
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialects>(m, "Dialects")
+ .def(
+ "__getitem__",
+ [=](PyDialects &self, std::string keyName) {
+ MlirDialect dialect =
+ self.getDialectForKey(keyName, /*attrError=*/false);
+ nb::object descriptor =
+ nb::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(keyName, std::move(descriptor));
+ },
+ "Gets a dialect by name using subscript notation.")
+ .def(
+ "__getattr__",
+ [=](PyDialects &self, std::string attrName) {
+ MlirDialect dialect =
+ self.getDialectForKey(attrName, /*attrError=*/true);
+ nb::object descriptor =
+ nb::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(attrName, std::move(descriptor));
+ },
+ "Gets a dialect by name using attribute notation.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialect
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialect>(m, "Dialect")
+ .def(nb::init<nb::object>(), nb::arg("descriptor"),
+ "Creates a Dialect from a DialectDescriptor.")
+ .def_prop_ro(
+ "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
+ "Returns the DialectDescriptor for this dialect.")
+ .def(
+ "__repr__",
+ [](const nb::object &self) {
+ auto clazz = self.attr("__class__");
+ return nb::str("<Dialect ") +
+ self.attr("descriptor").attr("namespace") +
+ nb::str(" (class ") + clazz.attr("__module__") +
+ nb::str(".") + clazz.attr("__name__") + nb::str(")>");
+ },
+ nb::sig("def __repr__(self) -> str"),
+ "Returns a string representation of the dialect.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialectRegistry
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialectRegistry>(m, "DialectRegistry")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule,
+ "Gets a capsule wrapping the MlirDialectRegistry.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyDialectRegistry::createFromCapsule,
+ "Creates a DialectRegistry from a capsule wrapping "
+ "`MlirDialectRegistry`.")
+ .def(nb::init<>(), "Creates a new empty dialect registry.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Location
+ //----------------------------------------------------------------------------
+ nb::class_<PyLocation>(m, "Location")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule,
+ "Gets a capsule wrapping the MlirLocation.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule,
+ "Creates a Location from a capsule wrapping MlirLocation.")
+ .def("__enter__", &PyLocation::contextEnter,
+ "Enters the location as a context manager.")
+ .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
+ nb::arg("exc_value").none(), nb::arg("traceback").none(),
+ "Exits the location context manager.")
+ .def(
+ "__eq__",
+ [](PyLocation &self, PyLocation &other) -> bool {
+ return mlirLocationEqual(self, other);
+ },
+ "Compares two locations for equality.")
+ .def(
+ "__eq__", [](PyLocation &self, nb::object other) { return false; },
+ "Compares location with non-location object (always returns False).")
+ .def_prop_ro_static(
+ "current",
+ [](nb::object & /*class*/) -> std::optional<PyLocation *> {
+ auto *loc = PyThreadContextEntry::getDefaultLocation();
+ if (!loc)
+ return std::nullopt;
+ return loc;
+ },
+ // clang-format off
+ nb::sig("def current(/) -> Location | None"),
+ // clang-format on
+ "Gets the Location bound to the current thread or raises ValueError.")
+ .def_static(
+ "unknown",
+ [](DefaultingPyMlirContext context) {
+ return PyLocation(context->getRef(),
+ mlirLocationUnknownGet(context->get()));
+ },
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing an unknown location.")
+ .def_static(
+ "callsite",
+ [](PyLocation callee, const std::vector<PyLocation> &frames,
+ DefaultingPyMlirContext context) {
+ if (frames.empty())
+ throw nb::value_error("No caller frames provided.");
+ MlirLocation caller = frames.back().get();
+ for (const PyLocation &frame :
+ llvm::reverse(llvm::ArrayRef(frames).drop_back()))
+ caller = mlirLocationCallSiteGet(frame.get(), caller);
+ return PyLocation(context->getRef(),
+ mlirLocationCallSiteGet(callee.get(), caller));
+ },
+ nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(),
+ "Gets a Location representing a caller and callsite.")
+ .def("is_a_callsite", mlirLocationIsACallSite,
+ "Returns True if this location is a CallSiteLoc.")
+ .def_prop_ro(
+ "callee",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationCallSiteGetCallee(self));
+ },
+ "Gets the callee location from a CallSiteLoc.")
+ .def_prop_ro(
+ "caller",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationCallSiteGetCaller(self));
+ },
+ "Gets the caller location from a CallSiteLoc.")
+ .def_static(
+ "file",
+ [](std::string filename, int line, int col,
+ DefaultingPyMlirContext context) {
+ return PyLocation(
+ context->getRef(),
+ mlirLocationFileLineColGet(
+ context->get(), toMlirStringRef(filename), line, col));
+ },
+ nb::arg("filename"), nb::arg("line"), nb::arg("col"),
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a file, line and column.")
+ .def_static(
+ "file",
+ [](std::string filename, int startLine, int startCol, int endLine,
+ int endCol, DefaultingPyMlirContext context) {
+ return PyLocation(context->getRef(),
+ mlirLocationFileLineColRangeGet(
+ context->get(), toMlirStringRef(filename),
+ startLine, startCol, endLine, endCol));
+ },
+ nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
+ nb::arg("end_line"), nb::arg("end_col"),
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a file, line and column range.")
+ .def("is_a_file", mlirLocationIsAFileLineColRange,
+ "Returns True if this location is a FileLineColLoc.")
+ .def_prop_ro(
+ "filename",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(
+ mlirLocationFileLineColRangeGetFilename(loc));
+ },
+ "Gets the filename from a FileLineColLoc.")
+ .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
+ "Gets the start line number from a `FileLineColLoc`.")
+ .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
+ "Gets the start column number from a `FileLineColLoc`.")
+ .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
+ "Gets the end line number from a `FileLineColLoc`.")
+ .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
+ "Gets the end column number from a `FileLineColLoc`.")
+ .def_static(
+ "fused",
+ [](const std::vector<PyLocation> &pyLocations,
+ std::optional<PyAttribute> metadata,
+ DefaultingPyMlirContext context) {
+ llvm::SmallVector<MlirLocation, 4> locations;
+ locations.reserve(pyLocations.size());
+ for (auto &pyLocation : pyLocations)
+ locations.push_back(pyLocation.get());
+ MlirLocation location = mlirLocationFusedGet(
+ context->get(), locations.size(), locations.data(),
+ metadata ? metadata->get() : MlirAttribute{0});
+ return PyLocation(context->getRef(), location);
+ },
+ nb::arg("locations"), nb::arg("metadata") = nb::none(),
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a fused location with optional "
+ "metadata.")
+ .def("is_a_fused", mlirLocationIsAFused,
+ "Returns True if this location is a `FusedLoc`.")
+ .def_prop_ro(
+ "locations",
+ [](PyLocation &self) {
+ unsigned numLocations = mlirLocationFusedGetNumLocations(self);
+ std::vector<MlirLocation> locations(numLocations);
+ if (numLocations)
+ mlirLocationFusedGetLocations(self, locations.data());
+ std::vector<PyLocation> pyLocations{};
+ pyLocations.reserve(numLocations);
+ for (unsigned i = 0; i < numLocations; ++i)
+ pyLocations.emplace_back(self.getContext(), locations[i]);
+ return pyLocations;
+ },
+ "Gets the list of locations from a `FusedLoc`.")
+ .def_static(
+ "name",
+ [](std::string name, std::optional<PyLocation> childLoc,
+ DefaultingPyMlirContext context) {
+ return PyLocation(
+ context->getRef(),
+ mlirLocationNameGet(
+ context->get(), toMlirStringRef(name),
+ childLoc ? childLoc->get()
+ : mlirLocationUnknownGet(context->get())));
+ },
+ nb::arg("name"), nb::arg("childLoc") = nb::none(),
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a named location with optional child "
+ "location.")
+ .def("is_a_name", mlirLocationIsAName,
+ "Returns True if this location is a `NameLoc`.")
+ .def_prop_ro(
+ "name_str",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(mlirLocationNameGetName(loc));
+ },
+ "Gets the name string from a `NameLoc`.")
+ .def_prop_ro(
+ "child_loc",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationNameGetChildLoc(self));
+ },
+ "Gets the child location from a `NameLoc`.")
+ .def_static(
+ "from_attr",
+ [](PyAttribute &attribute, DefaultingPyMlirContext context) {
+ return PyLocation(context->getRef(),
+ mlirLocationFromAttribute(attribute));
+ },
+ nb::arg("attribute"), nb::arg("context") = nb::none(),
+ "Gets a Location from a `LocationAttr`.")
+ .def_prop_ro(
+ "context",
+ [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that owns the `Location`.")
+ .def_prop_ro(
+ "attr",
+ [](PyLocation &self) {
+ return PyAttribute(self.getContext(),
+ mlirLocationGetAttribute(self));
+ },
+ "Get the underlying `LocationAttr`.")
+ .def(
+ "emit_error",
+ [](PyLocation &self, std::string message) {
+ mlirEmitError(self, message.c_str());
+ },
+ nb::arg("message"),
+ R"(
+ Emits an error diagnostic at this location.
+
+ Args:
+ message: The error message to emit.)")
+ .def(
+ "__repr__",
+ [](PyLocation &self) {
+ PyPrintAccumulator printAccum;
+ mlirLocationPrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly representation of the location.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Module
+ //----------------------------------------------------------------------------
+ nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule,
+ "Gets a capsule wrapping the MlirModule.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
+ R"(
+ Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
+
+ This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
+ prevent double-frees (of the underlying `mlir::Module`).)")
+ .def("_clear_mlir_module", &PyModule::clearMlirModule,
+ R"(
+ Clears the internal MLIR module reference.
+
+ This is used internally to prevent double-free when ownership is transferred
+ via the C API capsule mechanism. Not intended for normal use.)")
+ .def_static(
+ "parse",
+ [](const std::string &moduleAsm, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParse(
+ context->get(), toMlirStringRef(moduleAsm));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("asm"), nb::arg("context") = nb::none(),
+ kModuleParseDocstring)
+ .def_static(
+ "parse",
+ [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParse(
+ context->get(), toMlirStringRef(moduleAsm));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("asm"), nb::arg("context") = nb::none(),
+ kModuleParseDocstring)
+ .def_static(
+ "parseFile",
+ [](const std::string &path, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParseFromFile(
+ context->get(), toMlirStringRef(path));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("path"), nb::arg("context") = nb::none(),
+ kModuleParseDocstring)
+ .def_static(
+ "create",
+ [](const std::optional<PyLocation> &loc)
+ -> nb::typed<nb::object, PyModule> {
+ PyLocation pyLoc = maybeGetTracebackLocation(loc);
+ MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("loc") = nb::none(), "Creates an empty module.")
+ .def_prop_ro(
+ "context",
+ [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that created the `Module`.")
+ .def_prop_ro(
+ "operation",
+ [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
+ return PyOperation::forOperation(self.getContext(),
+ mlirModuleGetOperation(self.get()),
+ self.getRef().releaseObject())
+ .releaseObject();
+ },
+ "Accesses the module as an operation.")
+ .def_prop_ro(
+ "body",
+ [](PyModule &self) {
+ PyOperationRef moduleOp = PyOperation::forOperation(
+ self.getContext(), mlirModuleGetOperation(self.get()),
+ self.getRef().releaseObject());
+ PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
+ return returnBlock;
+ },
+ "Return the block for this module.")
+ .def(
+ "dump",
+ [](PyModule &self) {
+ mlirOperationDump(mlirModuleGetOperation(self.get()));
+ },
+ kDumpDocstring)
+ .def(
+ "__str__",
+ [](const nb::object &self) {
+ // Defer to the operation's __str__.
+ return self.attr("operation").attr("__str__")();
+ },
+ nb::sig("def __str__(self) -> str"),
+ R"(
+ Gets the assembly form of the operation with default options.
+
+ If more advanced control over the assembly formatting or I/O options is needed,
+ use the dedicated print or get_asm method, which supports keyword arguments to
+ customize behavior.
+ )")
+ .def(
+ "__eq__",
+ [](PyModule &self, PyModule &other) {
+ return mlirModuleEqual(self.get(), other.get());
+ },
+ "other"_a, "Compares two modules for equality.")
+ .def(
+ "__hash__",
+ [](PyModule &self) { return mlirModuleHashValue(self.get()); },
+ "Returns the hash value of the module.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Operation.
+ //----------------------------------------------------------------------------
+ // `_OperationBase` is the shared base bound for both `Operation` and
+ // `OpView`: every accessor below reaches the concrete `PyOperation` through
+ // `self.getOperation()`.
+ nb::class_<PyOperationBase>(m, "_OperationBase")
+ .def_prop_ro(
+ MLIR_PYTHON_CAPI_PTR_ATTR,
+ [](PyOperationBase &self) {
+ return self.getOperation().getCapsule();
+ },
+ "Gets a capsule wrapping the `MlirOperation`.")
+ // Two `__eq__` overloads. nanobind tries overloads in registration order,
+ // so the typed operation/operation overload must stay first; the
+ // `nb::object` catch-all makes comparison with any other object return
+ // False instead of failing overload resolution.
+ .def(
+ "__eq__",
+ [](PyOperationBase &self, PyOperationBase &other) {
+ return mlirOperationEqual(self.getOperation().get(),
+ other.getOperation().get());
+ },
+ "Compares two operations for equality.")
+ .def(
+ "__eq__",
+ [](PyOperationBase &self, nb::object other) { return false; },
+ "Compares operation with non-operation object (always returns "
+ "False).")
+ .def(
+ "__hash__",
+ [](PyOperationBase &self) {
+ return mlirOperationHashValue(self.getOperation().get());
+ },
+ "Returns the hash value of the operation.")
+ // The collection views (attributes, operands, regions, results) are built
+ // from the operation's `getRef()` handle (presumably keeping the owning
+ // operation alive while the view exists — confirm).
+ .def_prop_ro(
+ "attributes",
+ [](PyOperationBase &self) {
+ return PyOpAttributeMap(self.getOperation().getRef());
+ },
+ "Returns a dictionary-like map of operation attributes.")
+ .def_prop_ro(
+ "context",
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
+ PyOperation &concreteOperation = self.getOperation();
+ // NOTE(review): checkValid() presumably raises when the operation has
+ // been invalidated (see `_set_invalid`) — confirm.
+ concreteOperation.checkValid();
+ return concreteOperation.getContext().getObject();
+ },
+ "Context that owns the operation.")
+ .def_prop_ro(
+ "name",
+ [](PyOperationBase &self) {
+ auto &concreteOperation = self.getOperation();
+ concreteOperation.checkValid();
+ MlirOperation operation = concreteOperation.get();
+ return mlirIdentifierStr(mlirOperationGetName(operation));
+ },
+ "Returns the fully qualified name of the operation.")
+ .def_prop_ro(
+ "operands",
+ [](PyOperationBase &self) {
+ return PyOpOperandList(self.getOperation().getRef());
+ },
+ "Returns the list of operation operands.")
+ .def_prop_ro(
+ "regions",
+ [](PyOperationBase &self) {
+ return PyRegionList(self.getOperation().getRef());
+ },
+ "Returns the list of operation regions.")
+ .def_prop_ro(
+ "results",
+ [](PyOperationBase &self) {
+ return PyOpResultList(self.getOperation().getRef());
+ },
+ "Returns the list of Operation results.")
+ .def_prop_ro(
+ "result",
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
+ auto &operation = self.getOperation();
+ // getUniqueResult() is expected to raise unless the operation has
+ // exactly one result (see docstring); the wrapped result is then
+ // passed through maybeDownCast().
+ return PyOpResult(operation.getRef(), getUniqueResult(operation))
+ .maybeDownCast();
+ },
+ "Shortcut to get an op result if it has only one (throws an error "
+ "otherwise).")
+ .def_prop_rw(
+ "location",
+ // Getter wraps the operation's current location; setter overwrites
+ // it in place via mlirOperationSetLocation.
+ [](PyOperationBase &self) {
+ PyOperation &operation = self.getOperation();
+ return PyLocation(operation.getContext(),
+ mlirOperationGetLocation(operation.get()));
+ },
+ [](PyOperationBase &self, const PyLocation &location) {
+ PyOperation &operation = self.getOperation();
+ mlirOperationSetLocation(operation.get(), location.get());
+ },
+ nb::for_getter("Returns the source location the operation was "
+ "defined or derived from."),
+ nb::for_setter("Sets the source location the operation was defined "
+ "or derived from."))
+ .def_prop_ro(
+ "parent",
+ [](PyOperationBase &self)
+ -> std::optional<nb::typed<nb::object, PyOperation>> {
+ auto parent = self.getOperation().getParentOperation();
+ // An empty result is surfaced to Python as None.
+ if (parent)
+ return parent->getObject();
+ return {};
+ },
+ "Returns the parent operation, or `None` if at top level.")
+ .def(
+ "__str__",
+ // Text form using the same defaults that `get_asm` registers below.
+ [](PyOperationBase &self) {
+ return self.getAsm(/*binary=*/false,
+ /*largeElementsLimit=*/std::nullopt,
+ /*largeResourceLimit=*/std::nullopt,
+ /*enableDebugInfo=*/false,
+ /*prettyDebugInfo=*/false,
+ /*printGenericOpForm=*/false,
+ /*useLocalScope=*/false,
+ /*useNameLocAsPrefix=*/false,
+ /*assumeVerified=*/false,
+ /*skipRegions=*/false);
+ },
+ nb::sig("def __str__(self) -> str"),
+ "Returns the assembly form of the operation.")
+ // Two `print` overloads: the first reuses a prebuilt `AsmState`, the
+ // second takes the individual printing flags as keyword arguments.
+ .def("print",
+ nb::overload_cast<PyAsmState &, nb::object, bool>(
+ &PyOperationBase::print),
+ nb::arg("state"), nb::arg("file") = nb::none(),
+ nb::arg("binary") = false,
+ R"(
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ state: `AsmState` capturing the operation numbering and flags.
+ file: Optional file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
+ .def("print",
+ nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
+ bool, bool, bool, bool, bool, bool, nb::object,
+ bool, bool>(&PyOperationBase::print),
+ // Careful: Lots of arguments must match up with print method.
+ nb::arg("large_elements_limit") = nb::none(),
+ nb::arg("large_resource_limit") = nb::none(),
+ nb::arg("enable_debug_info") = false,
+ nb::arg("pretty_debug_info") = false,
+ nb::arg("print_generic_op_form") = false,
+ nb::arg("use_local_scope") = false,
+ nb::arg("use_name_loc_as_prefix") = false,
+ nb::arg("assume_verified") = false, nb::arg("file") = nb::none(),
+ nb::arg("binary") = false, nb::arg("skip_regions") = false,
+ R"(
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ large_elements_limit: Whether to elide elements attributes above this
+ number of elements. Defaults to None (no limit).
+ large_resource_limit: Whether to elide resource attributes above this
+ number of characters. Defaults to None (no limit). If large_elements_limit
+ is set and this is None, the behavior will be to use large_elements_limit
+ as large_resource_limit.
+ enable_debug_info: Whether to print debug/location information. Defaults
+ to False.
+ pretty_debug_info: Whether to format debug information for easier reading
+ by a human (warning: the result is unparseable). Defaults to False.
+ print_generic_op_form: Whether to print the generic assembly forms of all
+ ops. Defaults to False.
+ use_local_scope: Whether to print in a way that is more optimized for
+ multi-threaded access but may not be consistent with how the overall
+ module prints.
+ use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
+ prefixes for the SSA identifiers. Defaults to False.
+ assume_verified: By default, if not printing generic form, the verifier
+ will be run and if it fails, generic form will be printed with a comment
+ about failed verification. While a reasonable default for interactive use,
+ for systematic use, it is often better for the caller to verify explicitly
+ and report failures in a more robust fashion. Set this to True if doing this
+ in order to avoid running a redundant verification. If the IR is actually
+ invalid, behavior is undefined.
+ file: The file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write bytes (True) or str (False). Defaults to False.
+ skip_regions: Whether to skip printing regions. Defaults to False.)")
+ .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
+ nb::arg("desired_version") = nb::none(),
+ R"(
+ Write the bytecode form of the operation to a file like object.
+
+ Args:
+ file: The file like object to write to.
+ desired_version: Optional version of bytecode to emit.
+ Returns:
+ The bytecode writer status.)")
+ .def("get_asm", &PyOperationBase::getAsm,
+ // Careful: Lots of arguments must match up with get_asm method.
+ nb::arg("binary") = false,
+ nb::arg("large_elements_limit") = nb::none(),
+ nb::arg("large_resource_limit") = nb::none(),
+ nb::arg("enable_debug_info") = false,
+ nb::arg("pretty_debug_info") = false,
+ nb::arg("print_generic_op_form") = false,
+ nb::arg("use_local_scope") = false,
+ nb::arg("use_name_loc_as_prefix") = false,
+ nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
+ R"(
+ Gets the assembly form of the operation with all options available.
+
+ Args:
+ binary: Whether to return a bytes (True) or str (False) object. Defaults to
+ False.
+ ... others ...: See the print() method for common keyword arguments for
+ configuring the printout.
+ Returns:
+ Either a bytes or str object, depending on the setting of the `binary`
+ argument.)")
+ .def("verify", &PyOperationBase::verify,
+ "Verify the operation. Raises MLIRError if verification fails, and "
+ "returns true otherwise.")
+ // Intra-block reordering helpers; both forward to PyOperationBase.
+ .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
+ "Puts self immediately after the other operation in its parent "
+ "block.")
+ .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
+ "Puts self immediately before the other operation in its parent "
+ "block.")
+ .def("is_before_in_block", &PyOperationBase::isBeforeInBlock,
+ nb::arg("other"),
+ R"(
+ Checks if this operation is before another in the same block.
+
+ Args:
+ other: Another operation in the same parent block.
+
+ Returns:
+ True if this operation is before `other` in the operation list of the parent block.)")
+ .def(
+ "clone",
+ // `ip` is forwarded untouched; its None/False handling lives in
+ // PyOperation::clone (see the docstring).
+ [](PyOperationBase &self,
+ const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
+ return self.getOperation().clone(ip);
+ },
+ nb::arg("ip") = nb::none(),
+ R"(
+ Creates a deep copy of the operation.
+
+ Args:
+ ip: Optional insertion point where the cloned operation should be inserted.
+ If None, the current insertion point is used. If False, the operation
+ remains detached.
+
+ Returns:
+ A new Operation that is a clone of this operation.)")
+ .def(
+ "detach_from_parent",
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
+ PyOperation &operation = self.getOperation();
+ operation.checkValid();
+ // Detaching an operation that is already detached is a ValueError.
+ if (!operation.isAttached())
+ throw nb::value_error("Detached operation has no parent.");
+
+ operation.detachFromParent();
+ return operation.createOpView();
+ },
+ "Detaches the operation from its parent block.")
+ .def_prop_ro(
+ "attached",
+ [](PyOperationBase &self) {
+ PyOperation &operation = self.getOperation();
+ operation.checkValid();
+ return operation.isAttached();
+ },
+ "Reports if the operation is attached to its parent block.")
+ .def(
+ "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
+ R"(
+ Erases the operation and frees its memory.
+
+ Note:
+ After erasing, any Python references to the operation become invalid.)")
+ // Operation tree walk. Because an explicit `nb::sig` replaces the
+ // auto-generated signature, the default of `walk_order` must be spelled
+ // out in the signature string too; otherwise generated stubs advertise
+ // the argument as required even though `MlirWalkPostOrder` is bound as
+ // its default.
+ .def("walk", &PyOperationBase::walk, nb::arg("callback"),
+ nb::arg("walk_order") = MlirWalkPostOrder,
+ // clang-format off
+ nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = WalkOrder.POST_ORDER) -> None"),
+ // clang-format on
+ R"(
+ Walks the operation tree with a callback function.
+
+ Args:
+ callback: A callable that takes an Operation and returns a WalkResult.
+ walk_order: The order of traversal (PRE_ORDER or POST_ORDER). Defaults
+ to POST_ORDER.)");
+
+ // `Operation` bindings. `create` builds a brand-new operation from its
+ // name and optional results/operands/attributes/successors/regions.
+ nb::class_<PyOperation, PyOperationBase>(m, "Operation")
+ .def_static(
+ "create",
+ [](std::string_view name,
+ std::optional<std::vector<PyType *>> results,
+ std::optional<std::vector<PyValue *>> operands,
+ std::optional<nb::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors, int regions,
+ const std::optional<PyLocation> &location,
+ const nb::object &maybeIp,
+ bool inferType) -> nb::typed<nb::object, PyOperation> {
+ // Unpack/validate operands: a None entry in the sequence is rejected
+ // with ValueError rather than being dereferenced.
+ llvm::SmallVector<MlirValue, 4> mlirOperands;
+ if (operands) {
+ mlirOperands.reserve(operands->size());
+ for (PyValue *operand : *operands) {
+ if (!operand)
+ throw nb::value_error("operand value cannot be None");
+ mlirOperands.push_back(operand->get());
+ }
+ }
+
+ // NOTE(review): resolves the optional `loc`; per the docstring it
+ // falls back to the context-manager location when omitted.
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
+ return PyOperation::create(name, results, mlirOperands, attributes,
+ successors, regions, pyLoc, maybeIp,
+ inferType);
+ },
+ nb::arg("name"), nb::arg("results") = nb::none(),
+ nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
+ nb::arg("successors") = nb::none(), nb::arg("regions") = 0,
+ nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
+ nb::arg("infer_type") = false,
+ R"(
+ Creates a new operation.
+
+ Args:
+ name: Operation name (e.g. `dialect.operation`).
+ results: Optional sequence of Type representing op result types.
+ operands: Optional operands of the operation.
+ attributes: Optional Dict of {str: Attribute}.
+ successors: Optional List of Block for the operation's successors.
+ regions: Number of regions to create (default = 0).
+ loc: Optional Location object (defaults to resolve from context manager).
+ ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
+ infer_type: Whether to infer result types (default = False).
+ Returns:
+ A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
+ // Parsing, capsule interop and self-views for `Operation`.
+ .def_static(
+ "parse",
+ [](const std::string &text, const std::string &bufferName,
+ DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyOpView> {
+ PyOperationRef parsedOp =
+ PyOperation::parse(context->getRef(), text, bufferName);
+ return parsedOp->createOpView();
+ },
+ nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
+ nb::arg("context") = nb::none(),
+ "Parses an operation. Supports both text assembly format and binary "
+ "bytecode format.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule,
+ "Gets a capsule wrapping the MlirOperation.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyOperation::createFromCapsule,
+ "Creates an Operation from a capsule wrapping MlirOperation.")
+ .def_prop_ro(
+ "operation",
+ [](nb::object selfObj) -> nb::typed<nb::object, PyOperation> {
+ // An Operation is its own operation: hand back the same object.
+ return selfObj;
+ },
+ "Returns self (the operation).")
+ .def_prop_ro(
+ "opview",
+ [](PyOperation &op) -> nb::typed<nb::object, PyOpView> {
+ return op.createOpView();
+ },
+ R"(
+ Returns an OpView of this operation.
+
+ Note:
+ If the operation has a registered and loaded dialect then this OpView will
+ be concrete wrapper class.)")
+ .def_prop_ro("block", &PyOperation::getBlock,
+ "Returns the block containing this operation.")
+ .def_prop_ro(
+ "successors",
+ [](PyOperationBase &base) {
+ PyOperation &op = base.getOperation();
+ return PyOpSuccessors(op.getRef());
+ },
+ "Returns the list of Operation successors.")
+ .def("_set_invalid", &PyOperation::setInvalid,
+ "Invalidate the operation.");
+
+ // `OpView` is the base for op wrapper classes. It can wrap an existing
+ // Operation, or be built generically from an op name plus region/segment
+ // specs, results, operands and attributes.
+ auto opViewClass =
+ nb::class_<PyOpView, PyOperationBase>(m, "OpView")
+ .def(nb::init<nb::typed<nb::object, PyOperation>>(),
+ nb::arg("operation"))
+ .def(
+ "__init__",
+ // Generic constructor: placement-new the PyOpView produced by
+ // PyOpView::buildGeneric into the storage passed as `self`.
+ [](PyOpView *self, std::string_view name,
+ std::tuple<int, bool> opRegionSpec,
+ nb::object operandSegmentSpecObj,
+ nb::object resultSegmentSpecObj,
+ std::optional<nb::list> resultTypeList, nb::list operandList,
+ std::optional<nb::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors,
+ std::optional<int> regions,
+ const std::optional<PyLocation> &location,
+ const nb::object &maybeIp) {
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
+ new (self) PyOpView(PyOpView::buildGeneric(
+ name, opRegionSpec, operandSegmentSpecObj,
+ resultSegmentSpecObj, resultTypeList, operandList,
+ attributes, successors, regions, pyLoc, maybeIp));
+ },
+ nb::arg("name"), nb::arg("opRegionSpec"),
+ nb::arg("operandSegmentSpecObj") = nb::none(),
+ nb::arg("resultSegmentSpecObj") = nb::none(),
+ // NOTE(review): `operands` defaults to None but is bound to a
+ // non-optional `nb::list`; confirm whether omitting it is meant to
+ // be rejected during argument conversion.
+ nb::arg("results") = nb::none(), nb::arg("operands") = nb::none(),
+ nb::arg("attributes") = nb::none(),
+ nb::arg("successors") = nb::none(),
+ nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(),
+ nb::arg("ip") = nb::none())
+ .def_prop_ro(
+ "operation",
+ [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
+ return self.getOperationObject();
+ })
+ .def_prop_ro("opview",
+ [](nb::object self) -> nb::typed<nb::object, PyOpView> {
+ return self;
+ })
+ // Printing delegates to the wrapped Operation object's `str()`.
+ .def(
+ "__str__",
+ [](PyOpView &self) { return nb::str(self.getOperationObject()); })
+ .def_prop_ro(
+ "successors",
+ [](PyOperationBase &self) {
+ return PyOpSuccessors(self.getOperation().getRef());
+ },
+ "Returns the list of Operation successors.")
+ .def(
+ "_set_invalid",
+ [](PyOpView &self) { self.getOperation().setInvalid(); },
+ "Invalidate the operation.");
+ // Class-level defaults that `build_generic` reads via `cls.attr(...)`;
+ // presumably overridden by generated OpView subclasses.
+ opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
+ opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
+ opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
+ // It is faster to pass the operation_name, ods_regions, and
+ // ods_operand_segments/ods_result_segments as arguments to the constructor,
+ // rather than to access them as attributes.
+ // `build_generic` is the classmethod form: it reads OPERATION_NAME and the
+ // _ODS_* class attributes from `cls` and forwards them to buildGeneric.
+ opViewClass.attr("build_generic") = classmethod(
+ [](nb::handle cls, std::optional<nb::list> resultTypeList,
+ nb::list operandList, std::optional<nb::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors,
+ std::optional<int> regions, std::optional<PyLocation> location,
+ const nb::object &maybeIp) {
+ std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
+ std::tuple<int, bool> opRegionSpec =
+ nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
+ nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
+ nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
+ return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
+ resultSegmentSpec, resultTypeList,
+ operandList, attributes, successors,
+ regions, pyLoc, maybeIp);
+ },
+ // NOTE(review): as in `OpView.__init__`, `operands` defaults to None while
+ // bound to a non-optional `nb::list` — confirm this is intended.
+ nb::arg("cls"), nb::arg("results") = nb::none(),
+ nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
+ nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(),
+ nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
+ "Builds a specific, generated OpView based on class level attributes.");
+ // `parse` classmethod: parses source text/bytecode and checks that the
+ // parsed op name matches `cls.OPERATION_NAME`, raising MLIRError if not.
+ opViewClass.attr("parse") = classmethod(
+ [](const nb::object &cls, const std::string &sourceStr,
+ const std::string &sourceName,
+ DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
+ PyOperationRef parsed =
+ PyOperation::parse(context->getRef(), sourceStr, sourceName);
+
+ // Check if the expected operation was parsed, and cast to the
+ // appropriate `OpView` subclass if successful.
+ // NOTE: This accesses attributes that have been automatically added to
+ // `OpView` subclasses, and is not intended to be used on `OpView`
+ // directly.
+ std::string clsOpName =
+ nb::cast<std::string>(cls.attr("OPERATION_NAME"));
+ MlirStringRef identifier =
+ mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
+ std::string_view parsedOpName(identifier.data, identifier.length);
+ if (clsOpName != parsedOpName)
+ throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
+ parsedOpName + "'");
+ return PyOpView::constructDerived(cls, parsed.getObject());
+ },
+ nb::arg("cls"), nb::arg("source"), nb::kw_only(),
+ nb::arg("source_name") = "", nb::arg("context") = nb::none(),
+ "Parses a specific, generated OpView based on class level attributes.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyRegion.
+ //----------------------------------------------------------------------------
+ // `Region` bindings: block access/iteration plus identity comparison.
+ nb::class_<PyRegion>(m, "Region")
+ .def_prop_ro(
+ "blocks",
+ [](PyRegion &region) {
+ return PyBlockList(region.getParentOperation(), region.get());
+ },
+ "Returns a forward-optimized sequence of blocks.")
+ .def_prop_ro(
+ "owner",
+ [](PyRegion &region) -> nb::typed<nb::object, PyOpView> {
+ return region.getParentOperation()->createOpView();
+ },
+ "Returns the operation owning this region.")
+ .def(
+ "__iter__",
+ [](PyRegion &region) {
+ region.checkValid();
+ return PyBlockIterator(region.getParentOperation(),
+ mlirRegionGetFirstBlock(region.get()));
+ },
+ "Iterates over blocks in the region.")
+ // Typed overload first; the nb::object overload is the catch-all.
+ .def(
+ "__eq__",
+ [](PyRegion &lhs, PyRegion &rhs) {
+ return lhs.get().ptr == rhs.get().ptr;
+ },
+ "Compares two regions for pointer equality.")
+ .def(
+ "__eq__", [](PyRegion &, nb::object &) { return false; },
+ "Compares region with non-region object (always returns False).");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyBlock.
+ //----------------------------------------------------------------------------
+ // `Block` bindings. Every wrapper handed out below (region, arguments,
+ // operations, new blocks) is tied to this block's parent operation.
+ nb::class_<PyBlock>(m, "Block")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule,
+ "Gets a capsule wrapping the MlirBlock.")
+ .def_prop_ro(
+ "owner",
+ [](PyBlock &block) -> nb::typed<nb::object, PyOpView> {
+ return block.getParentOperation()->createOpView();
+ },
+ "Returns the owning operation of this block.")
+ .def_prop_ro(
+ "region",
+ [](PyBlock &block) {
+ return PyRegion(block.getParentOperation(),
+ mlirBlockGetParentRegion(block.get()));
+ },
+ "Returns the owning region of this block.")
+ .def_prop_ro(
+ "arguments",
+ [](PyBlock &block) {
+ return PyBlockArgumentList(block.getParentOperation(), block.get());
+ },
+ "Returns a list of block arguments.")
+ .def(
+ "add_argument",
+ [](PyBlock &block, const PyType &type, const PyLocation &loc) {
+ MlirValue added = mlirBlockAddArgument(block.get(), type, loc);
+ return PyBlockArgument(block.getParentOperation(), added);
+ },
+ "type"_a, "loc"_a,
+ R"(
+ Appends an argument of the specified type to the block.
+
+ Args:
+ type: The type of the argument to add.
+ loc: The source location for the argument.
+
+ Returns:
+ The newly added block argument.)")
+ .def(
+ "erase_argument",
+ [](PyBlock &block, unsigned position) {
+ return mlirBlockEraseArgument(block.get(), position);
+ },
+ nb::arg("index"),
+ R"(
+ Erases the argument at the specified index.
+
+ Args:
+ index: The index of the argument to erase.)")
+ .def_prop_ro(
+ "operations",
+ [](PyBlock &block) {
+ return PyOperationList(block.getParentOperation(), block.get());
+ },
+ "Returns a forward-optimized sequence of operations.")
+ .def_static(
+ "create_at_start",
+ // Builds a fresh block and inserts it at position 0 of `parent`.
+ [](PyRegion &parent, const nb::sequence &argTypes,
+ const std::optional<nb::sequence> &argLocs) {
+ parent.checkValid();
+ MlirBlock created = createBlock(argTypes, argLocs);
+ mlirRegionInsertOwnedBlock(parent, 0, created);
+ return PyBlock(parent.getParentOperation(), created);
+ },
+ nb::arg("parent"), nb::arg("arg_types") = nb::list(),
+ nb::arg("arg_locs") = std::nullopt,
+ "Creates and returns a new Block at the beginning of the given "
+ "region (with given argument types and locations).")
+ .def(
+ "append_to",
+ // Moves this block to the end of `region`, detaching it first when it
+ // currently belongs to a region.
+ [](PyBlock &self, PyRegion &region) {
+ // Validate both sides before mutating, consistent with the other
+ // Block/Region methods that call checkValid().
+ self.checkValid();
+ region.checkValid();
+ MlirBlock b = self.get();
+ if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
+ mlirBlockDetach(b);
+ mlirRegionAppendOwnedBlock(region.get(), b);
+ },
+ nb::arg("region"),
+ R"(
+ Appends this block to a region.
+
+ Transfers ownership if the block is currently owned by another region.
+
+ Args:
+ region: The region to append the block to.)")
+ // create_before/create_after take the argument types as varargs
+ // (`nb::args`) and locations as a keyword-only argument, then insert the
+ // new block into this block's parent region relative to this block.
+ .def(
+ "create_before",
+ [](PyBlock &self, const nb::args &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
+ self.checkValid();
+ MlirBlock block =
+ createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
+ return PyBlock(self.getParentOperation(), block);
+ },
+ nb::arg("arg_types"), nb::kw_only(),
+ nb::arg("arg_locs") = std::nullopt,
+ "Creates and returns a new Block before this block "
+ "(with given argument types and locations).")
+ .def(
+ "create_after",
+ [](PyBlock &self, const nb::args &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
+ self.checkValid();
+ MlirBlock block =
+ createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
+ return PyBlock(self.getParentOperation(), block);
+ },
+ nb::arg("arg_types"), nb::kw_only(),
+ nb::arg("arg_locs") = std::nullopt,
+ "Creates and returns a new Block after this block "
+ "(with given argument types and locations).")
+ .def(
+ "__iter__",
+ [](PyBlock &self) {
+ self.checkValid();
+ MlirOperation firstOperation =
+ mlirBlockGetFirstOperation(self.get());
+ return PyOperationIterator(self.getParentOperation(),
+ firstOperation);
+ },
+ "Iterates over operations in the block.")
+ // Typed `__eq__` overload first; the nb::object overload is the catch-all
+ // for comparisons against non-Block objects.
+ .def(
+ "__eq__",
+ [](PyBlock &self, PyBlock &other) {
+ return self.get().ptr == other.get().ptr;
+ },
+ "Compares two blocks for pointer equality.")
+ .def(
+ "__eq__", [](PyBlock &self, nb::object &other) { return false; },
+ "Compares block with non-block object (always returns False).")
+ // Hash on the raw block pointer, consistent with the pointer-equality
+ // `__eq__` above.
+ .def(
+ "__hash__",
+ [](PyBlock &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the block.")
+ .def(
+ "__str__",
+ [](PyBlock &self) {
+ self.checkValid();
+ PyPrintAccumulator printAccum;
+ mlirBlockPrint(self.get(), printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly form of the block.")
+ .def(
+ "append",
+ [](PyBlock &self, PyOperationBase &operation) {
+ // Validate the destination block before mutating anything, so a stale
+ // block cannot leave `operation` detached from its previous parent.
+ self.checkValid();
+ PyOperation &op = operation.getOperation();
+ // Moving between blocks: detach from the current parent first.
+ if (op.isAttached())
+ op.detachFromParent();
+
+ mlirBlockAppendOwnedOperation(self.get(), op.get());
+ // Record the new parent on the Python-side operation object.
+ op.setAttached(self.getParentOperation().getObject());
+ },
+ nb::arg("operation"),
+ R"(
+ Appends an operation to this block.
+
+ If the operation is currently in another block, it will be moved.
+
+ Args:
+ operation: The operation to append to the block.)")
+ .def_prop_ro(
+ "successors",
+ [](PyBlock &self) {
+ return PyBlockSuccessors(self, self.getParentOperation());
+ },
+ "Returns the list of Block successors.")
+ .def_prop_ro(
+ "predecessors",
+ [](PyBlock &self) {
+ return PyBlockPredecessors(self, self.getParentOperation());
+ },
+ "Returns the list of Block predecessors.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyInsertionPoint.
+ //----------------------------------------------------------------------------
+
+ // `InsertionPoint` bindings: a position either at the end of a block or
+ // immediately before a reference operation; usable as a context manager.
+ nb::class_<PyInsertionPoint>(m, "InsertionPoint")
+ .def(nb::init<PyBlock &>(), nb::arg("block"),
+ "Inserts after the last operation but still inside the block.")
+ .def("__enter__", &PyInsertionPoint::contextEnter,
+ "Enters the insertion point as a context manager.")
+ .def("__exit__", &PyInsertionPoint::contextExit,
+ nb::arg("exc_type").none(), nb::arg("exc_value").none(),
+ nb::arg("traceback").none(),
+ "Exits the insertion point context manager.")
+ .def_prop_ro_static(
+ "current",
+ [](nb::object & /*class*/) {
+ auto *active = PyThreadContextEntry::getDefaultInsertionPoint();
+ if (active == nullptr)
+ throw nb::value_error("No current InsertionPoint");
+ return active;
+ },
+ nb::sig("def current(/) -> InsertionPoint"),
+ "Gets the InsertionPoint bound to the current thread or raises "
+ "ValueError if none has been set.")
+ .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
+ "Inserts before a referenced operation.")
+ .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
+ nb::arg("block"),
+ R"(
+ Creates an insertion point at the beginning of a block.
+
+ Args:
+ block: The block at whose beginning operations should be inserted.
+
+ Returns:
+ An InsertionPoint at the block's beginning.)")
+ .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
+ nb::arg("block"),
+ R"(
+ Creates an insertion point before a block's terminator.
+
+ Args:
+ block: The block whose terminator to insert before.
+
+ Returns:
+ An InsertionPoint before the terminator.
+
+ Raises:
+ ValueError: If the block has no terminator.)")
+ .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
+ R"(
+ Creates an insertion point immediately after an operation.
+
+ Args:
+ operation: The operation after which to insert.
+
+ Returns:
+ An InsertionPoint after the operation.)")
+ .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
+ R"(
+ Inserts an operation at this insertion point.
+
+ Args:
+ operation: The operation to insert.)")
+ .def_prop_ro(
+ "block", [](PyInsertionPoint &ip) { return ip.getBlock(); },
+ "Returns the block that this `InsertionPoint` points to.")
+ .def_prop_ro(
+ "ref_operation",
+ [](PyInsertionPoint &ip)
+ -> std::optional<nb::typed<nb::object, PyOperation>> {
+ // No reference operation means "insert at end of block" -> None.
+ if (auto anchor = ip.getRefOperation())
+ return anchor->getObject();
+ return std::nullopt;
+ },
+ "The reference operation before which new operations are "
+ "inserted, or None if the insertion point is at the end of "
+ "the block.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyAttribute.
+ //----------------------------------------------------------------------------
+ // Generic `Attribute` bindings: parsing, capsule interop, comparison,
+ // printing and down-casting.
+ nb::class_<PyAttribute>(m, "Attribute")
+ // Delegate to the PyAttribute copy constructor, which will also lifetime
+ // extend the backing context which owns the MlirAttribute.
+ .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
+ "Casts the passed attribute to the generic `Attribute`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule,
+ "Gets a capsule wrapping the MlirAttribute.")
+ .def_static(
+ MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule,
+ "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
+ .def_static(
+ "parse",
+ [](const std::string &attrSpec, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyAttribute> {
+ // ErrorCapture collects diagnostics emitted while parsing so they can
+ // be attached to the MLIRError raised on failure.
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute attr = mlirAttributeParseGet(
+ context->get(), toMlirStringRef(attrSpec));
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Unable to parse attribute", errors.take());
+ return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
+ },
+ nb::arg("asm"), nb::arg("context") = nb::none(),
+ "Parses an attribute from an assembly form. Raises an `MLIRError` on "
+ "failure.")
+ .def_prop_ro(
+ "context",
+ [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that owns the `Attribute`.")
+ .def_prop_ro(
+ "type",
+ [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getContext(), mlirAttributeGetType(self))
+ .maybeDownCast();
+ },
+ "Returns the type of the `Attribute`.")
+ .def(
+ "get_named",
+ [](PyAttribute &self, std::string name) {
+ return PyNamedAttribute(self, std::move(name));
+ },
+ // Keep the source attribute alive as long as the returned binding.
+ nb::keep_alive<0, 1>(),
+ R"(
+ Binds a name to the attribute, creating a `NamedAttribute`.
+
+ Args:
+ name: The name to bind to the `Attribute`.
+
+ Returns:
+ A `NamedAttribute` with the given name and this attribute.)")
+ // Typed `__eq__` overload first; the nb::object overload is the catch-all
+ // for comparisons against non-Attribute objects.
+ .def(
+ "__eq__",
+ [](PyAttribute &self, PyAttribute &other) { return self == other; },
+ "Compares two attributes for equality.")
+ .def(
+ "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
+ "Compares attribute with non-attribute object (always returns "
+ "False).")
+ .def(
+ "__hash__",
+ [](PyAttribute &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the attribute.")
+ .def(
+ "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
+ kDumpDocstring)
+ .def(
+ "__str__",
+ [](PyAttribute &self) {
+ PyPrintAccumulator printAccum;
+ mlirAttributePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly form of the Attribute.")
+ .def(
+ "__repr__",
+ [](PyAttribute &self) {
+ // Generally, assembly formats are not printed for __repr__ because
+ // this can cause exceptionally long debug output and exceptions.
+ // However, attribute values are generally considered useful and
+ // are printed. This may need to be re-evaluated if debug dumps end
+ // up being excessive.
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Attribute(");
+ mlirAttributePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the attribute.")
+ .def_prop_ro(
+ "typeid",
+ [](PyAttribute &self) {
+ MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
+ // Debug-only sanity check; compiled out under NDEBUG.
+ assert(!mlirTypeIDIsNull(mlirTypeID) &&
+ "mlirTypeID was expected to be non-null.");
+ return PyTypeID(mlirTypeID);
+ },
+ "Returns the `TypeID` of the attribute.")
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the attribute to a more specific attribute if possible.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyNamedAttribute
+ //----------------------------------------------------------------------------
+ // `NamedAttribute` bindings: a (name, attribute) pair exposed read-only.
+ nb::class_<PyNamedAttribute>(m, "NamedAttribute")
+ .def(
+ "__repr__",
+ [](PyNamedAttribute &named) {
+ // Fetch the identifier text once and reuse it for data and length.
+ MlirStringRef label = mlirIdentifierStr(named.namedAttr.name);
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("NamedAttribute(");
+ printAccum.parts.append(nb::str(label.data, label.length));
+ printAccum.parts.append("=");
+ mlirAttributePrint(named.namedAttr.attribute,
+ printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the named attribute.")
+ .def_prop_ro(
+ "name",
+ [](PyNamedAttribute &named) {
+ return mlirIdentifierStr(named.namedAttr.name);
+ },
+ "The name of the `NamedAttribute` binding.")
+ .def_prop_ro(
+ "attr",
+ [](PyNamedAttribute &named) { return named.namedAttr.attribute; },
+ nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
+ "The underlying generic attribute of the `NamedAttribute` binding.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyType.
+ //----------------------------------------------------------------------------
+ nb::class_<PyType>(m, "Type")
+ // Delegate to the PyType copy constructor, which will also lifetime
+ // extend the backing context which owns the MlirType.
+ .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
+ "Casts the passed type to the generic `Type`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule,
+ "Gets a capsule wrapping the `MlirType`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule,
+ "Creates a Type from a capsule wrapping `MlirType`.")
+ .def_static(
+ "parse",
+ [](std::string typeSpec,
+ DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType type =
+ mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Unable to parse type", errors.take());
+ return PyType(context.get()->getRef(), type).maybeDownCast();
+ },
+ nb::arg("asm"), nb::arg("context") = nb::none(),
+ R"(
+ Parses the assembly form of a type.
+
+ Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
+
+ See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
+ .def_prop_ro(
+ "context",
+ [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that owns the `Type`.")
+ .def(
+ "__eq__", [](PyType &self, PyType &other) { return self == other; },
+ "Compares two types for equality.")
+ .def(
+ "__eq__", [](PyType &self, nb::object &other) { return false; },
+ nb::arg("other").none(),
+ "Compares type with non-type object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyType &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the `Type`.")
+ .def(
+ "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
+ .def(
+ "__str__",
+ [](PyType &self) {
+ PyPrintAccumulator printAccum;
+ mlirTypePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly form of the `Type`.")
+ .def(
+ "__repr__",
+ [](PyType &self) {
+ // Generally, assembly formats are not printed for __repr__ because
+ // this can cause exceptionally long debug output and exceptions.
+ // However, types are an exception as they typically have compact
+ // assembly forms and printing them is useful.
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Type(");
+ mlirTypePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the `Type`.")
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyType &self) -> nb::typed<nb::object, PyType> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the Type to a more specific `Type` if possible.")
+ .def_prop_ro(
+ "typeid",
+ [](PyType &self) {
+ MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
+ if (!mlirTypeIDIsNull(mlirTypeID))
+ return PyTypeID(mlirTypeID);
+ auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
+ throw nb::value_error(
+ (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
+ },
+ "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
+ "`Type` has no "
+ "`TypeID`.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyTypeID.
+ //----------------------------------------------------------------------------
+ nb::class_<PyTypeID>(m, "TypeID")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule,
+ "Gets a capsule wrapping the `MlirTypeID`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule,
+ "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
+ // Note, this tests whether the underlying TypeIDs are the same,
+ // not whether the wrapper MlirTypeIDs are the same, nor whether
+ // the Python objects are the same (i.e., PyTypeID is a value type).
+ .def(
+ "__eq__",
+ [](PyTypeID &self, PyTypeID &other) { return self == other; },
+ "Compares two `TypeID`s for equality.")
+ .def(
+ "__eq__",
+ [](PyTypeID &self, const nb::object &other) { return false; },
+ "Compares TypeID with non-TypeID object (always returns False).")
+ // Note, this gives the hash value of the underlying TypeID, not the
+ // hash value of the Python object, nor the hash value of the
+ // MlirTypeID wrapper.
+ .def(
+ "__hash__",
+ [](PyTypeID &self) {
+ return static_cast<size_t>(mlirTypeIDHashValue(self));
+ },
+ "Returns the hash value of the `TypeID`.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Value.
+ //----------------------------------------------------------------------------
+ m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
+
+ nb::class_<PyValue>(m, "Value", nb::is_generic(),
+ nb::sig("class Value(Generic[_T])"))
+ .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"),
+ "Creates a Value reference from another `Value`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule,
+ "Gets a capsule wrapping the `MlirValue`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule,
+ "Creates a `Value` from a capsule wrapping `MlirValue`.")
+ .def_prop_ro(
+ "context",
+ [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getParentOperation()->getContext().getObject();
+ },
+ "Context in which the value lives.")
+ .def(
+ "dump", [](PyValue &self) { mlirValueDump(self.get()); },
+ kDumpDocstring)
+ .def_prop_ro(
+ "owner",
+ [](PyValue &self) -> nb::object {
+ MlirValue v = self.get();
+ if (mlirValueIsAOpResult(v)) {
+ assert(mlirOperationEqual(self.getParentOperation()->get(),
+ mlirOpResultGetOwner(self.get())) &&
+ "expected the owner of the value in Python to match "
+ "that in "
+ "the IR");
+ return self.getParentOperation().getObject();
+ }
+
+ if (mlirValueIsABlockArgument(v)) {
+ MlirBlock block = mlirBlockArgumentGetOwner(self.get());
+ return nb::cast(PyBlock(self.getParentOperation(), block));
+ }
+
+ assert(false && "Value must be a block argument or an op result");
+ return nb::none();
+ },
+ "Returns the owner of the value (`Operation` for results, `Block` "
+ "for "
+ "arguments).")
+ .def_prop_ro(
+ "uses",
+ [](PyValue &self) {
+ return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
+ },
+ "Returns an iterator over uses of this value.")
+ .def(
+ "__eq__",
+ [](PyValue &self, PyValue &other) {
+ return self.get().ptr == other.get().ptr;
+ },
+ "Compares two values for pointer equality.")
+ .def(
+ "__eq__", [](PyValue &self, nb::object other) { return false; },
+ "Compares value with non-value object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyValue &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the value.")
+ .def(
+ "__str__",
+ [](PyValue &self) {
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Value(");
+ mlirValuePrint(self.get(), printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ R"(
+ Returns the string form of the value.
+
+ If the value is a block argument, this is the assembly form of its type and the
+ position in the argument list. If the value is an operation result, this is
+ equivalent to printing the operation that produced it.
+ )")
+ .def(
+ "get_name",
+ [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
+ PyPrintAccumulator printAccum;
+ MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
+ if (useLocalScope)
+ mlirOpPrintingFlagsUseLocalScope(flags);
+ if (useNameLocAsPrefix)
+ mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);
+ MlirAsmState valueState =
+ mlirAsmStateCreateForValue(self.get(), flags);
+ mlirValuePrintAsOperand(self.get(), valueState,
+ printAccum.getCallback(),
+ printAccum.getUserData());
+ mlirOpPrintingFlagsDestroy(flags);
+ mlirAsmStateDestroy(valueState);
+ return printAccum.join();
+ },
+ nb::arg("use_local_scope") = false,
+ nb::arg("use_name_loc_as_prefix") = false,
+ R"(
+ Returns the string form of value as an operand.
+
+ Args:
+ use_local_scope: Whether to use local scope for naming.
+ use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
+
+ Returns:
+ The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
+ .def(
+ "get_name",
+ [](PyValue &self, PyAsmState &state) {
+ PyPrintAccumulator printAccum;
+ MlirAsmState valueState = state.get();
+ mlirValuePrintAsOperand(self.get(), valueState,
+ printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ nb::arg("state"),
+ "Returns the string form of value as an operand (i.e., the ValueID).")
+ .def_prop_ro(
+ "type",
+ [](PyValue &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getParentOperation()->getContext(),
+ mlirValueGetType(self.get()))
+ .maybeDownCast();
+ },
+ "Returns the type of the value.")
+ .def(
+ "set_type",
+ [](PyValue &self, const PyType &type) {
+ mlirValueSetType(self.get(), type);
+ },
+ nb::arg("type"), "Sets the type of the value.",
+ nb::sig("def set_type(self, type: _T)"))
+ .def(
+ "replace_all_uses_with",
+ [](PyValue &self, PyValue &with) {
+ mlirValueReplaceAllUsesOfWith(self.get(), with.get());
+ },
+ "Replace all uses of value with the new value, updating anything in "
+ "the IR that uses `self` to use the other value instead.")
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with, PyOperation &exception) {
+ MlirOperation exceptedUser = exception.get();
+ mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ kValueReplaceAllUsesExceptDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with, const nb::list &exceptions) {
+ // Convert Python list to a SmallVector of MlirOperations
+ llvm::SmallVector<MlirOperation> exceptionOps;
+ for (nb::handle exception : exceptions) {
+ exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
+ }
+
+ mlirValueReplaceAllUsesExcept(
+ self, with, static_cast<intptr_t>(exceptionOps.size()),
+ exceptionOps.data());
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ kValueReplaceAllUsesExceptDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with, PyOperation &exception) {
+ MlirOperation exceptedUser = exception.get();
+ mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ kValueReplaceAllUsesExceptDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with,
+ std::vector<PyOperation> &exceptions) {
+ // Convert Python list to a SmallVector of MlirOperations
+ llvm::SmallVector<MlirOperation> exceptionOps;
+ for (PyOperation &exception : exceptions)
+ exceptionOps.push_back(exception);
+ mlirValueReplaceAllUsesExcept(
+ self, with, static_cast<intptr_t>(exceptionOps.size()),
+ exceptionOps.data());
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ kValueReplaceAllUsesExceptDocstring)
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyValue &self) -> nb::typed<nb::object, PyValue> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the `Value` to a more specific kind if possible.")
+ .def_prop_ro(
+ "location",
+ [](MlirValue self) {
+ return PyLocation(
+ PyMlirContext::forContext(mlirValueGetContext(self)),
+ mlirValueGetLocation(self));
+ },
+ "Returns the source location of the value.");
+
+ PyBlockArgument::bind(m);
+ PyOpResult::bind(m);
+ PyOpOperand::bind(m);
+
+ nb::class_<PyAsmState>(m, "AsmState")
+ .def(nb::init<PyValue &, bool>(), nb::arg("value"),
+ nb::arg("use_local_scope") = false,
+ R"(
+ Creates an `AsmState` for consistent SSA value naming.
+
+ Args:
+ value: The value to create state for.
+ use_local_scope: Whether to use local scope for naming.)")
+ .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
+ nb::arg("use_local_scope") = false,
+ R"(
+ Creates an AsmState for consistent SSA value naming.
+
+ Args:
+ op: The operation to create state for.
+ use_local_scope: Whether to use local scope for naming.)");
+
+ //----------------------------------------------------------------------------
+ // Mapping of SymbolTable.
+ //----------------------------------------------------------------------------
+ nb::class_<PySymbolTable>(m, "SymbolTable")
+ .def(nb::init<PyOperationBase &>(),
+ R"(
+ Creates a symbol table for an operation.
+
+ Args:
+ operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
+
+ Raises:
+ TypeError: If the operation is not a symbol table.)")
+ .def(
+ "__getitem__",
+ [](PySymbolTable &self,
+ const std::string &name) -> nb::typed<nb::object, PyOpView> {
+ return self.dunderGetItem(name);
+ },
+ R"(
+ Looks up a symbol by name in the symbol table.
+
+ Args:
+ name: The name of the symbol to look up.
+
+ Returns:
+ The operation defining the symbol.
+
+ Raises:
+ KeyError: If the symbol is not found.)")
+ .def("insert", &PySymbolTable::insert, nb::arg("operation"),
+ R"(
+ Inserts a symbol operation into the symbol table.
+
+ Args:
+ operation: An operation with a symbol name to insert.
+
+ Returns:
+ The symbol name attribute of the inserted operation.
+
+ Raises:
+ ValueError: If the operation does not have a symbol name.)")
+ .def("erase", &PySymbolTable::erase, nb::arg("operation"),
+ R"(
+ Erases a symbol operation from the symbol table.
+
+ Args:
+ operation: The symbol operation to erase.
+
+ Note:
+ The operation is also erased from the IR and invalidated.)")
+ .def("__delitem__", &PySymbolTable::dunderDel,
+ "Deletes a symbol by name from the symbol table.")
+ .def(
+ "__contains__",
+ [](PySymbolTable &table, const std::string &name) {
+ return !mlirOperationIsNull(mlirSymbolTableLookup(
+ table, mlirStringRefCreate(name.data(), name.length())));
+ },
+ "Checks if a symbol with the given name exists in the table.")
+ // Static helpers.
+ .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
+ nb::arg("symbol"), nb::arg("name"),
+ "Sets the symbol name for a symbol operation.")
+ .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
+ nb::arg("symbol"),
+ "Gets the symbol name from a symbol operation.")
+ .def_static("get_visibility", &PySymbolTable::getVisibility,
+ nb::arg("symbol"),
+ "Gets the visibility attribute of a symbol operation.")
+ .def_static("set_visibility", &PySymbolTable::setVisibility,
+ nb::arg("symbol"), nb::arg("visibility"),
+ "Sets the visibility attribute of a symbol operation.")
+ .def_static("replace_all_symbol_uses",
+ &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
+ nb::arg("new_symbol"), nb::arg("from_op"),
+ "Replaces all uses of a symbol with a new symbol name within "
+ "the given operation.")
+ .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
+ nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
+ nb::arg("callback"),
+ "Walks symbol tables starting from an operation with a "
+ "callback function.");
+
+ // Container bindings.
+ PyBlockArgumentList::bind(m);
+ PyBlockIterator::bind(m);
+ PyBlockList::bind(m);
+ PyBlockSuccessors::bind(m);
+ PyBlockPredecessors::bind(m);
+ PyOperationIterator::bind(m);
+ PyOperationList::bind(m);
+ PyOpAttributeMap::bind(m);
+ PyOpOperandIterator::bind(m);
+ PyOpOperandList::bind(m);
+ PyOpResultList::bind(m);
+ PyOpSuccessors::bind(m);
+ PyRegionIterator::bind(m);
+ PyRegionList::bind(m);
+
+ // Debug bindings.
+ PyGlobalDebugFlag::bind(m);
+
+ // Attribute builder getter.
+ PyAttrBuilderMap::bind(m);
+
+ // nb::register_exception_translator([](const std::exception_ptr &p,
+ // void *payload) {
+ // // We can't define exceptions with custom fields through pybind, so
+ // instead
+ // // the exception class is defined in python and imported here.
+ // try {
+ // if (p)
+ // std::rethrow_exception(p);
+ // } catch (const MLIRError &e) {
+ // nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ // .attr("MLIRError")(e.message, e.errorDiagnostics);
+ // PyErr_SetObject(PyExc_Exception, obj.ptr());
+ // }
+ // });
+}
+
+// Entry points defined in other translation units that are compiled or
+// linked into this extension module.
+namespace mlir::python {
+// Registers the C++ -> Python exception translator for MLIRError.
+void registerMLIRErrorInIRCore();
+
+// Populators for the remaining pieces of the `ir` submodule.
+void populateIRAffine(nb::module_ &m);
+void populateIRAttributes(nb::module_ &m);
+void populateIRInterfaces(nb::module_ &m);
+void populateIRTypes(nb::module_ &m);
+} // namespace mlir::python
+
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
@@ -158,4 +2415,18 @@ NB_MODULE(_mlir, m) {
auto passManagerModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passManagerModule);
+ registerMLIRErrorInIRCore();
+ nb::register_exception_translator([](const std::exception_ptr &p,
+ void *payload) {
+ // We can't define exceptions with custom fields through pybind, so
+ // instead the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+ });
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cdf01fff28cf2..c4165a04b284d 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,8 +8,8 @@
#include "Pass.h"
-#include "Globals.h"
-#include "IRModule.h"
+#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir-c/Pass.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index bc40943521829..0221bd10e723e 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_PASS_H
#define MLIR_BINDINGS_PYTHON_PASS_H
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace mlir {
namespace python {
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 2a5129b7f4ab1..a83b2c1883174 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -8,7 +8,7 @@
#include "Rewrite.h"
-#include "IRModule.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
index ae89e2b9589f1..f8ffdc7bdc458 100644
--- a/mlir/lib/Bindings/Python/Rewrite.h
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
#define MLIR_BINDINGS_PYTHON_REWRITE_H
-#include "NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace mlir {
namespace python {
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index e9b1aff0455e6..3a4af1f066298 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -532,6 +532,10 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
MainModule.cpp
+ IRAffine.cpp
+ IRAttributes.cpp
+ IRInterfaces.cpp
+ IRTypes.cpp
Pass.cpp
Rewrite.cpp
@@ -991,12 +995,8 @@ get_property(NB_LIBRARY_TARGET_NAME TARGET MLIRPythonModules.extension._mlir.dso
list(GET NB_LIBRARY_TARGET_NAME 0 NB_LIBRARY_TARGET_NAME)
add_mlir_library_install(${NB_LIBRARY_TARGET_NAME})
add_mlir_library(MLIRPythonSupport
- ${PYTHON_SOURCE_DIR}/Globals.cpp
- ${PYTHON_SOURCE_DIR}/IRAffine.cpp
- ${PYTHON_SOURCE_DIR}/IRAttributes.cpp
${PYTHON_SOURCE_DIR}/IRCore.cpp
- ${PYTHON_SOURCE_DIR}/IRInterfaces.cpp
- ${PYTHON_SOURCE_DIR}/IRTypes.cpp
+ ${PYTHON_SOURCE_DIR}/Globals.cpp
EXCLUDE_FROM_LIBMLIR
SHARED
LINK_COMPONENTS
@@ -1014,6 +1014,13 @@ set_target_properties(MLIRPythonSupport PROPERTIES
RUNTIME_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
ARCHIVE_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
)
+# Emit the nanobind runtime library into _mlir_libs, the same output directory
+# configured for MLIRPythonSupport (which links against it) just above.
+set_target_properties(${NB_LIBRARY_TARGET_NAME}
+  PROPERTIES
+    # On Windows the DLL is a RUNTIME artifact and its import library an
+    # ARCHIVE artifact, so both are redirected (harmless on other platforms).
+    RUNTIME_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+    ARCHIVE_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+    LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+    # NOTE(review): BINARY_OUTPUT_DIRECTORY is not a built-in CMake target
+    # property; kept only to mirror the MLIRPythonSupport properties.
+    BINARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
+)
set(eh_rtti_enable)
if(MSVC)
set(eh_rtti_enable /EHsc /GR)
@@ -1035,4 +1042,4 @@ endif()
target_link_libraries(
MLIRPythonModules.extension._mlir.dso
PUBLIC MLIRPythonSupport)
-
+# Compile MLIRPythonSupport with the configured nanobind domain.
+# NOTE(review): presumably this must match the domain the extension modules
+# are built with so they share type registrations — confirm.
+target_compile_definitions(MLIRPythonSupport
+  PRIVATE
+    NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+)
>From 33a53bae2f76dbaa0517929b6b0ea77fa96983a5 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 11 Dec 2025 13:46:11 -0800
Subject: [PATCH 07/38] works
---
mlir/cmake/modules/AddMLIRPython.cmake | 1 -
mlir/include/mlir/Bindings/Python/Globals.h | 1 -
mlir/include/mlir/Bindings/Python/IRCore.h | 19 +++++++++++
mlir/lib/Bindings/Python/IRAttributes.cpp | 14 +-------
mlir/lib/Bindings/Python/IRCore.cpp | 5 +--
mlir/lib/Bindings/Python/IRInterfaces.cpp | 2 +-
mlir/lib/Bindings/Python/IRTypes.cpp | 18 ++---------
mlir/lib/Bindings/Python/MainModule.cpp | 32 ++-----------------
mlir/lib/Bindings/Python/Pass.cpp | 3 +-
mlir/lib/Bindings/Python/Rewrite.cpp | 2 +-
mlir/python/CMakeLists.txt | 12 +++++--
.../python/lib/PythonTestModuleNanobind.cpp | 31 +++++++++++-------
12 files changed, 60 insertions(+), 80 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 111ef45609160..0a5d788b9bca0 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -601,7 +601,6 @@ function(add_mlir_python_common_capi_library name)
# Generate the aggregate .so that everything depends on.
add_mlir_aggregate(${name}
SHARED
- DISABLE_INSTALL
EMBED_LIBS ${_embed_libs}
)
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index fea7a201453ce..19ffe8164d727 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -245,7 +245,6 @@ struct PyGlobalDebugFlag {
static nanobind::ft_mutex mutex;
};
-
} // namespace python
} // namespace mlir
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 488196ea42e44..66a6272eaaf68 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1325,6 +1325,25 @@ struct MLIRError {
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
};
+/// Registers a nanobind exception translator that turns a C++ MLIRError into
+/// an instance of the Python `MLIRError` class from the `ir` module.
+///
+/// Exceptions with custom fields cannot be declared through the C++ binding
+/// layer, so the exception class is defined in Python and imported here when
+/// a translation is needed. Exceptions of any other type propagate out of the
+/// translator unchanged.
+inline void registerMLIRError() {
+  namespace nb = nanobind;
+  auto translateMLIRError = [](const std::exception_ptr &exc,
+                               void * /*payload*/) {
+    if (!exc)
+      return;
+    try {
+      std::rethrow_exception(exc);
+    } catch (const MLIRError &err) {
+      nb::module_ irModule =
+          nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"));
+      nb::object pyError =
+          irModule.attr("MLIRError")(err.message, err.errorDiagnostics);
+      PyErr_SetObject(PyExc_Exception, pyError.ptr());
+    }
+  };
+  // Captureless lambda, so it converts to the plain function pointer that
+  // register_exception_translator expects.
+  nb::register_exception_translator(translateMLIRError);
+}
+
+/// Non-inline variant defined in IRCore.cpp; it likewise registers an
+/// exception translator for MLIRError.
+void registerMLIRErrorInCore();
+
//------------------------------------------------------------------------------
// Utilities.
//------------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 36367e658697c..4323374a5d5b7 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1852,18 +1852,6 @@ void populateIRAttributes(nb::module_ &m) {
PyUnitAttribute::bind(m);
PyStridedLayoutAttribute::bind(m);
- nb::register_exception_translator([](const std::exception_ptr &p,
- void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
- });
+ registerMLIRError();
}
} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 88cffb64906d7..ea1e62b8165ad 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -920,9 +920,6 @@ nb::object PyOperation::create(std::string_view name,
PyMlirContext::ErrorCapture errors(location.getContext());
MlirOperation operation = mlirOperationCreate(&state);
if (!operation.ptr) {
- for (auto take : errors.take()) {
- std::cout << take.message << "\n";
- }
throw MLIRError("Operation creation failed", errors.take());
}
PyOperationRef created =
@@ -1672,7 +1669,7 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
}
}
-void registerMLIRErrorInIRCore() {
+void registerMLIRErrorInCore() {
nb::register_exception_translator([](const std::exception_ptr &p,
void *payload) {
// We can't define exceptions with custom fields through pybind, so
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index f1e494c375523..78d1f977b2ebc 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -12,11 +12,11 @@
#include <utility>
#include <vector>
-#include "mlir/Bindings/Python/IRCore.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 294ab91a059e2..7d9a0f16c913a 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -13,10 +13,10 @@
#include <optional>
-#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
namespace nb = nanobind;
using namespace mlir;
@@ -1175,18 +1175,6 @@ void populateIRTypes(nb::module_ &m) {
PyTupleType::bind(m);
PyFunctionType::bind(m);
PyOpaqueType::bind(m);
- nb::register_exception_translator([](const std::exception_ptr &p,
- void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
-});
-}
+ registerMLIRError();
}
+} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 686c55ee1e6a8..643851fcaf046 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -2250,21 +2250,6 @@ static void populateIRCore(nb::module_ &m) {
// Attribute builder getter.
PyAttrBuilderMap::bind(m);
-
- // nb::register_exception_translator([](const std::exception_ptr &p,
- // void *payload) {
- // // We can't define exceptions with custom fields through pybind, so
- // instead
- // // the exception class is defined in python and imported here.
- // try {
- // if (p)
- // std::rethrow_exception(p);
- // } catch (const MLIRError &e) {
- // nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- // .attr("MLIRError")(e.message, e.errorDiagnostics);
- // PyErr_SetObject(PyExc_Exception, obj.ptr());
- // }
- // });
}
namespace mlir::python {
@@ -2272,7 +2257,6 @@ void populateIRAffine(nb::module_ &m);
void populateIRAttributes(nb::module_ &m);
void populateIRInterfaces(nb::module_ &m);
void populateIRTypes(nb::module_ &m);
-void registerMLIRErrorInIRCore();
} // namespace mlir::python
// -----------------------------------------------------------------------------
@@ -2415,18 +2399,6 @@ NB_MODULE(_mlir, m) {
auto passManagerModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passManagerModule);
- registerMLIRErrorInIRCore();
- nb::register_exception_translator([](const std::exception_ptr &p,
- void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
- });
+ registerMLIRError();
+ registerMLIRErrorInCore();
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index c4165a04b284d..953d1eb7fd338 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,9 +8,9 @@
#include "Pass.h"
+#include "mlir-c/Pass.h"
#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/IRCore.h"
-#include "mlir-c/Pass.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
@@ -266,4 +266,5 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
},
"Print the textual representation for this PassManager, suitable to "
"be passed to `parse` for round-tripping.");
+ registerMLIRError();
}
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index a83b2c1883174..8a3a27f78c0e4 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -8,10 +8,10 @@
#include "Rewrite.h"
-#include "mlir/Bindings/Python/IRCore.h"
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/IRCore.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 3a4af1f066298..9286cead1a5c7 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -1002,11 +1002,16 @@ add_mlir_library(MLIRPythonSupport
LINK_COMPONENTS
Support
LINK_LIBS
+ Python::Module
${NB_LIBRARY_TARGET_NAME}
- MLIRCAPIIR
+ MLIRPythonCAPI
)
-target_link_libraries(MLIRPythonSupport PUBLIC ${NB_LIBRARY_TARGET_NAME})
nanobind_link_options(MLIRPythonSupport)
+# Remove `LINKER:-z,defs` (reject undefined symbols at link time) from
+# MLIRPythonSupport's link options if it is present.
+# NOTE(review): presumably needed because Python C-API symbols are left
+# undefined in this library and resolved by the interpreter at load time —
+# confirm.
+# list(REMOVE_ITEM) removes only exact list elements. A substring
+# string(REPLACE) could corrupt a longer option that merely contains this text,
+# and it would leave an empty element behind in LINK_OPTIONS.
+get_target_property(_current_link_options MLIRPythonSupport LINK_OPTIONS)
+if(_current_link_options)
+  set(_modified_link_options "${_current_link_options}")
+  list(REMOVE_ITEM _modified_link_options "LINKER:-z,defs")
+  set_property(TARGET MLIRPythonSupport
+    PROPERTY LINK_OPTIONS "${_modified_link_options}")
+endif()
set_target_properties(MLIRPythonSupport PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
BINARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
@@ -1042,4 +1047,7 @@ endif()
target_link_libraries(
MLIRPythonModules.extension._mlir.dso
PUBLIC MLIRPythonSupport)
+target_link_libraries(
+ MLIRPythonModules.extension._mlirPythonTestNanobind.dso
+ PUBLIC MLIRPythonSupport)
target_compile_definitions(MLIRPythonSupport PRIVATE NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index a497fcccf13d7..e53f1ab3b4d3f 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -14,6 +14,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
@@ -26,6 +27,24 @@ static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
}
+struct PyTestType : mlir::python::PyConcreteType<PyTestType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPythonTestTestType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirPythonTestTestTypeGetTypeID;
+ static constexpr const char *pyClassName = "TestType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](mlir::python::DefaultingPyMlirContext context) {
+ return PyTestType(context->getRef(),
+ mlirPythonTestTestTypeGet(context.get()->get()));
+ },
+ nb::arg("context").none() = nb::none());
+ }
+};
+
NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"register_python_test_dialect",
@@ -78,17 +97,7 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
// clang-format on
nb::arg("cls"), nb::arg("context").none() = nb::none());
- mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
- mlirPythonTestTestTypeGetTypeID)
- .def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPythonTestTestTypeGet(ctx));
- },
- // clang-format off
- nb::sig("def get(cls: object, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
- // clang-format on
- nb::arg("cls"), nb::arg("context").none() = nb::none());
+ PyTestType::bind(m);
auto typeCls =
mlir_type_subclass(m, "TestIntegerRankedTensorType",
>From 6b886af92ce9faaa947e3c5c5fbef8afcbfe32f2 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 11 Dec 2025 16:19:14 -0800
Subject: [PATCH 08/38] rebase
---
mlir/include/mlir/Bindings/Python/Globals.h | 44 -----------------
mlir/include/mlir/Bindings/Python/IRCore.h | 14 +++---
mlir/lib/Bindings/Python/Globals.cpp | 3 --
mlir/lib/Bindings/Python/IRAttributes.cpp | 8 ----
mlir/lib/Bindings/Python/IRCore.cpp | 1 +
mlir/lib/Bindings/Python/MainModule.cpp | 53 +++++++++++++++++++++
mlir/python/CMakeLists.txt | 9 ++--
7 files changed, 65 insertions(+), 67 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 19ffe8164d727..da06bbfaed479 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -201,50 +201,6 @@ class PyGlobals {
TracebackLoc tracebackLoc;
TypeIDAllocator typeIDAllocator;
};
-
-/// Wrapper for the global LLVM debugging flag.
-struct PyGlobalDebugFlag {
- static void set(nanobind::object &o, bool enable) {
- nanobind::ft_lock_guard lock(mutex);
- mlirEnableGlobalDebug(enable);
- }
-
- static bool get(const nanobind::object &) {
- nanobind::ft_lock_guard lock(mutex);
- return mlirIsGlobalDebugEnabled();
- }
-
- static void bind(nanobind::module_ &m) {
- // Debug flags.
- nanobind::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
- .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
- &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
- .def_static(
- "set_types",
- [](const std::string &type) {
- nanobind::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugType(type.c_str());
- },
- nanobind::arg("types"),
- "Sets specific debug types to be produced by LLVM.")
- .def_static(
- "set_types",
- [](const std::vector<std::string> &types) {
- std::vector<const char *> pointers;
- pointers.reserve(types.size());
- for (const std::string &str : types)
- pointers.push_back(str.c_str());
- nanobind::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
- },
- nanobind::arg("types"),
- "Sets multiple specific debug types to be produced by LLVM.");
- }
-
-private:
- static nanobind::ft_mutex mutex;
-};
-
} // namespace python
} // namespace mlir
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 66a6272eaaf68..e82ee8da20fe5 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1350,12 +1350,12 @@ void registerMLIRErrorInCore();
/// Helper for creating an @classmethod.
template <class Func, typename... Args>
-static nanobind::object classmethod(Func f, Args... args) {
+nanobind::object classmethod(Func f, Args... args) {
nanobind::object cf = nanobind::cpp_function(f, args...);
return nanobind::borrow<nanobind::object>((PyClassMethod_New(cf.ptr())));
}
-static nanobind::object
+inline nanobind::object
createCustomDialectWrapper(const std::string &dialectNamespace,
nanobind::object dialectDescriptor) {
auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
@@ -1368,21 +1368,21 @@ createCustomDialectWrapper(const std::string &dialectNamespace,
return (*dialectClass)(std::move(dialectDescriptor));
}
-static MlirStringRef toMlirStringRef(const std::string &s) {
+inline MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size());
}
-static MlirStringRef toMlirStringRef(std::string_view s) {
+inline MlirStringRef toMlirStringRef(std::string_view s) {
return mlirStringRefCreate(s.data(), s.size());
}
-static MlirStringRef toMlirStringRef(const nanobind::bytes &s) {
+inline MlirStringRef toMlirStringRef(const nanobind::bytes &s) {
return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
}
/// Create a block, using the current location context if no locations are
/// specified.
-static MlirBlock
+inline MlirBlock
createBlock(const nanobind::sequence &pyArgTypes,
const std::optional<nanobind::sequence> &pyArgLocs) {
SmallVector<MlirType> argTypes;
@@ -1871,7 +1871,7 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
/// Returns the list of types of the values held by container.
template <typename Container>
-static std::vector<nanobind::typed<nanobind::object, PyType>>
+std::vector<nanobind::typed<nanobind::object, PyType>>
getValueTypes(Container &container, PyMlirContextRef &context) {
std::vector<nanobind::typed<nanobind::object, PyType>> result;
result.reserve(container.size());
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index bc6b210426221..97a2df37a729b 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -267,7 +267,4 @@ bool PyGlobals::TracebackLoc::isUserTracebackFilename(
}
return isUserTracebackFilenameCache[file];
}
-
-nanobind::ft_mutex PyGlobalDebugFlag::mutex;
-
} // namespace mlir::python
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 4323374a5d5b7..e39eabdb136b8 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -228,14 +228,6 @@ struct nb_format_descriptor<double> {
static const char *format() { return "d"; }
};
-static MlirStringRef toMlirStringRef(const std::string &s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
-static MlirStringRef toMlirStringRef(const nb::bytes &s) {
- return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
-}
-
class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index ea1e62b8165ad..fc8743599508d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -26,6 +26,7 @@
#include <iostream>
#include <optional>
+#include <typeinfo>
namespace nb = nanobind;
using namespace nb::literals;
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 643851fcaf046..56dd4e0892655 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -185,6 +185,51 @@ maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
return {ref, mlirLoc};
}
+
+/// Wrapper for the global LLVM debugging flag.
+struct PyGlobalDebugFlag {
+ static void set(nanobind::object &o, bool enable) {
+ nanobind::ft_lock_guard lock(mutex);
+ mlirEnableGlobalDebug(enable);
+ }
+
+ static bool get(const nanobind::object &) {
+ nanobind::ft_lock_guard lock(mutex);
+ return mlirIsGlobalDebugEnabled();
+ }
+
+ static void bind(nanobind::module_ &m) {
+ // Debug flags.
+ nanobind::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
+ .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
+ &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
+ .def_static(
+ "set_types",
+ [](const std::string &type) {
+ nanobind::ft_lock_guard lock(mutex);
+ mlirSetGlobalDebugType(type.c_str());
+ },
+ nanobind::arg("types"),
+ "Sets specific debug types to be produced by LLVM.")
+ .def_static(
+ "set_types",
+ [](const std::vector<std::string> &types) {
+ std::vector<const char *> pointers;
+ pointers.reserve(types.size());
+ for (const std::string &str : types)
+ pointers.push_back(str.c_str());
+ nanobind::ft_lock_guard lock(mutex);
+ mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
+ },
+ nanobind::arg("types"),
+ "Sets multiple specific debug types to be produced by LLVM.");
+ }
+
+private:
+ static nanobind::ft_mutex mutex;
+};
+
+nanobind::ft_mutex PyGlobalDebugFlag::mutex;
} // namespace
//------------------------------------------------------------------------------
@@ -1241,6 +1286,14 @@ static void populateIRCore(nb::module_ &m) {
return PyOpSuccessors(self.getOperation().getRef());
},
"Returns the list of Operation successors.")
+ .def(
+ "replace_uses_of_with",
+ [](PyOperation &self, PyValue &of, PyValue &with) {
+ mlirOperationReplaceUsesOfWith(self.get(), of.get(), with.get());
+ },
+ "of"_a, "with_"_a,
+ "Replaces uses of the 'of' value with the 'with' value inside the "
+ "operation.")
.def("_set_invalid", &PyOperation::setInvalid,
"Invalidate the operation.");
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 9286cead1a5c7..a32c85cf10359 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -1006,12 +1006,11 @@ add_mlir_library(MLIRPythonSupport
${NB_LIBRARY_TARGET_NAME}
MLIRPythonCAPI
)
-nanobind_link_options(MLIRPythonSupport)
-get_target_property(_current_link_options MLIRPythonSupport LINK_OPTIONS)
-if(_current_link_options)
- string(REPLACE "LINKER:-z,defs" "" _modified_link_options "${_current_link_options}")
- set_property(TARGET MLIRPythonSupport PROPERTY LINK_OPTIONS "${_modified_link_options}")
+if((CMAKE_SYSTEM_NAME STREQUAL "Linux") AND (NOT LLVM_USE_SANITIZER))
+ target_link_options(MLIRPythonSupport PRIVATE "-Wl,-z,undefs")
+ target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "-Wl,-z,undefs")
endif()
+nanobind_link_options(MLIRPythonSupport)
set_target_properties(MLIRPythonSupport PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
BINARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
>From 7af0b285fe82a5b380de473d2f676e658b9a94d8 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 16 Dec 2025 15:13:36 -0800
Subject: [PATCH 09/38] fix after rebase
---
mlir/include/mlir/Bindings/Python/IRCore.h | 4 ++--
mlir/lib/Bindings/Python/MainModule.cpp | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index e82ee8da20fe5..649dfce22ad35 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1852,12 +1852,12 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
static void bindDerived(ClassTy &c) {
c.def_prop_ro(
"owner",
- [](PyOpResult &self) -> nanobind::typed<nanobind::object, PyOperation> {
+ [](PyOpResult &self) -> nanobind::typed<nanobind::object, PyOpView> {
assert(mlirOperationEqual(self.getParentOperation()->get(),
mlirOpResultGetOwner(self.get())) &&
"expected the owner of the value in Python to match that in "
"the IR");
- return self.getParentOperation().getObject();
+ return self.getParentOperation()->createOpView();
},
"Returns the operation that produces this result.");
c.def_prop_ro(
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 56dd4e0892655..f72775cc0b83a 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -1993,7 +1993,7 @@ static void populateIRCore(nb::module_ &m) {
kDumpDocstring)
.def_prop_ro(
"owner",
- [](PyValue &self) -> nb::object {
+ [](PyValue &self) -> nb::typed<nb::object, PyOpView> {
MlirValue v = self.get();
if (mlirValueIsAOpResult(v)) {
assert(mlirOperationEqual(self.getParentOperation()->get(),
@@ -2001,7 +2001,7 @@ static void populateIRCore(nb::module_ &m) {
"expected the owner of the value in Python to match "
"that in "
"the IR");
- return self.getParentOperation().getObject();
+ return self.getParentOperation()->createOpView();
}
if (mlirValueIsABlockArgument(v)) {
>From 1abbe3cfb125b7f2718592f2ee916700b31a546d Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 19 Dec 2025 10:03:47 -0800
Subject: [PATCH 10/38] try to fix Windows bad cast
---
mlir/include/mlir/Bindings/Python/Globals.h | 5 +----
mlir/lib/Bindings/Python/Globals.cpp | 5 +++++
mlir/python/CMakeLists.txt | 18 +++++++++---------
mlir/test/python/dialects/python_test.py | 12 +++---------
4 files changed, 18 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index da06bbfaed479..4584828868451 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -38,10 +38,7 @@ class PyGlobals {
~PyGlobals();
/// Most code should get the globals via this static accessor.
- static PyGlobals &get() {
- assert(instance && "PyGlobals is null");
- return *instance;
- }
+ static PyGlobals &get();
/// Get and set the list of parent modules to search for dialect
/// implementation classes.
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index 97a2df37a729b..ecac571a132f6 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -39,6 +39,11 @@ PyGlobals::PyGlobals() {
PyGlobals::~PyGlobals() { instance = nullptr; }
+PyGlobals &PyGlobals::get() {
+ assert(instance && "PyGlobals is null");
+ return *instance;
+}
+
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
{
nb::ft_lock_guard lock(mutex);
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index a32c85cf10359..a8bbd15124df5 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -994,17 +994,16 @@ endif()
get_property(NB_LIBRARY_TARGET_NAME TARGET MLIRPythonModules.extension._mlir.dso PROPERTY LINK_LIBRARIES)
list(GET NB_LIBRARY_TARGET_NAME 0 NB_LIBRARY_TARGET_NAME)
add_mlir_library_install(${NB_LIBRARY_TARGET_NAME})
-add_mlir_library(MLIRPythonSupport
+add_library(MLIRPythonSupport SHARED
${PYTHON_SOURCE_DIR}/IRCore.cpp
${PYTHON_SOURCE_DIR}/Globals.cpp
- EXCLUDE_FROM_LIBMLIR
- SHARED
- LINK_COMPONENTS
- Support
- LINK_LIBS
- Python::Module
- ${NB_LIBRARY_TARGET_NAME}
- MLIRPythonCAPI
+)
+target_link_libraries(MLIRPythonSupport PRIVATE
+ LLVMSupport
+ Python::Module
+ ${NB_LIBRARY_TARGET_NAME}
+ MLIRPythonCAPI
+
)
if((CMAKE_SYSTEM_NAME STREQUAL "Linux") AND (NOT LLVM_USE_SANITIZER))
target_link_options(MLIRPythonSupport PRIVATE "-Wl,-z,undefs")
@@ -1028,6 +1027,7 @@ set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES
set(eh_rtti_enable)
if(MSVC)
set(eh_rtti_enable /EHsc /GR)
+ set_property(TARGET MLIRPythonSupport PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
elseif(LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
set(eh_rtti_enable -frtti -fexceptions)
endif()
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 7bba20931e675..e50c8722f8959 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -613,12 +613,6 @@ def testCustomType():
b = TestType(a)
# Instance custom types should have typeids
assert isinstance(b.typeid, TypeID)
- # Subclasses of ir.Type should not have a static_typeid
- # CHECK: 'TestType' object has no attribute 'static_typeid'
- try:
- b.static_typeid
- except AttributeError as e:
- print(e)
i8 = IntegerType.get_signless(8)
try:
@@ -633,9 +627,9 @@ def testCustomType():
try:
TestType(42)
except TypeError as e:
- assert "Expected an MLIR object (got 42)" in str(e)
- except ValueError as e:
- assert "Cannot cast type to TestType (from 42)" in str(e)
+ assert "__init__(): incompatible function arguments. The following argument types are supported" in str(e)
+ assert "__init__(self, cast_from_type: mlir._mlir_libs._mlir.ir.Type) -> None" in str(e)
+ assert "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestType, int" in str(e)
else:
raise
>From a1ac2c72fa7010658cf56cb85b21fd317b2b4440 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 19 Dec 2025 14:04:46 -0800
Subject: [PATCH 11/38] port mlir_attribute_subclass
---
mlir/test/python/dialects/python_test.py | 6 ++--
.../python/lib/PythonTestModuleNanobind.cpp | 34 ++++++++++++-------
2 files changed, 24 insertions(+), 16 deletions(-)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index e50c8722f8959..0ba56b7922ff5 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -586,9 +586,9 @@ def testCustomAttribute():
try:
TestAttr(42)
except TypeError as e:
- assert "Expected an MLIR object (got 42)" in str(e)
- except ValueError as e:
- assert "Cannot cast attribute to TestAttr (from 42)" in str(e)
+ assert "__init__(): incompatible function arguments. The following argument types are supported" in str(e)
+ assert "__init__(self, cast_from_attr: mlir._mlir_libs._mlir.ir.Attribute) -> None" in str(e)
+ assert "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestAttr, int" in str(e)
else:
raise
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index e53f1ab3b4d3f..c8b95e2316778 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -45,6 +45,26 @@ struct PyTestType : mlir::python::PyConcreteType<PyTestType> {
}
};
+class PyTestAttr : public mlir::python::PyConcreteAttribute<PyTestAttr> {
+public:
+ static constexpr IsAFunctionTy isaFunction =
+ mlirAttributeIsAPythonTestTestAttribute;
+ static constexpr const char *pyClassName = "TestAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirPythonTestTestAttributeGetTypeID;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](mlir::python::DefaultingPyMlirContext context) {
+ return PyTestAttr(context->getRef(), mlirPythonTestTestAttributeGet(
+ context.get()->get()));
+ },
+ nb::arg("context").none() = nb::none());
+ }
+};
+
NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"register_python_test_dialect",
@@ -84,19 +104,7 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
nb::sig("def test_diagnostics_with_errors_and_notes(arg: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", /) -> None"));
// clang-format on
- mlir_attribute_subclass(m, "TestAttr",
- mlirAttributeIsAPythonTestTestAttribute,
- mlirPythonTestTestAttributeGetTypeID)
- .def_classmethod(
- "get",
- [](const nb::object &cls, MlirContext ctx) {
- return cls(mlirPythonTestTestAttributeGet(ctx));
- },
- // clang-format off
- nb::sig("def get(cls: object, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
- // clang-format on
- nb::arg("cls"), nb::arg("context").none() = nb::none());
-
+ PyTestAttr::bind(m);
PyTestType::bind(m);
auto typeCls =
>From 501def95208cd434c6bf49d7ff632644769a0462 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 19 Dec 2025 14:10:11 -0800
Subject: [PATCH 12/38] format
---
mlir/test/python/dialects/python_test.py | 30 +++++++++++++++++++-----
1 file changed, 24 insertions(+), 6 deletions(-)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 0ba56b7922ff5..9c0966d2d8798 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -586,9 +586,18 @@ def testCustomAttribute():
try:
TestAttr(42)
except TypeError as e:
- assert "__init__(): incompatible function arguments. The following argument types are supported" in str(e)
- assert "__init__(self, cast_from_attr: mlir._mlir_libs._mlir.ir.Attribute) -> None" in str(e)
- assert "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestAttr, int" in str(e)
+ assert (
+ "__init__(): incompatible function arguments. The following argument types are supported"
+ in str(e)
+ )
+ assert (
+ "__init__(self, cast_from_attr: mlir._mlir_libs._mlir.ir.Attribute) -> None"
+ in str(e)
+ )
+ assert (
+ "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestAttr, int"
+ in str(e)
+ )
else:
raise
@@ -627,9 +636,18 @@ def testCustomType():
try:
TestType(42)
except TypeError as e:
- assert "__init__(): incompatible function arguments. The following argument types are supported" in str(e)
- assert "__init__(self, cast_from_type: mlir._mlir_libs._mlir.ir.Type) -> None" in str(e)
- assert "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestType, int" in str(e)
+ assert (
+ "__init__(): incompatible function arguments. The following argument types are supported"
+ in str(e)
+ )
+ assert (
+ "__init__(self, cast_from_type: mlir._mlir_libs._mlir.ir.Type) -> None"
+ in str(e)
+ )
+ assert (
+ "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestType, int"
+ in str(e)
+ )
else:
raise
>From c65cd2a0fb7cbd1bd5040c367505e46b9f7b19cf Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 19 Dec 2025 18:40:39 -0800
Subject: [PATCH 13/38] massage CMake
---
mlir/cmake/modules/AddMLIRPython.cmake | 158 ++++++++++++++++++++-----
mlir/python/CMakeLists.txt | 71 ++---------
2 files changed, 137 insertions(+), 92 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 0a5d788b9bca0..d581f3ce51005 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -228,14 +228,19 @@ endfunction()
# aggregate dylib that is linked against.
function(declare_mlir_python_extension name)
cmake_parse_arguments(ARG
- ""
- "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT"
+ "SUPPORT_LIB"
+ "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT;SOURCES_TYPE"
"SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS"
${ARGN})
if(NOT ARG_ROOT_DIR)
set(ARG_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
endif()
+ if(ARG_SUPPORT_LIB)
+ set(SOURCES_TYPE "support")
+ else()
+ set(SOURCES_TYPE "extension")
+ endif()
set(_install_destination "src/python/${name}")
add_library(${name} INTERFACE)
@@ -243,7 +248,7 @@ function(declare_mlir_python_extension name)
# Yes: Leading-lowercase property names are load bearing and the recommended
# way to do this: https://gitlab.kitware.com/cmake/cmake/-/issues/19261
EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS"
- mlir_python_SOURCES_TYPE extension
+ mlir_python_SOURCES_TYPE "${SOURCES_TYPE}"
mlir_python_EXTENSION_MODULE_NAME "${ARG_MODULE_NAME}"
mlir_python_EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}"
mlir_python_DEPENDS ""
@@ -297,6 +302,39 @@ function(_mlir_python_install_sources name source_root_dir destination)
)
endfunction()
+function(build_nanobind_lib)
+ cmake_parse_arguments(ARG
+ ""
+ "INSTALL_COMPONENT;INSTALL_DESTINATION;OUTPUT_DIRECTORY"
+ ""
+ ${ARGN})
+
+ if (NB_ABI MATCHES "[0-9]t")
+ set(_ft "-ft")
+ endif()
+ # nanobind does a string match on the suffix to figure out whether to build
+ # the lib with free threading...
+ set(NB_LIBRARY_TARGET_NAME "nanobind${_ft}-${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
+ set(NB_LIBRARY_TARGET_NAME "${NB_LIBRARY_TARGET_NAME}" PARENT_SCOPE)
+ nanobind_build_library(${NB_LIBRARY_TARGET_NAME} AS_SYSINCLUDE)
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+ target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "-Wl,-z,undefs")
+ endif()
+ set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES
+ LIBRARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+ BINARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+ # Needed for windows (and don't hurt others).
+ RUNTIME_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+ ARCHIVE_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+ )
+ mlir_python_setup_extension_rpath(${NB_LIBRARY_TARGET_NAME})
+ install(TARGETS ${NB_LIBRARY_TARGET_NAME}
+ COMPONENT ${ARG_INSTALL_COMPONENT}
+ LIBRARY DESTINATION "${ARG_INSTALL_DESTINATION}"
+ RUNTIME DESTINATION "${ARG_INSTALL_DESTINATION}"
+ )
+endfunction()
+
# Function: add_mlir_python_modules
# Adds python modules to a project, building them from a list of declared
# source groupings (see declare_mlir_python_sources and
@@ -318,8 +356,16 @@ function(add_mlir_python_modules name)
"ROOT_PREFIX;INSTALL_PREFIX"
"COMMON_CAPI_LINK_LIBS;DECLARED_SOURCES"
${ARGN})
+
+ # This call sets NB_LIBRARY_TARGET_NAME.
+ build_nanobind_lib(
+ INSTALL_COMPONENT ${name}
+ INSTALL_DESTINATION "${ARG_INSTALL_PREFIX}/_mlir_libs"
+ OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ )
+
# Helper to process an individual target.
- function(_process_target modules_target sources_target)
+ function(_process_target modules_target sources_target support_libs)
get_target_property(_source_type ${sources_target} mlir_python_SOURCES_TYPE)
if(_source_type STREQUAL "pure")
@@ -337,16 +383,19 @@ function(add_mlir_python_modules name)
get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
# Transform relative source to based on root dir.
set(_extension_target "${modules_target}.extension.${_module_name}.dso")
- add_mlir_python_extension(${_extension_target} "${_module_name}"
+ add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME}
INSTALL_COMPONENT ${modules_target}
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
LINK_LIBS PRIVATE
${sources_target}
${ARG_COMMON_CAPI_LINK_LIBS}
+ ${support_libs}
)
add_dependencies(${modules_target} ${_extension_target})
mlir_python_setup_extension_rpath(${_extension_target})
+ elseif(_source_type STREQUAL "support")
+ # do nothing because already built
else()
message(SEND_ERROR "Unrecognized source type '${_source_type}' for python source target ${sources_target}")
return()
@@ -356,8 +405,34 @@ function(add_mlir_python_modules name)
# Build the modules target.
add_custom_target(${name} ALL)
_flatten_mlir_python_targets(_flat_targets ${ARG_DECLARED_SOURCES})
+
+ # Build all support libs first.
+ set(_mlir_python_support_libs)
+ foreach(sources_target ${_flat_targets})
+ get_target_property(_source_type ${sources_target} mlir_python_SOURCES_TYPE)
+ if(_source_type STREQUAL "support")
+ get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
+ set(_extension_target "${name}.extension.${_module_name}.dso")
+ add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME}
+ INSTALL_COMPONENT ${name}
+ INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
+ OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ SUPPORT_LIB
+ LINK_LIBS PRIVATE
+ LLVMSupport
+ Python::Module
+ ${sources_target}
+ ${ARG_COMMON_CAPI_LINK_LIBS}
+ )
+ add_dependencies(${name} ${_extension_target})
+ mlir_python_setup_extension_rpath(${_extension_target})
+ list(APPEND _mlir_python_support_libs "${_extension_target}")
+ endif()
+ endforeach()
+
+ # Build extensions.
foreach(sources_target ${_flat_targets})
- _process_target(${name} ${sources_target})
+ _process_target(${name} ${sources_target} ${_mlir_python_support_libs})
endforeach()
# Create an install target.
@@ -741,9 +816,9 @@ endfunction()
################################################################################
# Build python extension
################################################################################
-function(add_mlir_python_extension libname extname)
+function(add_mlir_python_extension libname extname nb_library_target_name)
cmake_parse_arguments(ARG
- ""
+ "SUPPORT_LIB"
"INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY"
"SOURCES;LINK_LIBS"
${ARGN})
@@ -760,12 +835,29 @@ function(add_mlir_python_extension libname extname)
set(eh_rtti_enable -frtti -fexceptions)
endif ()
- nanobind_add_module(${libname}
- NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
- FREE_THREADED
- NB_SHARED
- ${ARG_SOURCES}
- )
+ if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
+ set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir" CACHE STRING "" FORCE)
+ endif()
+
+ if(ARG_SUPPORT_LIB)
+ add_library(${libname} SHARED ${ARG_SOURCES})
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+ target_link_options(${libname} PRIVATE "-Wl,-z,undefs")
+ endif()
+ nanobind_link_options(${libname})
+ target_compile_definitions(${libname} PRIVATE NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
+ if (MSVC)
+ set_property(TARGET ${libname} PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
+ endif ()
+ else()
+ nanobind_add_module(${libname}
+ NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ FREE_THREADED
+ NB_SHARED
+ ${ARG_SOURCES}
+ )
+ endif()
+ target_link_libraries(${libname} PRIVATE ${nb_library_target_name})
if(APPLE)
# In llvm/cmake/modules/HandleLLVMOptions.cmake:268 we set -Wl,-flat_namespace which breaks
# the default name spacing on MacOS and causes "cross-wired" symbol resolution when multiple
@@ -778,29 +870,28 @@ function(add_mlir_python_extension libname extname)
# Avoid some warnings from upstream nanobind.
# If a superproject set MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES, let
# the super project handle compile options as it wishes.
- get_property(NB_LIBRARY_TARGET_NAME TARGET ${libname} PROPERTY LINK_LIBRARIES)
- target_compile_options(${NB_LIBRARY_TARGET_NAME}
+ target_compile_options(${nb_library_target_name}
PRIVATE
-Wno-c++98-compat-extra-semi
- -Wno-cast-qual
- -Wno-covered-switch-default
- -Wno-deprecated-literal-operator
- -Wno-nested-anon-types
- -Wno-unused-parameter
- -Wno-zero-length-array
- -Wno-missing-field-initializers
+ -Wno-cast-qual
+ -Wno-covered-switch-default
+ -Wno-deprecated-literal-operator
+ -Wno-nested-anon-types
+ -Wno-unused-parameter
+ -Wno-zero-length-array
+ -Wno-missing-field-initializers
${eh_rtti_enable})
target_compile_options(${libname}
PRIVATE
-Wno-c++98-compat-extra-semi
- -Wno-cast-qual
- -Wno-covered-switch-default
- -Wno-deprecated-literal-operator
- -Wno-nested-anon-types
- -Wno-unused-parameter
- -Wno-zero-length-array
- -Wno-missing-field-initializers
+ -Wno-cast-qual
+ -Wno-covered-switch-default
+ -Wno-deprecated-literal-operator
+ -Wno-nested-anon-types
+ -Wno-unused-parameter
+ -Wno-zero-length-array
+ -Wno-missing-field-initializers
${eh_rtti_enable})
endif()
@@ -819,11 +910,16 @@ function(add_mlir_python_extension libname extname)
target_compile_options(${libname} PRIVATE ${eh_rtti_enable})
# Configure the output to match python expectations.
+ if (ARG_SUPPORT_LIB)
+ set(_no_soname OFF)
+ else ()
+ set(_no_soname ON)
+ endif ()
set_target_properties(
${libname} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${ARG_OUTPUT_DIRECTORY}
OUTPUT_NAME "${extname}"
- NO_SONAME ON
+ NO_SONAME ${_no_soname}
)
if(WIN32)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index a8bbd15124df5..b22d2ec75b3ba 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -784,7 +784,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.AMDGPU.Nanobind
MODULE_NAME _mlirDialectsAMDGPU
ADD_TO_PARENT MLIRPythonSources.Dialects.amdgpu
ROOT_DIR "${PYTHON_SOURCE_DIR}"
- PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectAMDGPU.cpp
PRIVATE_LINK_LIBS
@@ -841,6 +840,16 @@ if(MLIR_INCLUDE_TESTS)
)
endif()
+declare_mlir_python_extension(MLIRPythonExtension.MLIRPythonSupport
+ SUPPORT_LIB
+ MODULE_NAME MLIRPythonSupport
+ ADD_TO_PARENT MLIRPythonSources.Core
+ ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ SOURCES
+ IRCore.cpp
+ Globals.cpp
+)
+
################################################################################
# Common CAPI dependency DSO.
# All python extensions must link through one DSO which exports the CAPI, and
@@ -990,63 +999,3 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
add_dependencies(MLIRPythonModules "${_mlirPythonTestNanobind_typestub_gen_target}")
endif()
endif()
-
-get_property(NB_LIBRARY_TARGET_NAME TARGET MLIRPythonModules.extension._mlir.dso PROPERTY LINK_LIBRARIES)
-list(GET NB_LIBRARY_TARGET_NAME 0 NB_LIBRARY_TARGET_NAME)
-add_mlir_library_install(${NB_LIBRARY_TARGET_NAME})
-add_library(MLIRPythonSupport SHARED
- ${PYTHON_SOURCE_DIR}/IRCore.cpp
- ${PYTHON_SOURCE_DIR}/Globals.cpp
-)
-target_link_libraries(MLIRPythonSupport PRIVATE
- LLVMSupport
- Python::Module
- ${NB_LIBRARY_TARGET_NAME}
- MLIRPythonCAPI
-
-)
-if((CMAKE_SYSTEM_NAME STREQUAL "Linux") AND (NOT LLVM_USE_SANITIZER))
- target_link_options(MLIRPythonSupport PRIVATE "-Wl,-z,undefs")
- target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "-Wl,-z,undefs")
-endif()
-nanobind_link_options(MLIRPythonSupport)
-set_target_properties(MLIRPythonSupport PROPERTIES
- LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
- BINARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
- # Needed for windows (and doesn't hurt others).
- RUNTIME_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
- ARCHIVE_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
-)
-set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES
- LIBRARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
- BINARY_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
- # Needed for windows (and doesn't hurt others).
- RUNTIME_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
- ARCHIVE_OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
-)
-set(eh_rtti_enable)
-if(MSVC)
- set(eh_rtti_enable /EHsc /GR)
- set_property(TARGET MLIRPythonSupport PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
-elseif(LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
- set(eh_rtti_enable -frtti -fexceptions)
-endif()
-target_compile_options(MLIRPythonSupport PRIVATE ${eh_rtti_enable})
-if(APPLE)
- # NanobindAdaptors.h uses PyClassMethod_New to build `pure_subclass`es but nanobind
- # doesn't declare this API as undefined in its linker flags. So we need to declare it as such
- # for downstream users that do not do something like `-undefined dynamic_lookup`.
- # Same for the rest.
- target_link_options(MLIRPythonSupport PUBLIC
- "LINKER:-U,_PyClassMethod_New"
- "LINKER:-U,_PyCode_Addr2Location"
- "LINKER:-U,_PyFrame_GetLasti"
- )
-endif()
-target_link_libraries(
- MLIRPythonModules.extension._mlir.dso
- PUBLIC MLIRPythonSupport)
-target_link_libraries(
- MLIRPythonModules.extension._mlirPythonTestNanobind.dso
- PUBLIC MLIRPythonSupport)
-target_compile_definitions(MLIRPythonSupport PRIVATE NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
>From d2186c2581074b31d44d676831e9bc54838ace0d Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sun, 21 Dec 2025 16:12:46 -0800
Subject: [PATCH 14/38] add standalone test/use of IRCore
---
.../include/Standalone-c/Dialects.h | 7 ++++++
.../examples/standalone/lib/CAPI/Dialects.cpp | 13 ++++++++++
.../python/StandaloneExtensionNanobind.cpp | 25 +++++++++++++++++++
.../standalone/test/python/smoketest.py | 4 +++
mlir/include/mlir/Bindings/Python/Globals.h | 1 -
5 files changed, 49 insertions(+), 1 deletion(-)
diff --git a/mlir/examples/standalone/include/Standalone-c/Dialects.h b/mlir/examples/standalone/include/Standalone-c/Dialects.h
index b3e47752ccc69..5aa9e004cb9fe 100644
--- a/mlir/examples/standalone/include/Standalone-c/Dialects.h
+++ b/mlir/examples/standalone/include/Standalone-c/Dialects.h
@@ -17,6 +17,13 @@ extern "C" {
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Standalone, standalone);
+MLIR_CAPI_EXPORTED MlirType mlirStandaloneCustomTypeGet(MlirContext ctx,
+ MlirStringRef value);
+
+MLIR_CAPI_EXPORTED bool mlirStandaloneTypeIsACustomType(MlirType t);
+
+MLIR_CAPI_EXPORTED MlirTypeID mlirStandaloneCustomTypeGetTypeID(void);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/examples/standalone/lib/CAPI/Dialects.cpp b/mlir/examples/standalone/lib/CAPI/Dialects.cpp
index 98006e81a3d26..4de55ba485490 100644
--- a/mlir/examples/standalone/lib/CAPI/Dialects.cpp
+++ b/mlir/examples/standalone/lib/CAPI/Dialects.cpp
@@ -9,7 +9,20 @@
#include "Standalone-c/Dialects.h"
#include "Standalone/StandaloneDialect.h"
+#include "Standalone/StandaloneTypes.h"
#include "mlir/CAPI/Registration.h"
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Standalone, standalone,
mlir::standalone::StandaloneDialect)
+
+MlirType mlirStandaloneCustomTypeGet(MlirContext ctx, MlirStringRef value) {
+ return wrap(mlir::standalone::CustomType::get(unwrap(ctx), unwrap(value)));
+}
+
+bool mlirStandaloneTypeIsACustomType(MlirType t) {
+ return llvm::isa<mlir::standalone::CustomType>(unwrap(t));
+}
+
+MlirTypeID mlirStandaloneCustomTypeGetTypeID() {
+ return wrap(mlir::standalone::CustomType::getTypeID());
+}
diff --git a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
index 0ec6cdfa7994b..37737cd89ee1e 100644
--- a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
+++ b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
@@ -11,17 +11,42 @@
#include "Standalone-c/Dialects.h"
#include "mlir-c/Dialect/Arith.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
+struct PyCustomType : mlir::python::PyConcreteType<PyCustomType> {
+ static constexpr IsAFunctionTy isaFunction = mlirStandaloneTypeIsACustomType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirStandaloneCustomTypeGetTypeID;
+ static constexpr const char *pyClassName = "CustomType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const std::string &value,
+ mlir::python::DefaultingPyMlirContext context) {
+ return PyCustomType(
+ context->getRef(),
+ mlirStandaloneCustomTypeGet(
+ context.get()->get(),
+ mlirStringRefCreateFromCString(value.c_str())));
+ },
+ nb::arg("value"), nb::arg("context").none() = nb::none());
+ }
+};
+
NB_MODULE(_standaloneDialectsNanobind, m) {
//===--------------------------------------------------------------------===//
// standalone dialect
//===--------------------------------------------------------------------===//
auto standaloneM = m.def_submodule("standalone");
+ PyCustomType::bind(standaloneM);
+
standaloneM.def(
"register_dialects",
[](MlirContext context, bool load) {
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index 09040eb2f45dc..dbb664d9190b2 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -19,6 +19,10 @@
# CHECK: standalone.foo %[[C2]] : i32
print(str(standalone_module), file=sys.stderr)
+ custom_type = standalone_d.CustomType.get("foo")
+ # CHECK: !standalone.custom<"foo">
+ print(custom_type)
+
# CHECK: Testing mlir package
print("Testing mlir package", file=sys.stderr)
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 4584828868451..2184e7e2dc5ca 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -15,7 +15,6 @@
#include <unordered_set>
#include <vector>
-#include "mlir-c/Debug.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/NanobindUtils.h"
>From 216039969c7ac4feaf4ad31ed5c4b07f039d16d0 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 23 Dec 2025 08:24:41 -0800
Subject: [PATCH 15/38] disable nanobind LTO unless LLVM_ENABLE_LTO is set
---
mlir/cmake/modules/AddMLIRPython.cmake | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index d581f3ce51005..790c63104911f 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -317,6 +317,14 @@ function(build_nanobind_lib)
set(NB_LIBRARY_TARGET_NAME "nanobind${_ft}-${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(NB_LIBRARY_TARGET_NAME "${NB_LIBRARY_TARGET_NAME}" PARENT_SCOPE)
nanobind_build_library(${NB_LIBRARY_TARGET_NAME} AS_SYSINCLUDE)
+ # nanobind configures with LTO for shared build which doesn't work everywhere
+ # (see https://github.com/llvm/llvm-project/issues/139602).
+ if(NOT LLVM_ENABLE_LTO)
+ set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES
+ INTERPROCEDURAL_OPTIMIZATION_RELEASE OFF
+ INTERPROCEDURAL_OPTIMIZATION_MINSIZEREL OFF
+ )
+ endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "-Wl,-z,undefs")
endif()
>From be662fa9c13e0140d5359bfde12424a8678ce4c9 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 23 Dec 2025 09:27:36 -0800
Subject: [PATCH 16/38] restore DISABLE_INSTALL
---
mlir/cmake/modules/AddMLIRPython.cmake | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 790c63104911f..d194c2dee342c 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -684,6 +684,7 @@ function(add_mlir_python_common_capi_library name)
# Generate the aggregate .so that everything depends on.
add_mlir_aggregate(${name}
SHARED
+ DISABLE_INSTALL
EMBED_LIBS ${_embed_libs}
)
>From 0606a281e4c9d34f4f4a77a41ff0baff2346c8af Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 23 Dec 2025 11:09:32 -0800
Subject: [PATCH 17/38] turn off VISIBILITY_INLINES_HIDDEN for libMLIRPythonSupport
---
mlir/cmake/modules/AddMLIRPython.cmake | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index d194c2dee342c..8ff4523d2cc06 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -855,6 +855,11 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
endif()
nanobind_link_options(${libname})
target_compile_definitions(${libname} PRIVATE NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
+ set_target_properties(${libname} PROPERTIES
+ VISIBILITY_INLINES_HIDDEN OFF
+ C_VISIBILITY_PRESET default
+ CXX_VISIBILITY_PRESET default
+ )
if (MSVC)
set_property(TARGET ${libname} PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif ()
>From 8651959647dfd0dfca6752a8b216a2c011fa2e77 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 23 Dec 2025 12:20:34 -0800
Subject: [PATCH 18/38] try MLIR_PYTHON_API_EXPORTED
---
mlir/cmake/modules/AddMLIRPython.cmake | 9 +-
mlir/include/mlir-c/Support.h | 2 +
mlir/include/mlir/Bindings/Python/Globals.h | 4 +-
mlir/include/mlir/Bindings/Python/IRCore.h | 121 +++++++++++---------
mlir/include/mlir/Bindings/Python/IRTypes.h | 3 +-
5 files changed, 75 insertions(+), 64 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 8ff4523d2cc06..154ec611fb358 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -854,11 +854,10 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
target_link_options(${libname} PRIVATE "-Wl,-z,undefs")
endif()
nanobind_link_options(${libname})
- target_compile_definitions(${libname} PRIVATE NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
- set_target_properties(${libname} PROPERTIES
- VISIBILITY_INLINES_HIDDEN OFF
- C_VISIBILITY_PRESET default
- CXX_VISIBILITY_PRESET default
+ target_compile_definitions(${libname}
+ PRIVATE
+ NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ MLIR_CAPI_BUILDING_LIBRARY=1
)
if (MSVC)
set_property(TARGET ${libname} PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h
index 78fc94f93439e..6abd8894227c3 100644
--- a/mlir/include/mlir-c/Support.h
+++ b/mlir/include/mlir-c/Support.h
@@ -46,6 +46,8 @@
#define MLIR_CAPI_EXPORTED __attribute__((visibility("default")))
#endif
+#define MLIR_PYTHON_API_EXPORTED MLIR_CAPI_EXPORTED
+
#ifdef __cplusplus
extern "C" {
#endif
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 2184e7e2dc5ca..112c7b9b0547f 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -31,7 +31,7 @@ namespace python {
/// Globals that are always accessible once the extension has been initialized.
/// Methods of this class are thread-safe.
-class PyGlobals {
+class MLIR_PYTHON_API_EXPORTED PyGlobals {
public:
PyGlobals();
~PyGlobals();
@@ -117,7 +117,7 @@ class PyGlobals {
std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);
- class TracebackLoc {
+ class MLIR_PYTHON_API_EXPORTED TracebackLoc {
public:
bool locTracebacksEnabled();
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 649dfce22ad35..ceedeb691eb58 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -52,7 +52,7 @@ class PyValue;
/// Template for a reference to a concrete type which captures a python
/// reference to its underlying python object.
template <typename T>
-class PyObjectRef {
+class MLIR_PYTHON_API_EXPORTED PyObjectRef {
public:
PyObjectRef(T *referrent, nanobind::object object)
: referrent(referrent), object(std::move(object)) {
@@ -111,7 +111,7 @@ class PyObjectRef {
/// Context. Pushing a Context will not modify the Location or InsertionPoint
/// unless if they are from a different context, in which case, they are
/// cleared.
-class PyThreadContextEntry {
+class MLIR_PYTHON_API_EXPORTED PyThreadContextEntry {
public:
enum class FrameKind {
Context,
@@ -167,7 +167,7 @@ class PyThreadContextEntry {
/// Wrapper around MlirLlvmThreadPool
/// Python object owns the C++ thread pool
-class PyThreadPool {
+class MLIR_PYTHON_API_EXPORTED PyThreadPool {
public:
PyThreadPool() {
ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
@@ -190,7 +190,7 @@ class PyThreadPool {
/// Wrapper around MlirContext.
using PyMlirContextRef = PyObjectRef<PyMlirContext>;
-class PyMlirContext {
+class MLIR_PYTHON_API_EXPORTED PyMlirContext {
public:
PyMlirContext() = delete;
PyMlirContext(MlirContext context);
@@ -271,7 +271,7 @@ class PyMlirContext {
/// Used in function arguments when None should resolve to the current context
/// manager set instance.
-class DefaultingPyMlirContext
+class MLIR_PYTHON_API_EXPORTED DefaultingPyMlirContext
: public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
public:
using Defaulting::Defaulting;
@@ -283,7 +283,7 @@ class DefaultingPyMlirContext
/// MlirContext. The lifetime of the context will extend at least to the
/// lifetime of these instances.
/// Immutable objects that depend on a context extend this directly.
-class BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED BaseContextObject {
public:
BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
assert(this->contextRef &&
@@ -298,7 +298,7 @@ class BaseContextObject {
};
/// Wrapper around an MlirLocation.
-class PyLocation : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyLocation : public BaseContextObject {
public:
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
: BaseContextObject(std::move(contextRef)), loc(loc) {}
@@ -329,7 +329,7 @@ class PyLocation : public BaseContextObject {
/// are only valid for the duration of a diagnostic callback and attempting
/// to access them outside of that will raise an exception. This applies to
/// nested diagnostics (in the notes) as well.
-class PyDiagnostic {
+class MLIR_PYTHON_API_EXPORTED PyDiagnostic {
public:
PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}
void invalidate();
@@ -379,7 +379,7 @@ class PyDiagnostic {
/// The object may remain live from a Python perspective for an arbitrary time
/// after detachment, but there is nothing the user can do with it (since there
/// is no way to attach an existing handler object).
-class PyDiagnosticHandler {
+class MLIR_PYTHON_API_EXPORTED PyDiagnosticHandler {
public:
PyDiagnosticHandler(MlirContext context, nanobind::object callback);
~PyDiagnosticHandler();
@@ -407,7 +407,7 @@ class PyDiagnosticHandler {
/// RAII object that captures any error diagnostics emitted to the provided
/// context.
-struct PyMlirContext::ErrorCapture {
+struct MLIR_PYTHON_API_EXPORTED PyMlirContext::ErrorCapture {
ErrorCapture(PyMlirContextRef ctx)
: ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler(
ctx->get(), handler, /*userData=*/this,
@@ -434,7 +434,7 @@ struct PyMlirContext::ErrorCapture {
/// plugins which extend dialect functionality through extension python code.
/// This should be seen as the "low-level" object and `Dialect` as the
/// high-level, user facing object.
-class PyDialectDescriptor : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyDialectDescriptor : public BaseContextObject {
public:
PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
: BaseContextObject(std::move(contextRef)), dialect(dialect) {}
@@ -447,7 +447,7 @@ class PyDialectDescriptor : public BaseContextObject {
/// User-level object for accessing dialects with dotted syntax such as:
/// ctx.dialect.std
-class PyDialects : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyDialects : public BaseContextObject {
public:
PyDialects(PyMlirContextRef contextRef)
: BaseContextObject(std::move(contextRef)) {}
@@ -458,7 +458,7 @@ class PyDialects : public BaseContextObject {
/// User-level dialect object. For dialects that have a registered extension,
/// this will be the base class of the extension dialect type. For un-extended,
/// objects of this type will be returned directly.
-class PyDialect {
+class MLIR_PYTHON_API_EXPORTED PyDialect {
public:
PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {}
@@ -471,7 +471,7 @@ class PyDialect {
/// Wrapper around an MlirDialectRegistry.
/// Upon construction, the Python wrapper takes ownership of the
/// underlying MlirDialectRegistry.
-class PyDialectRegistry {
+class MLIR_PYTHON_API_EXPORTED PyDialectRegistry {
public:
PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {}
PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {}
@@ -497,7 +497,7 @@ class PyDialectRegistry {
/// Used in function arguments when None should resolve to the current context
/// manager set instance.
-class DefaultingPyLocation
+class MLIR_PYTHON_API_EXPORTED DefaultingPyLocation
: public Defaulting<DefaultingPyLocation, PyLocation> {
public:
using Defaulting::Defaulting;
@@ -511,7 +511,7 @@ class DefaultingPyLocation
/// This is the top-level, user-owned object that contains regions/ops/blocks.
class PyModule;
using PyModuleRef = PyObjectRef<PyModule>;
-class PyModule : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyModule : public BaseContextObject {
public:
/// Returns a PyModule reference for the given MlirModule. This always returns
/// a new object.
@@ -551,7 +551,7 @@ class PyAsmState;
/// Base class for PyOperation and PyOpView which exposes the primary, user
/// visible methods for manipulating it.
-class PyOperationBase {
+class MLIR_PYTHON_API_EXPORTED PyOperationBase {
public:
virtual ~PyOperationBase() = default;
/// Implements the bound 'print' method and helps with others.
@@ -604,7 +604,8 @@ class PyOperationBase {
class PyOperation;
class PyOpView;
using PyOperationRef = PyObjectRef<PyOperation>;
-class PyOperation : public PyOperationBase, public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyOperation : public PyOperationBase,
+ public BaseContextObject {
public:
~PyOperation() override;
PyOperation &getOperation() override { return *this; }
@@ -722,7 +723,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
/// custom ODS-style operation classes. Since this class is subclass on the
/// python side, it must present an __init__ method that operates in pure
/// python types.
-class PyOpView : public PyOperationBase {
+class MLIR_PYTHON_API_EXPORTED PyOpView : public PyOperationBase {
public:
PyOpView(const nanobind::object &operationObject);
PyOperation &getOperation() override { return operation; }
@@ -758,7 +759,7 @@ class PyOpView : public PyOperationBase {
/// Wrapper around an MlirRegion.
/// Regions are managed completely by their containing operation. Unlike the
/// C++ API, the python API does not support detached regions.
-class PyRegion {
+class MLIR_PYTHON_API_EXPORTED PyRegion {
public:
PyRegion(PyOperationRef parentOperation, MlirRegion region)
: parentOperation(std::move(parentOperation)), region(region) {
@@ -777,7 +778,7 @@ class PyRegion {
};
/// Wrapper around an MlirAsmState.
-class PyAsmState {
+class MLIR_PYTHON_API_EXPORTED PyAsmState {
public:
PyAsmState(MlirValue value, bool useLocalScope) {
flags = mlirOpPrintingFlagsCreate();
@@ -812,7 +813,7 @@ class PyAsmState {
/// Wrapper around an MlirBlock.
/// Blocks are managed completely by their containing operation. Unlike the
/// C++ API, the python API does not support detached blocks.
-class PyBlock {
+class MLIR_PYTHON_API_EXPORTED PyBlock {
public:
PyBlock(PyOperationRef parentOperation, MlirBlock block)
: parentOperation(std::move(parentOperation)), block(block) {
@@ -836,7 +837,7 @@ class PyBlock {
/// Calls to insert() will insert a new operation before the
/// reference operation. If the reference operation is null, then appends to
/// the end of the block.
-class PyInsertionPoint {
+class MLIR_PYTHON_API_EXPORTED PyInsertionPoint {
public:
/// Creates an insertion point positioned after the last operation in the
/// block, but still inside the block.
@@ -877,7 +878,7 @@ class PyInsertionPoint {
};
/// Wrapper around the generic MlirType.
/// The lifetime of a type is bound by the PyContext that created it.
-class PyType : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyType : public BaseContextObject {
public:
PyType(PyMlirContextRef contextRef, MlirType type)
: BaseContextObject(std::move(contextRef)), type(type) {}
@@ -903,7 +904,7 @@ class PyType : public BaseContextObject {
/// A TypeID provides an efficient and unique identifier for a specific C++
/// type. This allows for a C++ type to be compared, hashed, and stored in an
/// opaque context. This class wraps around the generic MlirTypeID.
-class PyTypeID {
+class MLIR_PYTHON_API_EXPORTED PyTypeID {
public:
PyTypeID(MlirTypeID typeID) : typeID(typeID) {}
// Note, this tests whether the underlying TypeIDs are the same,
@@ -929,7 +930,7 @@ class PyTypeID {
/// concrete type class extends PyType); however, intermediate python-visible
/// base classes can be modeled by specifying a BaseTy.
template <typename DerivedTy, typename BaseTy = PyType>
-class PyConcreteType : public BaseTy {
+class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
@@ -1007,7 +1008,7 @@ class PyConcreteType : public BaseTy {
/// Wrapper around the generic MlirAttribute.
/// The lifetime of a type is bound by the PyContext that created it.
-class PyAttribute : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyAttribute : public BaseContextObject {
public:
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
: BaseContextObject(std::move(contextRef)), attr(attr) {}
@@ -1033,7 +1034,7 @@ class PyAttribute : public BaseContextObject {
/// Represents a Python MlirNamedAttr, carrying an optional owned name.
/// TODO: Refactor this and the C-API to be based on an Identifier owned
/// by the context so as to avoid ownership issues here.
-class PyNamedAttribute {
+class MLIR_PYTHON_API_EXPORTED PyNamedAttribute {
public:
/// Constructs a PyNamedAttr that retains an owned name. This should be
/// used in any code that originates an MlirNamedAttribute from a python
@@ -1059,7 +1060,7 @@ class PyNamedAttribute {
/// concrete attribute class extends PyAttribute); however, intermediate
/// python-visible base classes can be modeled by specifying a BaseTy.
template <typename DerivedTy, typename BaseTy = PyAttribute>
-class PyConcreteAttribute : public BaseTy {
+class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
@@ -1149,7 +1150,8 @@ class PyConcreteAttribute : public BaseTy {
static void bindDerived(ClassTy &m) {}
};
-class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
+class MLIR_PYTHON_API_EXPORTED PyStringAttribute
+ : public PyConcreteAttribute<PyStringAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
static constexpr const char *pyClassName = "StringAttr";
@@ -1166,7 +1168,7 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
/// value. For block argument values, this is the operation that contains the
/// block to which the value is an argument (blocks cannot be detached in Python
/// bindings so such operation always exists).
-class PyValue {
+class MLIR_PYTHON_API_EXPORTED PyValue {
public:
// The virtual here is "load bearing" in that it enables RTTI
// for PyConcreteValue CRTP classes that support maybeDownCast.
@@ -1196,7 +1198,7 @@ class PyValue {
};
/// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
-class PyAffineExpr : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyAffineExpr : public BaseContextObject {
public:
PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
: BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}
@@ -1223,7 +1225,7 @@ class PyAffineExpr : public BaseContextObject {
MlirAffineExpr affineExpr;
};
-class PyAffineMap : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyAffineMap : public BaseContextObject {
public:
PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)
: BaseContextObject(std::move(contextRef)), affineMap(affineMap) {}
@@ -1244,7 +1246,7 @@ class PyAffineMap : public BaseContextObject {
MlirAffineMap affineMap;
};
-class PyIntegerSet : public BaseContextObject {
+class MLIR_PYTHON_API_EXPORTED PyIntegerSet : public BaseContextObject {
public:
PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
: BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
@@ -1265,7 +1267,7 @@ class PyIntegerSet : public BaseContextObject {
};
/// Bindings for MLIR symbol tables.
-class PySymbolTable {
+class MLIR_PYTHON_API_EXPORTED PySymbolTable {
public:
/// Constructs a symbol table for the given operation.
explicit PySymbolTable(PyOperationBase &operation);
@@ -1317,7 +1319,7 @@ class PySymbolTable {
/// Custom exception that allows access to error diagnostic information. This is
/// converted to the `ir.MLIRError` python exception when thrown.
-struct MLIRError {
+struct MLIR_PYTHON_API_EXPORTED MLIRError {
MLIRError(llvm::Twine message,
std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {})
: message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {}
@@ -1342,7 +1344,7 @@ inline void registerMLIRError() {
});
}
-void registerMLIRErrorInCore();
+MLIR_PYTHON_API_EXPORTED void registerMLIRErrorInCore();
//------------------------------------------------------------------------------
// Utilities.
@@ -1455,7 +1457,7 @@ inline nanobind::object PyBlock::getCapsule() {
// Collections.
//------------------------------------------------------------------------------
-class PyRegionIterator {
+class MLIR_PYTHON_API_EXPORTED PyRegionIterator {
public:
PyRegionIterator(PyOperationRef operation, int nextIndex)
: operation(std::move(operation)), nextIndex(nextIndex) {}
@@ -1486,7 +1488,8 @@ class PyRegionIterator {
/// Regions of an op are fixed length and indexed numerically so are represented
/// with a sequence-like container.
-class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
+class MLIR_PYTHON_API_EXPORTED PyRegionList
+ : public Sliceable<PyRegionList, PyRegion> {
public:
static constexpr const char *pyClassName = "RegionSequence";
@@ -1529,7 +1532,7 @@ class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
PyOperationRef operation;
};
-class PyBlockIterator {
+class MLIR_PYTHON_API_EXPORTED PyBlockIterator {
public:
PyBlockIterator(PyOperationRef operation, MlirBlock next)
: operation(std::move(operation)), next(next) {}
@@ -1563,7 +1566,7 @@ class PyBlockIterator {
/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
/// we present them as a more full-featured list-like container but optimize
/// it for forward iteration. Blocks are always owned by a region.
-class PyBlockList {
+class MLIR_PYTHON_API_EXPORTED PyBlockList {
public:
PyBlockList(PyOperationRef operation, MlirRegion region)
: operation(std::move(operation)), region(region) {}
@@ -1636,7 +1639,7 @@ class PyBlockList {
MlirRegion region;
};
-class PyOperationIterator {
+class MLIR_PYTHON_API_EXPORTED PyOperationIterator {
public:
PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
: parentOperation(std::move(parentOperation)), next(next) {}
@@ -1672,7 +1675,7 @@ class PyOperationIterator {
/// Python, we present them as a more full-featured list-like container but
/// optimize it for forward iteration. Iterable operations are always owned
/// by a block.
-class PyOperationList {
+class MLIR_PYTHON_API_EXPORTED PyOperationList {
public:
PyOperationList(PyOperationRef parentOperation, MlirBlock block)
: parentOperation(std::move(parentOperation)), block(block) {}
@@ -1729,7 +1732,7 @@ class PyOperationList {
MlirBlock block;
};
-class PyOpOperand {
+class MLIR_PYTHON_API_EXPORTED PyOpOperand {
public:
PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
@@ -1754,7 +1757,7 @@ class PyOpOperand {
MlirOpOperand opOperand;
};
-class PyOpOperandIterator {
+class MLIR_PYTHON_API_EXPORTED PyOpOperandIterator {
public:
PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
@@ -1785,7 +1788,7 @@ class PyOpOperandIterator {
/// castable from it. The value hierarchy is one level deep and is not supposed
/// to accommodate other levels unless core MLIR changes.
template <typename DerivedTy>
-class PyConcreteValue : public PyValue {
+class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
@@ -1843,7 +1846,7 @@ class PyConcreteValue : public PyValue {
};
/// Python wrapper for MlirOpResult.
-class PyOpResult : public PyConcreteValue<PyOpResult> {
+class MLIR_PYTHON_API_EXPORTED PyOpResult : public PyConcreteValue<PyOpResult> {
public:
static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
static constexpr const char *pyClassName = "OpResult";
@@ -1887,7 +1890,8 @@ getValueTypes(Container &container, PyMlirContextRef &context) {
/// elements, random access is cheap. The (returned) result list is associated
/// with the operation whose results these are, and thus extends the lifetime of
/// this operation.
-class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
+class MLIR_PYTHON_API_EXPORTED PyOpResultList
+ : public Sliceable<PyOpResultList, PyOpResult> {
public:
static constexpr const char *pyClassName = "OpResultList";
using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
@@ -1940,7 +1944,8 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
};
/// Python wrapper for MlirBlockArgument.
-class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
+class MLIR_PYTHON_API_EXPORTED PyBlockArgument
+ : public PyConcreteValue<PyBlockArgument> {
public:
static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
static constexpr const char *pyClassName = "BlockArgument";
@@ -1979,7 +1984,7 @@ class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
/// elements, random access is cheap. The argument list is associated with the
/// operation that contains the block (detached blocks are not allowed in
/// Python bindings) and extends its lifetime.
-class PyBlockArgumentList
+class MLIR_PYTHON_API_EXPORTED PyBlockArgumentList
: public Sliceable<PyBlockArgumentList, PyBlockArgument> {
public:
static constexpr const char *pyClassName = "BlockArgumentList";
@@ -2032,7 +2037,8 @@ class PyBlockArgumentList
/// elements, random access is cheap. The (returned) operand list is associated
/// with the operation whose operands these are, and thus extends the lifetime
/// of this operation.
-class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
+class MLIR_PYTHON_API_EXPORTED PyOpOperandList
+ : public Sliceable<PyOpOperandList, PyValue> {
public:
static constexpr const char *pyClassName = "OpOperandList";
using SliceableT = Sliceable<PyOpOperandList, PyValue>;
@@ -2090,7 +2096,8 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
/// elements, random access is cheap. The (returned) successor list is
/// associated with the operation whose successors these are, and thus extends
/// the lifetime of this operation.
-class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
+class MLIR_PYTHON_API_EXPORTED PyOpSuccessors
+ : public Sliceable<PyOpSuccessors, PyBlock> {
public:
static constexpr const char *pyClassName = "OpSuccessors";
@@ -2138,7 +2145,8 @@ class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
/// elements, random access is cheap. The (returned) successor list is
/// associated with the operation and block whose successors these are, and thus
/// extends the lifetime of this operation and block.
-class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
+class MLIR_PYTHON_API_EXPORTED PyBlockSuccessors
+ : public Sliceable<PyBlockSuccessors, PyBlock> {
public:
static constexpr const char *pyClassName = "BlockSuccessors";
@@ -2180,7 +2188,8 @@ class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
/// WARNING: This Sliceable is more expensive than the others here because
/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
/// operands) anew for each indexed access.
-class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
+class MLIR_PYTHON_API_EXPORTED PyBlockPredecessors
+ : public Sliceable<PyBlockPredecessors, PyBlock> {
public:
static constexpr const char *pyClassName = "BlockPredecessors";
@@ -2218,7 +2227,7 @@ class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
/// A list of operation attributes. Can be indexed by name, producing
/// attributes, or by index, producing named attributes.
-class PyOpAttributeMap {
+class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
public:
PyOpAttributeMap(PyOperationRef operation)
: operation(std::move(operation)) {}
@@ -2354,7 +2363,7 @@ class PyOpAttributeMap {
PyOperationRef operation;
};
-MlirValue getUniqueResult(MlirOperation operation);
+MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
} // namespace python
} // namespace mlir
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index ba9642cf2c6a2..cd0cfbc7d61d8 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -14,7 +14,8 @@
namespace mlir {
/// Shaped Type Interface - ShapedType
-class PyShapedType : public python::PyConcreteType<PyShapedType> {
+class MLIR_PYTHON_API_EXPORTED PyShapedType
+ : public python::PyConcreteType<PyShapedType> {
public:
static const IsAFunctionTy isaFunction;
static constexpr const char *pyClassName = "ShapedType";
>From 2cac33fe7eece9be9346969e4c6c20429b6df90c Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 23 Dec 2025 13:37:11 -0800
Subject: [PATCH 19/38] [mlir][python] WIP: suffix support libs with NB domain, make PyGlobals instance file-local (globals doesn't work)
---
mlir/cmake/modules/AddMLIRPython.cmake | 18 ++++++++++++------
mlir/include/mlir/Bindings/Python/Globals.h | 2 --
mlir/lib/Bindings/Python/Globals.cpp | 18 ++++++++++++------
mlir/test/Examples/standalone/test.wheel.toy | 7 ++-----
4 files changed, 26 insertions(+), 19 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 154ec611fb358..dff5528aec917 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -317,6 +317,10 @@ function(build_nanobind_lib)
set(NB_LIBRARY_TARGET_NAME "nanobind${_ft}-${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(NB_LIBRARY_TARGET_NAME "${NB_LIBRARY_TARGET_NAME}" PARENT_SCOPE)
nanobind_build_library(${NB_LIBRARY_TARGET_NAME} AS_SYSINCLUDE)
+ target_compile_definitions(${NB_LIBRARY_TARGET_NAME}
+ PRIVATE
+ NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ )
# nanobind configures with LTO for shared build which doesn't work everywhere
# (see https://github.com/llvm/llvm-project/issues/139602).
if(NOT LLVM_ENABLE_LTO)
@@ -365,6 +369,10 @@ function(add_mlir_python_modules name)
"COMMON_CAPI_LINK_LIBS;DECLARED_SOURCES"
${ARGN})
+ if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
+ set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir" CACHE STRING "" FORCE)
+ endif()
+
# This call sets NB_LIBRARY_TARGET_NAME.
build_nanobind_lib(
INSTALL_COMPONENT ${name}
@@ -420,6 +428,8 @@ function(add_mlir_python_modules name)
get_target_property(_source_type ${sources_target} mlir_python_SOURCES_TYPE)
if(_source_type STREQUAL "support")
get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
+ # Use a similar mechanism as nanobind to help the runtime loader pick the correct lib.
+ set(_module_name "${_module_name}-${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(_extension_target "${name}.extension.${_module_name}.dso")
add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME}
INSTALL_COMPONENT ${name}
@@ -844,10 +854,6 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
set(eh_rtti_enable -frtti -fexceptions)
endif ()
- if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
- set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir" CACHE STRING "" FORCE)
- endif()
-
if(ARG_SUPPORT_LIB)
add_library(${libname} SHARED ${ARG_SOURCES})
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
@@ -859,9 +865,9 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
MLIR_CAPI_BUILDING_LIBRARY=1
)
- if (MSVC)
+ if(MSVC)
set_property(TARGET ${libname} PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
- endif ()
+ endif()
else()
nanobind_add_module(${libname}
NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 112c7b9b0547f..d9334cb35cc27 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -174,8 +174,6 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
MlirTypeID allocateTypeID() { return typeIDAllocator.allocate(); }
private:
- static PyGlobals *instance;
-
nanobind::ft_mutex mutex;
/// Module name prefixes to search under for dialect implementation modules.
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index ecac571a132f6..7e451c8009809 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -19,6 +19,8 @@
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include <iostream>
+
namespace nb = nanobind;
using namespace mlir;
@@ -26,22 +28,26 @@ using namespace mlir;
// PyGlobals
// -----------------------------------------------------------------------------
+namespace {
+python::PyGlobals *pyGlobalsInstance = nullptr;
+}
+
namespace mlir::python {
-PyGlobals *PyGlobals::instance = nullptr;
PyGlobals::PyGlobals() {
- assert(!instance && "PyGlobals already constructed");
- instance = this;
+ std::cerr << MAKE_MLIR_PYTHON_QUALNAME("dialects") << "\n";
+ assert(!pyGlobalsInstance && "PyGlobals already constructed");
+ pyGlobalsInstance = this;
// The default search path include {mlir.}dialects, where {mlir.} is the
// package prefix configured at compile time.
dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
}
-PyGlobals::~PyGlobals() { instance = nullptr; }
+PyGlobals::~PyGlobals() { pyGlobalsInstance = nullptr; }
PyGlobals &PyGlobals::get() {
- assert(instance && "PyGlobals is null");
- return *instance;
+ assert(pyGlobalsInstance && "PyGlobals is null");
+ return *pyGlobalsInstance;
}
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
diff --git a/mlir/test/Examples/standalone/test.wheel.toy b/mlir/test/Examples/standalone/test.wheel.toy
index b60347ba687d0..8dedaa07c84f7 100644
--- a/mlir/test/Examples/standalone/test.wheel.toy
+++ b/mlir/test/Examples/standalone/test.wheel.toy
@@ -1,10 +1,6 @@
# There's no real issue with windows here, it's just that some CMake generated paths for targets end up being longer
# than 255 chars when combined with the fact that pip wants to install into a tmp directory buried under
# C/Users/ContainerAdministrator/AppData/Local/Temp.
-# UNSUPPORTED: target={{.*(windows).*}}
-# REQUIRES: expensive_checks
-# REQUIRES: non-shared-libs-build
-# REQUIRES: bindings-python
# RUN: export CMAKE_BUILD_TYPE=%cmake_build_type
# RUN: export CMAKE_CXX_COMPILER=%host_cxx
@@ -18,7 +14,8 @@
# RUN: export MLIR_PYTHON_PACKAGE_PREFIX=mlir_standalone
# RUN: export MLIR_BINDINGS_PYTHON_NB_DOMAIN=mlir_standalone
-# RUN: %python -m pip wheel "%mlir_src_root/examples/standalone" -w "%mlir_obj_root/wheelhouse" -v | tee %t
+# RUN: %python -m pip install scikit-build-core
+# RUN: %python -m pip wheel "%mlir_src_root/examples/standalone" -w "%mlir_obj_root/wheelhouse" -v --no-build-isolation | tee %t
# RUN: rm -rf "%mlir_obj_root/standalone-python-bindings-install"
# RUN: %python -m pip install standalone_python_bindings -f "%mlir_obj_root/wheelhouse" --target "%mlir_obj_root/standalone-python-bindings-install" -v | tee -a %t
>From f94326a3864d3d59d00ae5ae702cd49ef46c2db9 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 23 Dec 2025 17:00:17 -0800
Subject: [PATCH 20/38] [mlir][python] namespace bindings under MLIR_BINDINGS_PYTHON_DOMAIN and mirror C enums (works)
---
.../examples/standalone/python/CMakeLists.txt | 1 +
.../python/StandaloneExtensionNanobind.cpp | 6 +-
mlir/include/mlir/Bindings/Python/Globals.h | 5 +-
mlir/include/mlir/Bindings/Python/IRCore.h | 43 ++++--
mlir/include/mlir/Bindings/Python/IRTypes.h | 8 +-
mlir/lib/Bindings/Python/Globals.cpp | 26 ++--
mlir/lib/Bindings/Python/IRAffine.cpp | 34 +++--
mlir/lib/Bindings/Python/IRAttributes.cpp | 18 ++-
mlir/lib/Bindings/Python/IRCore.cpp | 28 ++--
mlir/lib/Bindings/Python/IRInterfaces.cpp | 4 +-
mlir/lib/Bindings/Python/IRTypes.cpp | 37 +++--
mlir/lib/Bindings/Python/MainModule.cpp | 133 ++++++++++--------
mlir/lib/Bindings/Python/Pass.cpp | 39 +++--
mlir/lib/Bindings/Python/Pass.h | 3 +-
mlir/lib/Bindings/Python/Rewrite.cpp | 43 +++---
mlir/lib/Bindings/Python/Rewrite.h | 4 +-
mlir/python/CMakeLists.txt | 1 +
.../python/lib/PythonTestModuleNanobind.cpp | 13 +-
18 files changed, 275 insertions(+), 171 deletions(-)
diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt
index edaedf18cc843..d3b3aeadb6396 100644
--- a/mlir/examples/standalone/python/CMakeLists.txt
+++ b/mlir/examples/standalone/python/CMakeLists.txt
@@ -3,6 +3,7 @@ include(AddMLIRPython)
# Specifies that all MLIR packages are co-located under the `mlir_standalone`
# top level package (the API has been embedded in a relocatable way).
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
+add_compile_definitions("MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
################################################################################
diff --git a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
index 37737cd89ee1e..c568369913595 100644
--- a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
+++ b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp
@@ -17,7 +17,8 @@
namespace nb = nanobind;
-struct PyCustomType : mlir::python::PyConcreteType<PyCustomType> {
+struct PyCustomType
+ : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyCustomType> {
static constexpr IsAFunctionTy isaFunction = mlirStandaloneTypeIsACustomType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirStandaloneCustomTypeGetTypeID;
@@ -28,7 +29,8 @@ struct PyCustomType : mlir::python::PyConcreteType<PyCustomType> {
c.def_static(
"get",
[](const std::string &value,
- mlir::python::DefaultingPyMlirContext context) {
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ context) {
return PyCustomType(
context->getRef(),
mlirStandaloneCustomTypeGet(
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index d9334cb35cc27..5548a716cbe21 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -28,7 +28,7 @@
namespace mlir {
namespace python {
-
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Globals that are always accessible once the extension has been initialized.
/// Methods of this class are thread-safe.
class MLIR_PYTHON_API_EXPORTED PyGlobals {
@@ -174,6 +174,8 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
MlirTypeID allocateTypeID() { return typeIDAllocator.allocate(); }
private:
+ static PyGlobals *instance;
+
nanobind::ft_mutex mutex;
/// Module name prefixes to search under for dialect implementation modules.
@@ -195,6 +197,7 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
TracebackLoc tracebackLoc;
TypeIDAllocator typeIDAllocator;
};
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index ceedeb691eb58..7ed0a9f63bfda 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -33,6 +33,7 @@
namespace mlir {
namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
class PyBlock;
class PyDiagnostic;
@@ -325,6 +326,26 @@ class MLIR_PYTHON_API_EXPORTED PyLocation : public BaseContextObject {
MlirLocation loc;
};
+enum PyMlirDiagnosticSeverity : std::underlying_type<
+ MlirDiagnosticSeverity>::type { // Mirrors C MlirDiagnosticSeverity value-for-value.
+ MlirDiagnosticError = MlirDiagnosticError, // RHS is the C enumerator. In this namespace these names hide the C ones, so qualify C uses.
+ MlirDiagnosticWarning = MlirDiagnosticWarning,
+ MlirDiagnosticNote = MlirDiagnosticNote,
+ MlirDiagnosticRemark = MlirDiagnosticRemark
+};
+/// Mirror of the C MlirWalkResult enum (same underlying type and enumerator values).
+enum PyMlirWalkResult : std::underlying_type<MlirWalkResult>::type {
+ MlirWalkResultAdvance = MlirWalkResultAdvance,
+ MlirWalkResultInterrupt = MlirWalkResultInterrupt,
+ MlirWalkResultSkip = MlirWalkResultSkip
+};
+/// Traversal order for operation walk; mirrors the C MlirWalkOrder enum
+/// (same underlying type and enumerator values).
+enum PyMlirWalkOrder : std::underlying_type<MlirWalkOrder>::type {
+ MlirWalkPreOrder = MlirWalkPreOrder,
+ MlirWalkPostOrder = MlirWalkPostOrder
+};
+
/// Python class mirroring the C MlirDiagnostic struct. Note that these structs
/// are only valid for the duration of a diagnostic callback and attempting
/// to access them outside of that will raise an exception. This applies to
@@ -334,7 +355,7 @@ class MLIR_PYTHON_API_EXPORTED PyDiagnostic {
PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}
void invalidate();
bool isValid() { return valid; }
- MlirDiagnosticSeverity getSeverity();
+ PyMlirDiagnosticSeverity getSeverity();
PyLocation getLocation();
nanobind::str getMessage();
nanobind::tuple getNotes();
@@ -342,7 +363,7 @@ class MLIR_PYTHON_API_EXPORTED PyDiagnostic {
/// Materialized diagnostic information. This is safe to access outside the
/// diagnostic callback.
struct DiagnosticInfo {
- MlirDiagnosticSeverity severity;
+ PyMlirDiagnosticSeverity severity;
PyLocation location;
std::string message;
std::vector<DiagnosticInfo> notes;
@@ -573,8 +594,8 @@ class MLIR_PYTHON_API_EXPORTED PyOperationBase {
std::optional<int64_t> bytecodeVersion);
// Implement the walk method.
- void walk(std::function<MlirWalkResult(MlirOperation)> callback,
- MlirWalkOrder walkOrder);
+ void walk(std::function<PyMlirWalkResult(MlirOperation)> callback,
+ PyMlirWalkOrder walkOrder);
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);
@@ -2364,6 +2385,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
};
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
@@ -2371,11 +2393,16 @@ namespace nanobind {
namespace detail {
template <>
-struct type_caster<mlir::python::DefaultingPyMlirContext>
- : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {};
+struct type_caster<
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext>
+ : MlirDefaultingCaster<
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext> {
+};
template <>
-struct type_caster<mlir::python::DefaultingPyLocation>
- : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};
+struct type_caster<
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyLocation>
+ : MlirDefaultingCaster<
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyLocation> {};
} // namespace detail
} // namespace nanobind
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index cd0cfbc7d61d8..87e0e10764bd8 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -12,10 +12,11 @@
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace mlir {
-
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Shaped Type Interface - ShapedType
class MLIR_PYTHON_API_EXPORTED PyShapedType
- : public python::PyConcreteType<PyShapedType> {
+ : public PyConcreteType<PyShapedType> {
public:
static const IsAFunctionTy isaFunction;
static constexpr const char *pyClassName = "ShapedType";
@@ -26,7 +27,8 @@ class MLIR_PYTHON_API_EXPORTED PyShapedType
private:
void requireHasRank();
};
-
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
} // namespace mlir
#endif // MLIR_BINDINGS_PYTHON_IRTYPES_H
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index 7e451c8009809..e2e8693ba45f3 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -19,8 +19,6 @@
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
-#include <iostream>
-
namespace nb = nanobind;
using namespace mlir;
@@ -28,26 +26,24 @@ using namespace mlir;
// PyGlobals
// -----------------------------------------------------------------------------
-namespace {
-python::PyGlobals *pyGlobalsInstance = nullptr;
-}
-
-namespace mlir::python {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+PyGlobals *PyGlobals::instance = nullptr;
PyGlobals::PyGlobals() {
- std::cerr << MAKE_MLIR_PYTHON_QUALNAME("dialects") << "\n";
- assert(!pyGlobalsInstance && "PyGlobals already constructed");
- pyGlobalsInstance = this;
+ assert(!instance && "PyGlobals already constructed");
+ instance = this;
// The default search path include {mlir.}dialects, where {mlir.} is the
// package prefix configured at compile time.
dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
}
-PyGlobals::~PyGlobals() { pyGlobalsInstance = nullptr; }
+PyGlobals::~PyGlobals() { instance = nullptr; }
PyGlobals &PyGlobals::get() {
- assert(pyGlobalsInstance && "PyGlobals is null");
- return *pyGlobalsInstance;
+ assert(instance && "PyGlobals is null");
+ return *instance;
}
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
@@ -278,4 +274,6 @@ bool PyGlobals::TracebackLoc::isUserTracebackFilename(
}
return isUserTracebackFilenameCache[file];
}
-} // namespace mlir::python
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 624d8f0fa57ce..ce235470bbdc7 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -30,7 +30,7 @@
namespace nb = nanobind;
using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
using llvm::SmallVector;
using llvm::StringRef;
@@ -80,7 +80,9 @@ static bool isPermutation(const std::vector<PermutationTy> &permutation) {
return true;
}
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
/// and should be castable from it. Intermediate hierarchy classes can be
@@ -358,7 +360,9 @@ class PyAffineCeilDivExpr
}
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
return mlirAffineExprEqual(affineExpr, other.affineExpr);
@@ -380,7 +384,9 @@ PyAffineExpr PyAffineExpr::createFromCapsule(const nb::object &capsule) {
//------------------------------------------------------------------------------
// PyAffineMap and utilities.
//------------------------------------------------------------------------------
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// A list of expressions contained in an affine map. Internally these are
/// stored as a consecutive array leading to inexpensive random access. Both
@@ -416,7 +422,9 @@ class PyAffineMapExprList
PyAffineMap affineMap;
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
bool PyAffineMap::operator==(const PyAffineMap &other) const {
return mlirAffineMapEqual(affineMap, other.affineMap);
@@ -438,7 +446,9 @@ PyAffineMap PyAffineMap::createFromCapsule(const nb::object &capsule) {
//------------------------------------------------------------------------------
// PyIntegerSet and utilities.
//------------------------------------------------------------------------------
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
class PyIntegerSetConstraint {
public:
@@ -492,7 +502,9 @@ class PyIntegerSetConstraintList
PyIntegerSet set;
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
bool PyIntegerSet::operator==(const PyIntegerSet &other) const {
return mlirIntegerSetEqual(integerSet, other.integerSet);
@@ -511,7 +523,9 @@ PyIntegerSet PyIntegerSet::createFromCapsule(const nb::object &capsule) {
rawIntegerSet);
}
-namespace mlir::python {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateIRAffine(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of PyAffineExpr and derived classes.
@@ -998,4 +1012,6 @@ void populateIRAffine(nb::module_ &m) {
PyIntegerSetConstraint::bind(m);
PyIntegerSetConstraintList::bind(m);
}
-} // namespace mlir::python
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index e39eabdb136b8..a4d308bf049d8 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -24,7 +24,7 @@
namespace nb = nanobind;
using namespace nanobind::literals;
using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
using llvm::SmallVector;
@@ -121,7 +121,9 @@ subsequent processing.
type or if the buffer does not meet expectations.
)";
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
struct nb_buffer_info {
void *ptr = nullptr;
@@ -1745,7 +1747,9 @@ nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
throw nb::type_error(msg.c_str());
}
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
void PyStringAttribute::bindDerived(ClassTy &c) {
c.def_static(
@@ -1791,7 +1795,9 @@ void PyStringAttribute::bindDerived(ClassTy &c) {
"Returns the value of the string attribute as `bytes`");
}
-namespace mlir::python {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateIRAttributes(nb::module_ &m) {
PyAffineMapAttribute::bind(m);
PyDenseBoolArrayAttribute::bind(m);
@@ -1846,4 +1852,6 @@ void populateIRAttributes(nb::module_ &m) {
PyStridedLayoutAttribute::bind(m);
registerMLIRError();
}
-} // namespace mlir::python
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index fc8743599508d..069e177708afc 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -31,13 +31,14 @@
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlir;
-using namespace mlir::python;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
-namespace mlir::python {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
//------------------------------------------------------------------------------
// PyMlirContext
//------------------------------------------------------------------------------
@@ -169,7 +170,8 @@ MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
if (self->ctx->emitErrorDiagnostics)
return mlirLogicalResultFailure();
- if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
+ if (mlirDiagnosticGetSeverity(diag) !=
+ MlirDiagnosticSeverity::MlirDiagnosticError)
return mlirLogicalResultFailure();
self->errors.emplace_back(PyDiagnostic(diag).getInfo());
@@ -356,9 +358,10 @@ void PyDiagnostic::checkValid() {
}
}
-MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
+PyMlirDiagnosticSeverity PyDiagnostic::getSeverity() {
checkValid();
- return mlirDiagnosticGetSeverity(diagnostic);
+ return static_cast<PyMlirDiagnosticSeverity>(
+ mlirDiagnosticGetSeverity(diagnostic));
}
PyLocation PyDiagnostic::getLocation() {
@@ -672,12 +675,12 @@ void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
}
void PyOperationBase::walk(
- std::function<MlirWalkResult(MlirOperation)> callback,
- MlirWalkOrder walkOrder) {
+ std::function<PyMlirWalkResult(MlirOperation)> callback,
+ PyMlirWalkOrder walkOrder) {
PyOperation &operation = getOperation();
operation.checkValid();
struct UserData {
- std::function<MlirWalkResult(MlirOperation)> callback;
+ std::function<PyMlirWalkResult(MlirOperation)> callback;
bool gotException;
std::string exceptionWhat;
nb::object exceptionType;
@@ -687,7 +690,7 @@ void PyOperationBase::walk(
void *userData) {
UserData *calleeUserData = static_cast<UserData *>(userData);
try {
- return (calleeUserData->callback)(op);
+ return static_cast<MlirWalkResult>((calleeUserData->callback)(op));
} catch (nb::python_error &e) {
calleeUserData->gotException = true;
calleeUserData->exceptionWhat = std::string(e.what());
@@ -695,7 +698,8 @@ void PyOperationBase::walk(
return MlirWalkResult::MlirWalkResultInterrupt;
}
};
- mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
+ mlirOperationWalk(operation, walkCallback, &userData,
+ static_cast<MlirWalkOrder>(walkOrder));
if (userData.gotException) {
std::string message("Exception raised in callback: ");
message.append(userData.exceptionWhat);
@@ -1685,4 +1689,6 @@ void registerMLIRErrorInCore() {
}
});
}
-} // namespace mlir::python
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 78d1f977b2ebc..09112d4989ae4 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -25,7 +25,7 @@ namespace nb = nanobind;
namespace mlir {
namespace python {
-
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
constexpr static const char *constructorDoc =
R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
@@ -469,6 +469,6 @@ void populateIRInterfaces(nb::module_ &m) {
PyShapedTypeComponents::bind(m);
PyInferShapedTypeOpInterface::bind(m);
}
-
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 7d9a0f16c913a..62fb2ef207d58 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -20,12 +20,14 @@
namespace nb = nanobind;
using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
using llvm::SmallVector;
using llvm::Twine;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Checks whether the given type is an integer or float type.
static int mlirTypeIsAIntegerOrFloat(MlirType type) {
@@ -508,10 +510,12 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
}
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
// Shaped Type Interface - ShapedType
-void mlir::PyShapedType::bindDerived(ClassTy &c) {
+void PyShapedType::bindDerived(ClassTy &c) {
c.def_prop_ro(
"element_type",
[](PyShapedType &self) -> nb::typed<nb::object, PyType> {
@@ -616,17 +620,18 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
"shaped types.");
}
-void mlir::PyShapedType::requireHasRank() {
+void PyShapedType::requireHasRank() {
if (!mlirShapedTypeHasRank(*this)) {
throw nb::value_error(
"calling this method requires that the type has a rank.");
}
}
-const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction =
- mlirTypeIsAShaped;
+const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Vector Type subclass - VectorType.
class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
@@ -1098,10 +1103,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
}
};
-static MlirStringRef toMlirStringRef(const std::string &s) {
- return mlirStringRefCreate(s.data(), s.size());
-}
-
/// Opaque Type subclass - OpaqueType.
class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
public:
@@ -1141,9 +1142,13 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
}
};
-} // namespace
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
-namespace mlir::python {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateIRTypes(nb::module_ &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
@@ -1177,4 +1182,6 @@ void populateIRTypes(nb::module_ &m) {
PyOpaqueType::bind(m);
registerMLIRError();
}
-} // namespace mlir::python
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index f72775cc0b83a..392144ec5f0b7 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -16,7 +16,7 @@
namespace nb = nanobind;
using namespace mlir;
using namespace nb::literals;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
static const char kModuleParseDocstring[] =
R"(Parses a module's assembly format from a string.
@@ -35,6 +35,56 @@ in `exceptions`. `exceptions` can be either a single operation or a list of
operations.
)";
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+/// Wrapper for the global LLVM debugging flag.
+struct PyGlobalDebugFlag {
+ static void set(nanobind::object &o, bool enable) {
+ nanobind::ft_lock_guard lock(mutex);
+ mlirEnableGlobalDebug(enable);
+ }
+
+ static bool get(const nanobind::object &) {
+ nanobind::ft_lock_guard lock(mutex);
+ return mlirIsGlobalDebugEnabled();
+ }
+
+ static void bind(nanobind::module_ &m) {
+ // Debug flags.
+ nanobind::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
+ .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
+ &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
+ .def_static(
+ "set_types",
+ [](const std::string &type) {
+ nanobind::ft_lock_guard lock(mutex);
+ mlirSetGlobalDebugType(type.c_str());
+ },
+ nanobind::arg("types"),
+ "Sets specific debug types to be produced by LLVM.")
+ .def_static(
+ "set_types",
+ [](const std::vector<std::string> &types) {
+ std::vector<const char *> pointers;
+ pointers.reserve(types.size());
+ for (const std::string &str : types)
+ pointers.push_back(str.c_str());
+ nanobind::ft_lock_guard lock(mutex);
+ mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
+ },
+ nanobind::arg("types"),
+ "Sets multiple specific debug types to be produced by LLVM.");
+ }
+
+private:
+ static nanobind::ft_mutex mutex;
+};
+nanobind::ft_mutex PyGlobalDebugFlag::mutex;
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
namespace {
// see
// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
@@ -185,51 +235,6 @@ maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
return {ref, mlirLoc};
}
-
-/// Wrapper for the global LLVM debugging flag.
-struct PyGlobalDebugFlag {
- static void set(nanobind::object &o, bool enable) {
- nanobind::ft_lock_guard lock(mutex);
- mlirEnableGlobalDebug(enable);
- }
-
- static bool get(const nanobind::object &) {
- nanobind::ft_lock_guard lock(mutex);
- return mlirIsGlobalDebugEnabled();
- }
-
- static void bind(nanobind::module_ &m) {
- // Debug flags.
- nanobind::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
- .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
- &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
- .def_static(
- "set_types",
- [](const std::string &type) {
- nanobind::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugType(type.c_str());
- },
- nanobind::arg("types"),
- "Sets specific debug types to be produced by LLVM.")
- .def_static(
- "set_types",
- [](const std::vector<std::string> &types) {
- std::vector<const char *> pointers;
- pointers.reserve(types.size());
- for (const std::string &str : types)
- pointers.push_back(str.c_str());
- nanobind::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
- },
- nanobind::arg("types"),
- "Sets multiple specific debug types to be produced by LLVM.");
- }
-
-private:
- static nanobind::ft_mutex mutex;
-};
-
-nanobind::ft_mutex PyGlobalDebugFlag::mutex;
} // namespace
//------------------------------------------------------------------------------
@@ -242,20 +247,20 @@ static void populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
// Enums.
//----------------------------------------------------------------------------
- nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
- .value("ERROR", MlirDiagnosticError)
- .value("WARNING", MlirDiagnosticWarning)
- .value("NOTE", MlirDiagnosticNote)
- .value("REMARK", MlirDiagnosticRemark);
+ nb::enum_<PyMlirDiagnosticSeverity>(m, "DiagnosticSeverity")
+ .value("ERROR", PyMlirDiagnosticSeverity::MlirDiagnosticError)
+ .value("WARNING", PyMlirDiagnosticSeverity::MlirDiagnosticWarning)
+ .value("NOTE", PyMlirDiagnosticSeverity::MlirDiagnosticNote)
+ .value("REMARK", PyMlirDiagnosticSeverity::MlirDiagnosticRemark);
- nb::enum_<MlirWalkOrder>(m, "WalkOrder")
- .value("PRE_ORDER", MlirWalkPreOrder)
- .value("POST_ORDER", MlirWalkPostOrder);
+ nb::enum_<PyMlirWalkOrder>(m, "WalkOrder")
+ .value("PRE_ORDER", PyMlirWalkOrder::MlirWalkPreOrder)
+ .value("POST_ORDER", PyMlirWalkOrder::MlirWalkPostOrder);
- nb::enum_<MlirWalkResult>(m, "WalkResult")
- .value("ADVANCE", MlirWalkResultAdvance)
- .value("INTERRUPT", MlirWalkResultInterrupt)
- .value("SKIP", MlirWalkResultSkip);
+ nb::enum_<PyMlirWalkResult>(m, "WalkResult")
+ .value("ADVANCE", PyMlirWalkResult::MlirWalkResultAdvance)
+ .value("INTERRUPT", PyMlirWalkResult::MlirWalkResultInterrupt)
+ .value("SKIP", PyMlirWalkResult::MlirWalkResultSkip);
//----------------------------------------------------------------------------
// Mapping of Diagnostics.
@@ -1186,7 +1191,7 @@ static void populateIRCore(nb::module_ &m) {
Note:
After erasing, any Python references to the operation become invalid.)")
.def("walk", &PyOperationBase::walk, nb::arg("callback"),
- nb::arg("walk_order") = MlirWalkPostOrder,
+ nb::arg("walk_order") = PyMlirWalkOrder::MlirWalkPostOrder,
// clang-format off
nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
// clang-format on
@@ -2305,12 +2310,16 @@ static void populateIRCore(nb::module_ &m) {
PyAttrBuilderMap::bind(m);
}
-namespace mlir::python {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateIRAffine(nb::module_ &m);
void populateIRAttributes(nb::module_ &m);
void populateIRInterfaces(nb::module_ &m);
void populateIRTypes(nb::module_ &m);
-} // namespace mlir::python
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
// -----------------------------------------------------------------------------
// Module initialization.
@@ -2453,5 +2462,5 @@ NB_MODULE(_mlir, m) {
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passManagerModule);
registerMLIRError();
- registerMLIRErrorInCore();
+ // registerMLIRErrorInCore();
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 953d1eb7fd338..708bf00186299 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -19,9 +19,11 @@
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlir;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Owning Wrapper around a PassManager.
class PyPassManager {
@@ -53,23 +55,29 @@ class PyPassManager {
MlirPassManager passManager;
};
-} // namespace
+enum PyMlirPassDisplayMode : std::underlying_type<MlirPassDisplayMode>::type {
+ MLIR_PASS_DISPLAY_MODE_LIST = MLIR_PASS_DISPLAY_MODE_LIST,
+ MLIR_PASS_DISPLAY_MODE_PIPELINE = MLIR_PASS_DISPLAY_MODE_PIPELINE
+};
+
+struct PyMlirExternalPass : MlirExternalPass {};
/// Create the `mlir.passmanager` here.
-void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
+void populatePassManagerSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of enumerated types
//----------------------------------------------------------------------------
- nb::enum_<MlirPassDisplayMode>(m, "PassDisplayMode")
+ nb::enum_<PyMlirPassDisplayMode>(m, "PassDisplayMode")
.value("LIST", MLIR_PASS_DISPLAY_MODE_LIST)
.value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE);
//----------------------------------------------------------------------------
// Mapping of MlirExternalPass
//----------------------------------------------------------------------------
- nb::class_<MlirExternalPass>(m, "ExternalPass")
- .def("signal_pass_failure",
- [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
+ nb::class_<PyMlirExternalPass>(m, "ExternalPass")
+ .def("signal_pass_failure", [](PyMlirExternalPass pass) {
+ mlirExternalPassSignalFailure(pass);
+ });
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
@@ -148,11 +156,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"Enable pass timing.")
.def(
"enable_statistics",
- [](PyPassManager &passManager, MlirPassDisplayMode displayMode) {
- mlirPassManagerEnableStatistics(passManager.get(), displayMode);
+ [](PyPassManager &passManager, PyMlirPassDisplayMode displayMode) {
+ mlirPassManagerEnableStatistics(
+ passManager.get(),
+ static_cast<MlirPassDisplayMode>(displayMode));
},
- "displayMode"_a =
- MlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE,
+ "displayMode"_a = MLIR_PASS_DISPLAY_MODE_PIPELINE,
"Enable pass statistics.")
.def_static(
"parse",
@@ -211,7 +220,8 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
};
callbacks.run = [](MlirOperation op, MlirExternalPass pass,
void *userData) {
- nb::handle(static_cast<PyObject *>(userData))(op, pass);
+ nb::handle(static_cast<PyObject *>(userData))(
+ op, PyMlirExternalPass{pass.ptr});
};
auto externalPass = mlirCreateExternalPass(
passID, mlirStringRefCreate(name->data(), name->length()),
@@ -268,3 +278,6 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"be passed to `parse` for round-tripping.");
registerMLIRError();
}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index 0221bd10e723e..1a311666ebecd 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -13,8 +13,9 @@
namespace mlir {
namespace python {
-
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populatePassManagerSubmodule(nanobind::module_ &m);
+}
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 8a3a27f78c0e4..d80c8bf364d43 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -22,9 +22,11 @@
namespace nb = nanobind;
using namespace mlir;
using namespace nb::literals;
-using namespace mlir::python;
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
-namespace {
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
class PyPatternRewriter {
public:
@@ -60,6 +62,8 @@ class PyPatternRewriter {
PyMlirContextRef ctx;
};
+struct PyMlirPDLResultList : MlirPDLResultList {};
+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
static nb::object objectFromPDLValue(MlirPDLValue value) {
if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
@@ -118,7 +122,7 @@ class PyPDLPatternModule {
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
- f(PyPatternRewriter(rewriter), results,
+ f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
objectsFromPDLValues(nValues, values)));
},
fn.ptr());
@@ -133,7 +137,7 @@ class PyPDLPatternModule {
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
- f(PyPatternRewriter(rewriter), results,
+ f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
objectsFromPDLValues(nValues, values)));
},
fn.ptr());
@@ -223,10 +227,8 @@ class PyRewritePatternSet {
MlirContext ctx;
};
-} // namespace
-
/// Create the `mlir.rewrite` here.
-void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+void populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the PatternRewriter
//----------------------------------------------------------------------------
@@ -315,10 +317,10 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
- nb::class_<MlirPDLResultList>(m, "PDLResultList")
+ nb::class_<PyMlirPDLResultList>(m, "PDLResultList")
.def(
"append",
- [](MlirPDLResultList results, const PyValue &value) {
+ [](PyMlirPDLResultList results, const PyValue &value) {
mlirPDLResultListPushBackValue(results, value);
},
// clang-format off
@@ -327,7 +329,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
)
.def(
"append",
- [](MlirPDLResultList results, const PyOperation &op) {
+ [](PyMlirPDLResultList results, const PyOperation &op) {
mlirPDLResultListPushBackOperation(results, op);
},
// clang-format off
@@ -336,7 +338,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
)
.def(
"append",
- [](MlirPDLResultList results, const PyType &type) {
+ [](PyMlirPDLResultList results, const PyType &type) {
mlirPDLResultListPushBackType(results, type);
},
// clang-format off
@@ -345,7 +347,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
)
.def(
"append",
- [](MlirPDLResultList results, const PyAttribute &attr) {
+ [](PyMlirPDLResultList results, const PyAttribute &attr) {
mlirPDLResultListPushBackAttribute(results, attr);
},
// clang-format off
@@ -355,9 +357,9 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::class_<PyPDLPatternModule>(m, "PDLModule")
.def(
"__init__",
- [](PyPDLPatternModule &self, MlirModule module) {
- new (&self)
- PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
+ [](PyPDLPatternModule &self, PyModule &module) {
+ new (&self) PyPDLPatternModule(
+ mlirPDLPatternModuleFromModule(module.get()));
},
// clang-format off
nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
@@ -416,9 +418,9 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
"results.")
.def(
"apply_patterns_and_fold_greedily",
- [](PyModule &module, MlirFrozenRewritePatternSet set) {
+ [](PyModule &module, PyFrozenRewritePatternSet &set) {
auto status =
- mlirApplyPatternsAndFoldGreedily(module.get(), set, {});
+ mlirApplyPatternsAndFoldGreedily(module.get(), set.get(), {});
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
@@ -447,9 +449,9 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
"results.")
.def(
"apply_patterns_and_fold_greedily",
- [](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
+ [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
- op.getOperation(), set, {});
+ op.getOperation(), set.get(), {});
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
@@ -472,3 +474,6 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
"Applies the given patterns to the given op by a fast walk-based "
"driver.");
}
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
index f8ffdc7bdc458..d287f19187708 100644
--- a/mlir/lib/Bindings/Python/Rewrite.h
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -13,9 +13,9 @@
namespace mlir {
namespace python {
-
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateRewriteSubmodule(nanobind::module_ &m);
-
+}
} // namespace python
} // namespace mlir
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index b22d2ec75b3ba..2d2ae26bf3b28 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -3,6 +3,7 @@ include(AddMLIRPython)
# Specifies that all MLIR packages are co-located under the `MLIR_PYTHON_PACKAGE_PREFIX.`
# top level package (the API has been embedded in a relocatable way).
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
+add_compile_definitions("MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python")
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index c8b95e2316778..43573cbc305fa 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -27,7 +27,8 @@ static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
}
-struct PyTestType : mlir::python::PyConcreteType<PyTestType> {
+struct PyTestType
+ : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyTestType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPythonTestTestType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPythonTestTestTypeGetTypeID;
@@ -37,7 +38,8 @@ struct PyTestType : mlir::python::PyConcreteType<PyTestType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](mlir::python::DefaultingPyMlirContext context) {
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ context) {
return PyTestType(context->getRef(),
mlirPythonTestTestTypeGet(context.get()->get()));
},
@@ -45,7 +47,9 @@ struct PyTestType : mlir::python::PyConcreteType<PyTestType> {
}
};
-class PyTestAttr : public mlir::python::PyConcreteAttribute<PyTestAttr> {
+class PyTestAttr
+ : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
+ PyTestAttr> {
public:
static constexpr IsAFunctionTy isaFunction =
mlirAttributeIsAPythonTestTestAttribute;
@@ -57,7 +61,8 @@ class PyTestAttr : public mlir::python::PyConcreteAttribute<PyTestAttr> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](mlir::python::DefaultingPyMlirContext context) {
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ context) {
return PyTestAttr(context->getRef(), mlirPythonTestTestAttributeGet(
context.get()->get()));
},
>From 9e841ecc8ac8d5e7e86c8ff5603f64465633d47f Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 24 Dec 2025 10:55:35 -0800
Subject: [PATCH 21/38] try moving MLIR_BINDINGS_PYTHON_NB_DOMAIN compile defn
---
mlir/cmake/modules/AddMLIRPython.cmake | 6 ++++++
mlir/examples/standalone/python/CMakeLists.txt | 1 -
mlir/python/CMakeLists.txt | 1 -
3 files changed, 6 insertions(+), 2 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index dff5528aec917..a65f7147fdd56 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -369,6 +369,7 @@ function(add_mlir_python_modules name)
"COMMON_CAPI_LINK_LIBS;DECLARED_SOURCES"
${ARGN})
+ # TODO(max): do the same for MLIR_PYTHON_PACKAGE_PREFIX?
if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir" CACHE STRING "" FORCE)
endif()
@@ -863,6 +864,7 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
target_compile_definitions(${libname}
PRIVATE
NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
MLIR_CAPI_BUILDING_LIBRARY=1
)
if(MSVC)
@@ -875,6 +877,10 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
NB_SHARED
${ARG_SOURCES}
)
+ target_compile_definitions(${libname}
+ PRIVATE
+ MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ )
endif()
target_link_libraries(${libname} PRIVATE ${nb_library_target_name})
if(APPLE)
diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt
index d3b3aeadb6396..edaedf18cc843 100644
--- a/mlir/examples/standalone/python/CMakeLists.txt
+++ b/mlir/examples/standalone/python/CMakeLists.txt
@@ -3,7 +3,6 @@ include(AddMLIRPython)
# Specifies that all MLIR packages are co-located under the `mlir_standalone`
# top level package (the API has been embedded in a relocatable way).
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
-add_compile_definitions("MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
################################################################################
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 2d2ae26bf3b28..b22d2ec75b3ba 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -3,7 +3,6 @@ include(AddMLIRPython)
# Specifies that all MLIR packages are co-located under the `MLIR_PYTHON_PACKAGE_PREFIX.`
# top level package (the API has been embedded in a relocatable way).
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
-add_compile_definitions("MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python")
>From 74a4c3ab7f2e519389963c4aaf882b130b99491a Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 24 Dec 2025 16:11:18 -0800
Subject: [PATCH 22/38] remove registerError
---
mlir/include/mlir/Bindings/Python/IRCore.h | 19 -------------------
mlir/lib/Bindings/Python/IRAttributes.cpp | 1 -
mlir/lib/Bindings/Python/IRCore.cpp | 16 ----------------
mlir/lib/Bindings/Python/IRTypes.cpp | 1 -
mlir/lib/Bindings/Python/MainModule.cpp | 16 ++++++++++++++--
mlir/lib/Bindings/Python/Pass.cpp | 1 -
6 files changed, 14 insertions(+), 40 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 7ed0a9f63bfda..596ff7828631b 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1348,25 +1348,6 @@ struct MLIR_PYTHON_API_EXPORTED MLIRError {
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
};
-inline void registerMLIRError() {
- nanobind::register_exception_translator(
- [](const std::exception_ptr &p, void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nanobind::object obj =
- nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
- });
-}
-
-MLIR_PYTHON_API_EXPORTED void registerMLIRErrorInCore();
-
//------------------------------------------------------------------------------
// Utilities.
//------------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index a4d308bf049d8..f0f0ae9ba741e 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1850,7 +1850,6 @@ void populateIRAttributes(nb::module_ &m) {
PyUnitAttribute::bind(m);
PyStridedLayoutAttribute::bind(m);
- registerMLIRError();
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 069e177708afc..26e0128752838 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1673,22 +1673,6 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
throw std::runtime_error(message);
}
}
-
-void registerMLIRErrorInCore() {
- nb::register_exception_translator([](const std::exception_ptr &p,
- void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
- });
-}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 62fb2ef207d58..7350046f428c7 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -1180,7 +1180,6 @@ void populateIRTypes(nb::module_ &m) {
PyTupleType::bind(m);
PyFunctionType::bind(m);
PyOpaqueType::bind(m);
- registerMLIRError();
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 392144ec5f0b7..071f106da04bb 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -2461,6 +2461,18 @@ NB_MODULE(_mlir, m) {
auto passManagerModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passManagerModule);
- registerMLIRError();
- // registerMLIRErrorInCore();
+ nanobind::register_exception_translator(
+ [](const std::exception_ptr &p, void *payload) {
+ // We can't define exceptions with custom fields through pybind, so
+ // instead the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ nanobind::object obj =
+ nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+ });
}
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 708bf00186299..d8c0a253e8dda 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -276,7 +276,6 @@ void populatePassManagerSubmodule(nb::module_ &m) {
},
"Print the textual representation for this PassManager, suitable to "
"be passed to `parse` for round-tripping.");
- registerMLIRError();
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
>From 683a50a800542524565bdea2bf47a3f990f813c5 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 25 Dec 2025 13:54:14 -0800
Subject: [PATCH 23/38] comments
---
mlir/cmake/modules/AddMLIRPython.cmake | 15 +++++++++++++--
1 file changed, 13 insertions(+), 2 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index a65f7147fdd56..4a41cfe53e226 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -371,6 +371,9 @@ function(add_mlir_python_modules name)
# TODO(max): do the same for MLIR_PYTHON_PACKAGE_PREFIX?
if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
+ message(WARNING "MLIR_BINDINGS_PYTHON_NB_DOMAIN CMake var is not set - setting to a default `mlir`.\
+ It is highly recommend to set this to something unique so that your project's Python bindings do not collide with\
+ others'. See https://github.com/llvm/llvm-project/pull/171775 for more information.")
set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir" CACHE STRING "" FORCE)
endif()
@@ -858,13 +861,17 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
if(ARG_SUPPORT_LIB)
add_library(${libname} SHARED ${ARG_SOURCES})
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
- target_link_options(${libname} PRIVATE "-Wl,-z,undefs")
+ # nanobind handles this correctly for MacOS by explicitly setting -U for all the necessary Python symbols
+ # (see https://github.com/wjakob/nanobind/blob/master/cmake/darwin-ld-cpython.sym)
+ # but since we set -z,defs in llvm/cmake/modules/HandleLLVMOptions.cmake:340 for all Linux shlibs
+ # we need to negate it here (we could have our own linux-ld-cpython.sym but that would be too much
+ # maintenance - and this shlib is the only one where we do this).
+ target_link_options(${libname} PRIVATE "LINKER:-z,undefs")
endif()
nanobind_link_options(${libname})
target_compile_definitions(${libname}
PRIVATE
NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
- MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
MLIR_CAPI_BUILDING_LIBRARY=1
)
if(MSVC)
@@ -883,6 +890,10 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
)
endif()
target_link_libraries(${libname} PRIVATE ${nb_library_target_name})
+ target_compile_definitions(${libname}
+ PRIVATE
+ MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ )
if(APPLE)
# In llvm/cmake/modules/HandleLLVMOptions.cmake:268 we set -Wl,-flat_namespace which breaks
# the default name spacing on MacOS and causes "cross-wired" symbol resolution when multiple
>From 90654526a15eec449e9c6238cebad13cd787ec76 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 25 Dec 2025 19:13:46 -0800
Subject: [PATCH 24/38] address comments
---
mlir/lib/Bindings/Python/Rewrite.cpp | 31 ----------------------------
1 file changed, 31 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index d80c8bf364d43..baa58f3680ae2 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -416,37 +416,6 @@ void populateRewriteSubmodule(nb::module_ &m) {
// clang-format on
"Applys the given patterns to the given module greedily while folding "
"results.")
- .def(
- "apply_patterns_and_fold_greedily",
- [](PyModule &module, PyFrozenRewritePatternSet &set) {
- auto status =
- mlirApplyPatternsAndFoldGreedily(module.get(), set.get(), {});
- if (mlirLogicalResultIsFailure(status))
- throw std::runtime_error(
- "pattern application failed to converge");
- },
- "module"_a, "set"_a,
- // clang-format off
- nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"),
- // clang-format on
- "Applys the given patterns to the given module greedily while "
- "folding "
- "results.")
- .def(
- "apply_patterns_and_fold_greedily",
- [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
- auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
- op.getOperation(), set.get(), {});
- if (mlirLogicalResultIsFailure(status))
- throw std::runtime_error(
- "pattern application failed to converge");
- },
- "op"_a, "set"_a,
- // clang-format off
- nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
- // clang-format on
- "Applys the given patterns to the given op greedily while folding "
- "results.")
.def(
"apply_patterns_and_fold_greedily",
[](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
>From 248b66e54e6f90952864efaf3201ee2c3ad7ac96 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 26 Dec 2025 12:50:47 -0800
Subject: [PATCH 25/38] fix empty _mlir_python_support_libs
---
mlir/cmake/modules/AddMLIRPython.cmake | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 4a41cfe53e226..02b71e7826fb5 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -454,7 +454,7 @@ function(add_mlir_python_modules name)
# Build extensions.
foreach(sources_target ${_flat_targets})
- _process_target(${name} ${sources_target} ${_mlir_python_support_libs})
+ _process_target(${name} ${sources_target} "${_mlir_python_support_libs}")
endforeach()
# Create an install target.
>From 0bd8f3bdef2b9fda82bb5b013b1c4983de8073c8 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 26 Dec 2025 13:24:06 -0800
Subject: [PATCH 26/38] parameterize add_mlir_python_modules
---
mlir/cmake/modules/AddMLIRPython.cmake | 31 ++++++++++++++++----------
1 file changed, 19 insertions(+), 12 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 02b71e7826fb5..71316339d4a2a 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -305,7 +305,7 @@ endfunction()
function(build_nanobind_lib)
cmake_parse_arguments(ARG
""
- "INSTALL_COMPONENT;INSTALL_DESTINATION;OUTPUT_DIRECTORY"
+ "INSTALL_COMPONENT;INSTALL_DESTINATION;OUTPUT_DIRECTORY;MLIR_BINDINGS_PYTHON_NB_DOMAIN"
""
${ARGN})
@@ -314,12 +314,12 @@ function(build_nanobind_lib)
endif()
# nanobind does a string match on the suffix to figure out whether to build
# the lib with free threading...
- set(NB_LIBRARY_TARGET_NAME "nanobind${_ft}-${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
+ set(NB_LIBRARY_TARGET_NAME "nanobind${_ft}-${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(NB_LIBRARY_TARGET_NAME "${NB_LIBRARY_TARGET_NAME}" PARENT_SCOPE)
nanobind_build_library(${NB_LIBRARY_TARGET_NAME} AS_SYSINCLUDE)
target_compile_definitions(${NB_LIBRARY_TARGET_NAME}
PRIVATE
- NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ NB_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
)
# nanobind configures with LTO for shared build which doesn't work everywhere
# (see https://github.com/llvm/llvm-project/issues/139602).
@@ -365,16 +365,20 @@ endfunction()
function(add_mlir_python_modules name)
cmake_parse_arguments(ARG
""
- "ROOT_PREFIX;INSTALL_PREFIX"
+ "ROOT_PREFIX;INSTALL_PREFIX;MLIR_BINDINGS_PYTHON_NB_DOMAIN"
"COMMON_CAPI_LINK_LIBS;DECLARED_SOURCES"
${ARGN})
# TODO(max): do the same for MLIR_PYTHON_PACKAGE_PREFIX?
- if(NOT MLIR_BINDINGS_PYTHON_NB_DOMAIN)
+ if((NOT ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN) AND MLIR_BINDINGS_PYTHON_NB_DOMAIN)
+ set(ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN})
+ endif()
+ if((NOT ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN) OR ("${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}" STREQUAL ""))
message(WARNING "MLIR_BINDINGS_PYTHON_NB_DOMAIN CMake var is not set - setting to a default `mlir`.\
It is highly recommend to set this to something unique so that your project's Python bindings do not collide with\
- others'. See https://github.com/llvm/llvm-project/pull/171775 for more information.")
- set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir" CACHE STRING "" FORCE)
+ others'. You also pass explicitly to `add_mlir_python_modules`.\
+ See https://github.com/llvm/llvm-project/pull/171775 for more information.")
+ set(ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir")
endif()
# This call sets NB_LIBRARY_TARGET_NAME.
@@ -382,6 +386,7 @@ function(add_mlir_python_modules name)
INSTALL_COMPONENT ${name}
INSTALL_DESTINATION "${ARG_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
)
# Helper to process an individual target.
@@ -407,6 +412,7 @@ function(add_mlir_python_modules name)
INSTALL_COMPONENT ${modules_target}
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
LINK_LIBS PRIVATE
${sources_target}
${ARG_COMMON_CAPI_LINK_LIBS}
@@ -433,12 +439,13 @@ function(add_mlir_python_modules name)
if(_source_type STREQUAL "support")
get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
# Use a similar mechanism as nanobind to help the runtime loader pick the correct lib.
- set(_module_name "${_module_name}-${MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
+ set(_module_name "${_module_name}-${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
set(_extension_target "${name}.extension.${_module_name}.dso")
add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME}
INSTALL_COMPONENT ${name}
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
SUPPORT_LIB
LINK_LIBS PRIVATE
LLVMSupport
@@ -842,7 +849,7 @@ endfunction()
function(add_mlir_python_extension libname extname nb_library_target_name)
cmake_parse_arguments(ARG
"SUPPORT_LIB"
- "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY"
+ "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY;MLIR_BINDINGS_PYTHON_NB_DOMAIN"
"SOURCES;LINK_LIBS"
${ARGN})
if(ARG_UNPARSED_ARGUMENTS)
@@ -871,7 +878,7 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
nanobind_link_options(${libname})
target_compile_definitions(${libname}
PRIVATE
- NB_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ NB_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
MLIR_CAPI_BUILDING_LIBRARY=1
)
if(MSVC)
@@ -879,7 +886,7 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
endif()
else()
nanobind_add_module(${libname}
- NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
FREE_THREADED
NB_SHARED
${ARG_SOURCES}
@@ -892,7 +899,7 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
target_link_libraries(${libname} PRIVATE ${nb_library_target_name})
target_compile_definitions(${libname}
PRIVATE
- MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ MLIR_BINDINGS_PYTHON_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
)
if(APPLE)
# In llvm/cmake/modules/HandleLLVMOptions.cmake:268 we set -Wl,-flat_namespace which breaks
>From 2e3a6dfc95bea758a8791566a83afc8b077c7aba Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 26 Dec 2025 13:47:05 -0800
Subject: [PATCH 27/38] address comments
---
mlir/cmake/modules/AddMLIRPython.cmake | 50 ++++++++++++++++++--------
mlir/docs/Bindings/Python.md | 7 ++++
mlir/python/CMakeLists.txt | 2 +-
3 files changed, 43 insertions(+), 16 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 71316339d4a2a..ff59882583f22 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -228,15 +228,15 @@ endfunction()
# aggregate dylib that is linked against.
function(declare_mlir_python_extension name)
cmake_parse_arguments(ARG
- "SUPPORT_LIB"
- "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT;SOURCES_TYPE"
+ "_PRIVATE_SUPPORT_LIB"
+ "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT"
"SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS"
${ARGN})
if(NOT ARG_ROOT_DIR)
set(ARG_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
endif()
- if(ARG_SUPPORT_LIB)
+ if(ARG__PRIVATE_SUPPORT_LIB)
set(SOURCES_TYPE "support")
else()
set(SOURCES_TYPE "extension")
@@ -309,6 +309,8 @@ function(build_nanobind_lib)
""
${ARGN})
+ # Only build in free-threaded mode if the Python ABI supports it.
+ # See https://github.com/wjakob/nanobind/blob/4ba51fcf795971c5d603d875ae4184bc0c9bd8e6/cmake/nanobind-config.cmake#L363-L371.
if (NB_ABI MATCHES "[0-9]t")
set(_ft "-ft")
endif()
@@ -321,6 +323,14 @@ function(build_nanobind_lib)
PRIVATE
NB_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
)
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+ # nanobind handles this correctly for MacOS by explicitly setting -U for all the necessary Python symbols
+ # (see https://github.com/wjakob/nanobind/blob/master/cmake/darwin-ld-cpython.sym)
+ # but since we set -z,defs in llvm/cmake/modules/HandleLLVMOptions.cmake:340 for all Linux shlibs
+ # we need to negate it here (we could have our own linux-ld-cpython.sym but that would be too much
+ # maintenance).
+ target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "LINKER:-z,undefs")
+ endif()
# nanobind configures with LTO for shared build which doesn't work everywhere
# (see https://github.com/llvm/llvm-project/issues/139602).
if(NOT LLVM_ENABLE_LTO)
@@ -329,13 +339,10 @@ function(build_nanobind_lib)
INTERPROCEDURAL_OPTIMIZATION_MINSIZEREL OFF
)
endif()
- if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
- target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "-Wl,-z,undefs")
- endif()
set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
BINARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
- # Needed for windows (and don't hurt others).
+ # Needed for windows (and doesn't hurt others).
RUNTIME_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
ARCHIVE_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
)
@@ -358,6 +365,11 @@ endfunction()
# for non-relocatable modules or a deeper directory tree for relocatable.
# INSTALL_PREFIX: Prefix into the install tree for installing the package.
# Typically mirrors the path above but without an absolute path.
+# MLIR_BINDINGS_PYTHON_NB_DOMAIN: nanobind (and MLIR) domain within which
+# extensions will be compiled. This determines whether this package
+# will share nanobind types with other bindings packages. Most likely
+# you want this to be unique to your project (and a specific set of bindings,
+# if your project builds several bindings packages).
# DECLARED_SOURCES: List of declared source groups to include. The entire
# DAG of source modules is included.
# COMMON_CAPI_LINK_LIBS: List of dylibs (typically one) to make every
@@ -446,10 +458,9 @@ function(add_mlir_python_modules name)
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}
- SUPPORT_LIB
+ _PRIVATE_SUPPORT_LIB
LINK_LIBS PRIVATE
LLVMSupport
- Python::Module
${sources_target}
${ARG_COMMON_CAPI_LINK_LIBS}
)
@@ -726,7 +737,7 @@ function(add_mlir_python_common_capi_library name)
set_target_properties(${name} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
BINARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
- # Needed for windows (and don't hurt others).
+ # Needed for windows (and doesn't hurt others).
RUNTIME_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
ARCHIVE_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
)
@@ -848,7 +859,7 @@ endfunction()
################################################################################
function(add_mlir_python_extension libname extname nb_library_target_name)
cmake_parse_arguments(ARG
- "SUPPORT_LIB"
+ "_PRIVATE_SUPPORT_LIB"
"INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY;MLIR_BINDINGS_PYTHON_NB_DOMAIN"
"SOURCES;LINK_LIBS"
${ARGN})
@@ -865,14 +876,14 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
set(eh_rtti_enable -frtti -fexceptions)
endif ()
- if(ARG_SUPPORT_LIB)
+ if(ARG__PRIVATE_SUPPORT_LIB)
add_library(${libname} SHARED ${ARG_SOURCES})
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# nanobind handles this correctly for MacOS by explicitly setting -U for all the necessary Python symbols
# (see https://github.com/wjakob/nanobind/blob/master/cmake/darwin-ld-cpython.sym)
# but since we set -z,defs in llvm/cmake/modules/HandleLLVMOptions.cmake:340 for all Linux shlibs
# we need to negate it here (we could have our own linux-ld-cpython.sym but that would be too much
- # maintenance - and this shlib is the only one where we do this).
+ # maintenance).
target_link_options(${libname} PRIVATE "LINKER:-z,undefs")
endif()
nanobind_link_options(${libname})
@@ -952,12 +963,21 @@ function(add_mlir_python_extension libname extname nb_library_target_name)
target_compile_options(${libname} PRIVATE ${eh_rtti_enable})
- # Configure the output to match python expectations.
- if (ARG_SUPPORT_LIB)
+ # Quoting CMake:
+ #
+ # "If you use it on normal shared libraries which other targets link against, on some platforms a
+ # linker will insert a full path to the library (as specified at link time) into the dynamic section of the
+ # dependent binary. Therefore, once installed, dynamic loader may eventually fail to locate the library
+ # for the binary."
+ #
+ # So for support libs we do need an SO name but for extensions we do not (they're MODULEs anyway -
+ # i.e., can't be linked against, only loaded).
+ if (ARG__PRIVATE_SUPPORT_LIB)
set(_no_soname OFF)
else ()
set(_no_soname ON)
endif ()
+ # Configure the output to match python expectations.
set_target_properties(
${libname} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${ARG_OUTPUT_DIRECTORY}
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 4f4f531f7723c..3f6ac8172dbac 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -37,6 +37,13 @@
LLVM ERROR: ... unregistered/uninitialized dialect/type/pass ...`
```
+* **`MLIR_BINDINGS_PYTHON_NB_DOMAIN`**: `STRING`
+
+ nanobind (and MLIR) domain within which extensions will be compiled.
+ This determines whether this package will share nanobind types with other bindings packages.
+ Most likely you want this to be unique to your project (and a specific set of bindings).
+ This can also be passed explicitly to `add_mlir_python_modules` if your project builds several bindings packages.
+
### Recommended development practices
It is recommended to use a Python virtual environment. Many ways exist for this,
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index b22d2ec75b3ba..4a9fb127ee08c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -841,7 +841,7 @@ if(MLIR_INCLUDE_TESTS)
endif()
declare_mlir_python_extension(MLIRPythonExtension.MLIRPythonSupport
- SUPPORT_LIB
+ _PRIVATE_SUPPORT_LIB
MODULE_NAME MLIRPythonSupport
ADD_TO_PARENT MLIRPythonSources.Core
ROOT_DIR "${PYTHON_SOURCE_DIR}"
>From cbb95585e2086f5486530574b14e990c5513f4db Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Sat, 27 Dec 2025 04:57:19 +0000
Subject: [PATCH 28/38] Reflect rename in bazel file
---
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 785f1e01f5416..35e573bee8a1a 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1172,14 +1172,15 @@ PYBIND11_FEATURES = [
filegroup(
name = "MLIRBindingsPythonSourceFiles",
srcs = [
+ "lib/Bindings/Python/Globals.cpp",
"lib/Bindings/Python/IRAffine.cpp",
"lib/Bindings/Python/IRAttributes.cpp",
"lib/Bindings/Python/IRCore.cpp",
"lib/Bindings/Python/IRInterfaces.cpp",
- "lib/Bindings/Python/IRModule.cpp",
"lib/Bindings/Python/IRTypes.cpp",
"lib/Bindings/Python/Pass.cpp",
"lib/Bindings/Python/Rewrite.cpp",
+
],
)
>From 8c3e16680b5a50a52bb1d22d78c98aa2b51b38a9 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 27 Dec 2025 10:23:40 -0800
Subject: [PATCH 29/38] address jpienaar comments
---
mlir/cmake/modules/AddMLIRPython.cmake | 8 ++++----
mlir/docs/Bindings/Python.md | 4 ++--
mlir/include/mlir/Bindings/Python/IRCore.h | 13 ++++++++-----
mlir/lib/Bindings/Python/IRCore.cpp | 3 +--
4 files changed, 15 insertions(+), 13 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index ff59882583f22..f4d078dfe7118 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -367,9 +367,9 @@ endfunction()
# Typically mirrors the path above but without an absolute path.
# MLIR_BINDINGS_PYTHON_NB_DOMAIN: nanobind (and MLIR) domain within which
# extensions will be compiled. This determines whether this package
-# will share nanobind types with other bindings packages. Most likely
-# you want this to be unique to your project (and a specific set of bindings,
-# if your project builds several bindings packages).
+# will share nanobind types with other bindings packages. Expected to be unique
+# per project (and per specific set of bindings, for projects with multiple
+# bindings packages).
# DECLARED_SOURCES: List of declared source groups to include. The entire
# DAG of source modules is included.
# COMMON_CAPI_LINK_LIBS: List of dylibs (typically one) to make every
@@ -452,7 +452,7 @@ function(add_mlir_python_modules name)
get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
# Use a similar mechanism as nanobind to help the runtime loader pick the correct lib.
set(_module_name "${_module_name}-${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}")
- set(_extension_target "${name}.extension.${_module_name}.dso")
+ set(_extension_target "${name}.extension.${_module_name}.so")
add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME}
INSTALL_COMPONENT ${name}
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 3f6ac8172dbac..4278774933a4a 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -41,8 +41,8 @@
nanobind (and MLIR) domain within which extensions will be compiled.
This determines whether this package will share nanobind types with other bindings packages.
- Most likely you want this to be unique to your project (and a specific set of bindings).
- This can also be passed explicitly to `add_mlir_python_modules` if your project builds several bindings packages.
+ Expected to be unique per project (and per specific set of bindings, for projects with multiple bindings packages).
+ Can also be passed explicitly to `add_mlir_python_modules`.
### Recommended development practices
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 596ff7828631b..616a9636ec799 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -29,6 +29,7 @@
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ThreadPool.h"
namespace mlir {
@@ -1403,11 +1404,13 @@ createBlock(const nanobind::sequence &pyArgTypes,
argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
}
- if (argTypes.size() != argLocs.size())
- throw nanobind::value_error(("Expected " + Twine(argTypes.size()) +
- " locations, got: " + Twine(argLocs.size()))
- .str()
- .c_str());
+ if (argTypes.size() != argLocs.size()) {
+ throw nanobind::value_error(
+ llvm::formatv("Expected {0} locations, got: {1}", argTypes.size(),
+ argLocs.size())
+ .str()
+ .c_str());
+ }
return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
}
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 26e0128752838..0fe508de38e85 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -924,9 +924,8 @@ nb::object PyOperation::create(std::string_view name,
// Construct the operation.
PyMlirContext::ErrorCapture errors(location.getContext());
MlirOperation operation = mlirOperationCreate(&state);
- if (!operation.ptr) {
+ if (!operation.ptr)
throw MLIRError("Operation creation failed", errors.take());
- }
PyOperationRef created =
PyOperation::createDetached(location.getContext(), operation);
maybeInsertOperation(created, maybeIp);
>From ebc57692587b1ec38276faf9bfb276f20806b5bf Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 27 Dec 2025 11:05:25 -0800
Subject: [PATCH 30/38] move impls
---
mlir/include/mlir/Bindings/Python/IRCore.h | 444 +++-----------------
mlir/lib/Bindings/Python/IRCore.cpp | 454 +++++++++++++++++++++
mlir/lib/Bindings/Python/MainModule.cpp | 20 +
3 files changed, 539 insertions(+), 379 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 616a9636ec799..3ee6eac0cbf3f 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -171,20 +171,14 @@ class MLIR_PYTHON_API_EXPORTED PyThreadContextEntry {
/// Python object owns the C++ thread pool
class MLIR_PYTHON_API_EXPORTED PyThreadPool {
public:
- PyThreadPool() {
- ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
- }
+ PyThreadPool();
PyThreadPool(const PyThreadPool &) = delete;
PyThreadPool(PyThreadPool &&) = delete;
int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }
- std::string _mlir_thread_pool_ptr() const {
- std::stringstream ss;
- ss << ownedThreadPool.get();
- return ss.str();
- }
+ std::string _mlir_thread_pool_ptr() const;
private:
std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
@@ -209,9 +203,7 @@ class MLIR_PYTHON_API_EXPORTED PyMlirContext {
/// Gets a strong reference to this context, which will ensure it is kept
/// alive for the life of the reference.
- PyMlirContextRef getRef() {
- return PyMlirContextRef(this, nanobind::cast(this));
- }
+ PyMlirContextRef getRef();
/// Gets a capsule wrapping the void* within the MlirContext.
nanobind::object getCapsule();
@@ -652,32 +644,17 @@ class MLIR_PYTHON_API_EXPORTED PyOperation : public PyOperationBase,
/// Detaches the operation from its parent block and updates its state
/// accordingly.
- void detachFromParent() {
- mlirOperationRemoveFromParent(getOperation());
- setDetached();
- parentKeepAlive = nanobind::object();
- }
+ void detachFromParent();
/// Gets the backing operation.
operator MlirOperation() const { return get(); }
- MlirOperation get() const {
- checkValid();
- return operation;
- }
+ MlirOperation get() const;
- PyOperationRef getRef() {
- return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle));
- }
+ PyOperationRef getRef();
bool isAttached() { return attached; }
- void setAttached(const nanobind::object &parent = nanobind::object()) {
- assert(!attached && "operation already attached");
- attached = true;
- }
- void setDetached() {
- assert(attached && "operation already detached");
- attached = false;
- }
+ void setAttached(const nanobind::object &parent = nanobind::object());
+ void setDetached();
void checkValid() const;
/// Gets the owning block or raises an exception if the operation has no
@@ -802,24 +779,8 @@ class MLIR_PYTHON_API_EXPORTED PyRegion {
/// Wrapper around an MlirAsmState.
class MLIR_PYTHON_API_EXPORTED PyAsmState {
public:
- PyAsmState(MlirValue value, bool useLocalScope) {
- flags = mlirOpPrintingFlagsCreate();
- // The OpPrintingFlags are not exposed Python side, create locally and
- // associate lifetime with the state.
- if (useLocalScope)
- mlirOpPrintingFlagsUseLocalScope(flags);
- state = mlirAsmStateCreateForValue(value, flags);
- }
-
- PyAsmState(PyOperationBase &operation, bool useLocalScope) {
- flags = mlirOpPrintingFlagsCreate();
- // The OpPrintingFlags are not exposed Python side, create locally and
- // associate lifetime with the state.
- if (useLocalScope)
- mlirOpPrintingFlagsUseLocalScope(flags);
- state =
- mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
- }
+ PyAsmState(MlirValue value, bool useLocalScope);
+ PyAsmState(PyOperationBase &operation, bool useLocalScope);
~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); }
// Delete copy constructors.
PyAsmState(PyAsmState &other) = delete;
@@ -898,6 +859,7 @@ class MLIR_PYTHON_API_EXPORTED PyInsertionPoint {
std::optional<PyOperationRef> refOperation;
PyBlock block;
};
+
/// Wrapper around the generic MlirType.
/// The lifetime of a type is bound by the PyContext that created it.
class MLIR_PYTHON_API_EXPORTED PyType : public BaseContextObject {
@@ -1353,26 +1315,6 @@ struct MLIR_PYTHON_API_EXPORTED MLIRError {
// Utilities.
//------------------------------------------------------------------------------
-/// Helper for creating an @classmethod.
-template <class Func, typename... Args>
-nanobind::object classmethod(Func f, Args... args) {
- nanobind::object cf = nanobind::cpp_function(f, args...);
- return nanobind::borrow<nanobind::object>((PyClassMethod_New(cf.ptr())));
-}
-
-inline nanobind::object
-createCustomDialectWrapper(const std::string &dialectNamespace,
- nanobind::object dialectDescriptor) {
- auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
- if (!dialectClass) {
- // Use the base class.
- return nanobind::cast(PyDialect(std::move(dialectDescriptor)));
- }
-
- // Create the custom implementation.
- return (*dialectClass)(std::move(dialectDescriptor));
-}
-
inline MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size());
}
@@ -1387,49 +1329,16 @@ inline MlirStringRef toMlirStringRef(const nanobind::bytes &s) {
/// Create a block, using the current location context if no locations are
/// specified.
-inline MlirBlock
+MlirBlock MLIR_PYTHON_API_EXPORTED
createBlock(const nanobind::sequence &pyArgTypes,
- const std::optional<nanobind::sequence> &pyArgLocs) {
- SmallVector<MlirType> argTypes;
- argTypes.reserve(nanobind::len(pyArgTypes));
- for (const auto &pyType : pyArgTypes)
- argTypes.push_back(nanobind::cast<PyType &>(pyType));
-
- SmallVector<MlirLocation> argLocs;
- if (pyArgLocs) {
- argLocs.reserve(nanobind::len(*pyArgLocs));
- for (const auto &pyLoc : *pyArgLocs)
- argLocs.push_back(nanobind::cast<PyLocation &>(pyLoc));
- } else if (!argTypes.empty()) {
- argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
- }
+ const std::optional<nanobind::sequence> &pyArgLocs);
- if (argTypes.size() != argLocs.size()) {
- throw nanobind::value_error(
- llvm::formatv("Expected {0} locations, got: {1}", argTypes.size(),
- argLocs.size())
- .str()
- .c_str());
- }
- return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
-}
-
-struct PyAttrBuilderMap {
- static bool dunderContains(const std::string &attributeKind) {
- return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
- }
+struct MLIR_PYTHON_API_EXPORTED PyAttrBuilderMap {
+ static bool dunderContains(const std::string &attributeKind);
static nanobind::callable
- dunderGetItemNamed(const std::string &attributeKind) {
- auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
- if (!builder)
- throw nanobind::key_error(attributeKind.c_str());
- return *builder;
- }
+ dunderGetItemNamed(const std::string &attributeKind);
static void dunderSetItemNamed(const std::string &attributeKind,
- nanobind::callable func, bool replace) {
- PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
- replace);
- }
+ nanobind::callable func, bool replace);
static void bind(nanobind::module_ &m) {
nanobind::class_<PyAttrBuilderMap>(m, "AttrBuilder")
@@ -1450,14 +1359,6 @@ struct PyAttrBuilderMap {
}
};
-//------------------------------------------------------------------------------
-// PyBlock
-//------------------------------------------------------------------------------
-
-inline nanobind::object PyBlock::getCapsule() {
- return nanobind::steal<nanobind::object>(mlirPythonBlockToCapsule(get()));
-}
-
//------------------------------------------------------------------------------
// Collections.
//------------------------------------------------------------------------------
@@ -1469,14 +1370,7 @@ class MLIR_PYTHON_API_EXPORTED PyRegionIterator {
PyRegionIterator &dunderIter() { return *this; }
- PyRegion dunderNext() {
- operation->checkValid();
- if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
- throw nanobind::stop_iteration();
- }
- MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
- return PyRegion(operation, region);
- }
+ PyRegion dunderNext();
static void bind(nanobind::module_ &m) {
nanobind::class_<PyRegionIterator>(m, "RegionIterator")
@@ -1499,17 +1393,9 @@ class MLIR_PYTHON_API_EXPORTED PyRegionList
static constexpr const char *pyClassName = "RegionSequence";
PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumRegions(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
+ intptr_t length = -1, intptr_t step = 1);
- PyRegionIterator dunderIter() {
- operation->checkValid();
- return PyRegionIterator(operation, startIndex);
- }
+ PyRegionIterator dunderIter();
static void bindDerived(ClassTy &c) {
c.def("__iter__", &PyRegionList::dunderIter,
@@ -1520,19 +1406,11 @@ class MLIR_PYTHON_API_EXPORTED PyRegionList
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyRegionList, PyRegion>;
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumRegions(operation->get());
- }
+ intptr_t getRawNumElements();
- PyRegion getRawElement(intptr_t pos) {
- operation->checkValid();
- return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
- }
+ PyRegion getRawElement(intptr_t pos);
- PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyRegionList(operation, startIndex, length, step);
- }
+ PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) const;
PyOperationRef operation;
};
@@ -1544,16 +1422,7 @@ class MLIR_PYTHON_API_EXPORTED PyBlockIterator {
PyBlockIterator &dunderIter() { return *this; }
- PyBlock dunderNext() {
- operation->checkValid();
- if (mlirBlockIsNull(next)) {
- throw nanobind::stop_iteration();
- }
-
- PyBlock returnBlock(operation, next);
- next = mlirBlockGetNextInRegion(next);
- return returnBlock;
- }
+ PyBlock dunderNext();
static void bind(nanobind::module_ &m) {
nanobind::class_<PyBlockIterator>(m, "BlockIterator")
@@ -1576,49 +1445,14 @@ class MLIR_PYTHON_API_EXPORTED PyBlockList {
PyBlockList(PyOperationRef operation, MlirRegion region)
: operation(std::move(operation)), region(region) {}
- PyBlockIterator dunderIter() {
- operation->checkValid();
- return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
- }
+ PyBlockIterator dunderIter();
- intptr_t dunderLen() {
- operation->checkValid();
- intptr_t count = 0;
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- count += 1;
- block = mlirBlockGetNextInRegion(block);
- }
- return count;
- }
+ intptr_t dunderLen();
- PyBlock dunderGetItem(intptr_t index) {
- operation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nanobind::index_error("attempt to access out of bounds block");
- }
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- if (index == 0) {
- return PyBlock(operation, block);
- }
- block = mlirBlockGetNextInRegion(block);
- index -= 1;
- }
- throw nanobind::index_error("attempt to access out of bounds block");
- }
+ PyBlock dunderGetItem(intptr_t index);
PyBlock appendBlock(const nanobind::args &pyArgTypes,
- const std::optional<nanobind::sequence> &pyArgLocs) {
- operation->checkValid();
- MlirBlock block =
- createBlock(nanobind::cast<nanobind::sequence>(pyArgTypes), pyArgLocs);
- mlirRegionAppendOwnedBlock(region, block);
- return PyBlock(operation, block);
- }
+ const std::optional<nanobind::sequence> &pyArgLocs);
static void bind(nanobind::module_ &m) {
nanobind::class_<PyBlockList>(m, "BlockList")
@@ -1651,17 +1485,7 @@ class MLIR_PYTHON_API_EXPORTED PyOperationIterator {
PyOperationIterator &dunderIter() { return *this; }
- nanobind::typed<nanobind::object, PyOpView> dunderNext() {
- parentOperation->checkValid();
- if (mlirOperationIsNull(next)) {
- throw nanobind::stop_iteration();
- }
-
- PyOperationRef returnOperation =
- PyOperation::forOperation(parentOperation->getContext(), next);
- next = mlirOperationGetNextInBlock(next);
- return returnOperation->createOpView();
- }
+ nanobind::typed<nanobind::object, PyOpView> dunderNext();
static void bind(nanobind::module_ &m) {
nanobind::class_<PyOperationIterator>(m, "OperationIterator")
@@ -1691,36 +1515,9 @@ class MLIR_PYTHON_API_EXPORTED PyOperationList {
mlirBlockGetFirstOperation(block));
}
- intptr_t dunderLen() {
- parentOperation->checkValid();
- intptr_t count = 0;
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- count += 1;
- childOp = mlirOperationGetNextInBlock(childOp);
- }
- return count;
- }
+ intptr_t dunderLen();
- nanobind::typed<nanobind::object, PyOpView> dunderGetItem(intptr_t index) {
- parentOperation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nanobind::index_error("attempt to access out of bounds operation");
- }
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- if (index == 0) {
- return PyOperation::forOperation(parentOperation->getContext(), childOp)
- ->createOpView();
- }
- childOp = mlirOperationGetNextInBlock(childOp);
- index -= 1;
- }
- throw nanobind::index_error("attempt to access out of bounds operation");
- }
+ nanobind::typed<nanobind::object, PyOpView> dunderGetItem(intptr_t index);
static void bind(nanobind::module_ &m) {
nanobind::class_<PyOperationList>(m, "OperationList")
@@ -1741,14 +1538,9 @@ class MLIR_PYTHON_API_EXPORTED PyOpOperand {
public:
PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
- nanobind::typed<nanobind::object, PyOpView> getOwner() {
- MlirOperation owner = mlirOpOperandGetOwner(opOperand);
- PyMlirContextRef context =
- PyMlirContext::forContext(mlirOperationGetContext(owner));
- return PyOperation::forOperation(context, owner)->createOpView();
- }
+ nanobind::typed<nanobind::object, PyOpView> getOwner() const;
- size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
+ size_t getOperandNumber() const;
static void bind(nanobind::module_ &m) {
nanobind::class_<PyOpOperand>(m, "OpOperand")
@@ -1768,14 +1560,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpOperandIterator {
PyOpOperandIterator &dunderIter() { return *this; }
- PyOpOperand dunderNext() {
- if (mlirOpOperandIsNull(opOperand))
- throw nanobind::stop_iteration();
-
- PyOpOperand returnOpOperand(opOperand);
- opOperand = mlirOpOperandGetNextUse(opOperand);
- return returnOpOperand;
- }
+ PyOpOperand dunderNext();
static void bind(nanobind::module_ &m) {
nanobind::class_<PyOpOperandIterator>(m, "OpOperandIterator")
@@ -1931,19 +1716,12 @@ class MLIR_PYTHON_API_EXPORTED PyOpResultList
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyOpResultList, PyOpResult>;
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumResults(operation->get());
- }
+ intptr_t getRawNumElements();
- PyOpResult getRawElement(intptr_t index) {
- PyValue value(operation, mlirOperationGetResult(operation->get(), index));
- return PyOpResult(value);
- }
+ PyOpResult getRawElement(intptr_t index);
- PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpResultList(operation, startIndex, length, step);
- }
+ PyOpResultList slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const;
PyOperationRef operation;
};
@@ -2017,22 +1795,14 @@ class MLIR_PYTHON_API_EXPORTED PyBlockArgumentList
friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
/// Returns the number of arguments in the list.
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirBlockGetNumArguments(block);
- }
+ intptr_t getRawNumElements();
/// Returns `pos`-the element in the list.
- PyBlockArgument getRawElement(intptr_t pos) {
- MlirValue argument = mlirBlockGetArgument(block, pos);
- return PyBlockArgument(operation, argument);
- }
+ PyBlockArgument getRawElement(intptr_t pos) const;
/// Returns a sublist of this list.
PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
- intptr_t step) {
- return PyBlockArgumentList(operation, block, startIndex, length, step);
- }
+ intptr_t step) const;
PyOperationRef operation;
MlirBlock block;
@@ -2056,10 +1826,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpOperandList
step),
operation(operation) {}
- void dunderSetItem(intptr_t index, PyValue value) {
- index = wrapIndex(index);
- mlirOperationSetOperand(operation->get(), index, value.get());
- }
+ void dunderSetItem(intptr_t index, PyValue value);
static void bindDerived(ClassTy &c) {
c.def("__setitem__", &PyOpOperandList::dunderSetItem,
@@ -2071,28 +1838,12 @@ class MLIR_PYTHON_API_EXPORTED PyOpOperandList
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyOpOperandList, PyValue>;
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumOperands(operation->get());
- }
+ intptr_t getRawNumElements();
- PyValue getRawElement(intptr_t pos) {
- MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
- MlirOperation owner;
- if (mlirValueIsAOpResult(operand))
- owner = mlirOpResultGetOwner(operand);
- else if (mlirValueIsABlockArgument(operand))
- owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
- else
- assert(false && "Value must be an block arg or op result.");
- PyOperationRef pyOwner =
- PyOperation::forOperation(operation->getContext(), owner);
- return PyValue(pyOwner, operand);
- }
+ PyValue getRawElement(intptr_t pos);
- PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpOperandList(operation, startIndex, length, step);
- }
+ PyOpOperandList slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const;
PyOperationRef operation;
};
@@ -2114,10 +1865,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpSuccessors
step),
operation(operation) {}
- void dunderSetItem(intptr_t index, PyBlock block) {
- index = wrapIndex(index);
- mlirOperationSetSuccessor(operation->get(), index, block.get());
- }
+ void dunderSetItem(intptr_t index, PyBlock block);
static void bindDerived(ClassTy &c) {
c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nanobind::arg("index"),
@@ -2129,19 +1877,12 @@ class MLIR_PYTHON_API_EXPORTED PyOpSuccessors
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyOpSuccessors, PyBlock>;
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumSuccessors(operation->get());
- }
+ intptr_t getRawNumElements();
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
- return PyBlock(operation, block);
- }
+ PyBlock getRawElement(intptr_t pos);
- PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpSuccessors(operation, startIndex, length, step);
- }
+ PyOpSuccessors slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const;
PyOperationRef operation;
};
@@ -2168,19 +1909,12 @@ class MLIR_PYTHON_API_EXPORTED PyBlockSuccessors
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyBlockSuccessors, PyBlock>;
- intptr_t getRawNumElements() {
- block.checkValid();
- return mlirBlockGetNumSuccessors(block.get());
- }
+ intptr_t getRawNumElements();
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
- return PyBlock(operation, block);
- }
+ PyBlock getRawElement(intptr_t pos);
- PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyBlockSuccessors(block, operation, startIndex, length, step);
- }
+ PyBlockSuccessors slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) const;
PyOperationRef operation;
PyBlock block;
@@ -2211,20 +1945,12 @@ class MLIR_PYTHON_API_EXPORTED PyBlockPredecessors
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyBlockPredecessors, PyBlock>;
- intptr_t getRawNumElements() {
- block.checkValid();
- return mlirBlockGetNumPredecessors(block.get());
- }
+ intptr_t getRawNumElements();
- PyBlock getRawElement(intptr_t pos) {
- MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
- return PyBlock(operation, block);
- }
+ PyBlock getRawElement(intptr_t pos);
PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
- intptr_t step) {
- return PyBlockPredecessors(block, operation, startIndex, length, step);
- }
+ intptr_t step) const;
PyOperationRef operation;
PyBlock block;
@@ -2238,61 +1964,21 @@ class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
: operation(std::move(operation)) {}
nanobind::typed<nanobind::object, PyAttribute>
- dunderGetItemNamed(const std::string &name) {
- MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
- toMlirStringRef(name));
- if (mlirAttributeIsNull(attr)) {
- throw nanobind::key_error("attempt to access a non-existent attribute");
- }
- return PyAttribute(operation->getContext(), attr).maybeDownCast();
- }
+ dunderGetItemNamed(const std::string &name);
- PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0 || index >= dunderLen()) {
- throw nanobind::index_error("attempt to access out of bounds attribute");
- }
- MlirNamedAttribute namedAttr =
- mlirOperationGetAttribute(operation->get(), index);
- return PyNamedAttribute(
- namedAttr.attribute,
- std::string(mlirIdentifierStr(namedAttr.name).data,
- mlirIdentifierStr(namedAttr.name).length));
- }
+ PyNamedAttribute dunderGetItemIndexed(intptr_t index);
- void dunderSetItem(const std::string &name, const PyAttribute &attr) {
- mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
- attr);
- }
+ void dunderSetItem(const std::string &name, const PyAttribute &attr);
- void dunderDelItem(const std::string &name) {
- int removed = mlirOperationRemoveAttributeByName(operation->get(),
- toMlirStringRef(name));
- if (!removed)
- throw nanobind::key_error("attempt to delete a non-existent attribute");
- }
+ void dunderDelItem(const std::string &name);
- intptr_t dunderLen() {
- return mlirOperationGetNumAttributes(operation->get());
- }
+ intptr_t dunderLen();
- bool dunderContains(const std::string &name) {
- return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
- operation->get(), toMlirStringRef(name)));
- }
+ bool dunderContains(const std::string &name);
static void
forEachAttr(MlirOperation op,
- llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
- intptr_t n = mlirOperationGetNumAttributes(op);
- for (intptr_t i = 0; i < n; ++i) {
- MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
- MlirStringRef name = mlirIdentifierStr(na.name);
- fn(name, na.attribute);
- }
- }
+ llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn);
static void bind(nanobind::module_ &m) {
nanobind::class_<PyOpAttributeMap>(m, "OpAttributeMap")
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0fe508de38e85..2ea07a6c9adec 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -39,6 +39,20 @@ using llvm::Twine;
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+//------------------------------------------------------------------------------
+// PyThreadPool
+//------------------------------------------------------------------------------
+
+/// Creates the thread pool wrapped by this object and stores it in
+/// `ownedThreadPool`.
+PyThreadPool::PyThreadPool()
+    : ownedThreadPool(std::make_unique<llvm::DefaultThreadPool>()) {}
+
+/// Returns the address of the owned thread pool, formatted by streaming the
+/// raw pointer into a std::ostream.
+std::string PyThreadPool::_mlir_thread_pool_ptr() const {
+  std::ostringstream os;
+  os << ownedThreadPool.get();
+  return os.str();
+}
+
//------------------------------------------------------------------------------
// PyMlirContext
//------------------------------------------------------------------------------
@@ -62,6 +76,9 @@ PyMlirContext::~PyMlirContext() {
mlirContextDestroy(context);
}
+/// Returns a PyMlirContextRef that pairs this context with the Python object
+/// produced by `nanobind::cast(this)`.
+PyMlirContextRef PyMlirContext::getRef() {
+  nanobind::object pyHandle = nanobind::cast(this);
+  return PyMlirContextRef(this, std::move(pyHandle));
+}
nb::object PyMlirContext::getCapsule() {
return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
}
@@ -598,6 +615,31 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
return PyOperation::createDetached(std::move(contextRef), op);
}
+/// Removes the operation from its parent in the IR, marks this wrapper as
+/// detached, and drops the `parentKeepAlive` reference.
+void PyOperation::detachFromParent() {
+  // The IR is updated first; wrapper bookkeeping follows.
+  mlirOperationRemoveFromParent(getOperation());
+  setDetached();
+  parentKeepAlive = nanobind::object{};
+}
+
+/// Returns the underlying MlirOperation. Throws (via checkValid) if the
+/// operation has been invalidated.
+MlirOperation PyOperation::get() const {
+  checkValid();
+  return operation;
+}
+
+/// Returns a PyOperationRef pairing this operation with a new (borrowed)
+/// reference to its Python handle.
+PyOperationRef PyOperation::getRef() {
+  auto pyObject = nanobind::borrow<nanobind::object>(handle);
+  return PyOperationRef(this, std::move(pyObject));
+}
+
+/// Marks the operation as attached; asserts it was not already attached.
+/// NOTE(review): `parent` is not used by this implementation.
+void PyOperation::setAttached(const nanobind::object &parent) {
+  assert(!attached && "operation already attached");
+  attached = true;
+}
+
+/// Marks the operation as detached; asserts it was attached.
+void PyOperation::setDetached() {
+  assert(attached && "operation already detached");
+  attached = false;
+}
+
void PyOperation::checkValid() const {
if (!valid) {
throw std::runtime_error("the operation has been invalidated");
@@ -1292,6 +1334,36 @@ PyOpView::PyOpView(const nb::object &operationObject)
: operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
operationObject(operation.getRef().getObject()) {}
+//------------------------------------------------------------------------------
+// PyBlock
+//------------------------------------------------------------------------------
+
+/// Wraps this block in a Python capsule via mlirPythonBlockToCapsule.
+nanobind::object PyBlock::getCapsule() {
+  auto *capsule = mlirPythonBlockToCapsule(get());
+  // `steal`: the returned object takes over the reference produced above.
+  return nanobind::steal<nanobind::object>(capsule);
+}
+
+//------------------------------------------------------------------------------
+// PyAsmState
+//------------------------------------------------------------------------------
+
+/// Creates the printing flags backing a PyAsmState. The OpPrintingFlags are
+/// not exposed on the Python side, so each state creates its own set locally
+/// and associates their lifetime with the state.
+static MlirOpPrintingFlags createAsmStatePrintingFlags(bool useLocalScope) {
+  MlirOpPrintingFlags printingFlags = mlirOpPrintingFlagsCreate();
+  if (useLocalScope)
+    mlirOpPrintingFlagsUseLocalScope(printingFlags);
+  return printingFlags;
+}
+
+/// Creates an AsmState for printing `value`, optionally using local scope.
+PyAsmState::PyAsmState(MlirValue value, bool useLocalScope) {
+  // Assigned in the body (not the member-initializer list) so `flags` always
+  // exists before `state` is built from it, regardless of the order in which
+  // the two members are declared.
+  flags = createAsmStatePrintingFlags(useLocalScope);
+  state = mlirAsmStateCreateForValue(value, flags);
+}
+
+/// Creates an AsmState for printing `operation`, optionally using local scope.
+PyAsmState::PyAsmState(PyOperationBase &operation, bool useLocalScope) {
+  flags = createAsmStatePrintingFlags(useLocalScope);
+  state = mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
+}
+
//------------------------------------------------------------------------------
// PyInsertionPoint.
//------------------------------------------------------------------------------
@@ -1672,6 +1744,388 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
throw std::runtime_error(message);
}
}
+
+/// Creates a new block via mlirBlockCreate with one argument per entry of
+/// `pyArgTypes`.
+///
+/// If `pyArgLocs` is given, its entries are used as the argument locations.
+/// Otherwise every argument gets DefaultingPyLocation::resolve(), which is
+/// only consulted when there is at least one argument. Throws
+/// nanobind::value_error when the location count differs from the type count.
+MlirBlock createBlock(const nanobind::sequence &pyArgTypes,
+                      const std::optional<nanobind::sequence> &pyArgLocs) {
+  SmallVector<MlirType> types;
+  types.reserve(nanobind::len(pyArgTypes));
+  for (const auto &item : pyArgTypes)
+    types.push_back(nanobind::cast<PyType &>(item));
+
+  SmallVector<MlirLocation> locations;
+  if (!pyArgLocs) {
+    if (!types.empty())
+      locations.assign(types.size(), DefaultingPyLocation::resolve());
+  } else {
+    locations.reserve(nanobind::len(*pyArgLocs));
+    for (const auto &item : *pyArgLocs)
+      locations.push_back(nanobind::cast<PyLocation &>(item));
+  }
+
+  if (locations.size() != types.size())
+    throw nanobind::value_error(("Expected " + Twine(types.size()) +
+                                 " locations, got: " + Twine(locations.size()))
+                                    .str()
+                                    .c_str());
+
+  return mlirBlockCreate(types.size(), types.data(), locations.data());
+}
+
+//------------------------------------------------------------------------------
+// PyAttrBuilderMap
+//------------------------------------------------------------------------------
+
+/// Returns true if PyGlobals has an attribute builder for `attributeKind`.
+bool PyAttrBuilderMap::dunderContains(const std::string &attributeKind) {
+  auto maybeBuilder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
+  return maybeBuilder.has_value();
+}
+
+/// Returns the attribute builder registered for `attributeKind`; throws
+/// nanobind::key_error (carrying the kind) if none is registered.
+nanobind::callable
+PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
+  if (auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind))
+    return *builder;
+  throw nanobind::key_error(attributeKind.c_str());
+}
+
+/// Registers `func` as the builder for `attributeKind` in PyGlobals; `replace`
+/// is forwarded unchanged. The callable is moved into the registry call.
+void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
+                                          nanobind::callable func,
+                                          bool replace) {
+  PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
+                                            replace);
+}
+
+//------------------------------------------------------------------------------
+// Collections.
+//------------------------------------------------------------------------------
+
+/// Returns the region at `nextIndex` and advances; raises StopIteration once
+/// every region of the operation has been produced.
+PyRegion PyRegionIterator::dunderNext() {
+  operation->checkValid();
+  intptr_t numRegions = mlirOperationGetNumRegions(operation->get());
+  if (nextIndex >= numRegions)
+    throw nanobind::stop_iteration();
+  MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex);
+  ++nextIndex;
+  return PyRegion(operation, region);
+}
+
+/// Builds a (possibly sliced) view over the regions of `operation`. A
+/// `length` of -1 selects the operation's full region count.
+PyRegionList::PyRegionList(PyOperationRef operation, intptr_t startIndex,
+                           intptr_t length, intptr_t step)
+    : Sliceable(startIndex,
+                length != -1 ? length
+                             : mlirOperationGetNumRegions(operation->get()),
+                step),
+      operation(std::move(operation)) {}
+
+/// Returns an iterator that starts at this list's start index.
+/// NOTE(review): only `startIndex` is forwarded; the iterator runs to the
+/// operation's last region and ignores this slice's length and step.
+PyRegionIterator PyRegionList::dunderIter() {
+  operation->checkValid();
+  return PyRegionIterator(operation, startIndex);
+}
+
+/// Number of regions on the underlying operation (ignores slicing).
+intptr_t PyRegionList::getRawNumElements() {
+  operation->checkValid();
+  return mlirOperationGetNumRegions(operation->get());
+}
+
+/// Region at raw position `pos` of the underlying operation.
+PyRegion PyRegionList::getRawElement(intptr_t pos) {
+  operation->checkValid();
+  MlirRegion region = mlirOperationGetRegion(operation->get(), pos);
+  return PyRegion(operation, region);
+}
+
+/// Returns a new region list over the same operation with the given slice.
+PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length,
+                                 intptr_t step) const {
+  return PyRegionList(operation, startIndex, length, step);
+}
+
+/// Returns the current block and advances to the next block in the region;
+/// raises StopIteration when the chain is exhausted.
+PyBlock PyBlockIterator::dunderNext() {
+  operation->checkValid();
+  if (mlirBlockIsNull(next))
+    throw nanobind::stop_iteration();
+
+  MlirBlock current = next;
+  next = mlirBlockGetNextInRegion(current);
+  return PyBlock(operation, current);
+}
+
+/// Returns an iterator positioned at the region's first block.
+PyBlockIterator PyBlockList::dunderIter() {
+  operation->checkValid();
+  return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
+}
+
+/// Counts the blocks in the region by walking the block chain.
+intptr_t PyBlockList::dunderLen() {
+  operation->checkValid();
+  intptr_t numBlocks = 0;
+  for (MlirBlock block = mlirRegionGetFirstBlock(region);
+       !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block))
+    ++numBlocks;
+  return numBlocks;
+}
+
+/// Returns the block at `index`; negative indices count from the end.
+/// Throws nanobind::index_error when the index is out of range.
+PyBlock PyBlockList::dunderGetItem(intptr_t index) {
+  operation->checkValid();
+  if (index < 0)
+    index += dunderLen();
+  if (index < 0)
+    throw nanobind::index_error("attempt to access out of bounds block");
+
+  MlirBlock block = mlirRegionGetFirstBlock(region);
+  for (intptr_t remaining = index; !mlirBlockIsNull(block); --remaining) {
+    if (remaining == 0)
+      return PyBlock(operation, block);
+    block = mlirBlockGetNextInRegion(block);
+  }
+  throw nanobind::index_error("attempt to access out of bounds block");
+}
+
+/// Creates a block with the given argument types/locations (see createBlock)
+/// and appends it to the region.
+PyBlock
+PyBlockList::appendBlock(const nanobind::args &pyArgTypes,
+                         const std::optional<nanobind::sequence> &pyArgLocs) {
+  operation->checkValid();
+  auto argTypes = nanobind::cast<nanobind::sequence>(pyArgTypes);
+  MlirBlock newBlock = createBlock(argTypes, pyArgLocs);
+  mlirRegionAppendOwnedBlock(region, newBlock);
+  return PyBlock(operation, newBlock);
+}
+
+/// Returns an OpView for the current operation and advances to the next one
+/// in the block; raises StopIteration when the block is exhausted.
+nanobind::typed<nanobind::object, PyOpView> PyOperationIterator::dunderNext() {
+  parentOperation->checkValid();
+  if (mlirOperationIsNull(next))
+    throw nanobind::stop_iteration();
+
+  MlirOperation current = next;
+  PyOperationRef wrapped =
+      PyOperation::forOperation(parentOperation->getContext(), current);
+  next = mlirOperationGetNextInBlock(current);
+  return wrapped->createOpView();
+}
+
+/// Counts the operations in the block by walking the operation chain.
+intptr_t PyOperationList::dunderLen() {
+  parentOperation->checkValid();
+  intptr_t numOps = 0;
+  for (MlirOperation op = mlirBlockGetFirstOperation(block);
+       !mlirOperationIsNull(op); op = mlirOperationGetNextInBlock(op))
+    ++numOps;
+  return numOps;
+}
+
+/// Returns an OpView for the operation at `index`; negative indices count
+/// from the end. Throws nanobind::index_error when out of range.
+nanobind::typed<nanobind::object, PyOpView>
+PyOperationList::dunderGetItem(intptr_t index) {
+  parentOperation->checkValid();
+  if (index < 0)
+    index += dunderLen();
+  if (index < 0)
+    throw nanobind::index_error("attempt to access out of bounds operation");
+
+  MlirOperation childOp = mlirBlockGetFirstOperation(block);
+  for (intptr_t remaining = index; !mlirOperationIsNull(childOp);
+       --remaining) {
+    if (remaining == 0)
+      return PyOperation::forOperation(parentOperation->getContext(), childOp)
+          ->createOpView();
+    childOp = mlirOperationGetNextInBlock(childOp);
+  }
+  throw nanobind::index_error("attempt to access out of bounds operation");
+}
+
+/// Returns an OpView for the operation that owns this operand.
+nanobind::typed<nanobind::object, PyOpView> PyOpOperand::getOwner() const {
+  MlirOperation ownerOp = mlirOpOperandGetOwner(opOperand);
+  MlirContext ctx = mlirOperationGetContext(ownerOp);
+  return PyOperation::forOperation(PyMlirContext::forContext(ctx), ownerOp)
+      ->createOpView();
+}
+
+/// Returns this operand's position in its owner's operand list.
+size_t PyOpOperand::getOperandNumber() const {
+  const auto number = mlirOpOperandGetOperandNumber(opOperand);
+  return number;
+}
+
+/// Returns the current use and advances along the use chain; raises
+/// StopIteration when there are no further uses.
+PyOpOperand PyOpOperandIterator::dunderNext() {
+  if (mlirOpOperandIsNull(opOperand))
+    throw nanobind::stop_iteration();
+
+  MlirOpOperand current = opOperand;
+  opOperand = mlirOpOperandGetNextUse(current);
+  return PyOpOperand(current);
+}
+
+//------------------------------------------------------------------------------
+// Sliceable value/block lists and PyOpAttributeMap
+//------------------------------------------------------------------------------
+
+/// Number of results of the underlying operation (ignores slicing).
+intptr_t PyOpResultList::getRawNumElements() {
+  operation->checkValid();
+  return mlirOperationGetNumResults(operation->get());
+}
+
+/// Result at raw position `index`, wrapped as a PyOpResult.
+PyOpResult PyOpResultList::getRawElement(intptr_t index) {
+  MlirValue result = mlirOperationGetResult(operation->get(), index);
+  // Kept as a named PyValue: PyOpResult is constructed from a PyValue lvalue.
+  PyValue value(operation, result);
+  return PyOpResult(value);
+}
+
+/// Returns a new result list over the same operation with the given slice.
+PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length,
+                                     intptr_t step) const {
+  return PyOpResultList(operation, startIndex, length, step);
+}
+
+/// Number of arguments of the underlying block (ignores slicing).
+intptr_t PyBlockArgumentList::getRawNumElements() {
+  operation->checkValid();
+  return mlirBlockGetNumArguments(block);
+}
+
+/// Block argument at raw position `pos`.
+PyBlockArgument PyBlockArgumentList::getRawElement(intptr_t pos) const {
+  MlirValue blockArg = mlirBlockGetArgument(block, pos);
+  return PyBlockArgument(operation, blockArg);
+}
+
+/// Returns a new argument list over the same block with the given slice.
+PyBlockArgumentList PyBlockArgumentList::slice(intptr_t startIndex,
+                                               intptr_t length,
+                                               intptr_t step) const {
+  return PyBlockArgumentList(operation, block, startIndex, length, step);
+}
+
+/// Replaces the operand at `index` (normalized by wrapIndex) with `value`.
+void PyOpOperandList::dunderSetItem(intptr_t index, PyValue value) {
+  const intptr_t wrapped = wrapIndex(index);
+  mlirOperationSetOperand(operation->get(), wrapped, value.get());
+}
+
+/// Number of operands of the underlying operation (ignores slicing).
+intptr_t PyOpOperandList::getRawNumElements() {
+  operation->checkValid();
+  return mlirOperationGetNumOperands(operation->get());
+}
+
+/// Operand at raw position `pos`, wrapped with its defining owner: the
+/// producing op for an op result, or the parent op of the owning block for a
+/// block argument.
+PyValue PyOpOperandList::getRawElement(intptr_t pos) {
+  MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
+  MlirOperation owner;
+  if (mlirValueIsAOpResult(operand)) {
+    owner = mlirOpResultGetOwner(operand);
+  } else if (mlirValueIsABlockArgument(operand)) {
+    MlirBlock ownerBlock = mlirBlockArgumentGetOwner(operand);
+    owner = mlirBlockGetParentOperation(ownerBlock);
+  } else {
+    assert(false && "Value must be an block arg or op result.");
+  }
+  PyOperationRef pyOwner =
+      PyOperation::forOperation(operation->getContext(), owner);
+  return PyValue(pyOwner, operand);
+}
+
+/// Returns a new operand list over the same operation with the given slice.
+PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length,
+                                       intptr_t step) const {
+  return PyOpOperandList(operation, startIndex, length, step);
+}
+
+/// Replaces the successor at `index` (normalized by wrapIndex) with `block`.
+void PyOpSuccessors::dunderSetItem(intptr_t index, PyBlock block) {
+  const intptr_t wrapped = wrapIndex(index);
+  mlirOperationSetSuccessor(operation->get(), wrapped, block.get());
+}
+
+/// Number of successors of the underlying operation (ignores slicing).
+intptr_t PyOpSuccessors::getRawNumElements() {
+  operation->checkValid();
+  return mlirOperationGetNumSuccessors(operation->get());
+}
+
+/// Successor block at raw position `pos`.
+PyBlock PyOpSuccessors::getRawElement(intptr_t pos) {
+  MlirBlock successor = mlirOperationGetSuccessor(operation->get(), pos);
+  return PyBlock(operation, successor);
+}
+
+/// Returns a new successor list over the same operation with the given slice.
+PyOpSuccessors PyOpSuccessors::slice(intptr_t startIndex, intptr_t length,
+                                     intptr_t step) const {
+  return PyOpSuccessors(operation, startIndex, length, step);
+}
+
+/// Number of successors of the wrapped block (ignores slicing).
+intptr_t PyBlockSuccessors::getRawNumElements() {
+  block.checkValid();
+  return mlirBlockGetNumSuccessors(block.get());
+}
+
+/// Successor of the wrapped block at raw position `pos`.
+PyBlock PyBlockSuccessors::getRawElement(intptr_t pos) {
+  MlirBlock successor = mlirBlockGetSuccessor(block.get(), pos);
+  return PyBlock(operation, successor);
+}
+
+/// Returns a new successor list over the same block with the given slice.
+PyBlockSuccessors PyBlockSuccessors::slice(intptr_t startIndex, intptr_t length,
+                                           intptr_t step) const {
+  return PyBlockSuccessors(block, operation, startIndex, length, step);
+}
+
+/// Number of predecessors of the wrapped block (ignores slicing).
+intptr_t PyBlockPredecessors::getRawNumElements() {
+  block.checkValid();
+  return mlirBlockGetNumPredecessors(block.get());
+}
+
+/// Predecessor of the wrapped block at raw position `pos`.
+PyBlock PyBlockPredecessors::getRawElement(intptr_t pos) {
+  MlirBlock predecessor = mlirBlockGetPredecessor(block.get(), pos);
+  return PyBlock(operation, predecessor);
+}
+
+/// Returns a new predecessor list over the same block with the given slice.
+PyBlockPredecessors PyBlockPredecessors::slice(intptr_t startIndex,
+                                               intptr_t length,
+                                               intptr_t step) const {
+  return PyBlockPredecessors(block, operation, startIndex, length, step);
+}
+
+/// Returns the attribute named `name`, down-cast to its concrete Python type.
+/// Throws nanobind::key_error if the operation has no such attribute.
+nanobind::typed<nanobind::object, PyAttribute>
+PyOpAttributeMap::dunderGetItemNamed(const std::string &name) {
+  MlirAttribute found =
+      mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name));
+  if (mlirAttributeIsNull(found))
+    throw nanobind::key_error("attempt to access a non-existent attribute");
+  return PyAttribute(operation->getContext(), found).maybeDownCast();
+}
+
+/// Returns the (name, attribute) pair at `index`; negative indices count from
+/// the end. Throws nanobind::index_error when out of range.
+PyNamedAttribute PyOpAttributeMap::dunderGetItemIndexed(intptr_t index) {
+  const intptr_t numAttrs = dunderLen();
+  if (index < 0)
+    index += numAttrs;
+  if (index < 0 || index >= numAttrs)
+    throw nanobind::index_error("attempt to access out of bounds attribute");
+
+  MlirNamedAttribute namedAttr =
+      mlirOperationGetAttribute(operation->get(), index);
+  MlirStringRef attrName = mlirIdentifierStr(namedAttr.name);
+  return PyNamedAttribute(namedAttr.attribute,
+                          std::string(attrName.data, attrName.length));
+}
+
+/// Sets (or overwrites) the attribute named `name` on the operation.
+void PyOpAttributeMap::dunderSetItem(const std::string &name,
+                                     const PyAttribute &attr) {
+  mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
+                                  attr);
+}
+
+/// Removes the attribute named `name`; throws nanobind::key_error if absent.
+void PyOpAttributeMap::dunderDelItem(const std::string &name) {
+  const int removed = mlirOperationRemoveAttributeByName(
+      operation->get(), toMlirStringRef(name));
+  if (removed == 0)
+    throw nanobind::key_error("attempt to delete a non-existent attribute");
+}
+
+/// Number of attributes attached to the operation.
+intptr_t PyOpAttributeMap::dunderLen() {
+  return mlirOperationGetNumAttributes(operation->get());
+}
+
+/// True if the operation has an attribute named `name`.
+bool PyOpAttributeMap::dunderContains(const std::string &name) {
+  MlirAttribute found =
+      mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name));
+  return !mlirAttributeIsNull(found);
+}
+
+/// Invokes `fn(name, attribute)` for each attribute of `op`, in index order.
+void PyOpAttributeMap::forEachAttr(
+    MlirOperation op,
+    llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
+  for (intptr_t i = 0, e = mlirOperationGetNumAttributes(op); i < e; ++i) {
+    MlirNamedAttribute named = mlirOperationGetAttribute(op, i);
+    fn(mlirIdentifierStr(named.name), named.attribute);
+  }
+}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 071f106da04bb..79c8e36609d76 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -235,6 +235,26 @@ maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
return {ref, mlirLoc};
}
+
+/// Helper for creating an @classmethod.
+///
+/// Wraps `f` (with the nanobind annotations in `args`) in a function object
+/// and passes it to PyClassMethod_New. PyClassMethod_New returns a new
+/// (owned) reference, so the result is adopted with `steal`; using `borrow`
+/// here would add a second reference that is never released and leak the
+/// classmethod object. Throws nanobind::python_error if PyClassMethod_New
+/// fails.
+template <class Func, typename... Args>
+nanobind::object classmethod(Func f, Args... args) {
+  nanobind::object cf = nanobind::cpp_function(f, args...);
+  PyObject *method = PyClassMethod_New(cf.ptr());
+  if (!method)
+    throw nanobind::python_error();
+  return nanobind::steal<nanobind::object>(method);
+}
+
+/// Builds the Python object for a dialect. If PyGlobals has a dialect class
+/// registered for `dialectNamespace`, that class is called with the
+/// descriptor; otherwise a base PyDialect wrapping the descriptor is returned.
+nanobind::object
+createCustomDialectWrapper(const std::string &dialectNamespace,
+                           nanobind::object dialectDescriptor) {
+  if (auto dialectClass =
+          PyGlobals::get().lookupDialectClass(dialectNamespace)) {
+    // A custom implementation is registered: instantiate it.
+    return (*dialectClass)(std::move(dialectDescriptor));
+  }
+  // Fall back to the base class.
+  return nanobind::cast(PyDialect(std::move(dialectDescriptor)));
+}
} // namespace
//------------------------------------------------------------------------------
>From b2446c8996ce4a58845ae63e64cda8de662306f6 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 27 Dec 2025 11:45:51 -0800
Subject: [PATCH 31/38] remove stray newline
---
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 -
1 file changed, 1 deletion(-)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 35e573bee8a1a..40faca3c826df 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1180,7 +1180,6 @@ filegroup(
"lib/Bindings/Python/IRTypes.cpp",
"lib/Bindings/Python/Pass.cpp",
"lib/Bindings/Python/Rewrite.cpp",
-
],
)
>From bf596cd234c3f24b6e1a31ad647e7932f094a14a Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 29 Dec 2025 11:24:45 -0800
Subject: [PATCH 32/38] jakub's suggestion
---
mlir/include/mlir/Bindings/Python/IRCore.h | 7 +++----
mlir/lib/Bindings/Python/Pass.cpp | 2 +-
2 files changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 3ee6eac0cbf3f..23c6ef02aebd1 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -319,22 +319,21 @@ class MLIR_PYTHON_API_EXPORTED PyLocation : public BaseContextObject {
MlirLocation loc;
};
-enum PyMlirDiagnosticSeverity : std::underlying_type<
- MlirDiagnosticSeverity>::type {
+enum PyMlirDiagnosticSeverity : std::underlying_type_t<MlirDiagnosticSeverity> {
MlirDiagnosticError = MlirDiagnosticError,
MlirDiagnosticWarning = MlirDiagnosticWarning,
MlirDiagnosticNote = MlirDiagnosticNote,
MlirDiagnosticRemark = MlirDiagnosticRemark
};
-enum PyMlirWalkResult : std::underlying_type<MlirWalkResult>::type {
+enum PyMlirWalkResult : std::underlying_type_t<MlirWalkResult> {
MlirWalkResultAdvance = MlirWalkResultAdvance,
MlirWalkResultInterrupt = MlirWalkResultInterrupt,
MlirWalkResultSkip = MlirWalkResultSkip
};
/// Traversal order for operation walk.
-enum PyMlirWalkOrder : std::underlying_type<MlirWalkOrder>::type {
+enum PyMlirWalkOrder : std::underlying_type_t<MlirWalkOrder> {
MlirWalkPreOrder = MlirWalkPreOrder,
MlirWalkPostOrder = MlirWalkPostOrder
};
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index d8c0a253e8dda..b4a256d847ad5 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -55,7 +55,7 @@ class PyPassManager {
MlirPassManager passManager;
};
-enum PyMlirPassDisplayMode : std::underlying_type<MlirPassDisplayMode>::type {
+enum PyMlirPassDisplayMode : std::underlying_type_t<MlirPassDisplayMode> {
MLIR_PASS_DISPLAY_MODE_LIST = MLIR_PASS_DISPLAY_MODE_LIST,
MLIR_PASS_DISPLAY_MODE_PIPELINE = MLIR_PASS_DISPLAY_MODE_PIPELINE
};
>From 14dd9b7e2096b196d6d01d355b6a5c8d10377682 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 29 Dec 2025 16:16:23 -0800
Subject: [PATCH 33/38] factor out more impls
---
mlir/include/mlir/Bindings/Python/IRCore.h | 291 ++-------------------
mlir/lib/Bindings/Python/IRCore.cpp | 49 ++++
mlir/lib/Bindings/Python/MainModule.cpp | 268 ++++++++++++++++++-
3 files changed, 330 insertions(+), 278 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 23c6ef02aebd1..0f402b4ce15ff 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1339,23 +1339,7 @@ struct MLIR_PYTHON_API_EXPORTED PyAttrBuilderMap {
static void dunderSetItemNamed(const std::string &attributeKind,
nanobind::callable func, bool replace);
- static void bind(nanobind::module_ &m) {
- nanobind::class_<PyAttrBuilderMap>(m, "AttrBuilder")
- .def_static("contains", &PyAttrBuilderMap::dunderContains,
- nanobind::arg("attribute_kind"),
- "Checks whether an attribute builder is registered for the "
- "given attribute kind.")
- .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
- nanobind::arg("attribute_kind"),
- "Gets the registered attribute builder for the given "
- "attribute kind.")
- .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
- nanobind::arg("attribute_kind"),
- nanobind::arg("attr_builder"),
- nanobind::arg("replace") = false,
- "Register an attribute builder for building MLIR "
- "attributes from Python values.");
- }
+ static void bind(nanobind::module_ &m);
};
//------------------------------------------------------------------------------
@@ -1371,13 +1355,7 @@ class MLIR_PYTHON_API_EXPORTED PyRegionIterator {
PyRegion dunderNext();
- static void bind(nanobind::module_ &m) {
- nanobind::class_<PyRegionIterator>(m, "RegionIterator")
- .def("__iter__", &PyRegionIterator::dunderIter,
- "Returns an iterator over the regions in the operation.")
- .def("__next__", &PyRegionIterator::dunderNext,
- "Returns the next region in the iteration.");
- }
+ static void bind(nanobind::module_ &m);
private:
PyOperationRef operation;
@@ -1396,10 +1374,7 @@ class MLIR_PYTHON_API_EXPORTED PyRegionList
PyRegionIterator dunderIter();
- static void bindDerived(ClassTy &c) {
- c.def("__iter__", &PyRegionList::dunderIter,
- "Returns an iterator over the regions in the sequence.");
- }
+ static void bindDerived(ClassTy &c);
private:
/// Give the parent CRTP class access to hook implementations below.
@@ -1423,13 +1398,7 @@ class MLIR_PYTHON_API_EXPORTED PyBlockIterator {
PyBlock dunderNext();
- static void bind(nanobind::module_ &m) {
- nanobind::class_<PyBlockIterator>(m, "BlockIterator")
- .def("__iter__", &PyBlockIterator::dunderIter,
- "Returns an iterator over the blocks in the operation's region.")
- .def("__next__", &PyBlockIterator::dunderNext,
- "Returns the next block in the iteration.");
- }
+ static void bind(nanobind::module_ &m);
private:
PyOperationRef operation;
@@ -1453,24 +1422,7 @@ class MLIR_PYTHON_API_EXPORTED PyBlockList {
PyBlock appendBlock(const nanobind::args &pyArgTypes,
const std::optional<nanobind::sequence> &pyArgLocs);
- static void bind(nanobind::module_ &m) {
- nanobind::class_<PyBlockList>(m, "BlockList")
- .def("__getitem__", &PyBlockList::dunderGetItem,
- "Returns the block at the specified index.")
- .def("__iter__", &PyBlockList::dunderIter,
- "Returns an iterator over blocks in the operation's region.")
- .def("__len__", &PyBlockList::dunderLen,
- "Returns the number of blocks in the operation's region.")
- .def("append", &PyBlockList::appendBlock,
- R"(
- Appends a new block, with argument types as positional args.
-
- Returns:
- The created block.
- )",
- nanobind::arg("args"), nanobind::kw_only(),
- nanobind::arg("arg_locs") = std::nullopt);
- }
+ static void bind(nanobind::module_ &m);
private:
PyOperationRef operation;
@@ -1486,13 +1438,7 @@ class MLIR_PYTHON_API_EXPORTED PyOperationIterator {
nanobind::typed<nanobind::object, PyOpView> dunderNext();
- static void bind(nanobind::module_ &m) {
- nanobind::class_<PyOperationIterator>(m, "OperationIterator")
- .def("__iter__", &PyOperationIterator::dunderIter,
- "Returns an iterator over the operations in an operation's block.")
- .def("__next__", &PyOperationIterator::dunderNext,
- "Returns the next operation in the iteration.");
- }
+ static void bind(nanobind::module_ &m);
private:
PyOperationRef parentOperation;
@@ -1518,15 +1464,7 @@ class MLIR_PYTHON_API_EXPORTED PyOperationList {
nanobind::typed<nanobind::object, PyOpView> dunderGetItem(intptr_t index);
- static void bind(nanobind::module_ &m) {
- nanobind::class_<PyOperationList>(m, "OperationList")
- .def("__getitem__", &PyOperationList::dunderGetItem,
- "Returns the operation at the specified index.")
- .def("__iter__", &PyOperationList::dunderIter,
- "Returns an iterator over operations in the list.")
- .def("__len__", &PyOperationList::dunderLen,
- "Returns the number of operations in the list.");
- }
+ static void bind(nanobind::module_ &m);
private:
PyOperationRef parentOperation;
@@ -1541,13 +1479,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpOperand {
size_t getOperandNumber() const;
- static void bind(nanobind::module_ &m) {
- nanobind::class_<PyOpOperand>(m, "OpOperand")
- .def_prop_ro("owner", &PyOpOperand::getOwner,
- "Returns the operation that owns this operand.")
- .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
- "Returns the operand number in the owning operation.");
- }
+ static void bind(nanobind::module_ &m);
private:
MlirOpOperand opOperand;
@@ -1561,13 +1493,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpOperandIterator {
PyOpOperand dunderNext();
- static void bind(nanobind::module_ &m) {
- nanobind::class_<PyOpOperandIterator>(m, "OpOperandIterator")
- .def("__iter__", &PyOpOperandIterator::dunderIter,
- "Returns an iterator over operands.")
- .def("__next__", &PyOpOperandIterator::dunderNext,
- "Returns the next operand in the iteration.");
- }
+ static void bind(nanobind::module_ &m);
private:
MlirOpOperand opOperand;
@@ -1641,24 +1567,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpResult : public PyConcreteValue<PyOpResult> {
static constexpr const char *pyClassName = "OpResult";
using PyConcreteValue::PyConcreteValue;
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "owner",
- [](PyOpResult &self) -> nanobind::typed<nanobind::object, PyOpView> {
- assert(mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match that in "
- "the IR");
- return self.getParentOperation()->createOpView();
- },
- "Returns the operation that produces this result.");
- c.def_prop_ro(
- "result_number",
- [](PyOpResult &self) {
- return mlirOpResultGetResultNumber(self.get());
- },
- "Returns the position of this result in the operation's result list.");
- }
+ static void bindDerived(ClassTy &c);
};
/// Returns the list of types of the values held by container.
@@ -1686,28 +1595,9 @@ class MLIR_PYTHON_API_EXPORTED PyOpResultList
using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumResults(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "types",
- [](PyOpResultList &self) {
- return getValueTypes(self, self.operation->getContext());
- },
- "Returns a list of types for all results in this result list.");
- c.def_prop_ro(
- "owner",
- [](PyOpResultList &self)
- -> nanobind::typed<nanobind::object, PyOpView> {
- return self.operation->createOpView();
- },
- "Returns the operation that owns this result list.");
- }
+ intptr_t length = -1, intptr_t step = 1);
+
+ static void bindDerived(ClassTy &c);
PyOperationRef &getOperation() { return operation; }
@@ -1733,33 +1623,7 @@ class MLIR_PYTHON_API_EXPORTED PyBlockArgument
static constexpr const char *pyClassName = "BlockArgument";
using PyConcreteValue::PyConcreteValue;
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "owner",
- [](PyBlockArgument &self) {
- return PyBlock(self.getParentOperation(),
- mlirBlockArgumentGetOwner(self.get()));
- },
- "Returns the block that owns this argument.");
- c.def_prop_ro(
- "arg_number",
- [](PyBlockArgument &self) {
- return mlirBlockArgumentGetArgNumber(self.get());
- },
- "Returns the position of this argument in the block's argument list.");
- c.def(
- "set_type",
- [](PyBlockArgument &self, PyType type) {
- return mlirBlockArgumentSetType(self.get(), type);
- },
- nanobind::arg("type"), "Sets the type of this block argument.");
- c.def(
- "set_location",
- [](PyBlockArgument &self, PyLocation loc) {
- return mlirBlockArgumentSetLocation(self.get(), loc);
- },
- nanobind::arg("loc"), "Sets the location of this block argument.");
- }
+ static void bindDerived(ClassTy &c);
};
/// A list of block arguments. Internally, these are stored as consecutive
@@ -1774,20 +1638,9 @@ class MLIR_PYTHON_API_EXPORTED PyBlockArgumentList
PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumArguments(block) : length,
- step),
- operation(std::move(operation)), block(block) {}
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "types",
- [](PyBlockArgumentList &self) {
- return getValueTypes(self, self.operation->getContext());
- },
- "Returns a list of types for all arguments in this argument list.");
- }
+ intptr_t step = 1);
+
+ static void bindDerived(ClassTy &c);
private:
/// Give the parent CRTP class access to hook implementations below.
@@ -1818,20 +1671,11 @@ class MLIR_PYTHON_API_EXPORTED PyOpOperandList
using SliceableT = Sliceable<PyOpOperandList, PyValue>;
PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumOperands(operation->get())
- : length,
- step),
- operation(operation) {}
+ intptr_t length = -1, intptr_t step = 1);
void dunderSetItem(intptr_t index, PyValue value);
- static void bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpOperandList::dunderSetItem,
- nanobind::arg("index"), nanobind::arg("value"),
- "Sets the operand at the specified index to a new value.");
- }
+ static void bindDerived(ClassTy &c);
private:
/// Give the parent CRTP class access to hook implementations below.
@@ -1857,20 +1701,11 @@ class MLIR_PYTHON_API_EXPORTED PyOpSuccessors
static constexpr const char *pyClassName = "OpSuccessors";
PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumSuccessors(operation->get())
- : length,
- step),
- operation(operation) {}
+ intptr_t length = -1, intptr_t step = 1);
void dunderSetItem(intptr_t index, PyBlock block);
- static void bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nanobind::arg("index"),
- nanobind::arg("block"),
- "Sets the successor block at the specified index.");
- }
+ static void bindDerived(ClassTy &c);
private:
/// Give the parent CRTP class access to hook implementations below.
@@ -1897,12 +1732,7 @@ class MLIR_PYTHON_API_EXPORTED PyBlockSuccessors
PyBlockSuccessors(PyBlock block, PyOperationRef operation,
intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumSuccessors(block.get())
- : length,
- step),
- operation(operation), block(block) {}
+ intptr_t step = 1);
private:
/// Give the parent CRTP class access to hook implementations below.
@@ -1933,12 +1763,7 @@ class MLIR_PYTHON_API_EXPORTED PyBlockPredecessors
PyBlockPredecessors(PyBlock block, PyOperationRef operation,
intptr_t startIndex = 0, intptr_t length = -1,
- intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumPredecessors(block.get())
- : length,
- step),
- operation(operation), block(block) {}
+ intptr_t step = 1);
private:
/// Give the parent CRTP class access to hook implementations below.
@@ -1979,75 +1804,7 @@ class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
forEachAttr(MlirOperation op,
llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn);
- static void bind(nanobind::module_ &m) {
- nanobind::class_<PyOpAttributeMap>(m, "OpAttributeMap")
- .def("__contains__", &PyOpAttributeMap::dunderContains,
- nanobind::arg("name"),
- "Checks if an attribute with the given name exists in the map.")
- .def("__len__", &PyOpAttributeMap::dunderLen,
- "Returns the number of attributes in the map.")
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
- nanobind::arg("name"), "Gets an attribute by name.")
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
- nanobind::arg("index"), "Gets a named attribute by index.")
- .def("__setitem__", &PyOpAttributeMap::dunderSetItem,
- nanobind::arg("name"), nanobind::arg("attr"),
- "Sets an attribute with the given name.")
- .def("__delitem__", &PyOpAttributeMap::dunderDelItem,
- nanobind::arg("name"), "Deletes an attribute with the given name.")
- .def(
- "__iter__",
- [](PyOpAttributeMap &self) {
- nanobind::list keys;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute) {
- keys.append(nanobind::str(name.data, name.length));
- });
- return nanobind::iter(keys);
- },
- "Iterates over attribute names.")
- .def(
- "keys",
- [](PyOpAttributeMap &self) {
- nanobind::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute) {
- out.append(nanobind::str(name.data, name.length));
- });
- return out;
- },
- "Returns a list of attribute names.")
- .def(
- "values",
- [](PyOpAttributeMap &self) {
- nanobind::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef, MlirAttribute attr) {
- out.append(PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast());
- });
- return out;
- },
- "Returns a list of attribute values.")
- .def(
- "items",
- [](PyOpAttributeMap &self) {
- nanobind::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute attr) {
- out.append(nanobind::make_tuple(
- nanobind::str(name.data, name.length),
- PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast()));
- });
- return out;
- },
- "Returns a list of `(name, attribute)` tuples.");
- }
+ static void bind(nanobind::module_ &m);
private:
PyOperationRef operation;
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2ea07a6c9adec..a204dd7a4c3b8 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1958,6 +1958,14 @@ PyOpOperand PyOpOperandIterator::dunderNext() {
// PyConcreteValue
//------------------------------------------------------------------------------
+PyOpResultList::PyOpResultList(PyOperationRef operation, intptr_t startIndex,
+ intptr_t length, intptr_t step)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumResults(operation->get())
+ : length,
+ step),
+ operation(std::move(operation)) {}
+
intptr_t PyOpResultList::getRawNumElements() {
operation->checkValid();
return mlirOperationGetNumResults(operation->get());
@@ -1973,6 +1981,13 @@ PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length,
return PyOpResultList(operation, startIndex, length, step);
}
+PyBlockArgumentList::PyBlockArgumentList(PyOperationRef operation,
+ MlirBlock block, intptr_t startIndex,
+ intptr_t length, intptr_t step)
+ : Sliceable(startIndex,
+ length == -1 ? mlirBlockGetNumArguments(block) : length, step),
+ operation(std::move(operation)), block(block) {}
+
intptr_t PyBlockArgumentList::getRawNumElements() {
operation->checkValid();
return mlirBlockGetNumArguments(block);
@@ -1989,6 +2004,14 @@ PyBlockArgumentList PyBlockArgumentList::slice(intptr_t startIndex,
return PyBlockArgumentList(operation, block, startIndex, length, step);
}
+PyOpOperandList::PyOpOperandList(PyOperationRef operation, intptr_t startIndex,
+ intptr_t length, intptr_t step)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumOperands(operation->get())
+ : length,
+ step),
+ operation(operation) {}
+
void PyOpOperandList::dunderSetItem(intptr_t index, PyValue value) {
index = wrapIndex(index);
mlirOperationSetOperand(operation->get(), index, value.get());
@@ -2018,6 +2041,14 @@ PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length,
return PyOpOperandList(operation, startIndex, length, step);
}
+PyOpSuccessors::PyOpSuccessors(PyOperationRef operation, intptr_t startIndex,
+ intptr_t length, intptr_t step)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumSuccessors(operation->get())
+ : length,
+ step),
+ operation(operation) {}
+
void PyOpSuccessors::dunderSetItem(intptr_t index, PyBlock block) {
index = wrapIndex(index);
mlirOperationSetSuccessor(operation->get(), index, block.get());
@@ -2038,6 +2069,14 @@ PyOpSuccessors PyOpSuccessors::slice(intptr_t startIndex, intptr_t length,
return PyOpSuccessors(operation, startIndex, length, step);
}
+PyBlockSuccessors::PyBlockSuccessors(PyBlock block, PyOperationRef operation,
+ intptr_t startIndex, intptr_t length,
+ intptr_t step)
+ : Sliceable(startIndex,
+ length == -1 ? mlirBlockGetNumSuccessors(block.get()) : length,
+ step),
+ operation(operation), block(block) {}
+
intptr_t PyBlockSuccessors::getRawNumElements() {
block.checkValid();
return mlirBlockGetNumSuccessors(block.get());
@@ -2053,6 +2092,16 @@ PyBlockSuccessors PyBlockSuccessors::slice(intptr_t startIndex, intptr_t length,
return PyBlockSuccessors(block, operation, startIndex, length, step);
}
+PyBlockPredecessors::PyBlockPredecessors(PyBlock block,
+ PyOperationRef operation,
+ intptr_t startIndex, intptr_t length,
+ intptr_t step)
+ : Sliceable(startIndex,
+ length == -1 ? mlirBlockGetNumPredecessors(block.get())
+ : length,
+ step),
+ operation(operation), block(block) {}
+
intptr_t PyBlockPredecessors::getRawNumElements() {
block.checkValid();
return mlirBlockGetNumPredecessors(block.get());
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 79c8e36609d76..9790a8feb8d03 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -257,6 +257,263 @@ createCustomDialectWrapper(const std::string &dialectNamespace,
}
} // namespace
+//===----------------------------------------------------------------------===//
+// NB: all bind and bindDerived methods need to reside in the same
+// binary/extension as the NB_MODULE macro/call. This is because
+// nb_internals *internals lives in the non-unique nanobind::detail namespace
+// (i.e., the same namespace for all bindings packages).
+//===----------------------------------------------------------------------===//
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+void PyRegionList::bindDerived(ClassTy &c) {
+ c.def("__iter__", &PyRegionList::dunderIter,
+ "Returns an iterator over the regions in the sequence.");
+}
+
+void PyOpResult::bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "owner",
+ [](PyOpResult &self) -> nanobind::typed<nanobind::object, PyOpView> {
+ assert(mlirOperationEqual(self.getParentOperation()->get(),
+ mlirOpResultGetOwner(self.get())) &&
+ "expected the owner of the value in Python to match that in "
+ "the IR");
+ return self.getParentOperation()->createOpView();
+ },
+ "Returns the operation that produces this result.");
+ c.def_prop_ro(
+ "result_number",
+ [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); },
+ "Returns the position of this result in the operation's result list.");
+}
+
+void PyOpResultList::bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "types",
+ [](PyOpResultList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ },
+ "Returns a list of types for all results in this result list.");
+ c.def_prop_ro(
+ "owner",
+ [](PyOpResultList &self) -> nanobind::typed<nanobind::object, PyOpView> {
+ return self.operation->createOpView();
+ },
+ "Returns the operation that owns this result list.");
+}
+
+void PyBlockArgument::bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "owner",
+ [](PyBlockArgument &self) {
+ return PyBlock(self.getParentOperation(),
+ mlirBlockArgumentGetOwner(self.get()));
+ },
+ "Returns the block that owns this argument.");
+ c.def_prop_ro(
+ "arg_number",
+ [](PyBlockArgument &self) {
+ return mlirBlockArgumentGetArgNumber(self.get());
+ },
+ "Returns the position of this argument in the block's argument list.");
+ c.def(
+ "set_type",
+ [](PyBlockArgument &self, PyType type) {
+ return mlirBlockArgumentSetType(self.get(), type);
+ },
+ nanobind::arg("type"), "Sets the type of this block argument.");
+ c.def(
+ "set_location",
+ [](PyBlockArgument &self, PyLocation loc) {
+ return mlirBlockArgumentSetLocation(self.get(), loc);
+ },
+ nanobind::arg("loc"), "Sets the location of this block argument.");
+}
+
+void PyOpOperandList::bindDerived(ClassTy &c) {
+ c.def("__setitem__", &PyOpOperandList::dunderSetItem, nanobind::arg("index"),
+ nanobind::arg("value"),
+ "Sets the operand at the specified index to a new value.");
+}
+
+void PyOpSuccessors::bindDerived(ClassTy &c) {
+ c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nanobind::arg("index"),
+ nanobind::arg("block"),
+ "Sets the successor block at the specified index.");
+}
+
+void PyBlockArgumentList::bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "types",
+ [](PyBlockArgumentList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ },
+ "Returns a list of types for all arguments in this argument list.");
+}
+
+void PyAttrBuilderMap::bind(nanobind::module_ &m) {
+ nanobind::class_<PyAttrBuilderMap>(m, "AttrBuilder")
+ .def_static("contains", &PyAttrBuilderMap::dunderContains,
+ nanobind::arg("attribute_kind"),
+ "Checks whether an attribute builder is registered for the "
+ "given attribute kind.")
+ .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
+ nanobind::arg("attribute_kind"),
+ "Gets the registered attribute builder for the given "
+ "attribute kind.")
+ .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
+ nanobind::arg("attribute_kind"),
+ nanobind::arg("attr_builder"),
+ nanobind::arg("replace") = false,
+ "Register an attribute builder for building MLIR "
+ "attributes from Python values.");
+}
+
+void PyRegionIterator::bind(nanobind::module_ &m) {
+ nanobind::class_<PyRegionIterator>(m, "RegionIterator")
+ .def("__iter__", &PyRegionIterator::dunderIter,
+ "Returns an iterator over the regions in the operation.")
+ .def("__next__", &PyRegionIterator::dunderNext,
+ "Returns the next region in the iteration.");
+}
+
+void PyBlockIterator::bind(nanobind::module_ &m) {
+ nanobind::class_<PyBlockIterator>(m, "BlockIterator")
+ .def("__iter__", &PyBlockIterator::dunderIter,
+ "Returns an iterator over the blocks in the operation's region.")
+ .def("__next__", &PyBlockIterator::dunderNext,
+ "Returns the next block in the iteration.");
+}
+
+void PyBlockList::bind(nanobind::module_ &m) {
+ nanobind::class_<PyBlockList>(m, "BlockList")
+ .def("__getitem__", &PyBlockList::dunderGetItem,
+ "Returns the block at the specified index.")
+ .def("__iter__", &PyBlockList::dunderIter,
+ "Returns an iterator over blocks in the operation's region.")
+ .def("__len__", &PyBlockList::dunderLen,
+ "Returns the number of blocks in the operation's region.")
+ .def("append", &PyBlockList::appendBlock,
+ R"(
+ Appends a new block, with argument types as positional args.
+
+ Returns:
+ The created block.
+ )",
+ nanobind::arg("args"), nanobind::kw_only(),
+ nanobind::arg("arg_locs") = std::nullopt);
+}
+
+void PyOperationIterator::bind(nanobind::module_ &m) {
+ nanobind::class_<PyOperationIterator>(m, "OperationIterator")
+ .def("__iter__", &PyOperationIterator::dunderIter,
+ "Returns an iterator over the operations in an operation's block.")
+ .def("__next__", &PyOperationIterator::dunderNext,
+ "Returns the next operation in the iteration.");
+}
+
+void PyOperationList::bind(nanobind::module_ &m) {
+ nanobind::class_<PyOperationList>(m, "OperationList")
+ .def("__getitem__", &PyOperationList::dunderGetItem,
+ "Returns the operation at the specified index.")
+ .def("__iter__", &PyOperationList::dunderIter,
+ "Returns an iterator over operations in the list.")
+ .def("__len__", &PyOperationList::dunderLen,
+ "Returns the number of operations in the list.");
+}
+
+void PyOpOperand::bind(nanobind::module_ &m) {
+ nanobind::class_<PyOpOperand>(m, "OpOperand")
+ .def_prop_ro("owner", &PyOpOperand::getOwner,
+ "Returns the operation that owns this operand.")
+ .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
+ "Returns the operand number in the owning operation.");
+}
+
+void PyOpOperandIterator::bind(nanobind::module_ &m) {
+ nanobind::class_<PyOpOperandIterator>(m, "OpOperandIterator")
+ .def("__iter__", &PyOpOperandIterator::dunderIter,
+ "Returns an iterator over operands.")
+ .def("__next__", &PyOpOperandIterator::dunderNext,
+ "Returns the next operand in the iteration.");
+}
+
+void PyOpAttributeMap::bind(nanobind::module_ &m) {
+ nanobind::class_<PyOpAttributeMap>(m, "OpAttributeMap")
+ .def("__contains__", &PyOpAttributeMap::dunderContains,
+ nanobind::arg("name"),
+ "Checks if an attribute with the given name exists in the map.")
+ .def("__len__", &PyOpAttributeMap::dunderLen,
+ "Returns the number of attributes in the map.")
+ .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
+ nanobind::arg("name"), "Gets an attribute by name.")
+ .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
+ nanobind::arg("index"), "Gets a named attribute by index.")
+ .def("__setitem__", &PyOpAttributeMap::dunderSetItem,
+ nanobind::arg("name"), nanobind::arg("attr"),
+ "Sets an attribute with the given name.")
+ .def("__delitem__", &PyOpAttributeMap::dunderDelItem,
+ nanobind::arg("name"), "Deletes an attribute with the given name.")
+ .def(
+ "__iter__",
+ [](PyOpAttributeMap &self) {
+ nanobind::list keys;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
+ keys.append(nanobind::str(name.data, name.length));
+ });
+ return nanobind::iter(keys);
+ },
+ "Iterates over attribute names.")
+ .def(
+ "keys",
+ [](PyOpAttributeMap &self) {
+ nanobind::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
+ out.append(nanobind::str(name.data, name.length));
+ });
+ return out;
+ },
+ "Returns a list of attribute names.")
+ .def(
+ "values",
+ [](PyOpAttributeMap &self) {
+ nanobind::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
+ out.append(PyAttribute(self.operation->getContext(), attr)
+ .maybeDownCast());
+ });
+ return out;
+ },
+ "Returns a list of attribute values.")
+ .def(
+ "items",
+ [](PyOpAttributeMap &self) {
+ nanobind::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute attr) {
+ out.append(nanobind::make_tuple(
+ nanobind::str(name.data, name.length),
+ PyAttribute(self.operation->getContext(), attr)
+ .maybeDownCast()));
+ });
+ return out;
+ },
+ "Returns a list of `(name, attribute)` tuples.");
+}
+
+void populateIRAffine(nb::module_ &m);
+void populateIRAttributes(nb::module_ &m);
+void populateIRInterfaces(nb::module_ &m);
+void populateIRTypes(nb::module_ &m);
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
//------------------------------------------------------------------------------
// Populates the core exports of the 'ir' submodule.
//------------------------------------------------------------------------------
@@ -2330,17 +2587,6 @@ static void populateIRCore(nb::module_ &m) {
PyAttrBuilderMap::bind(m);
}
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-void populateIRAffine(nb::module_ &m);
-void populateIRAttributes(nb::module_ &m);
-void populateIRInterfaces(nb::module_ &m);
-void populateIRTypes(nb::module_ &m);
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
-
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
>From a3996ca9add4e722f190e60c365f08712fb6f234 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 30 Dec 2025 17:37:06 -0800
Subject: [PATCH 34/38] try twolevel_namespace
---
mlir/lib/Bindings/Python/IRCore.cpp | 28 +++++++++++++++++++++++++
mlir/lib/Bindings/Python/MainModule.cpp | 28 -------------------------
2 files changed, 28 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index a204dd7a4c3b8..57fa5a420776f 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1993,6 +1993,34 @@ intptr_t PyBlockArgumentList::getRawNumElements() {
return mlirBlockGetNumArguments(block);
}
+void PyBlockArgument::bindDerived(ClassTy &c) {
+ c.def_prop_ro(
+ "owner",
+ [](PyBlockArgument &self) {
+ return PyBlock(self.getParentOperation(),
+ mlirBlockArgumentGetOwner(self.get()));
+ },
+ "Returns the block that owns this argument.");
+ c.def_prop_ro(
+ "arg_number",
+ [](PyBlockArgument &self) {
+ return mlirBlockArgumentGetArgNumber(self.get());
+ },
+ "Returns the position of this argument in the block's argument list.");
+ c.def(
+ "set_type",
+ [](PyBlockArgument &self, PyType type) {
+ return mlirBlockArgumentSetType(self.get(), type);
+ },
+ nanobind::arg("type"), "Sets the type of this block argument.");
+ c.def(
+ "set_location",
+ [](PyBlockArgument &self, PyLocation loc) {
+ return mlirBlockArgumentSetLocation(self.get(), loc);
+ },
+ nanobind::arg("loc"), "Sets the location of this block argument.");
+}
+
PyBlockArgument PyBlockArgumentList::getRawElement(intptr_t pos) const {
MlirValue argument = mlirBlockGetArgument(block, pos);
return PyBlockArgument(operation, argument);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 9790a8feb8d03..b49a9f1e3af24 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -303,34 +303,6 @@ void PyOpResultList::bindDerived(ClassTy &c) {
"Returns the operation that owns this result list.");
}
-void PyBlockArgument::bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "owner",
- [](PyBlockArgument &self) {
- return PyBlock(self.getParentOperation(),
- mlirBlockArgumentGetOwner(self.get()));
- },
- "Returns the block that owns this argument.");
- c.def_prop_ro(
- "arg_number",
- [](PyBlockArgument &self) {
- return mlirBlockArgumentGetArgNumber(self.get());
- },
- "Returns the position of this argument in the block's argument list.");
- c.def(
- "set_type",
- [](PyBlockArgument &self, PyType type) {
- return mlirBlockArgumentSetType(self.get(), type);
- },
- nanobind::arg("type"), "Sets the type of this block argument.");
- c.def(
- "set_location",
- [](PyBlockArgument &self, PyLocation loc) {
- return mlirBlockArgumentSetLocation(self.get(), loc);
- },
- nanobind::arg("loc"), "Sets the location of this block argument.");
-}
-
void PyOpOperandList::bindDerived(ClassTy &c) {
c.def("__setitem__", &PyOpOperandList::dunderSetItem, nanobind::arg("index"),
nanobind::arg("value"),
>From ab047834362d419117f49970325c751615fde6fc Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 30 Dec 2025 20:44:13 -0800
Subject: [PATCH 35/38] move impls
---
mlir/include/mlir/Bindings/Python/IRCore.h | 32 +-
mlir/lib/Bindings/Python/IRCore.cpp | 3152 ++++++++++++++++++--
mlir/lib/Bindings/Python/MainModule.cpp | 2659 +----------------
3 files changed, 2902 insertions(+), 2941 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 0f402b4ce15ff..af6c8dbbb7fa8 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -51,6 +51,16 @@ class PyType;
class PySymbolTable;
class PyValue;
+/// Wrapper for the global LLVM debugging flag, bound to Python as `_GlobalDebug`.
+struct MLIR_PYTHON_API_EXPORTED PyGlobalDebugFlag {
+  static void set(nanobind::object &o, bool enable); // `flag` setter; `o` unused.
+  static bool get(const nanobind::object &); // `flag` getter.
+  static void bind(nanobind::module_ &m); // Registers `_GlobalDebug` on `m`.
+
+private:
+  static nanobind::ft_mutex mutex; // Held around every debug-flag C-API call.
+};
+
/// Template for a reference to a concrete type which captures a python
/// reference to its underlying python object.
template <typename T>
@@ -1454,11 +1464,7 @@ class MLIR_PYTHON_API_EXPORTED PyOperationList {
PyOperationList(PyOperationRef parentOperation, MlirBlock block)
: parentOperation(std::move(parentOperation)), block(block) {}
- PyOperationIterator dunderIter() {
- parentOperation->checkValid();
- return PyOperationIterator(parentOperation,
- mlirBlockGetFirstOperation(block));
- }
+  PyOperationIterator dunderIter(); // Defined out of line in IRCore.cpp.
intptr_t dunderLen();
@@ -1570,20 +1576,6 @@ class MLIR_PYTHON_API_EXPORTED PyOpResult : public PyConcreteValue<PyOpResult> {
static void bindDerived(ClassTy &c);
};
-/// Returns the list of types of the values held by container.
-template <typename Container>
-std::vector<nanobind::typed<nanobind::object, PyType>>
-getValueTypes(Container &container, PyMlirContextRef &context) {
- std::vector<nanobind::typed<nanobind::object, PyType>> result;
- result.reserve(container.size());
- for (int i = 0, e = container.size(); i < e; ++i) {
- result.push_back(PyType(context->getRef(),
- mlirValueGetType(container.getElement(i).get()))
- .maybeDownCast());
- }
- return result;
-}
-
/// A list of operation results. Internally, these are stored as consecutive
/// elements, random access is cheap. The (returned) result list is associated
/// with the operation whose results these are, and thus extends the lifetime of
@@ -1811,6 +1803,8 @@ class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap {
};
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
+MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m); // Presumably binds core IR classes on m (confirm).
+MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m); // Presumably binds root-level globals on m (confirm).
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 57fa5a420776f..483c930b115b7 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -24,9 +24,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
-#include <iostream>
#include <optional>
-#include <typeinfo>
namespace nb = nanobind;
using namespace nb::literals;
@@ -36,9 +34,399 @@ using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
+static const char kModuleParseDocstring[] = // Python docstring (module parse).
+    R"(Parses a module's assembly format from a string.
+
+Returns a new MlirModule or raises an MLIRError if the parsing fails.
+
+See also: https://mlir.llvm.org/docs/LangRef/
+)";
+
+static const char kDumpDocstring[] = // Python docstring (debug dump).
+    "Dumps a debug representation of the object to stderr.";
+
+static const char kValueReplaceAllUsesExceptDocstring[] = // Python docstring.
+    R"(Replace all uses of this value with the `with` value, except for those
+in `exceptions`. `exceptions` can be either a single operation or a list of
+operations.
+)";
+
+//------------------------------------------------------------------------------
+// Utilities.
+//------------------------------------------------------------------------------
+
+/// Helper for creating an @classmethod.
+///
+/// Wraps `f` (with nanobind binding options `args`) in a nanobind function
+/// object and converts it with `PyClassMethod_New`. That CPython call returns
+/// a *new* reference, so it is stolen (not borrowed) into the returned
+/// `nb::object`; borrowing would add a second reference and leak the
+/// classmethod. A null result means CPython raised, and is rethrown as a
+/// Python exception instead of being wrapped as a null object.
+template <class Func, typename... Args>
+static nb::object classmethod(Func f, Args... args) {
+  nb::object cf = nb::cpp_function(f, args...);
+  PyObject *method = PyClassMethod_New(cf.ptr());
+  if (!method)
+    throw nb::python_error();
+  return nb::steal<nb::object>(method);
+}
+
+static nb::object // Wraps a dialect descriptor in its Python dialect class.
+createCustomDialectWrapper(const std::string &dialectNamespace,
+                           nb::object dialectDescriptor) {
+  auto dialectClass =
+      python::MLIR_BINDINGS_PYTHON_DOMAIN::PyGlobals::get().lookupDialectClass(
+          dialectNamespace);
+  if (!dialectClass) {
+    // No class found for this namespace: fall back to the generic PyDialect.
+    return nb::cast(python::MLIR_BINDINGS_PYTHON_DOMAIN::PyDialect(
+        std::move(dialectDescriptor)));
+  }
+
+  // Instantiate the class found for the namespace, passing the descriptor.
+  return (*dialectClass)(std::move(dialectDescriptor));
+}
+
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+/// Creates a new MlirBlock whose arguments have the given types.
+///
+/// `pyArgTypes` must hold `PyType` objects. If `pyArgLocs` is given, it must
+/// hold exactly one `PyLocation` per type. Otherwise every argument gets the
+/// location returned by `DefaultingPyLocation::resolve()`.
+/// Throws ValueError when the number of locations differs from the number of
+/// types. The block comes from `mlirBlockCreate`; callers such as
+/// `PyBlockList::appendBlock` then hand it to a region.
+MlirBlock createBlock(const nb::sequence &pyArgTypes,
+                      const std::optional<nb::sequence> &pyArgLocs) {
+  SmallVector<MlirType> argTypes;
+  argTypes.reserve(nb::len(pyArgTypes));
+  for (const auto &pyType : pyArgTypes)
+    argTypes.push_back(nb::cast<PyType &>(pyType));
+
+  SmallVector<MlirLocation> argLocs;
+  if (pyArgLocs) {
+    argLocs.reserve(nb::len(*pyArgLocs));
+    for (const auto &pyLoc : *pyArgLocs)
+      argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
+  } else if (!argTypes.empty()) {
+    // Without explicit locations, every argument shares the default location.
+    argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
+  }
+
+  if (argTypes.size() != argLocs.size())
+    throw nb::value_error(("Expected " + Twine(argTypes.size()) +
+                           " locations, got: " + Twine(argLocs.size()))
+                              .str()
+                              .c_str());
+  return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
+}
+
+void PyGlobalDebugFlag::set(nb::object &o, bool enable) { // `o` is unused.
+  nb::ft_lock_guard lock(mutex); // Serialize with the other debug-flag calls.
+  mlirEnableGlobalDebug(enable);
+}
+
+bool PyGlobalDebugFlag::get(const nb::object &) { // Current LLVM flag value.
+  nb::ft_lock_guard lock(mutex); // Same lock as set() and set_types.
+  return mlirIsGlobalDebugEnabled();
+}
+
+void PyGlobalDebugFlag::bind(nb::module_ &m) {
+  // Debug flags: the `flag` property plus two `set_types` overloads.
+  nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
+      .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
+                          &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
+      .def_static(
+          "set_types", // Overload taking a single debug type.
+          [](const std::string &type) {
+            nb::ft_lock_guard lock(mutex);
+            mlirSetGlobalDebugType(type.c_str());
+          },
+          "types"_a, "Sets specific debug types to be produced by LLVM.")
+      .def_static(
+          "set_types", // Overload taking a list of debug types.
+          [](const std::vector<std::string> &types) {
+            std::vector<const char *> pointers; // Borrowed from `types`.
+            pointers.reserve(types.size());
+            for (const std::string &str : types)
+              pointers.push_back(str.c_str());
+            nb::ft_lock_guard lock(mutex);
+            mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
+          },
+          "types"_a,
+          "Sets multiple specific debug types to be produced by LLVM.");
+}
+
+nb::ft_mutex PyGlobalDebugFlag::mutex; // Out-of-line static member definition.
+
+bool PyAttrBuilderMap::dunderContains(const std::string &attributeKind) { // Lookup succeeds?
+  return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
+}
+
+nb::callable // The builder for `attributeKind`; KeyError if none.
+PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
+  auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
+  if (!builder)
+    throw nb::key_error(attributeKind.c_str());
+  return *builder;
+}
+
+void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
+                                          nb::callable func, bool replace) { // Forwards to PyGlobals.
+  PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
+                                            replace);
+}
+
+void PyAttrBuilderMap::bind(nb::module_ &m) {
+  nb::class_<PyAttrBuilderMap>(m, "AttrBuilder") // Static-methods-only class.
+      .def_static("contains", &PyAttrBuilderMap::dunderContains,
+                  "attribute_kind"_a,
+                  "Checks whether an attribute builder is registered for the "
+                  "given attribute kind.")
+      .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
+                  "attribute_kind"_a,
+                  "Gets the registered attribute builder for the given "
+                  "attribute kind.")
+      .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
+                  "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
+                  "Register an attribute builder for building MLIR "
+                  "attributes from Python values.");
+}
+
+//------------------------------------------------------------------------------
+// PyBlock
+//------------------------------------------------------------------------------
+
+nb::object PyBlock::getCapsule() { // Wraps get() in a Python capsule.
+  return nb::steal<nb::object>(mlirPythonBlockToCapsule(get())); // Owns the returned ref.
+}
+
+//------------------------------------------------------------------------------
+// Collections.
+//------------------------------------------------------------------------------
+
+PyRegion PyRegionIterator::dunderNext() { // Python `__next__`.
+  operation->checkValid();
+  if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
+    throw nb::stop_iteration(); // No regions left on the operation.
+  }
+  MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); // Advances cursor.
+  return PyRegion(operation, region);
+}
+
+void PyRegionIterator::bind(nb::module_ &m) {
+  nb::class_<PyRegionIterator>(m, "RegionIterator")
+      .def("__iter__", &PyRegionIterator::dunderIter,
+           "Returns an iterator over the regions in the operation.")
+      .def("__next__", &PyRegionIterator::dunderNext,
+           "Returns the next region in the iteration.");
+}
+
+PyRegionList::PyRegionList(PyOperationRef operation, intptr_t startIndex,
+                           intptr_t length, intptr_t step)
+    : Sliceable(startIndex,
+                length == -1 ? mlirOperationGetNumRegions(operation->get())
+                             : length, // -1 selects every region.
+                step),
+      operation(std::move(operation)) {} // Moved only after the base read it.
+
+PyRegionIterator PyRegionList::dunderIter() { // Python `__iter__`.
+  operation->checkValid();
+  return PyRegionIterator(operation, startIndex); // NOTE(review): ignores this slice's length/step.
+}
+
+void PyRegionList::bindDerived(ClassTy &c) {
+  c.def("__iter__", &PyRegionList::dunderIter,
+        "Returns an iterator over the regions in the sequence.");
+}
+
+intptr_t PyRegionList::getRawNumElements() { // Live count, ignoring the slice.
+  operation->checkValid();
+  return mlirOperationGetNumRegions(operation->get());
+}
+
+PyRegion PyRegionList::getRawElement(intptr_t pos) { // `pos` presumably pre-linearized.
+  operation->checkValid();
+  return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
+}
+
+PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length,
+                                 intptr_t step) const { // Same op, new window.
+  return PyRegionList(operation, startIndex, length, step);
+}
+
+PyBlock PyBlockIterator::dunderNext() { // Python `__next__`.
+  operation->checkValid();
+  if (mlirBlockIsNull(next)) {
+    throw nb::stop_iteration(); // Walked past the region's last block.
+  }
+
+  PyBlock returnBlock(operation, next);
+  next = mlirBlockGetNextInRegion(next); // Advance before returning.
+  return returnBlock;
+}
+
+void PyBlockIterator::bind(nb::module_ &m) {
+  nb::class_<PyBlockIterator>(m, "BlockIterator")
+      .def("__iter__", &PyBlockIterator::dunderIter,
+           "Returns an iterator over the blocks in the operation's region.")
+      .def("__next__", &PyBlockIterator::dunderNext,
+           "Returns the next block in the iteration.");
+}
+
+PyBlockIterator PyBlockList::dunderIter() { // Starts at the first block.
+  operation->checkValid();
+  return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
+}
+
+intptr_t PyBlockList::dunderLen() { // O(n): walks the block chain.
+  operation->checkValid();
+  intptr_t count = 0;
+  MlirBlock block = mlirRegionGetFirstBlock(region);
+  while (!mlirBlockIsNull(block)) {
+    count += 1;
+    block = mlirBlockGetNextInRegion(block);
+  }
+  return count;
+}
+
+PyBlock PyBlockList::dunderGetItem(intptr_t index) { // Accepts negative indices.
+  operation->checkValid();
+  if (index < 0) {
+    index += dunderLen(); // Python-style wrap (extra walk of the chain).
+  }
+  if (index < 0) {
+    throw nb::index_error("attempt to access out of bounds block");
+  }
+  MlirBlock block = mlirRegionGetFirstBlock(region);
+  while (!mlirBlockIsNull(block)) {
+    if (index == 0) {
+      return PyBlock(operation, block);
+    }
+    block = mlirBlockGetNextInRegion(block);
+    index -= 1;
+  }
+  throw nb::index_error("attempt to access out of bounds block"); // Index >= length.
+}
+
+PyBlock PyBlockList::appendBlock(const nb::args &pyArgTypes,
+                                 const std::optional<nb::sequence> &pyArgLocs) {
+  operation->checkValid();
+  MlirBlock block = createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs); // Types = *args.
+  mlirRegionAppendOwnedBlock(region, block); // Hands the new block to the region.
+  return PyBlock(operation, block);
+}
+
+void PyBlockList::bind(nb::module_ &m) {
+  nb::class_<PyBlockList>(m, "BlockList")
+      .def("__getitem__", &PyBlockList::dunderGetItem,
+           "Returns the block at the specified index.")
+      .def("__iter__", &PyBlockList::dunderIter,
+           "Returns an iterator over blocks in the operation's region.")
+      .def("__len__", &PyBlockList::dunderLen,
+           "Returns the number of blocks in the operation's region.")
+      .def("append", &PyBlockList::appendBlock,
+           R"(
+      Appends a new block, with argument types as positional args.
+
+      Returns:
+        The created block.
+      )",
+           "args"_a, nb::kw_only(), "arg_locs"_a = std::nullopt); // arg_locs is keyword-only.
+}
+
+nb::typed<nb::object, PyOpView> PyOperationIterator::dunderNext() { // Yields OpViews.
+  parentOperation->checkValid();
+  if (mlirOperationIsNull(next)) {
+    throw nb::stop_iteration(); // Walked past the block's last operation.
+  }
+
+  PyOperationRef returnOperation =
+      PyOperation::forOperation(parentOperation->getContext(), next);
+  next = mlirOperationGetNextInBlock(next); // Advance before returning.
+  return returnOperation->createOpView();
+}
+
+void PyOperationIterator::bind(nb::module_ &m) {
+  nb::class_<PyOperationIterator>(m, "OperationIterator")
+      .def("__iter__", &PyOperationIterator::dunderIter,
+           "Returns an iterator over the operations in an operation's block.")
+      .def("__next__", &PyOperationIterator::dunderNext,
+           "Returns the next operation in the iteration.");
+}
+
+PyOperationIterator PyOperationList::dunderIter() { // Starts at the first op.
+  parentOperation->checkValid();
+  return PyOperationIterator(parentOperation,
+                             mlirBlockGetFirstOperation(block));
+}
+
+intptr_t PyOperationList::dunderLen() { // O(n): walks the operation chain.
+  parentOperation->checkValid();
+  intptr_t count = 0;
+  MlirOperation childOp = mlirBlockGetFirstOperation(block);
+  while (!mlirOperationIsNull(childOp)) {
+    count += 1;
+    childOp = mlirOperationGetNextInBlock(childOp);
+  }
+  return count;
+}
+
+nb::typed<nb::object, PyOpView> PyOperationList::dunderGetItem(intptr_t index) { // Negative OK.
+  parentOperation->checkValid();
+  if (index < 0) {
+    index += dunderLen(); // Python-style wrap (extra walk of the chain).
+  }
+  if (index < 0) {
+    throw nb::index_error("attempt to access out of bounds operation");
+  }
+  MlirOperation childOp = mlirBlockGetFirstOperation(block);
+  while (!mlirOperationIsNull(childOp)) {
+    if (index == 0) {
+      return PyOperation::forOperation(parentOperation->getContext(), childOp)
+          ->createOpView();
+    }
+    childOp = mlirOperationGetNextInBlock(childOp);
+    index -= 1;
+  }
+  throw nb::index_error("attempt to access out of bounds operation"); // Index >= length.
+}
+
+void PyOperationList::bind(nb::module_ &m) {
+  nb::class_<PyOperationList>(m, "OperationList")
+      .def("__getitem__", &PyOperationList::dunderGetItem,
+           "Returns the operation at the specified index.")
+      .def("__iter__", &PyOperationList::dunderIter,
+           "Returns an iterator over operations in the list.")
+      .def("__len__", &PyOperationList::dunderLen,
+           "Returns the number of operations in the list.");
+}
+
+nb::typed<nb::object, PyOpView> PyOpOperand::getOwner() const { // Op using this operand.
+  MlirOperation owner = mlirOpOperandGetOwner(opOperand);
+  PyMlirContextRef context =
+      PyMlirContext::forContext(mlirOperationGetContext(owner)); // Owner's context.
+  return PyOperation::forOperation(context, owner)->createOpView();
+}
+
+size_t PyOpOperand::getOperandNumber() const { // Index in the owner's operands.
+  return mlirOpOperandGetOperandNumber(opOperand);
+}
+
+void PyOpOperand::bind(nb::module_ &m) {
+  nb::class_<PyOpOperand>(m, "OpOperand")
+      .def_prop_ro("owner", &PyOpOperand::getOwner,
+                   "Returns the operation that owns this operand.")
+      .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
+                   "Returns the operand number in the owning operation.");
+}
+
+PyOpOperand PyOpOperandIterator::dunderNext() { // Walks the use chain.
+  if (mlirOpOperandIsNull(opOperand))
+    throw nb::stop_iteration(); // No further uses.
+
+  PyOpOperand returnOpOperand(opOperand);
+  opOperand = mlirOpOperandGetNextUse(opOperand); // Advance before returning.
+  return returnOpOperand;
+}
+
+void PyOpOperandIterator::bind(nb::module_ &m) {
+  nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
+      .def("__iter__", &PyOpOperandIterator::dunderIter,
+           "Returns an iterator over operands.")
+      .def("__next__", &PyOpOperandIterator::dunderNext,
+           "Returns the next operand in the iteration.");
+}
+
//------------------------------------------------------------------------------
// PyThreadPool
//------------------------------------------------------------------------------
@@ -77,8 +465,9 @@ PyMlirContext::~PyMlirContext() {
}
PyMlirContextRef PyMlirContext::getRef() {
- return PyMlirContextRef(this, nanobind::cast(this));
+ return PyMlirContextRef(this, nb::cast(this));
}
+
nb::object PyMlirContext::getCapsule() {
return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
}
@@ -618,7 +1007,7 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
void PyOperation::detachFromParent() {
mlirOperationRemoveFromParent(getOperation());
setDetached();
- parentKeepAlive = nanobind::object();
+ parentKeepAlive = nb::object();
}
MlirOperation PyOperation::get() const {
@@ -627,10 +1016,10 @@ MlirOperation PyOperation::get() const {
}
PyOperationRef PyOperation::getRef() {
- return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle));
+ return PyOperationRef(this, nb::borrow<nb::object>(handle));
}
-void PyOperation::setAttached(const nanobind::object &parent) {
+void PyOperation::setAttached(const nb::object &parent) {
assert(!attached && "operation already attached");
attached = true;
}
@@ -1001,6 +1390,75 @@ void PyOperation::erase() {
mlirOperationDestroy(operation);
}
+void PyOpResult::bindDerived(ClassTy &c) { // OpResult-specific properties.
+  c.def_prop_ro(
+      "owner",
+      [](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
+        assert(mlirOperationEqual(self.getParentOperation()->get(),
+                                  mlirOpResultGetOwner(self.get())) &&
+               "expected the owner of the value in Python to match that in "
+               "the IR");
+        return self.getParentOperation()->createOpView(); // Match only asserted.
+      },
+      "Returns the operation that produces this result.");
+  c.def_prop_ro(
+      "result_number",
+      [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); },
+      "Returns the position of this result in the operation's result list.");
+}
+
+/// Returns the list of types of the values held by container.
+///
+/// `Container` is a sliceable value list (PyOpResultList and
+/// PyBlockArgumentList use this helper) that provides `size()` and
+/// `getElement(i)`; each element's `get()` yields the MlirValue whose type is
+/// wrapped as a PyType in `context` and passed through `maybeDownCast()`.
+/// The loop index is pointer-sized, matching the lists' intptr_t lengths,
+/// rather than narrowing the size to `int`.
+template <typename Container>
+static std::vector<nb::typed<nb::object, PyType>>
+getValueTypes(Container &container, PyMlirContextRef &context) {
+  std::vector<nb::typed<nb::object, PyType>> result;
+  const intptr_t numElements = container.size();
+  result.reserve(numElements);
+  for (intptr_t i = 0; i < numElements; ++i) {
+    result.push_back(PyType(context->getRef(),
+                            mlirValueGetType(container.getElement(i).get()))
+                         .maybeDownCast());
+  }
+  return result;
+}
+
+PyOpResultList::PyOpResultList(PyOperationRef operation, intptr_t startIndex,
+                               intptr_t length, intptr_t step)
+    : Sliceable(startIndex,
+                length == -1 ? mlirOperationGetNumResults(operation->get())
+                             : length, // -1 selects every result.
+                step),
+      operation(std::move(operation)) {} // Moved only after the base read it.
+
+void PyOpResultList::bindDerived(ClassTy &c) {
+  c.def_prop_ro(
+      "types",
+      [](PyOpResultList &self) {
+        return getValueTypes(self, self.operation->getContext());
+      },
+      "Returns a list of types for all results in this result list.");
+  c.def_prop_ro(
+      "owner",
+      [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
+        return self.operation->createOpView();
+      },
+      "Returns the operation that owns this result list.");
+}
+
+intptr_t PyOpResultList::getRawNumElements() { // Live count, ignoring the slice.
+  operation->checkValid();
+  return mlirOperationGetNumResults(operation->get());
+}
+
+PyOpResult PyOpResultList::getRawElement(intptr_t index) { // NOTE(review): no explicit checkValid(), unlike siblings.
+  PyValue value(operation, mlirOperationGetResult(operation->get(), index));
+  return PyOpResult(value);
+}
+
+PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length,
+                                     intptr_t step) const { // Same op, new window.
+  return PyOpResultList(operation, startIndex, length, step);
+}
+
//------------------------------------------------------------------------------
// PyOpView
//------------------------------------------------------------------------------
@@ -1335,15 +1793,7 @@ PyOpView::PyOpView(const nb::object &operationObject)
operationObject(operation.getRef().getObject()) {}
//------------------------------------------------------------------------------
-// PyBlock
-//------------------------------------------------------------------------------
-
-nanobind::object PyBlock::getCapsule() {
- return nanobind::steal<nanobind::object>(mlirPythonBlockToCapsule(get()));
-}
-
-//------------------------------------------------------------------------------
-// PyAsmState
+// PyAsmState
//------------------------------------------------------------------------------
PyAsmState::PyAsmState(MlirValue value, bool useLocalScope) {
@@ -1745,254 +2195,6 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
}
}
-MlirBlock createBlock(const nanobind::sequence &pyArgTypes,
- const std::optional<nanobind::sequence> &pyArgLocs) {
- SmallVector<MlirType> argTypes;
- argTypes.reserve(nanobind::len(pyArgTypes));
- for (const auto &pyType : pyArgTypes)
- argTypes.push_back(nanobind::cast<PyType &>(pyType));
-
- SmallVector<MlirLocation> argLocs;
- if (pyArgLocs) {
- argLocs.reserve(nanobind::len(*pyArgLocs));
- for (const auto &pyLoc : *pyArgLocs)
- argLocs.push_back(nanobind::cast<PyLocation &>(pyLoc));
- } else if (!argTypes.empty()) {
- argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
- }
-
- if (argTypes.size() != argLocs.size()) {
- throw nanobind::value_error(("Expected " + Twine(argTypes.size()) +
- " locations, got: " + Twine(argLocs.size()))
- .str()
- .c_str());
- }
- return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
-}
-
-//------------------------------------------------------------------------------
-// PyAttrBuilderMap
-//------------------------------------------------------------------------------
-
-bool PyAttrBuilderMap::dunderContains(const std::string &attributeKind) {
- return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
-}
-
-nanobind::callable
-PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
- auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
- if (!builder)
- throw nanobind::key_error(attributeKind.c_str());
- return *builder;
-}
-
-void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
- nanobind::callable func,
- bool replace) {
- PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
- replace);
-}
-
-//------------------------------------------------------------------------------
-// Collections.
-//------------------------------------------------------------------------------
-
-PyRegion PyRegionIterator::dunderNext() {
- operation->checkValid();
- if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
- throw nanobind::stop_iteration();
- }
- MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
- return PyRegion(operation, region);
-}
-
-PyRegionList::PyRegionList(PyOperationRef operation, intptr_t startIndex,
- intptr_t length, intptr_t step)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumRegions(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
-
-PyRegionIterator PyRegionList::dunderIter() {
- operation->checkValid();
- return PyRegionIterator(operation, startIndex);
-}
-
-intptr_t PyRegionList::getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumRegions(operation->get());
-}
-
-PyRegion PyRegionList::getRawElement(intptr_t pos) {
- operation->checkValid();
- return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
-}
-
-PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length,
- intptr_t step) const {
- return PyRegionList(operation, startIndex, length, step);
-}
-
-PyBlock PyBlockIterator::dunderNext() {
- operation->checkValid();
- if (mlirBlockIsNull(next)) {
- throw nanobind::stop_iteration();
- }
-
- PyBlock returnBlock(operation, next);
- next = mlirBlockGetNextInRegion(next);
- return returnBlock;
-}
-
-PyBlockIterator PyBlockList::dunderIter() {
- operation->checkValid();
- return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
-}
-
-intptr_t PyBlockList::dunderLen() {
- operation->checkValid();
- intptr_t count = 0;
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- count += 1;
- block = mlirBlockGetNextInRegion(block);
- }
- return count;
-}
-
-PyBlock PyBlockList::dunderGetItem(intptr_t index) {
- operation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nanobind::index_error("attempt to access out of bounds block");
- }
- MlirBlock block = mlirRegionGetFirstBlock(region);
- while (!mlirBlockIsNull(block)) {
- if (index == 0) {
- return PyBlock(operation, block);
- }
- block = mlirBlockGetNextInRegion(block);
- index -= 1;
- }
- throw nanobind::index_error("attempt to access out of bounds block");
-}
-
-PyBlock
-PyBlockList::appendBlock(const nanobind::args &pyArgTypes,
- const std::optional<nanobind::sequence> &pyArgLocs) {
- operation->checkValid();
- MlirBlock block =
- createBlock(nanobind::cast<nanobind::sequence>(pyArgTypes), pyArgLocs);
- mlirRegionAppendOwnedBlock(region, block);
- return PyBlock(operation, block);
-}
-
-nanobind::typed<nanobind::object, PyOpView> PyOperationIterator::dunderNext() {
- parentOperation->checkValid();
- if (mlirOperationIsNull(next)) {
- throw nanobind::stop_iteration();
- }
-
- PyOperationRef returnOperation =
- PyOperation::forOperation(parentOperation->getContext(), next);
- next = mlirOperationGetNextInBlock(next);
- return returnOperation->createOpView();
-}
-
-intptr_t PyOperationList::dunderLen() {
- parentOperation->checkValid();
- intptr_t count = 0;
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- count += 1;
- childOp = mlirOperationGetNextInBlock(childOp);
- }
- return count;
-}
-
-nanobind::typed<nanobind::object, PyOpView>
-PyOperationList::dunderGetItem(intptr_t index) {
- parentOperation->checkValid();
- if (index < 0) {
- index += dunderLen();
- }
- if (index < 0) {
- throw nanobind::index_error("attempt to access out of bounds operation");
- }
- MlirOperation childOp = mlirBlockGetFirstOperation(block);
- while (!mlirOperationIsNull(childOp)) {
- if (index == 0) {
- return PyOperation::forOperation(parentOperation->getContext(), childOp)
- ->createOpView();
- }
- childOp = mlirOperationGetNextInBlock(childOp);
- index -= 1;
- }
- throw nanobind::index_error("attempt to access out of bounds operation");
-}
-
-nanobind::typed<nanobind::object, PyOpView> PyOpOperand::getOwner() const {
- MlirOperation owner = mlirOpOperandGetOwner(opOperand);
- PyMlirContextRef context =
- PyMlirContext::forContext(mlirOperationGetContext(owner));
- return PyOperation::forOperation(context, owner)->createOpView();
-}
-
-size_t PyOpOperand::getOperandNumber() const {
- return mlirOpOperandGetOperandNumber(opOperand);
-}
-
-PyOpOperand PyOpOperandIterator::dunderNext() {
- if (mlirOpOperandIsNull(opOperand))
- throw nanobind::stop_iteration();
-
- PyOpOperand returnOpOperand(opOperand);
- opOperand = mlirOpOperandGetNextUse(opOperand);
- return returnOpOperand;
-}
-
-//------------------------------------------------------------------------------
-// PyConcreteValue
-//------------------------------------------------------------------------------
-
-PyOpResultList::PyOpResultList(PyOperationRef operation, intptr_t startIndex,
- intptr_t length, intptr_t step)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumResults(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
-
-intptr_t PyOpResultList::getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumResults(operation->get());
-}
-
-PyOpResult PyOpResultList::getRawElement(intptr_t index) {
- PyValue value(operation, mlirOperationGetResult(operation->get(), index));
- return PyOpResult(value);
-}
-
-PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length,
- intptr_t step) const {
- return PyOpResultList(operation, startIndex, length, step);
-}
-
-PyBlockArgumentList::PyBlockArgumentList(PyOperationRef operation,
- MlirBlock block, intptr_t startIndex,
- intptr_t length, intptr_t step)
- : Sliceable(startIndex,
- length == -1 ? mlirBlockGetNumArguments(block) : length, step),
- operation(std::move(operation)), block(block) {}
-
-intptr_t PyBlockArgumentList::getRawNumElements() {
- operation->checkValid();
- return mlirBlockGetNumArguments(block);
-}
-
void PyBlockArgument::bindDerived(ClassTy &c) {
c.def_prop_ro(
"owner",
@@ -2012,13 +2214,34 @@ void PyBlockArgument::bindDerived(ClassTy &c) {
[](PyBlockArgument &self, PyType type) {
return mlirBlockArgumentSetType(self.get(), type);
},
- nanobind::arg("type"), "Sets the type of this block argument.");
+ "type"_a, "Sets the type of this block argument.");
c.def(
"set_location",
[](PyBlockArgument &self, PyLocation loc) {
return mlirBlockArgumentSetLocation(self.get(), loc);
},
- nanobind::arg("loc"), "Sets the location of this block argument.");
+ "loc"_a, "Sets the location of this block argument.");
+}
+
+PyBlockArgumentList::PyBlockArgumentList(PyOperationRef operation,
+                                         MlirBlock block, intptr_t startIndex,
+                                         intptr_t length, intptr_t step)
+    : Sliceable(startIndex,
+                length == -1 ? mlirBlockGetNumArguments(block) : length, step), // -1: all args.
+      operation(std::move(operation)), block(block) {}
+
+void PyBlockArgumentList::bindDerived(ClassTy &c) { // Adds the `types` property.
+  c.def_prop_ro(
+      "types",
+      [](PyBlockArgumentList &self) {
+        return getValueTypes(self, self.operation->getContext());
+      },
+      "Returns a list of types for all arguments in this argument list.");
+}
+
+intptr_t PyBlockArgumentList::getRawNumElements() { // Live count, ignoring the slice.
+  operation->checkValid();
+  return mlirBlockGetNumArguments(block);
 }
PyBlockArgument PyBlockArgumentList::getRawElement(intptr_t pos) const {
@@ -2045,6 +2268,11 @@ void PyOpOperandList::dunderSetItem(intptr_t index, PyValue value) {
mlirOperationSetOperand(operation->get(), index, value.get());
}
+void PyOpOperandList::bindDerived(ClassTy &c) { // Makes operands assignable.
+  c.def("__setitem__", &PyOpOperandList::dunderSetItem, "index"_a, "value"_a,
+        "Sets the operand at the specified index to a new value.");
+}
+
intptr_t PyOpOperandList::getRawNumElements() {
operation->checkValid();
return mlirOperationGetNumOperands(operation->get());
@@ -2082,6 +2310,11 @@ void PyOpSuccessors::dunderSetItem(intptr_t index, PyBlock block) {
mlirOperationSetSuccessor(operation->get(), index, block.get());
}
+void PyOpSuccessors::bindDerived(ClassTy &c) { // Makes successors assignable.
+  c.def("__setitem__", &PyOpSuccessors::dunderSetItem, "index"_a, "block"_a,
+        "Sets the successor block at the specified index.");
+}
+
intptr_t PyOpSuccessors::getRawNumElements() {
operation->checkValid();
return mlirOperationGetNumSuccessors(operation->get());
@@ -2146,12 +2379,12 @@ PyBlockPredecessors PyBlockPredecessors::slice(intptr_t startIndex,
return PyBlockPredecessors(block, operation, startIndex, length, step);
}
-nanobind::typed<nanobind::object, PyAttribute>
+nb::typed<nb::object, PyAttribute>
PyOpAttributeMap::dunderGetItemNamed(const std::string &name) {
MlirAttribute attr =
mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name));
if (mlirAttributeIsNull(attr)) {
- throw nanobind::key_error("attempt to access a non-existent attribute");
+ throw nb::key_error("attempt to access a non-existent attribute");
}
return PyAttribute(operation->getContext(), attr).maybeDownCast();
}
@@ -2161,7 +2394,7 @@ PyNamedAttribute PyOpAttributeMap::dunderGetItemIndexed(intptr_t index) {
index += dunderLen();
}
if (index < 0 || index >= dunderLen()) {
- throw nanobind::index_error("attempt to access out of bounds attribute");
+ throw nb::index_error("attempt to access out of bounds attribute");
}
MlirNamedAttribute namedAttr =
mlirOperationGetAttribute(operation->get(), index);
@@ -2181,7 +2414,7 @@ void PyOpAttributeMap::dunderDelItem(const std::string &name) {
int removed = mlirOperationRemoveAttributeByName(operation->get(),
toMlirStringRef(name));
if (!removed)
- throw nanobind::key_error("attempt to delete a non-existent attribute");
+ throw nb::key_error("attempt to delete a non-existent attribute");
}
intptr_t PyOpAttributeMap::dunderLen() {
@@ -2203,6 +2436,2387 @@ void PyOpAttributeMap::forEachAttr(
fn(name, na.attribute);
}
}
+
+/// Binds the Python `OpAttributeMap` class. It exposes an operation's
+/// attributes through a mapping-style interface: `in`, `len`, `[]` by name or
+/// index, item assignment/deletion, iteration over names, `keys`, `values`
+/// and `items`.
+void PyOpAttributeMap::bind(nb::module_ &m) {
+  // Collects the names of all attributes of the wrapped operation, in
+  // forEachAttr order. Shared by `__iter__` and `keys` so that both always
+  // report exactly the same names.
+  auto attributeNames = [](PyOpAttributeMap &self) {
+    nb::list names;
+    PyOpAttributeMap::forEachAttr(
+        self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
+          names.append(nb::str(name.data, name.length));
+        });
+    return names;
+  };
+
+  nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
+      .def("__contains__", &PyOpAttributeMap::dunderContains, "name"_a,
+           "Checks if an attribute with the given name exists in the map.")
+      .def("__len__", &PyOpAttributeMap::dunderLen,
+           "Returns the number of attributes in the map.")
+      .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed, "name"_a,
+           "Gets an attribute by name.")
+      .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed, "index"_a,
+           "Gets a named attribute by index.")
+      .def("__setitem__", &PyOpAttributeMap::dunderSetItem, "name"_a, "attr"_a,
+           "Sets an attribute with the given name.")
+      .def("__delitem__", &PyOpAttributeMap::dunderDelItem, "name"_a,
+           "Deletes an attribute with the given name.")
+      .def(
+          "__iter__",
+          [attributeNames](PyOpAttributeMap &self) {
+            // Iterate over a snapshot of the names, as before.
+            return nb::iter(attributeNames(self));
+          },
+          "Iterates over attribute names.")
+      .def("keys", attributeNames, "Returns a list of attribute names.")
+      .def(
+          "values",
+          [](PyOpAttributeMap &self) {
+            nb::list out;
+            PyOpAttributeMap::forEachAttr(
+                self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
+                  out.append(PyAttribute(self.operation->getContext(), attr)
+                                 .maybeDownCast());
+                });
+            return out;
+          },
+          "Returns a list of attribute values.")
+      .def(
+          "items",
+          [](PyOpAttributeMap &self) {
+            nb::list out;
+            PyOpAttributeMap::forEachAttr(
+                self.operation->get(),
+                [&](MlirStringRef name, MlirAttribute attr) {
+                  out.append(nb::make_tuple(
+                      nb::str(name.data, name.length),
+                      PyAttribute(self.operation->getContext(), attr)
+                          .maybeDownCast()));
+                });
+            return out;
+          },
+          "Returns a list of `(name, attribute)` tuples.");
+}
+
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
+namespace {
+// Backports of CPython C-API helpers that older Python versions lack, taken
+// from pythoncapi_compat:
+// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
+// Every definition below is guarded by a Python version check and/or an
+// `#ifndef`/`#if !defined`, so it is only compiled when the Python headers in
+// use do not already provide it.
+
+// C-style cast helper used by the backports below.
+#ifndef _Py_CAST
+#define _Py_CAST(type, expr) ((type)(expr))
+#endif
+
+// Static inline functions should use _Py_NULL rather than using directly NULL
+// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
+// _Py_NULL is defined as nullptr.
+#ifndef _Py_NULL
+#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \
+    (defined(__cplusplus) && __cplusplus >= 201103)
+#define _Py_NULL nullptr
+#else
+#define _Py_NULL NULL
+#endif
+#endif
+
+// Python 3.10.0a3: Py_NewRef() and Py_XNewRef() are only backported for
+// older versions.
+#if PY_VERSION_HEX < 0x030A00A3
+
+// bpo-42262 added Py_XNewRef()
+// Increments the reference count of `obj` if it is non-null and returns it,
+// turning a borrowed (possibly null) reference into a new one.
+#if !defined(Py_XNewRef)
+[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
+  Py_XINCREF(obj);
+  return obj;
+}
+#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
+#endif
+
+// bpo-42262 added Py_NewRef()
+// Same as Py_XNewRef() but uses Py_INCREF, so `obj` must be non-null.
+#if !defined(Py_NewRef)
+[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
+  Py_INCREF(obj);
+  return obj;
+}
+#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
+#endif
+
+#endif // Python 3.10.0a3
+
+// Python 3.9.0b1: frame accessors are only backported for older versions,
+// and not when building for PyPy. Like the CPython originals, each returns a
+// new reference that the caller must release.
+#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
+
+// bpo-40429 added PyThreadState_GetFrame()
+// Returns a new reference to the thread's current frame (may be null).
+PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
+  assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
+  return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
+}
+
+// bpo-40421 added PyFrame_GetBack()
+// Returns a new reference to the calling frame (may be null).
+PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
+  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+  return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
+}
+
+// bpo-40421 added PyFrame_GetCode()
+// Returns a new reference to the frame's code object (never null).
+PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
+  assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+  assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
+  return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
+}
+
+#endif // Python 3.9.0b1
+
+using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
+
+/// Builds an MLIR location describing the current Python call stack.
+///
+/// Frames are visited starting from the current thread state's frame and
+/// moving to each caller. Frames whose filename is rejected by
+/// `isUserTracebackFilename` are skipped. At most `locTracebackFramesLimit()`
+/// accepted frames are kept, further capped by the scratch buffer capacity.
+/// Each kept frame becomes a NameLoc (the function name) wrapping its file
+/// location.
+///
+/// Kept frames are folded into nested CallSiteLocs. The first kept frame is
+/// the outermost callee. Returns an unknown location when no frame is kept.
+///
+/// On Python >= 3.11, throws nb::python_error if PyCode_Addr2Location fails.
+MlirLocation tracebackToLocation(MlirContext ctx) {
+  size_t framesLimit =
+      PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
+  // Use a thread_local here to avoid requiring a large amount of space.
+  thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
+      frames;
+  // Never write past the end of the scratch buffer, whatever limit was set.
+  if (framesLimit > frames.size())
+    framesLimit = frames.size();
+  size_t count = 0;
+
+  nb::gil_scoped_acquire acquire;
+  PyThreadState *tstate = PyThreadState_GET();
+  // PyThreadState_GetFrame, PyFrame_GetBack and PyFrame_GetCode each return a
+  // new (strong) reference, possibly null for the first two. Owning them
+  // through nb::object releases every reference exactly once. This covers the
+  // code objects and the case where an exception propagates out of the loop.
+  auto own = [](auto *ptr) {
+    return nb::steal(reinterpret_cast<PyObject *>(ptr));
+  };
+  // The increment expression fetches the caller frame before the assignment
+  // drops the current one, so the current frame can be gc'd promptly.
+  for (nb::object frame = own(PyThreadState_GetFrame(tstate));
+       frame.is_valid() && count < framesLimit;
+       frame = own(PyFrame_GetBack(
+           reinterpret_cast<PyFrameObject *>(frame.ptr())))) {
+    auto *pyFrame = reinterpret_cast<PyFrameObject *>(frame.ptr());
+    nb::object codeRef = own(PyFrame_GetCode(pyFrame));
+    auto *code = reinterpret_cast<PyCodeObject *>(codeRef.ptr());
+    auto fileNameStr =
+        nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
+    llvm::StringRef fileName(fileNameStr);
+    if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
+      continue;
+
+    // co_qualname and PyCode_Addr2Location added in py3.11
+#if PY_VERSION_HEX < 0x030B00F0
+    std::string name =
+        nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
+    llvm::StringRef funcName(name);
+    int startLine = PyFrame_GetLineNumber(pyFrame);
+    MlirLocation loc =
+        mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
+#else
+    std::string name =
+        nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
+    llvm::StringRef funcName(name);
+    int startLine, startCol, endLine, endCol;
+    int lasti = PyFrame_GetLasti(pyFrame);
+    if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
+                              &endCol)) {
+      throw nb::python_error();
+    }
+    MlirLocation loc = mlirLocationFileLineColRangeGet(
+        ctx, wrap(fileName), startLine, startCol, endLine, endCol);
+#endif
+
+    frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
+    ++count;
+  }
+
+  if (count == 0)
+    return mlirLocationUnknownGet(ctx);
+
+  MlirLocation callee = frames[0];
+  assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
+  if (count == 1)
+    return callee;
+
+  // Fold from the last kept frame back towards the first. Each frame becomes
+  // the callee of a callsite whose caller is the chain built so far.
+  MlirLocation caller = frames[count - 1];
+  assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
+  for (size_t i = count - 2; i >= 1; --i)
+    caller = mlirLocationCallSiteGet(frames[i], caller);
+
+  return mlirLocationCallSiteGet(callee, caller);
+}
+
+/// Chooses the location for a newly created IR object. Precedence:
+/// 1. `location`, if provided;
+/// 2. the ambient default Location, when traceback locations are disabled;
+/// 3. otherwise, a location built from the Python call stack in the default
+///    context.
+PyLocation
+maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
+  if (location)
+    return *location;
+
+  bool useTraceback =
+      PyGlobals::get().getTracebackLoc().locTracebacksEnabled();
+  if (!useTraceback)
+    return DefaultingPyLocation::resolve();
+
+  MlirContext mlirCtx = DefaultingPyMlirContext::resolve().get();
+  MlirLocation tracebackLoc = tracebackToLocation(mlirCtx);
+  PyMlirContextRef ctxRef = PyMlirContext::forContext(mlirCtx);
+  return {ctxRef, tracebackLoc};
+}
+} // namespace
+
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+/// Populates the root extension module. It registers:
+/// - the `T`/`U` type variables used in the signatures below;
+/// - the `_Globals` class and its `globals` instance;
+/// - the registration decorators for dialects, operations, type casters and
+///   value casters.
+void populateRoot(nb::module_ &m) {
+  m.attr("T") = nb::type_var("T");
+  m.attr("U") = nb::type_var("U");
+
+  nb::class_<PyGlobals>(m, "_Globals")
+      .def_prop_rw("dialect_search_modules",
+                   &PyGlobals::getDialectSearchPrefixes,
+                   &PyGlobals::setDialectSearchPrefixes)
+      .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
+           "module_name"_a)
+      .def(
+          "_check_dialect_module_loaded",
+          [](PyGlobals &self, const std::string &dialectNamespace) {
+            return self.loadDialectModule(dialectNamespace);
+          },
+          "dialect_namespace"_a)
+      .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
+           "dialect_namespace"_a, "dialect_class"_a,
+           "Testing hook for directly registering a dialect")
+      .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
+           "operation_name"_a, "operation_class"_a, nb::kw_only(),
+           "replace"_a = false,
+           "Testing hook for directly registering an operation")
+      .def(
+          "loc_tracebacks_enabled",
+          [](PyGlobals &self) {
+            return self.getTracebackLoc().locTracebacksEnabled();
+          },
+          "Returns whether default locations are built from the Python "
+          "traceback.")
+      .def(
+          "set_loc_tracebacks_enabled",
+          [](PyGlobals &self, bool enabled) {
+            self.getTracebackLoc().setLocTracebacksEnabled(enabled);
+          },
+          "enabled"_a,
+          "Enables or disables building default locations from the Python "
+          "traceback.")
+      .def(
+          "loc_tracebacks_frame_limit",
+          [](PyGlobals &self) {
+            return self.getTracebackLoc().locTracebackFramesLimit();
+          },
+          "Returns the maximum number of traceback frames recorded in a "
+          "location.")
+      .def(
+          "set_loc_tracebacks_frame_limit",
+          [](PyGlobals &self, std::optional<int> n) {
+            self.getTracebackLoc().setLocTracebackFramesLimit(
+                n.value_or(PyGlobals::TracebackLoc::kMaxFrames));
+          },
+          // `.none()` keeps accepting an explicit None (selects the maximum).
+          "n"_a.none(),
+          "Sets the maximum number of traceback frames recorded in a "
+          "location; None selects the largest supported limit.")
+      .def(
+          "register_traceback_file_inclusion",
+          [](PyGlobals &self, const std::string &filename) {
+            self.getTracebackLoc().registerTracebackFileInclusion(filename);
+          },
+          "filename"_a,
+          "Registers `filename` for inclusion when filtering traceback "
+          "frames.")
+      .def(
+          "register_traceback_file_exclusion",
+          [](PyGlobals &self, const std::string &filename) {
+            self.getTracebackLoc().registerTracebackFileExclusion(filename);
+          },
+          "filename"_a,
+          "Registers `filename` for exclusion when filtering traceback "
+          "frames.");
+
+  // Aside from making the globals accessible to python, having python manage
+  // it is necessary to make sure it is destroyed (and releases its python
+  // resources) properly.
+  m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
+
+  // Registration decorators.
+  m.def(
+      "register_dialect",
+      [](nb::type_object pyClass) {
+        std::string dialectNamespace =
+            nb::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
+        PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
+        return pyClass;
+      },
+      "dialect_class"_a,
+      "Class decorator for registering a custom Dialect wrapper");
+  m.def(
+      "register_operation",
+      [](const nb::type_object &dialectClass, bool replace) -> nb::object {
+        return nb::cpp_function(
+            [dialectClass,
+             replace](nb::type_object opClass) -> nb::type_object {
+              std::string operationName =
+                  nb::cast<std::string>(opClass.attr("OPERATION_NAME"));
+              PyGlobals::get().registerOperationImpl(operationName, opClass,
+                                                     replace);
+              // Dict-stuff the new opClass by name onto the dialect class.
+              nb::object opClassName = opClass.attr("__name__");
+              dialectClass.attr(opClassName) = opClass;
+              return opClass;
+            });
+      },
+      // clang-format off
+      nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) "
+              "-> typing.Callable[[type[T]], type[T]]"),
+      // clang-format on
+      "dialect_class"_a, nb::kw_only(), "replace"_a = false,
+      "Produce a class decorator for registering an Operation class as part of "
+      "a dialect");
+  m.def(
+      MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
+      [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
+        return nb::cpp_function([mlirTypeID, replace](
+                                    nb::callable typeCaster) -> nb::object {
+          PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
+          return typeCaster;
+        });
+      },
+      // clang-format off
+      nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
+              "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
+      // clang-format on
+      "typeid"_a, nb::kw_only(), "replace"_a = false,
+      "Register a type caster for casting MLIR types to custom user types.");
+  m.def(
+      MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
+      [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
+        return nb::cpp_function(
+            [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
+              PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
+                                                   replace);
+              return valueCaster;
+            });
+      },
+      // clang-format off
+      nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
+              "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
+      // clang-format on
+      "typeid"_a, nb::kw_only(), "replace"_a = false,
+      "Register a value caster for casting MLIR values to custom user values.");
+}
+
+//------------------------------------------------------------------------------
+// Populates the core exports of the 'ir' submodule.
+//------------------------------------------------------------------------------
+void populateIRCore(nb::module_ &m) {
+ //----------------------------------------------------------------------------
+ // Enums.
+ //----------------------------------------------------------------------------
+ nb::enum_<PyMlirDiagnosticSeverity>(m, "DiagnosticSeverity")
+ .value("ERROR", PyMlirDiagnosticSeverity::MlirDiagnosticError)
+ .value("WARNING", PyMlirDiagnosticSeverity::MlirDiagnosticWarning)
+ .value("NOTE", PyMlirDiagnosticSeverity::MlirDiagnosticNote)
+ .value("REMARK", PyMlirDiagnosticSeverity::MlirDiagnosticRemark);
+
+ nb::enum_<PyMlirWalkOrder>(m, "WalkOrder")
+ .value("PRE_ORDER", PyMlirWalkOrder::MlirWalkPreOrder)
+ .value("POST_ORDER", PyMlirWalkOrder::MlirWalkPostOrder);
+
+ nb::enum_<PyMlirWalkResult>(m, "WalkResult")
+ .value("ADVANCE", PyMlirWalkResult::MlirWalkResultAdvance)
+ .value("INTERRUPT", PyMlirWalkResult::MlirWalkResultInterrupt)
+ .value("SKIP", PyMlirWalkResult::MlirWalkResultSkip);
+
+ //----------------------------------------------------------------------------
+ // Mapping of Diagnostics.
+ //----------------------------------------------------------------------------
+ nb::class_<PyDiagnostic>(m, "Diagnostic")
+ .def_prop_ro("severity", &PyDiagnostic::getSeverity,
+ "Returns the severity of the diagnostic.")
+ .def_prop_ro("location", &PyDiagnostic::getLocation,
+ "Returns the location associated with the diagnostic.")
+ .def_prop_ro("message", &PyDiagnostic::getMessage,
+ "Returns the message text of the diagnostic.")
+ .def_prop_ro("notes", &PyDiagnostic::getNotes,
+ "Returns a tuple of attached note diagnostics.")
+ .def(
+ "__str__",
+ [](PyDiagnostic &self) -> nb::str {
+ if (!self.isValid())
+ return nb::str("<Invalid Diagnostic>");
+ return self.getMessage();
+ },
+ "Returns the diagnostic message as a string.");
+
+ nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
+ .def(
+ "__init__",
+ [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
+ new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
+ },
+ "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
+ .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
+ "The severity level of the diagnostic.")
+ .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
+ "The location associated with the diagnostic.")
+ .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
+ "The message text of the diagnostic.")
+ .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
+ "List of attached note diagnostics.")
+ .def(
+ "__str__",
+ [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
+ "Returns the diagnostic message as a string.");
+
+ nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
+ .def("detach", &PyDiagnosticHandler::detach,
+ "Detaches the diagnostic handler from the context.")
+ .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
+ "Returns True if the handler is attached to a context.")
+ .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
+ "Returns True if an error was encountered during diagnostic "
+ "handling.")
+ .def("__enter__", &PyDiagnosticHandler::contextEnter,
+ "Enters the diagnostic handler as a context manager.")
+ .def("__exit__", &PyDiagnosticHandler::contextExit, "exc_type"_a.none(),
+ "exc_value"_a.none(), "traceback"_a.none(),
+ "Exits the diagnostic handler context manager.");
+
+ // Expose DefaultThreadPool to python
+ nb::class_<PyThreadPool>(m, "ThreadPool")
+ .def(
+ "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
+ "Creates a new thread pool with default concurrency.")
+ .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
+ "Returns the maximum number of threads in the pool.")
+ .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
+ "Returns the raw pointer to the LLVM thread pool as a string.");
+
+ nb::class_<PyMlirContext>(m, "Context")
+ .def(
+ "__init__",
+ [](PyMlirContext &self) {
+ MlirContext context = mlirContextCreateWithThreading(false);
+ new (&self) PyMlirContext(context);
+ },
+ R"(
+ Creates a new MLIR context.
+
+ The context is the top-level container for all MLIR objects. It owns the storage
+ for types, attributes, locations, and other core IR objects. A context can be
+ configured to allow or disallow unregistered dialects and can have dialects
+ loaded on-demand.)")
+ .def_static("_get_live_count", &PyMlirContext::getLiveCount,
+ "Gets the number of live Context objects.")
+ .def(
+ "_get_context_again",
+ [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
+ PyMlirContextRef ref = PyMlirContext::forContext(self.get());
+ return ref.releaseObject();
+ },
+ "Gets another reference to the same context.")
+ .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
+ "Gets the number of live modules owned by this context.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule,
+ "Gets a capsule wrapping the MlirContext.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyMlirContext::createFromCapsule,
+ "Creates a Context from a capsule wrapping MlirContext.")
+ .def("__enter__", &PyMlirContext::contextEnter,
+ "Enters the context as a context manager.")
+ .def("__exit__", &PyMlirContext::contextExit, "exc_type"_a.none(),
+ "exc_value"_a.none(), "traceback"_a.none(),
+ "Exits the context manager.")
+ .def_prop_ro_static(
+ "current",
+ [](nb::object & /*class*/)
+ -> std::optional<nb::typed<nb::object, PyMlirContext>> {
+ auto *context = PyThreadContextEntry::getDefaultContext();
+ if (!context)
+ return {};
+ return nb::cast(context);
+ },
+ nb::sig("def current(/) -> Context | None"),
+ "Gets the Context bound to the current thread or returns None if no "
+ "context is set.")
+ .def_prop_ro(
+ "dialects",
+ [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+ "Gets a container for accessing dialects by name.")
+ .def_prop_ro(
+ "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+ "Alias for `dialects`.")
+ .def(
+ "get_dialect_descriptor",
+ [=](PyMlirContext &self, std::string &name) {
+ MlirDialect dialect = mlirContextGetOrLoadDialect(
+ self.get(), {name.data(), name.size()});
+ if (mlirDialectIsNull(dialect)) {
+ throw nb::value_error(
+ (Twine("Dialect '") + name + "' not found").str().c_str());
+ }
+ return PyDialectDescriptor(self.getRef(), dialect);
+ },
+ "dialect_name"_a,
+ "Gets or loads a dialect by name, returning its descriptor object.")
+ .def_prop_rw(
+ "allow_unregistered_dialects",
+ [](PyMlirContext &self) -> bool {
+ return mlirContextGetAllowUnregisteredDialects(self.get());
+ },
+ [](PyMlirContext &self, bool value) {
+ mlirContextSetAllowUnregisteredDialects(self.get(), value);
+ },
+ "Controls whether unregistered dialects are allowed in this context.")
+ .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
+ "callback"_a,
+ "Attaches a diagnostic handler that will receive callbacks.")
+ .def(
+ "enable_multithreading",
+ [](PyMlirContext &self, bool enable) {
+ mlirContextEnableMultithreading(self.get(), enable);
+ },
+ "enable"_a,
+ R"(
+ Enables or disables multi-threading support in the context.
+
+ Args:
+ enable: Whether to enable (True) or disable (False) multi-threading.
+ )")
+ .def(
+ "set_thread_pool",
+ [](PyMlirContext &self, PyThreadPool &pool) {
+ // we should disable multi-threading first before setting
+ // new thread pool otherwise the assert in
+ // MLIRContext::setThreadPool will be raised.
+ mlirContextEnableMultithreading(self.get(), false);
+ mlirContextSetThreadPool(self.get(), pool.get());
+ },
+ R"(
+ Sets a custom thread pool for the context to use.
+
+ Args:
+ pool: A ThreadPool object to use for parallel operations.
+
+ Note:
+ Multi-threading is automatically disabled before setting the thread pool.)")
+ .def(
+ "get_num_threads",
+ [](PyMlirContext &self) {
+ return mlirContextGetNumThreads(self.get());
+ },
+ "Gets the number of threads in the context's thread pool.")
+ .def(
+ "_mlir_thread_pool_ptr",
+ [](PyMlirContext &self) {
+ MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
+ std::stringstream ss;
+ ss << pool.ptr;
+ return ss.str();
+ },
+ "Gets the raw pointer to the LLVM thread pool as a string.")
+ .def(
+ "is_registered_operation",
+ [](PyMlirContext &self, std::string &name) {
+ return mlirContextIsRegisteredOperation(
+ self.get(), MlirStringRef{name.data(), name.size()});
+ },
+ "operation_name"_a,
+ R"(
+ Checks whether an operation with the given name is registered.
+
+ Args:
+ operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
+
+ Returns:
+ True if the operation is registered, False otherwise.)")
+ .def(
+ "append_dialect_registry",
+ [](PyMlirContext &self, PyDialectRegistry ®istry) {
+ mlirContextAppendDialectRegistry(self.get(), registry);
+ },
+ "registry"_a,
+ R"(
+ Appends the contents of a dialect registry to the context.
+
+ Args:
+ registry: A DialectRegistry containing dialects to append.)")
+ .def_prop_rw("emit_error_diagnostics",
+ &PyMlirContext::getEmitErrorDiagnostics,
+ &PyMlirContext::setEmitErrorDiagnostics,
+ R"(
+ Controls whether error diagnostics are emitted to diagnostic handlers.
+
+ By default, error diagnostics are captured and reported through MLIRError exceptions.)")
+ .def(
+ "load_all_available_dialects",
+ [](PyMlirContext &self) {
+ mlirContextLoadAllAvailableDialects(self.get());
+ },
+ R"(
+ Loads all dialects available in the registry into the context.
+
+ This eagerly loads all dialects that have been registered, making them
+ immediately available for use.)");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialectDescriptor
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
+ .def_prop_ro(
+ "namespace",
+ [](PyDialectDescriptor &self) {
+ MlirStringRef ns = mlirDialectGetNamespace(self.get());
+ return nb::str(ns.data, ns.length);
+ },
+ "Returns the namespace of the dialect.")
+ .def(
+ "__repr__",
+ [](PyDialectDescriptor &self) {
+ MlirStringRef ns = mlirDialectGetNamespace(self.get());
+ std::string repr("<DialectDescriptor ");
+ repr.append(ns.data, ns.length);
+ repr.append(">");
+ return repr;
+ },
+ nb::sig("def __repr__(self) -> str"),
+ "Returns a string representation of the dialect descriptor.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialects
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialects>(m, "Dialects")
+ .def(
+ "__getitem__",
+ [=](PyDialects &self, std::string keyName) {
+ MlirDialect dialect =
+ self.getDialectForKey(keyName, /*attrError=*/false);
+ nb::object descriptor =
+ nb::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(keyName, std::move(descriptor));
+ },
+ "Gets a dialect by name using subscript notation.")
+ .def(
+ "__getattr__",
+ [=](PyDialects &self, std::string attrName) {
+ MlirDialect dialect =
+ self.getDialectForKey(attrName, /*attrError=*/true);
+ nb::object descriptor =
+ nb::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(attrName, std::move(descriptor));
+ },
+ "Gets a dialect by name using attribute notation.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialect
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialect>(m, "Dialect")
+ .def(nb::init<nb::object>(), "descriptor"_a,
+ "Creates a Dialect from a DialectDescriptor.")
+ .def_prop_ro(
+ "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
+ "Returns the DialectDescriptor for this dialect.")
+ .def(
+ "__repr__",
+ [](const nb::object &self) {
+ auto clazz = self.attr("__class__");
+ return nb::str("<Dialect ") +
+ self.attr("descriptor").attr("namespace") +
+ nb::str(" (class ") + clazz.attr("__module__") +
+ nb::str(".") + clazz.attr("__name__") + nb::str(")>");
+ },
+ nb::sig("def __repr__(self) -> str"),
+ "Returns a string representation of the dialect.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialectRegistry
+ //----------------------------------------------------------------------------
+ nb::class_<PyDialectRegistry>(m, "DialectRegistry")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule,
+ "Gets a capsule wrapping the MlirDialectRegistry.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyDialectRegistry::createFromCapsule,
+ "Creates a DialectRegistry from a capsule wrapping "
+ "`MlirDialectRegistry`.")
+ .def(nb::init<>(), "Creates a new empty dialect registry.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Location
+ //----------------------------------------------------------------------------
+ nb::class_<PyLocation>(m, "Location")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule,
+ "Gets a capsule wrapping the MlirLocation.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule,
+ "Creates a Location from a capsule wrapping MlirLocation.")
+ .def("__enter__", &PyLocation::contextEnter,
+ "Enters the location as a context manager.")
+ .def("__exit__", &PyLocation::contextExit, "exc_type"_a.none(),
+ "exc_value"_a.none(), "traceback"_a.none(),
+ "Exits the location context manager.")
+ .def(
+ "__eq__",
+ [](PyLocation &self, PyLocation &other) -> bool {
+ return mlirLocationEqual(self, other);
+ },
+ "Compares two locations for equality.")
+ .def(
+ "__eq__", [](PyLocation &self, nb::object other) { return false; },
+ "Compares location with non-location object (always returns False).")
+ .def_prop_ro_static(
+ "current",
+ [](nb::object & /*class*/) -> std::optional<PyLocation *> {
+ auto *loc = PyThreadContextEntry::getDefaultLocation();
+ if (!loc)
+ return std::nullopt;
+ return loc;
+ },
+ // clang-format off
+ nb::sig("def current(/) -> Location | None"),
+ // clang-format on
+ "Gets the Location bound to the current thread or raises ValueError.")
+ .def_static(
+ "unknown",
+ [](DefaultingPyMlirContext context) {
+ return PyLocation(context->getRef(),
+ mlirLocationUnknownGet(context->get()));
+ },
+ "context"_a = nb::none(),
+ "Gets a Location representing an unknown location.")
+ .def_static(
+ "callsite",
+ [](PyLocation callee, const std::vector<PyLocation> &frames,
+ DefaultingPyMlirContext context) {
+ if (frames.empty())
+ throw nb::value_error("No caller frames provided.");
+ MlirLocation caller = frames.back().get();
+ for (const PyLocation &frame :
+ llvm::reverse(llvm::ArrayRef(frames).drop_back()))
+ caller = mlirLocationCallSiteGet(frame.get(), caller);
+ return PyLocation(context->getRef(),
+ mlirLocationCallSiteGet(callee.get(), caller));
+ },
+ "callee"_a, "frames"_a, "context"_a = nb::none(),
+ "Gets a Location representing a caller and callsite.")
+ .def("is_a_callsite", mlirLocationIsACallSite,
+ "Returns True if this location is a CallSiteLoc.")
+ .def_prop_ro(
+ "callee",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationCallSiteGetCallee(self));
+ },
+ "Gets the callee location from a CallSiteLoc.")
+ .def_prop_ro(
+ "caller",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationCallSiteGetCaller(self));
+ },
+ "Gets the caller location from a CallSiteLoc.")
+ .def_static(
+ "file",
+ [](std::string filename, int line, int col,
+ DefaultingPyMlirContext context) {
+ return PyLocation(
+ context->getRef(),
+ mlirLocationFileLineColGet(
+ context->get(), toMlirStringRef(filename), line, col));
+ },
+ "filename"_a, "line"_a, "col"_a, "context"_a = nb::none(),
+ "Gets a Location representing a file, line and column.")
+ .def_static(
+ "file",
+ [](std::string filename, int startLine, int startCol, int endLine,
+ int endCol, DefaultingPyMlirContext context) {
+ return PyLocation(context->getRef(),
+ mlirLocationFileLineColRangeGet(
+ context->get(), toMlirStringRef(filename),
+ startLine, startCol, endLine, endCol));
+ },
+ "filename"_a, "start_line"_a, "start_col"_a, "end_line"_a,
+ "end_col"_a, "context"_a = nb::none(),
+ "Gets a Location representing a file, line and column range.")
+ .def("is_a_file", mlirLocationIsAFileLineColRange,
+ "Returns True if this location is a FileLineColLoc.")
+ .def_prop_ro(
+ "filename",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(
+ mlirLocationFileLineColRangeGetFilename(loc));
+ },
+ "Gets the filename from a FileLineColLoc.")
+ .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
+ "Gets the start line number from a `FileLineColLoc`.")
+ .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
+ "Gets the start column number from a `FileLineColLoc`.")
+ .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
+ "Gets the end line number from a `FileLineColLoc`.")
+ .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
+ "Gets the end column number from a `FileLineColLoc`.")
+ .def_static(
+ "fused",
+ [](const std::vector<PyLocation> &pyLocations,
+ std::optional<PyAttribute> metadata,
+ DefaultingPyMlirContext context) {
+ llvm::SmallVector<MlirLocation, 4> locations;
+ locations.reserve(pyLocations.size());
+ for (auto &pyLocation : pyLocations)
+ locations.push_back(pyLocation.get());
+ MlirLocation location = mlirLocationFusedGet(
+ context->get(), locations.size(), locations.data(),
+ metadata ? metadata->get() : MlirAttribute{0});
+ return PyLocation(context->getRef(), location);
+ },
+ "locations"_a, "metadata"_a = nb::none(), "context"_a = nb::none(),
+ "Gets a Location representing a fused location with optional "
+ "metadata.")
+ .def("is_a_fused", mlirLocationIsAFused,
+ "Returns True if this location is a `FusedLoc`.")
+ .def_prop_ro(
+ "locations",
+ [](PyLocation &self) {
+ unsigned numLocations = mlirLocationFusedGetNumLocations(self);
+ std::vector<MlirLocation> locations(numLocations);
+ if (numLocations)
+ mlirLocationFusedGetLocations(self, locations.data());
+ std::vector<PyLocation> pyLocations{};
+ pyLocations.reserve(numLocations);
+ for (unsigned i = 0; i < numLocations; ++i)
+ pyLocations.emplace_back(self.getContext(), locations[i]);
+ return pyLocations;
+ },
+ "Gets the list of locations from a `FusedLoc`.")
+ .def_static(
+ "name",
+ [](std::string name, std::optional<PyLocation> childLoc,
+ DefaultingPyMlirContext context) {
+ return PyLocation(
+ context->getRef(),
+ mlirLocationNameGet(
+ context->get(), toMlirStringRef(name),
+ childLoc ? childLoc->get()
+ : mlirLocationUnknownGet(context->get())));
+ },
+ "name"_a, "childLoc"_a = nb::none(), "context"_a = nb::none(),
+ "Gets a Location representing a named location with optional child "
+ "location.")
+ .def("is_a_name", mlirLocationIsAName,
+ "Returns True if this location is a `NameLoc`.")
+ .def_prop_ro(
+ "name_str",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(mlirLocationNameGetName(loc));
+ },
+ "Gets the name string from a `NameLoc`.")
+ .def_prop_ro(
+ "child_loc",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationNameGetChildLoc(self));
+ },
+ "Gets the child location from a `NameLoc`.")
+ .def_static(
+ "from_attr",
+ [](PyAttribute &attribute, DefaultingPyMlirContext context) {
+ return PyLocation(context->getRef(),
+ mlirLocationFromAttribute(attribute));
+ },
+ "attribute"_a, "context"_a = nb::none(),
+ "Gets a Location from a `LocationAttr`.")
+ .def_prop_ro(
+ "context",
+ [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that owns the `Location`.")
+ .def_prop_ro(
+ "attr",
+ [](PyLocation &self) {
+ return PyAttribute(self.getContext(),
+ mlirLocationGetAttribute(self));
+ },
+ "Get the underlying `LocationAttr`.")
+ .def(
+ "emit_error",
+ [](PyLocation &self, std::string message) {
+ mlirEmitError(self, message.c_str());
+ },
+ "message"_a,
+ R"(
+ Emits an error diagnostic at this location.
+
+ Args:
+ message: The error message to emit.)")
+ .def(
+ "__repr__",
+ [](PyLocation &self) {
+ PyPrintAccumulator printAccum;
+ mlirLocationPrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly representation of the location.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Module
+ //----------------------------------------------------------------------------
+ // Python `Module` class wrapping PyModule. Registered as weak-referenceable,
+ // so Python code may take weak references to Module objects.
+ nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule,
+ "Gets a capsule wrapping the MlirModule.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
+ R"(
+ Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
+
+ This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
+ prevent double-frees (of the underlying `mlir::Module`).)")
+ .def("_clear_mlir_module", &PyModule::clearMlirModule,
+ R"(
+ Clears the internal MLIR module reference.
+
+ This is used internally to prevent double-free when ownership is transferred
+ via the C API capsule mechanism. Not intended for normal use.)")
+ // `parse` has two overloads, one for `str` and one for `bytes` assembly. In
+ // each, diagnostics emitted during parsing are collected by ErrorCapture and
+ // attached to the MLIRError thrown when the C API returns a null module.
+ .def_static(
+ "parse",
+ [](const std::string &moduleAsm, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParse(
+ context->get(), toMlirStringRef(moduleAsm));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
+ .def_static(
+ "parse",
+ [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParse(
+ context->get(), toMlirStringRef(moduleAsm));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ "asm"_a, "context"_a = nb::none(), kModuleParseDocstring)
+ // Parses a module from a file path, with the same error handling as
+ // `parse`. NOTE(review): this reuses kModuleParseDocstring and the "module
+ // assembly" error message from the string overloads; confirm both read
+ // correctly for a path argument.
+ .def_static(
+ "parseFile",
+ [](const std::string &path, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParseFromFile(
+ context->get(), toMlirStringRef(path));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ "path"_a, "context"_a = nb::none(), kModuleParseDocstring)
+ // Creates an empty module. `loc` is resolved by maybeGetTracebackLocation
+ // (presumably falling back to an ambient or traceback location when it is
+ // None -- confirm).
+ .def_static(
+ "create",
+ [](const std::optional<PyLocation> &loc)
+ -> nb::typed<nb::object, PyModule> {
+ PyLocation pyLoc = maybeGetTracebackLocation(loc);
+ MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
+ return PyModule::forModule(module).releaseObject();
+ },
+ "loc"_a = nb::none(), "Creates an empty module.")
+ .def_prop_ro(
+ "context",
+ [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that created the `Module`.")
+ // `operation` and `body` both wrap the module's top-level operation with
+ // PyOperation::forOperation. They pass this Module's Python object as the
+ // final argument (presumably a parent keep-alive -- confirm).
+ .def_prop_ro(
+ "operation",
+ [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
+ return PyOperation::forOperation(self.getContext(),
+ mlirModuleGetOperation(self.get()),
+ self.getRef().releaseObject())
+ .releaseObject();
+ },
+ "Accesses the module as an operation.")
+ .def_prop_ro(
+ "body",
+ [](PyModule &self) {
+ PyOperationRef moduleOp = PyOperation::forOperation(
+ self.getContext(), mlirModuleGetOperation(self.get()),
+ self.getRef().releaseObject());
+ PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
+ return returnBlock;
+ },
+ "Return the block for this module.")
+ .def(
+ "dump",
+ [](PyModule &self) {
+ mlirOperationDump(mlirModuleGetOperation(self.get()));
+ },
+ kDumpDocstring)
+ // Delegates to `self.operation.__str__()`, so a Module prints exactly like
+ // its top-level operation.
+ .def(
+ "__str__",
+ [](const nb::object &self) {
+ // Defer to the operation's __str__.
+ return self.attr("operation").attr("__str__")();
+ },
+ nb::sig("def __str__(self) -> str"),
+ R"(
+ Gets the assembly form of the operation with default options.
+
+ If more advanced control over the assembly formatting or I/O options is needed,
+ use the dedicated print or get_asm method, which supports keyword arguments to
+ customize behavior.
+ )")
+ // Equality and hashing use mlirModuleEqual / mlirModuleHashValue.
+ // NOTE(review): unlike Operation, Region and Block, no fallback
+ // `__eq__(object)` overload is registered here. Comparing a Module with a
+ // non-Module object therefore depends on how nanobind handles a failed
+ // overload match -- confirm this is intended.
+ .def(
+ "__eq__",
+ [](PyModule &self, PyModule &other) {
+ return mlirModuleEqual(self.get(), other.get());
+ },
+ "other"_a, "Compares two modules for equality.")
+ .def(
+ "__hash__",
+ [](PyModule &self) { return mlirModuleHashValue(self.get()); },
+ "Returns the hash value of the module.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Operation.
+ //----------------------------------------------------------------------------
+ // Python `_OperationBase`: API shared by `Operation` and `OpView`. Every
+ // binding resolves the concrete PyOperation through `getOperation()`.
+ nb::class_<PyOperationBase>(m, "_OperationBase")
+ .def_prop_ro(
+ MLIR_PYTHON_CAPI_PTR_ATTR,
+ [](PyOperationBase &self) {
+ return self.getOperation().getCapsule();
+ },
+ "Gets a capsule wrapping the `MlirOperation`.")
+ // Two `__eq__` overloads: one compares operations with mlirOperationEqual,
+ // and the other is an always-False fallback for any non-operation object.
+ .def(
+ "__eq__",
+ [](PyOperationBase &self, PyOperationBase &other) {
+ return mlirOperationEqual(self.getOperation().get(),
+ other.getOperation().get());
+ },
+ "Compares two operations for equality.")
+ .def(
+ "__eq__",
+ [](PyOperationBase &self, nb::object other) { return false; },
+ "Compares operation with non-operation object (always returns "
+ "False).")
+ .def(
+ "__hash__",
+ [](PyOperationBase &self) {
+ return mlirOperationHashValue(self.getOperation().get());
+ },
+ "Returns the hash value of the operation.")
+ .def_prop_ro(
+ "attributes",
+ [](PyOperationBase &self) {
+ return PyOpAttributeMap(self.getOperation().getRef());
+ },
+ "Returns a dictionary-like map of operation attributes.")
+ // `context` and `name` call checkValid() before touching the underlying
+ // MlirOperation.
+ .def_prop_ro(
+ "context",
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
+ PyOperation &concreteOperation = self.getOperation();
+ concreteOperation.checkValid();
+ return concreteOperation.getContext().getObject();
+ },
+ "Context that owns the operation.")
+ .def_prop_ro(
+ "name",
+ [](PyOperationBase &self) {
+ auto &concreteOperation = self.getOperation();
+ concreteOperation.checkValid();
+ MlirOperation operation = concreteOperation.get();
+ return mlirIdentifierStr(mlirOperationGetName(operation));
+ },
+ "Returns the fully qualified name of the operation.")
+ .def_prop_ro(
+ "operands",
+ [](PyOperationBase &self) {
+ return PyOpOperandList(self.getOperation().getRef());
+ },
+ "Returns the list of operation operands.")
+ .def_prop_ro(
+ "regions",
+ [](PyOperationBase &self) {
+ return PyRegionList(self.getOperation().getRef());
+ },
+ "Returns the list of operation regions.")
+ .def_prop_ro(
+ "results",
+ [](PyOperationBase &self) {
+ return PyOpResultList(self.getOperation().getRef());
+ },
+ "Returns the list of Operation results.")
+ // Single-result shortcut. getUniqueResult presumably raises when the op
+ // does not have exactly one result (per the docstring). The result is
+ // downcast to its most specific Python value class.
+ .def_prop_ro(
+ "result",
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
+ auto &operation = self.getOperation();
+ return PyOpResult(operation.getRef(), getUniqueResult(operation))
+ .maybeDownCast();
+ },
+ "Shortcut to get an op result if it has only one (throws an error "
+ "otherwise).")
+ .def_prop_rw(
+ "location",
+ [](PyOperationBase &self) {
+ PyOperation &operation = self.getOperation();
+ return PyLocation(operation.getContext(),
+ mlirOperationGetLocation(operation.get()));
+ },
+ [](PyOperationBase &self, const PyLocation &location) {
+ PyOperation &operation = self.getOperation();
+ mlirOperationSetLocation(operation.get(), location.get());
+ },
+ nb::for_getter("Returns the source location the operation was "
+ "defined or derived from."),
+ nb::for_setter("Sets the source location the operation was defined "
+ "or derived from."))
+ .def_prop_ro(
+ "parent",
+ [](PyOperationBase &self)
+ -> std::optional<nb::typed<nb::object, PyOperation>> {
+ auto parent = self.getOperation().getParentOperation();
+ if (parent)
+ return parent->getObject();
+ return {};
+ },
+ "Returns the parent operation, or `None` if at top level.")
+ // Default textual form: every getAsm option is disabled.
+ .def(
+ "__str__",
+ [](PyOperationBase &self) {
+ return self.getAsm(/*binary=*/false,
+ /*largeElementsLimit=*/std::nullopt,
+ /*largeResourceLimit=*/std::nullopt,
+ /*enableDebugInfo=*/false,
+ /*prettyDebugInfo=*/false,
+ /*printGenericOpForm=*/false,
+ /*useLocalScope=*/false,
+ /*useNameLocAsPrefix=*/false,
+ /*assumeVerified=*/false,
+ /*skipRegions=*/false);
+ },
+ nb::sig("def __str__(self) -> str"),
+ "Returns the assembly form of the operation.")
+ // Two `print` overloads: one driven by an AsmState, and one driven by
+ // keyword flags. The keyword overload's argument list must stay in sync
+ // with PyOperationBase::print's parameter order.
+ .def("print",
+ nb::overload_cast<PyAsmState &, nb::object, bool>(
+ &PyOperationBase::print),
+ "state"_a, "file"_a = nb::none(), "binary"_a = false,
+ R"(
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ state: `AsmState` capturing the operation numbering and flags.
+ file: Optional file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
+ .def("print",
+ nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
+ bool, bool, bool, bool, bool, bool, nb::object,
+ bool, bool>(&PyOperationBase::print),
+ // Careful: Lots of arguments must match up with print method.
+ "large_elements_limit"_a = nb::none(),
+ "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
+ "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
+ "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
+ "assume_verified"_a = false, "file"_a = nb::none(),
+ "binary"_a = false, "skip_regions"_a = false,
+ R"(
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ large_elements_limit: Whether to elide elements attributes above this
+ number of elements. Defaults to None (no limit).
+ large_resource_limit: Whether to elide resource attributes above this
+ number of characters. Defaults to None (no limit). If large_elements_limit
+ is set and this is None, the behavior will be to use large_elements_limit
+ as large_resource_limit.
+ enable_debug_info: Whether to print debug/location information. Defaults
+ to False.
+ pretty_debug_info: Whether to format debug information for easier reading
+ by a human (warning: the result is unparseable). Defaults to False.
+ print_generic_op_form: Whether to print the generic assembly forms of all
+ ops. Defaults to False.
+ use_local_scope: Whether to print in a way that is more optimized for
+ multi-threaded access but may not be consistent with how the overall
+ module prints.
+ use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
+ prefixes for the SSA identifiers. Defaults to False.
+ assume_verified: By default, if not printing generic form, the verifier
+ will be run and if it fails, generic form will be printed with a comment
+ about failed verification. While a reasonable default for interactive use,
+ for systematic use, it is often better for the caller to verify explicitly
+ and report failures in a more robust fashion. Set this to True if doing this
+ in order to avoid running a redundant verification. If the IR is actually
+ invalid, behavior is undefined.
+ file: The file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write bytes (True) or str (False). Defaults to False.
+ skip_regions: Whether to skip printing regions. Defaults to False.)")
+ .def("write_bytecode", &PyOperationBase::writeBytecode, "file"_a,
+ "desired_version"_a = nb::none(),
+ R"(
+ Write the bytecode form of the operation to a file like object.
+
+ Args:
+ file: The file like object to write to.
+ desired_version: Optional version of bytecode to emit.
+ Returns:
+ The bytecode writer status.)")
+ .def("get_asm", &PyOperationBase::getAsm,
+ // Careful: Lots of arguments must match up with get_asm method.
+ "binary"_a = false, "large_elements_limit"_a = nb::none(),
+ "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
+ "pretty_debug_info"_a = false, "print_generic_op_form"_a = false,
+ "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
+ "assume_verified"_a = false, "skip_regions"_a = false,
+ R"(
+ Gets the assembly form of the operation with all options available.
+
+ Args:
+ binary: Whether to return a bytes (True) or str (False) object. Defaults to
+ False.
+ ... others ...: See the print() method for common keyword arguments for
+ configuring the printout.
+ Returns:
+ Either a bytes or str object, depending on the setting of the `binary`
+ argument.)")
+ .def("verify", &PyOperationBase::verify,
+ "Verify the operation. Raises MLIRError if verification fails, and "
+ "returns True otherwise.")
+ .def("move_after", &PyOperationBase::moveAfter, "other"_a,
+ "Puts self immediately after the other operation in its parent "
+ "block.")
+ .def("move_before", &PyOperationBase::moveBefore, "other"_a,
+ "Puts self immediately before the other operation in its parent "
+ "block.")
+ .def("is_before_in_block", &PyOperationBase::isBeforeInBlock, "other"_a,
+ R"(
+ Checks if this operation is before another in the same block.
+
+ Args:
+ other: Another operation in the same parent block.
+
+ Returns:
+ True if this operation is before `other` in the operation list of the parent block.)")
+ .def(
+ "clone",
+ [](PyOperationBase &self,
+ const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
+ return self.getOperation().clone(ip);
+ },
+ "ip"_a = nb::none(),
+ R"(
+ Creates a deep copy of the operation.
+
+ Args:
+ ip: Optional insertion point where the cloned operation should be inserted.
+ If None, the current insertion point is used. If False, the operation
+ remains detached.
+
+ Returns:
+ A new Operation that is a clone of this operation.)")
+ // Raises ValueError for an already-detached operation. Otherwise it
+ // detaches the operation and returns it as an OpView.
+ .def(
+ "detach_from_parent",
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
+ PyOperation &operation = self.getOperation();
+ operation.checkValid();
+ if (!operation.isAttached())
+ throw nb::value_error("Detached operation has no parent.");
+
+ operation.detachFromParent();
+ return operation.createOpView();
+ },
+ "Detaches the operation from its parent block.")
+ .def_prop_ro(
+ "attached",
+ [](PyOperationBase &self) {
+ PyOperation &operation = self.getOperation();
+ operation.checkValid();
+ return operation.isAttached();
+ },
+ "Reports if the operation is attached to its parent block.")
+ .def(
+ "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
+ R"(
+ Erases the operation and frees its memory.
+
+ Note:
+ After erasing, any Python references to the operation become invalid.)")
+ // The explicit signature must advertise the registered default for
+ // `walk_order` (post-order). Otherwise stubs and help() present it as a
+ // required argument.
+ .def("walk", &PyOperationBase::walk, "callback"_a,
+ "walk_order"_a = PyMlirWalkOrder::MlirWalkPostOrder,
+ // clang-format off
+ nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = WalkOrder.POST_ORDER) -> None"),
+ // clang-format on
+ R"(
+ Walks the operation tree with a callback function.
+
+ Args:
+ callback: A callable that takes an Operation and returns a WalkResult.
+ walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
+
+ // Python `Operation` class: the generic, concrete operation object. The
+ // `opview` property converts it to an `OpView`.
+ nb::class_<PyOperation, PyOperationBase>(m, "Operation")
+ // Generic creation entry point. Operands are unpacked to MlirValue here, so
+ // a None entry is rejected with ValueError before PyOperation::create runs.
+ .def_static(
+ "create",
+ [](std::string_view name,
+ std::optional<std::vector<PyType *>> results,
+ std::optional<std::vector<PyValue *>> operands,
+ std::optional<nb::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors, int regions,
+ const std::optional<PyLocation> &location,
+ const nb::object &maybeIp,
+ bool inferType) -> nb::typed<nb::object, PyOperation> {
+ // Unpack/validate operands.
+ llvm::SmallVector<MlirValue, 4> mlirOperands;
+ if (operands) {
+ mlirOperands.reserve(operands->size());
+ for (PyValue *operand : *operands) {
+ if (!operand)
+ throw nb::value_error("operand value cannot be None");
+ mlirOperands.push_back(operand->get());
+ }
+ }
+
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
+ return PyOperation::create(name, results, mlirOperands, attributes,
+ successors, regions, pyLoc, maybeIp,
+ inferType);
+ },
+ "name"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
+ "attributes"_a = nb::none(), "successors"_a = nb::none(),
+ "regions"_a = 0, "loc"_a = nb::none(), "ip"_a = nb::none(),
+ "infer_type"_a = false,
+ R"(
+ Creates a new operation.
+
+ Args:
+ name: Operation name (e.g. `dialect.operation`).
+ results: Optional sequence of Type representing op result types.
+ operands: Optional operands of the operation.
+ attributes: Optional Dict of {str: Attribute}.
+ successors: Optional List of Block for the operation's successors.
+ regions: Number of regions to create (default = 0).
+ loc: Optional Location object (defaults to resolve from context manager).
+ ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
+ infer_type: Whether to infer result types (default = False).
+ Returns:
+ A new Operation object. It is detached unless it was inserted via `ip`; detached operations can be added to blocks, which causes them to become attached.)")
+ .def_static(
+ "parse",
+ [](const std::string &sourceStr, const std::string &sourceName,
+ DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyOpView> {
+ return PyOperation::parse(context->getRef(), sourceStr, sourceName)
+ ->createOpView();
+ },
+ "source"_a, nb::kw_only(), "source_name"_a = "",
+ "context"_a = nb::none(),
+ "Parses an operation. Supports both text assembly format and binary "
+ "bytecode format.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule,
+ "Gets a capsule wrapping the MlirOperation.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyOperation::createFromCapsule,
+ "Creates an Operation from a capsule wrapping MlirOperation.")
+ // `operation` returns the Python object itself, mirroring
+ // `OpView.operation`.
+ .def_prop_ro(
+ "operation",
+ [](nb::object self) -> nb::typed<nb::object, PyOperation> {
+ return self;
+ },
+ "Returns self (the operation).")
+ .def_prop_ro(
+ "opview",
+ [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
+ return self.createOpView();
+ },
+ R"(
+ Returns an OpView of this operation.
+
+ Note:
+ If the operation has a registered and loaded dialect then this OpView will
+ be concrete wrapper class.)")
+ .def_prop_ro("block", &PyOperation::getBlock,
+ "Returns the block containing this operation.")
+ .def_prop_ro(
+ "successors",
+ [](PyOperationBase &self) {
+ return PyOpSuccessors(self.getOperation().getRef());
+ },
+ "Returns the list of Operation successors.")
+ .def(
+ "replace_uses_of_with",
+ [](PyOperation &self, PyValue &of, PyValue &with) {
+ mlirOperationReplaceUsesOfWith(self.get(), of.get(), with.get());
+ },
+ "of"_a, "with_"_a,
+ "Replaces uses of the 'of' value with the 'with' value inside the "
+ "operation.")
+ .def("_set_invalid", &PyOperation::setInvalid,
+ "Invalidate the operation.");
+
+ // Python `OpView` class: the base for operation-specific views over an
+ // Operation. Generated subclasses are expected to define OPERATION_NAME and
+ // the _ODS_* class attributes read by `build_generic` and `parse` below.
+ auto opViewClass =
+ nb::class_<PyOpView, PyOperationBase>(m, "OpView")
+ .def(nb::init<nb::typed<nb::object, PyOperation>>(), "operation"_a)
+ // Generic builder constructor. The caller passes the operation name and
+ // ODS metadata explicitly, and everything is forwarded to
+ // PyOpView::buildGeneric.
+ .def(
+ "__init__",
+ [](PyOpView *self, std::string_view name,
+ std::tuple<int, bool> opRegionSpec,
+ nb::object operandSegmentSpecObj,
+ nb::object resultSegmentSpecObj,
+ std::optional<nb::list> resultTypeList, nb::list operandList,
+ std::optional<nb::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors,
+ std::optional<int> regions,
+ const std::optional<PyLocation> &location,
+ const nb::object &maybeIp) {
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
+ new (self) PyOpView(PyOpView::buildGeneric(
+ name, opRegionSpec, operandSegmentSpecObj,
+ resultSegmentSpecObj, resultTypeList, operandList,
+ attributes, successors, regions, pyLoc, maybeIp));
+ },
+ "name"_a, "opRegionSpec"_a,
+ "operandSegmentSpecObj"_a = nb::none(),
+ "resultSegmentSpecObj"_a = nb::none(), "results"_a = nb::none(),
+ "operands"_a = nb::none(), "attributes"_a = nb::none(),
+ "successors"_a = nb::none(), "regions"_a = nb::none(),
+ "loc"_a = nb::none(), "ip"_a = nb::none())
+ .def_prop_ro(
+ "operation",
+ [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
+ return self.getOperationObject();
+ })
+ .def_prop_ro("opview",
+ [](nb::object self) -> nb::typed<nb::object, PyOpView> {
+ return self;
+ })
+ .def(
+ "__str__",
+ [](PyOpView &self) { return nb::str(self.getOperationObject()); })
+ .def_prop_ro(
+ "successors",
+ [](PyOperationBase &self) {
+ return PyOpSuccessors(self.getOperation().getRef());
+ },
+ "Returns the list of Operation successors.")
+ .def(
+ "_set_invalid",
+ [](PyOpView &self) { self.getOperation().setInvalid(); },
+ "Invalidate the operation.");
+ // Default ODS class attributes on the base `OpView`. Generated subclasses
+ // presumably override them -- confirm against the ODS Python generator.
+ opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
+ opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
+ opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
+ // It is faster to pass the operation_name, ods_regions, and
+ // ods_operand_segments/ods_result_segments as arguments to the constructor,
+ // rather than to access them as attributes.
+ // `build_generic` is the attribute-reading counterpart: it pulls the name
+ // and ODS metadata from `cls` and then calls PyOpView::buildGeneric.
+ opViewClass.attr("build_generic") = classmethod(
+ [](nb::handle cls, std::optional<nb::list> resultTypeList,
+ nb::list operandList, std::optional<nb::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors,
+ std::optional<int> regions, std::optional<PyLocation> location,
+ const nb::object &maybeIp) {
+ std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
+ std::tuple<int, bool> opRegionSpec =
+ nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
+ nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
+ nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
+ return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
+ resultSegmentSpec, resultTypeList,
+ operandList, attributes, successors,
+ regions, pyLoc, maybeIp);
+ },
+ "cls"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
+ "attributes"_a = nb::none(), "successors"_a = nb::none(),
+ "regions"_a = nb::none(), "loc"_a = nb::none(), "ip"_a = nb::none(),
+ "Builds a specific, generated OpView based on class level attributes.");
+ // Classmethod `parse`: parses a single operation and raises MLIRError when
+ // its name differs from `cls.OPERATION_NAME`.
+ opViewClass.attr("parse") = classmethod(
+ [](const nb::object &cls, const std::string &sourceStr,
+ const std::string &sourceName,
+ DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
+ PyOperationRef parsed =
+ PyOperation::parse(context->getRef(), sourceStr, sourceName);
+
+ // Check if the expected operation was parsed, and cast it to the
+ // appropriate `OpView` subclass if successful.
+ // NOTE: This accesses attributes that have been automatically added to
+ // `OpView` subclasses, and is not intended to be used on `OpView`
+ // directly.
+ std::string clsOpName =
+ nb::cast<std::string>(cls.attr("OPERATION_NAME"));
+ MlirStringRef identifier =
+ mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
+ std::string_view parsedOpName(identifier.data, identifier.length);
+ if (clsOpName != parsedOpName)
+ throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
+ parsedOpName + "'");
+ return PyOpView::constructDerived(cls, parsed.getObject());
+ },
+ "cls"_a, "source"_a, nb::kw_only(), "source_name"_a = "",
+ "context"_a = nb::none(),
+ "Parses a specific, generated OpView based on class level attributes.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyRegion.
+ //----------------------------------------------------------------------------
+ // Python `Region` class. Block iteration and `owner` both reach through the
+ // parent operation the region was obtained from.
+ nb::class_<PyRegion>(m, "Region")
+ .def_prop_ro(
+ "blocks",
+ [](PyRegion &region) {
+ return PyBlockList(region.getParentOperation(), region.get());
+ },
+ "Returns a forward-optimized sequence of blocks.")
+ .def_prop_ro(
+ "owner",
+ [](PyRegion &region) -> nb::typed<nb::object, PyOpView> {
+ auto &&owningOp = region.getParentOperation();
+ return owningOp->createOpView();
+ },
+ "Returns the operation owning this region.")
+ // Iteration starts from the region's first block, after a validity check.
+ .def(
+ "__iter__",
+ [](PyRegion &region) {
+ region.checkValid();
+ return PyBlockIterator(region.getParentOperation(),
+ mlirRegionGetFirstBlock(region.get()));
+ },
+ "Iterates over blocks in the region.")
+ // Equality is identity of the underlying MlirRegion pointer. Any
+ // non-Region operand falls through to the always-False overload.
+ .def(
+ "__eq__",
+ [](PyRegion &lhs, PyRegion &rhs) {
+ return lhs.get().ptr == rhs.get().ptr;
+ },
+ "Compares two regions for pointer equality.")
+ .def(
+ "__eq__",
+ [](PyRegion & /*self*/, nb::object & /*other*/) { return false; },
+ "Compares region with non-region object (always returns False).");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyBlock.
+ //----------------------------------------------------------------------------
+ // Python `Block` class. Each wrapper carries its parent operation, which is
+ // propagated to every derived object (regions, arguments, iterators, views).
+ nb::class_<PyBlock>(m, "Block")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule,
+ "Gets a capsule wrapping the MlirBlock.")
+ .def_prop_ro(
+ "owner",
+ [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
+ return self.getParentOperation()->createOpView();
+ },
+ "Returns the owning operation of this block.")
+ .def_prop_ro(
+ "region",
+ [](PyBlock &self) {
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ return PyRegion(self.getParentOperation(), region);
+ },
+ "Returns the owning region of this block.")
+ .def_prop_ro(
+ "arguments",
+ [](PyBlock &self) {
+ return PyBlockArgumentList(self.getParentOperation(), self.get());
+ },
+ "Returns a list of block arguments.")
+ .def(
+ "add_argument",
+ [](PyBlock &self, const PyType &type, const PyLocation &loc) {
+ return PyBlockArgument(self.getParentOperation(),
+ mlirBlockAddArgument(self.get(), type, loc));
+ },
+ "type"_a, "loc"_a,
+ R"(
+ Appends an argument of the specified type to the block.
+
+ Args:
+ type: The type of the argument to add.
+ loc: The source location for the argument.
+
+ Returns:
+ The newly added block argument.)")
+ .def(
+ "erase_argument",
+ [](PyBlock &self, unsigned index) {
+ return mlirBlockEraseArgument(self.get(), index);
+ },
+ "index"_a,
+ R"(
+ Erases the argument at the specified index.
+
+ Args:
+ index: The index of the argument to erase.)")
+ .def_prop_ro(
+ "operations",
+ [](PyBlock &self) {
+ return PyOperationList(self.getParentOperation(), self.get());
+ },
+ "Returns a forward-optimized sequence of operations.")
+ // Builds a new block from argument types/locations and inserts it at
+ // index 0 of `parent`.
+ .def_static(
+ "create_at_start",
+ [](PyRegion &parent, const nb::sequence &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
+ parent.checkValid();
+ MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
+ mlirRegionInsertOwnedBlock(parent, 0, block);
+ return PyBlock(parent.getParentOperation(), block);
+ },
+ "parent"_a, "arg_types"_a = nb::list(), "arg_locs"_a = std::nullopt,
+ "Creates and returns a new Block at the beginning of the given "
+ "region (with given argument types and locations).")
+ // The block is detached from its current region (if any) before being
+ // appended to `region`.
+ .def(
+ "append_to",
+ [](PyBlock &self, PyRegion &region) {
+ MlirBlock b = self.get();
+ if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
+ mlirBlockDetach(b);
+ mlirRegionAppendOwnedBlock(region.get(), b);
+ },
+ "region"_a,
+ R"(
+ Appends this block to a region.
+
+ Transfers ownership if the block is currently owned by another region.
+
+ Args:
+ region: The region to append the block to.)")
+ // `create_before` / `create_after` insert a sibling block into this
+ // block's parent region. NOTE(review): the parent region is not null-checked
+ // for a detached block -- confirm the C API tolerates this or add a check.
+ .def(
+ "create_before",
+ [](PyBlock &self, const nb::args &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
+ self.checkValid();
+ MlirBlock block =
+ createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
+ return PyBlock(self.getParentOperation(), block);
+ },
+ "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
+ "Creates and returns a new Block before this block "
+ "(with given argument types and locations).")
+ .def(
+ "create_after",
+ [](PyBlock &self, const nb::args &pyArgTypes,
+ const std::optional<nb::sequence> &pyArgLocs) {
+ self.checkValid();
+ MlirBlock block =
+ createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
+ MlirRegion region = mlirBlockGetParentRegion(self.get());
+ mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
+ return PyBlock(self.getParentOperation(), block);
+ },
+ "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt,
+ "Creates and returns a new Block after this block "
+ "(with given argument types and locations).")
+ .def(
+ "__iter__",
+ [](PyBlock &self) {
+ self.checkValid();
+ MlirOperation firstOperation =
+ mlirBlockGetFirstOperation(self.get());
+ return PyOperationIterator(self.getParentOperation(),
+ firstOperation);
+ },
+ "Iterates over operations in the block.")
+ // Equality and hashing are both based on the MlirBlock pointer. Any
+ // non-Block operand falls through to the always-False `__eq__` overload.
+ .def(
+ "__eq__",
+ [](PyBlock &self, PyBlock &other) {
+ return self.get().ptr == other.get().ptr;
+ },
+ "Compares two blocks for pointer equality.")
+ .def(
+ "__eq__", [](PyBlock &self, nb::object &other) { return false; },
+ "Compares block with non-block object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyBlock &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the block.")
+ .def(
+ "__str__",
+ [](PyBlock &self) {
+ self.checkValid();
+ PyPrintAccumulator printAccum;
+ mlirBlockPrint(self.get(), printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly form of the block.")
+ // Moves `operation` to the end of this block. It is detached first if it
+ // is attached elsewhere, then marked attached under this block's parent
+ // operation.
+ .def(
+ "append",
+ [](PyBlock &self, PyOperationBase &operation) {
+ if (operation.getOperation().isAttached())
+ operation.getOperation().detachFromParent();
+
+ MlirOperation mlirOperation = operation.getOperation().get();
+ mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
+ operation.getOperation().setAttached(
+ self.getParentOperation().getObject());
+ },
+ "operation"_a,
+ R"(
+ Appends an operation to this block.
+
+ If the operation is currently in another block, it will be moved.
+
+ Args:
+ operation: The operation to append to the block.)")
+ .def_prop_ro(
+ "successors",
+ [](PyBlock &self) {
+ return PyBlockSuccessors(self, self.getParentOperation());
+ },
+ "Returns the list of Block successors.")
+ .def_prop_ro(
+ "predecessors",
+ [](PyBlock &self) {
+ return PyBlockPredecessors(self, self.getParentOperation());
+ },
+ "Returns the list of Block predecessors.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyInsertionPoint.
+ //----------------------------------------------------------------------------
+
+ // Python `InsertionPoint` class. It is usable as a context manager. The
+ // two constructor overloads (block / reference operation) must keep their
+ // registration order.
+ nb::class_<PyInsertionPoint>(m, "InsertionPoint")
+ .def(nb::init<PyBlock &>(), "block"_a,
+ "Inserts after the last operation but still inside the block.")
+ .def("__enter__", &PyInsertionPoint::contextEnter,
+ "Enters the insertion point as a context manager.")
+ .def("__exit__", &PyInsertionPoint::contextExit, "exc_type"_a.none(),
+ "exc_value"_a.none(), "traceback"_a.none(),
+ "Exits the insertion point context manager.")
+ // Class-level `current`: the thread's default insertion point, or
+ // ValueError when none has been pushed.
+ .def_prop_ro_static(
+ "current",
+ [](nb::object & /*class*/) {
+ if (auto *threadIp = PyThreadContextEntry::getDefaultInsertionPoint())
+ return threadIp;
+ throw nb::value_error("No current InsertionPoint");
+ },
+ nb::sig("def current(/) -> InsertionPoint"),
+ "Gets the InsertionPoint bound to the current thread or raises "
+ "ValueError if none has been set.")
+ .def(nb::init<PyOperationBase &>(), "beforeOperation"_a,
+ "Inserts before a referenced operation.")
+ .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, "block"_a,
+ R"(
+ Creates an insertion point at the beginning of a block.
+
+ Args:
+ block: The block at whose beginning operations should be inserted.
+
+ Returns:
+ An InsertionPoint at the block's beginning.)")
+ .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
+ "block"_a,
+ R"(
+ Creates an insertion point before a block's terminator.
+
+ Args:
+ block: The block whose terminator to insert before.
+
+ Returns:
+ An InsertionPoint before the terminator.
+
+ Raises:
+ ValueError: If the block has no terminator.)")
+ .def_static("after", &PyInsertionPoint::after, "operation"_a,
+ R"(
+ Creates an insertion point immediately after an operation.
+
+ Args:
+ operation: The operation after which to insert.
+
+ Returns:
+ An InsertionPoint after the operation.)")
+ .def("insert", &PyInsertionPoint::insert, "operation"_a,
+ R"(
+ Inserts an operation at this insertion point.
+
+ Args:
+ operation: The operation to insert.)")
+ .def_prop_ro(
+ "block", [](PyInsertionPoint &ip) { return ip.getBlock(); },
+ "Returns the block that this `InsertionPoint` points to.")
+ // None when the insertion point is at the end of its block.
+ .def_prop_ro(
+ "ref_operation",
+ [](PyInsertionPoint &ip)
+ -> std::optional<nb::typed<nb::object, PyOperation>> {
+ auto refOp = ip.getRefOperation();
+ if (!refOp)
+ return std::nullopt;
+ return refOp->getObject();
+ },
+ "The reference operation before which new operations are "
+ "inserted, or None if the insertion point is at the end of "
+ "the block.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyAttribute.
+ //----------------------------------------------------------------------------
+ nb::class_<PyAttribute>(m, "Attribute")
+ // Delegate to the PyAttribute copy constructor, which will also lifetime
+ // extend the backing context which owns the MlirAttribute.
+ .def(nb::init<PyAttribute &>(), "cast_from_type"_a,
+ "Casts the passed attribute to the generic `Attribute`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule,
+ "Gets a capsule wrapping the MlirAttribute.")
+ .def_static(
+ MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule,
+ "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
+ .def_static(
+ "parse",
+ [](const std::string &attrSpec, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyAttribute> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute attr = mlirAttributeParseGet(
+ context->get(), toMlirStringRef(attrSpec));
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Unable to parse attribute", errors.take());
+ return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
+ },
+ "asm"_a, "context"_a = nb::none(),
+ "Parses an attribute from an assembly form. Raises an `MLIRError` on "
+ "failure.")
+ .def_prop_ro(
+ "context",
+ [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that owns the `Attribute`.")
+ .def_prop_ro(
+ "type",
+ [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getContext(), mlirAttributeGetType(self))
+ .maybeDownCast();
+ },
+ "Returns the type of the `Attribute`.")
+ .def(
+ "get_named",
+ [](PyAttribute &self, std::string name) {
+ return PyNamedAttribute(self, std::move(name));
+ },
+ nb::keep_alive<0, 1>(),
+ R"(
+ Binds a name to the attribute, creating a `NamedAttribute`.
+
+ Args:
+ name: The name to bind to the `Attribute`.
+
+ Returns:
+ A `NamedAttribute` with the given name and this attribute.)")
+ .def(
+ "__eq__",
+ [](PyAttribute &self, PyAttribute &other) { return self == other; },
+ "Compares two attributes for equality.")
+ .def(
+ "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
+ "Compares attribute with non-attribute object (always returns "
+ "False).")
+ .def(
+ "__hash__",
+ [](PyAttribute &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the attribute.")
+ .def(
+ "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
+ kDumpDocstring)
+ .def(
+ "__str__",
+ [](PyAttribute &self) {
+ PyPrintAccumulator printAccum;
+ mlirAttributePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly form of the Attribute.")
+ .def(
+ "__repr__",
+ [](PyAttribute &self) {
+ // Generally, assembly formats are not printed for __repr__ because
+ // this can cause exceptionally long debug output and exceptions.
+ // However, attribute values are generally considered useful and
+ // are printed. This may need to be re-evaluated if debug dumps end
+ // up being excessive.
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Attribute(");
+ mlirAttributePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the attribute.")
+ .def_prop_ro(
+ "typeid",
+ [](PyAttribute &self) {
+ MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
+ assert(!mlirTypeIDIsNull(mlirTypeID) &&
+ "mlirTypeID was expected to be non-null.");
+ return PyTypeID(mlirTypeID);
+ },
+ "Returns the `TypeID` of the attribute.")
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the attribute to a more specific attribute if possible.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyNamedAttribute
+ //----------------------------------------------------------------------------
+ nb::class_<PyNamedAttribute>(m, "NamedAttribute")
+ .def(
+ "__repr__",
+ [](PyNamedAttribute &self) {
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("NamedAttribute(");
+ printAccum.parts.append(
+ nb::str(mlirIdentifierStr(self.namedAttr.name).data,
+ mlirIdentifierStr(self.namedAttr.name).length));
+ printAccum.parts.append("=");
+ mlirAttributePrint(self.namedAttr.attribute,
+ printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the named attribute.")
+ .def_prop_ro(
+ "name",
+ [](PyNamedAttribute &self) {
+ return mlirIdentifierStr(self.namedAttr.name);
+ },
+ "The name of the `NamedAttribute` binding.")
+ .def_prop_ro(
+ "attr",
+ [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
+ nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
+ "The underlying generic attribute of the `NamedAttribute` binding.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyType.
+ //----------------------------------------------------------------------------
+ nb::class_<PyType>(m, "Type")
+ // Delegate to the PyType copy constructor, which will also lifetime
+ // extend the backing context which owns the MlirType.
+ .def(nb::init<PyType &>(), "cast_from_type"_a,
+ "Casts the passed type to the generic `Type`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule,
+ "Gets a capsule wrapping the `MlirType`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule,
+ "Creates a Type from a capsule wrapping `MlirType`.")
+ .def_static(
+ "parse",
+ [](std::string typeSpec,
+ DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType type =
+ mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Unable to parse type", errors.take());
+ return PyType(context.get()->getRef(), type).maybeDownCast();
+ },
+ "asm"_a, "context"_a = nb::none(),
+ R"(
+ Parses the assembly form of a type.
+
+ Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
+
+ See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
+ .def_prop_ro(
+ "context",
+ [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
+ "Context that owns the `Type`.")
+ .def(
+ "__eq__", [](PyType &self, PyType &other) { return self == other; },
+ "Compares two types for equality.")
+ .def(
+ "__eq__", [](PyType &self, nb::object &other) { return false; },
+ "other"_a.none(),
+ "Compares type with non-type object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyType &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the `Type`.")
+ .def(
+ "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
+ .def(
+ "__str__",
+ [](PyType &self) {
+ PyPrintAccumulator printAccum;
+ mlirTypePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly form of the `Type`.")
+ .def(
+ "__repr__",
+ [](PyType &self) {
+ // Generally, assembly formats are not printed for __repr__ because
+ // this can cause exceptionally long debug output and exceptions.
+ // However, types are an exception as they typically have compact
+ // assembly forms and printing them is useful.
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Type(");
+ mlirTypePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the `Type`.")
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyType &self) -> nb::typed<nb::object, PyType> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the Type to a more specific `Type` if possible.")
+ .def_prop_ro(
+ "typeid",
+ [](PyType &self) {
+ MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
+ if (!mlirTypeIDIsNull(mlirTypeID))
+ return PyTypeID(mlirTypeID);
+ auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
+ throw nb::value_error(
+ (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
+ },
+ "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
+ "`Type` has no "
+ "`TypeID`.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyTypeID.
+ //----------------------------------------------------------------------------
+ nb::class_<PyTypeID>(m, "TypeID")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule,
+ "Gets a capsule wrapping the `MlirTypeID`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule,
+ "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
+ // Note, this tests whether the underlying TypeIDs are the same,
+ // not whether the wrapper MlirTypeIDs are the same, nor whether
+ // the Python objects are the same (i.e., PyTypeID is a value type).
+ .def(
+ "__eq__",
+ [](PyTypeID &self, PyTypeID &other) { return self == other; },
+ "Compares two `TypeID`s for equality.")
+ .def(
+ "__eq__",
+ [](PyTypeID &self, const nb::object &other) { return false; },
+ "Compares TypeID with non-TypeID object (always returns False).")
+ // Note, this gives the hash value of the underlying TypeID, not the
+ // hash value of the Python object, nor the hash value of the
+ // MlirTypeID wrapper.
+ .def(
+ "__hash__",
+ [](PyTypeID &self) {
+ return static_cast<size_t>(mlirTypeIDHashValue(self));
+ },
+ "Returns the hash value of the `TypeID`.");
+
+ //----------------------------------------------------------------------------
+ // Mapping of Value.
+ //----------------------------------------------------------------------------
+ m.attr("_T") = nb::type_var("_T", "bound"_a = m.attr("Type"));
+
+ nb::class_<PyValue>(m, "Value", nb::is_generic(),
+ nb::sig("class Value(Generic[_T])"))
+ .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), "value"_a,
+ "Creates a Value reference from another `Value`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule,
+ "Gets a capsule wrapping the `MlirValue`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule,
+ "Creates a `Value` from a capsule wrapping `MlirValue`.")
+ .def_prop_ro(
+ "context",
+ [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getParentOperation()->getContext().getObject();
+ },
+ "Context in which the value lives.")
+ .def(
+ "dump", [](PyValue &self) { mlirValueDump(self.get()); },
+ kDumpDocstring)
+ .def_prop_ro(
+ "owner",
+ [](PyValue &self) -> nb::typed<nb::object, PyOpView> {
+ MlirValue v = self.get();
+ if (mlirValueIsAOpResult(v)) {
+ assert(mlirOperationEqual(self.getParentOperation()->get(),
+ mlirOpResultGetOwner(self.get())) &&
+ "expected the owner of the value in Python to match "
+ "that in "
+ "the IR");
+ return self.getParentOperation()->createOpView();
+ }
+
+ if (mlirValueIsABlockArgument(v)) {
+ MlirBlock block = mlirBlockArgumentGetOwner(self.get());
+ return nb::cast(PyBlock(self.getParentOperation(), block));
+ }
+
+ assert(false && "Value must be a block argument or an op result");
+ return nb::none();
+ },
+ "Returns the owner of the value (`Operation` for results, `Block` "
+ "for "
+ "arguments).")
+ .def_prop_ro(
+ "uses",
+ [](PyValue &self) {
+ return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
+ },
+ "Returns an iterator over uses of this value.")
+ .def(
+ "__eq__",
+ [](PyValue &self, PyValue &other) {
+ return self.get().ptr == other.get().ptr;
+ },
+ "Compares two values for pointer equality.")
+ .def(
+ "__eq__", [](PyValue &self, nb::object other) { return false; },
+ "Compares value with non-value object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyValue &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the value.")
+ .def(
+ "__str__",
+ [](PyValue &self) {
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Value(");
+ mlirValuePrint(self.get(), printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ R"(
+ Returns the string form of the value.
+
+ If the value is a block argument, this is the assembly form of its type and the
+ position in the argument list. If the value is an operation result, this is
+ equivalent to printing the operation that produced it.
+ )")
+ .def(
+ "get_name",
+ [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
+ PyPrintAccumulator printAccum;
+ MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
+ if (useLocalScope)
+ mlirOpPrintingFlagsUseLocalScope(flags);
+ if (useNameLocAsPrefix)
+ mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);
+ MlirAsmState valueState =
+ mlirAsmStateCreateForValue(self.get(), flags);
+ mlirValuePrintAsOperand(self.get(), valueState,
+ printAccum.getCallback(),
+ printAccum.getUserData());
+ mlirOpPrintingFlagsDestroy(flags);
+ mlirAsmStateDestroy(valueState);
+ return printAccum.join();
+ },
+ "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false,
+ R"(
+ Returns the string form of value as an operand.
+
+ Args:
+ use_local_scope: Whether to use local scope for naming.
+ use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
+
+ Returns:
+ The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
+ .def(
+ "get_name",
+ [](PyValue &self, PyAsmState &state) {
+ PyPrintAccumulator printAccum;
+ MlirAsmState valueState = state.get();
+ mlirValuePrintAsOperand(self.get(), valueState,
+ printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "state"_a,
+ "Returns the string form of value as an operand (i.e., the ValueID).")
+ .def_prop_ro(
+ "type",
+ [](PyValue &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getParentOperation()->getContext(),
+ mlirValueGetType(self.get()))
+ .maybeDownCast();
+ },
+ "Returns the type of the value.")
+ .def(
+ "set_type",
+ [](PyValue &self, const PyType &type) {
+ mlirValueSetType(self.get(), type);
+ },
+ "type"_a, "Sets the type of the value.",
+ nb::sig("def set_type(self, type: _T)"))
+ .def(
+ "replace_all_uses_with",
+ [](PyValue &self, PyValue &with) {
+ mlirValueReplaceAllUsesOfWith(self.get(), with.get());
+ },
+ "Replace all uses of value with the new value, updating anything in "
+ "the IR that uses `self` to use the other value instead.")
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with, PyOperation &exception) {
+ MlirOperation exceptedUser = exception.get();
+ mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
+ },
+ "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with, const nb::list &exceptions) {
+ // Convert Python list to a SmallVector of MlirOperations
+ llvm::SmallVector<MlirOperation> exceptionOps;
+ for (nb::handle exception : exceptions) {
+ exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
+ }
+
+ mlirValueReplaceAllUsesExcept(
+ self, with, static_cast<intptr_t>(exceptionOps.size()),
+ exceptionOps.data());
+ },
+ "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with, PyOperation &exception) {
+ MlirOperation exceptedUser = exception.get();
+ mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
+ },
+ "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](PyValue &self, PyValue &with,
+ std::vector<PyOperation> &exceptions) {
+ // Convert Python list to a SmallVector of MlirOperations
+ llvm::SmallVector<MlirOperation> exceptionOps;
+ for (PyOperation &exception : exceptions)
+ exceptionOps.push_back(exception);
+ mlirValueReplaceAllUsesExcept(
+ self, with, static_cast<intptr_t>(exceptionOps.size()),
+ exceptionOps.data());
+ },
+ "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyValue &self) -> nb::typed<nb::object, PyValue> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the `Value` to a more specific kind if possible.")
+ .def_prop_ro(
+ "location",
+ [](MlirValue self) {
+ return PyLocation(
+ PyMlirContext::forContext(mlirValueGetContext(self)),
+ mlirValueGetLocation(self));
+ },
+ "Returns the source location of the value.");
+
+ PyBlockArgument::bind(m);
+ PyOpResult::bind(m);
+ PyOpOperand::bind(m);
+
+ nb::class_<PyAsmState>(m, "AsmState")
+ .def(nb::init<PyValue &, bool>(), "value"_a, "use_local_scope"_a = false,
+ R"(
+ Creates an `AsmState` for consistent SSA value naming.
+
+ Args:
+ value: The value to create state for.
+ use_local_scope: Whether to use local scope for naming.)")
+ .def(nb::init<PyOperationBase &, bool>(), "op"_a,
+ "use_local_scope"_a = false,
+ R"(
+ Creates an AsmState for consistent SSA value naming.
+
+ Args:
+ op: The operation to create state for.
+ use_local_scope: Whether to use local scope for naming.)");
+
+ //----------------------------------------------------------------------------
+ // Mapping of SymbolTable.
+ //----------------------------------------------------------------------------
+ nb::class_<PySymbolTable>(m, "SymbolTable")
+ .def(nb::init<PyOperationBase &>(),
+ R"(
+ Creates a symbol table for an operation.
+
+ Args:
+ operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
+
+ Raises:
+ TypeError: If the operation is not a symbol table.)")
+ .def(
+ "__getitem__",
+ [](PySymbolTable &self,
+ const std::string &name) -> nb::typed<nb::object, PyOpView> {
+ return self.dunderGetItem(name);
+ },
+ R"(
+ Looks up a symbol by name in the symbol table.
+
+ Args:
+ name: The name of the symbol to look up.
+
+ Returns:
+ The operation defining the symbol.
+
+ Raises:
+ KeyError: If the symbol is not found.)")
+ .def("insert", &PySymbolTable::insert, "operation"_a,
+ R"(
+ Inserts a symbol operation into the symbol table.
+
+ Args:
+ operation: An operation with a symbol name to insert.
+
+ Returns:
+ The symbol name attribute of the inserted operation.
+
+ Raises:
+ ValueError: If the operation does not have a symbol name.)")
+ .def("erase", &PySymbolTable::erase, "operation"_a,
+ R"(
+ Erases a symbol operation from the symbol table.
+
+ Args:
+ operation: The symbol operation to erase.
+
+ Note:
+ The operation is also erased from the IR and invalidated.)")
+ .def("__delitem__", &PySymbolTable::dunderDel,
+ "Deletes a symbol by name from the symbol table.")
+ .def(
+ "__contains__",
+ [](PySymbolTable &table, const std::string &name) {
+ return !mlirOperationIsNull(mlirSymbolTableLookup(
+ table, mlirStringRefCreate(name.data(), name.length())));
+ },
+ "Checks if a symbol with the given name exists in the table.")
+ // Static helpers.
+ .def_static("set_symbol_name", &PySymbolTable::setSymbolName, "symbol"_a,
+ "name"_a, "Sets the symbol name for a symbol operation.")
+ .def_static("get_symbol_name", &PySymbolTable::getSymbolName, "symbol"_a,
+ "Gets the symbol name from a symbol operation.")
+ .def_static("get_visibility", &PySymbolTable::getVisibility, "symbol"_a,
+ "Gets the visibility attribute of a symbol operation.")
+ .def_static("set_visibility", &PySymbolTable::setVisibility, "symbol"_a,
+ "visibility"_a,
+ "Sets the visibility attribute of a symbol operation.")
+ .def_static("replace_all_symbol_uses",
+ &PySymbolTable::replaceAllSymbolUses, "old_symbol"_a,
+ "new_symbol"_a, "from_op"_a,
+ "Replaces all uses of a symbol with a new symbol name within "
+ "the given operation.")
+ .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
+ "from_op"_a, "all_sym_uses_visible"_a, "callback"_a,
+ "Walks symbol tables starting from an operation with a "
+ "callback function.");
+
+ // Container bindings.
+ PyBlockArgumentList::bind(m);
+ PyBlockIterator::bind(m);
+ PyBlockList::bind(m);
+ PyBlockSuccessors::bind(m);
+ PyBlockPredecessors::bind(m);
+ PyOperationIterator::bind(m);
+ PyOperationList::bind(m);
+ PyOpAttributeMap::bind(m);
+ PyOpOperandIterator::bind(m);
+ PyOpOperandList::bind(m);
+ PyOpResultList::bind(m);
+ PyOpSuccessors::bind(m);
+ PyRegionIterator::bind(m);
+ PyRegionList::bind(m);
+
+ // Debug bindings.
+ PyGlobalDebugFlag::bind(m);
+
+ // Attribute builder getter.
+ PyAttrBuilderMap::bind(m);
+}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index b49a9f1e3af24..b2c9380bc1d73 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -11,2679 +11,32 @@
#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir/Bindings/Python/NanobindUtils.h"
namespace nb = nanobind;
-using namespace mlir;
-using namespace nb::literals;
using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
-/// Docstring for parsing a module from its assembly form; presumably
-/// attached to a module `parse` binding further down in this file.
-static const char kModuleParseDocstring[] =
- R"(Parses a module's assembly format from a string.
-
-Returns a new MlirModule or raises an MLIRError if the parsing fails.
-
-See also: https://mlir.llvm.org/docs/LangRef/
-)";
-
-/// Docstring shared by the `dump` bindings.
-static const char kDumpDocstring[] =
- "Dumps a debug representation of the object to stderr.";
-
-/// Docstring shared by the `replace_all_uses_except` overloads (single
-/// operation vs. list of operations).
-static const char kValueReplaceAllUsesExceptDocstring[] =
- R"(Replace all uses of this value with the `with` value, except for those
-in `exceptions`. `exceptions` can be either a single operation or a list of
-operations.
-)";
-
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-/// Wrapper for the global LLVM debugging flag.
-///
-/// Exposed to Python as `_GlobalDebug`. Every call this struct makes into the
-/// global-debug C API is taken under `mutex`, so concurrent Python threads
-/// do not race on the global debug state.
-struct PyGlobalDebugFlag {
- /// Setter for the static `flag` property: enables or disables LLVM global
- /// debug output. The object argument is unused.
- static void set(nanobind::object &o, bool enable) {
- nanobind::ft_lock_guard lock(mutex);
- mlirEnableGlobalDebug(enable);
- }
-
- /// Getter for the static `flag` property. The object argument is unused.
- static bool get(const nanobind::object &) {
- nanobind::ft_lock_guard lock(mutex);
- return mlirIsGlobalDebugEnabled();
- }
-
- /// Registers `_GlobalDebug` on `m` with a static read/write `flag` property
- /// and two `set_types` overloads (a single debug type, or a list of them).
- static void bind(nanobind::module_ &m) {
- // Debug flags.
- nanobind::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
- .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
- &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
- .def_static(
- "set_types",
- [](const std::string &type) {
- nanobind::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugType(type.c_str());
- },
- nanobind::arg("types"),
- "Sets specific debug types to be produced by LLVM.")
- .def_static(
- "set_types",
- [](const std::vector<std::string> &types) {
- // The C API takes `const char *` entries; they borrow from the
- // strings in `types`, which stays alive for the call below.
- std::vector<const char *> pointers;
- pointers.reserve(types.size());
- for (const std::string &str : types)
- pointers.push_back(str.c_str());
- nanobind::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
- },
- nanobind::arg("types"),
- "Sets multiple specific debug types to be produced by LLVM.");
- }
-
-private:
- /// Serializes every global-debug C API call made by this struct.
- static nanobind::ft_mutex mutex;
-};
-nanobind::ft_mutex PyGlobalDebugFlag::mutex;
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
-
-namespace {
-// see
-// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
-// Backports of CPython C-API helpers for interpreters older than the versions
-// noted on each guard; tracebackToLocation below relies on
-// PyThreadState_GetFrame, PyFrame_GetBack and PyFrame_GetCode.
-
-#ifndef _Py_CAST
-#define _Py_CAST(type, expr) ((type)(expr))
-#endif
-
-// Static inline functions should use _Py_NULL rather than using directly NULL
-// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
-// _Py_NULL is defined as nullptr.
-#ifndef _Py_NULL
-#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \
- (defined(__cplusplus) && __cplusplus >= 201103)
-#define _Py_NULL nullptr
-#else
-#define _Py_NULL NULL
-#endif
-#endif
-
-// Python 3.10.0a3
-#if PY_VERSION_HEX < 0x030A00A3
-
-// bpo-42262 added Py_XNewRef()
-// Fallback: returns `obj` after adding one reference (NULL is passed through).
-#if !defined(Py_XNewRef)
-[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
- Py_XINCREF(obj);
- return obj;
-}
-#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
-#endif
-
-// bpo-42262 added Py_NewRef()
-// Fallback: returns `obj` after adding one reference.
-#if !defined(Py_NewRef)
-[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
- Py_INCREF(obj);
- return obj;
-}
-#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
-#endif
-
-#endif // Python 3.10.0a3
-
-// Python 3.9.0b1
-// Each fallback below returns a new (strong) reference via Py_XNewRef or
-// Py_NewRef, which the caller is responsible for releasing.
-#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
-
-// bpo-40429 added PyThreadState_GetFrame()
-PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
- assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
- return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
-}
-
-// bpo-40421 added PyFrame_GetBack()
-PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
- assert(frame != _Py_NULL && "expected frame != _Py_NULL");
- return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
-}
-
-// bpo-40421 added PyFrame_GetCode()
-PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
- assert(frame != _Py_NULL && "expected frame != _Py_NULL");
- assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
- return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
-}
-
-#endif // Python 3.9.0b1
-
-/// Builds an MLIR location from the current Python call stack.
-///
-/// Walks frames starting from the thread's current frame and moving to each
-/// caller. It keeps only frames whose filename passes
-/// isUserTracebackFilename(), stopping after locTracebackFramesLimit() kept
-/// frames. Each kept frame becomes a name location (function name) wrapping
-/// a file/line (or, on 3.11+, file/line/column-range) location.
-///
-/// Returns an unknown location if no frame is kept, the single location if
-/// exactly one is kept, and otherwise a nested CallSite chain whose callee
-/// is frames[0] (the current frame).
-MlirLocation tracebackToLocation(MlirContext ctx) {
- size_t framesLimit =
- PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
- // Use a thread_local here to avoid requiring a large amount of space.
- // NOTE(review): frames[count] is not bounds-checked; this assumes
- // locTracebackFramesLimit() <= kMaxFrames -- TODO confirm.
- thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
- frames;
- size_t count = 0;
-
- nb::gil_scoped_acquire acquire;
- PyThreadState *tstate = PyThreadState_GET();
- PyFrameObject *next;
- PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
- // In the increment expression:
- // 1. get the next prev frame;
- // 2. decrement the ref count on the current frame (in order that it can get
- // gc'd, along with any objects in its closure and etc);
- // 3. set current = next.
- for (; pyFrame != nullptr && count < framesLimit;
- next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
- // NOTE(review): PyFrame_GetCode returns a new reference (compare the
- // Py_NewRef in the pre-3.9 shim above), but `code` is never Py_DECREF'd
- // in this loop. Each visited frame therefore appears to leak one
- // reference to its code object. Worth fixing wherever this code now
- // lives.
- PyCodeObject *code = PyFrame_GetCode(pyFrame);
- auto fileNameStr =
- nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
- llvm::StringRef fileName(fileNameStr);
- // Filtered-out frames do not count toward framesLimit. The loop's
- // increment expression still releases the frame on `continue`.
- if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
- continue;
-
- // co_qualname and PyCode_Addr2Location added in py3.11
-#if PY_VERSION_HEX < 0x030B00F0
- std::string name =
- nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
- llvm::StringRef funcName(name);
- int startLine = PyFrame_GetLineNumber(pyFrame);
- MlirLocation loc =
- mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
-#else
- std::string name =
- nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
- llvm::StringRef funcName(name);
- int startLine, startCol, endLine, endCol;
- int lasti = PyFrame_GetLasti(pyFrame);
- if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
- &endCol)) {
- // NOTE(review): throwing here skips the Py_XDECREF(pyFrame) after the
- // loop, so the current frame reference leaks on this error path.
- throw nb::python_error();
- }
- MlirLocation loc = mlirLocationFileLineColRangeGet(
- ctx, wrap(fileName), startLine, startCol, endLine, endCol);
-#endif
-
- frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
- ++count;
- }
- // When the loop breaks (after the last iter), current frame (if non-null)
- // is leaked without this.
- Py_XDECREF(pyFrame);
-
- if (count == 0)
- return mlirLocationUnknownGet(ctx);
-
- MlirLocation callee = frames[0];
- assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
- if (count == 1)
- return callee;
-
- // Start from the last kept frame and wrap each intermediate frame around
- // it, so every frame is recorded as called from the next one in `frames`.
- MlirLocation caller = frames[count - 1];
- assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
- for (int i = count - 2; i >= 1; i--)
- caller = mlirLocationCallSiteGet(frames[i], caller);
-
- return mlirLocationCallSiteGet(callee, caller);
-}
-
-/// Returns `location` when one is provided. Otherwise, when traceback
-/// locations are enabled, returns a location built from the Python call stack
-/// in the resolved default context. If they are disabled, returns the
-/// resolved default location.
-PyLocation
-maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
- if (location.has_value())
- return location.value();
- if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
- return DefaultingPyLocation::resolve();
-
- PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
- MlirLocation mlirLoc = tracebackToLocation(ctx.get());
- PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
- return {ref, mlirLoc};
-}
-
-/// Helper for creating an @classmethod.
-/// NOTE(review): PyClassMethod_New returns a new reference, so wrapping it
-/// with borrow<> (which adds another reference) looks like a one-reference
-/// leak; nanobind::steal<> is probably what was intended -- confirm. A NULL
-/// return from PyClassMethod_New is also not checked.
-template <class Func, typename... Args>
-nanobind::object classmethod(Func f, Args... args) {
- nanobind::object cf = nanobind::cpp_function(f, args...);
- return nanobind::borrow<nanobind::object>((PyClassMethod_New(cf.ptr())));
-}
-
-/// Returns a dialect wrapper object for `dialectNamespace`. If PyGlobals has
-/// a class registered for that namespace, the result is an instance of that
-/// class. Otherwise it is a plain PyDialect. In both cases
-/// `dialectDescriptor` is passed to the constructor.
-nanobind::object
-createCustomDialectWrapper(const std::string &dialectNamespace,
- nanobind::object dialectDescriptor) {
- auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
- if (!dialectClass) {
- // Use the base class.
- return nanobind::cast(PyDialect(std::move(dialectDescriptor)));
- }
-
- // Create the custom implementation.
- return (*dialectClass)(std::move(dialectDescriptor));
-}
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// NB: all bind and bindDerived methods need to reside in the same
-// binary/extension as the NB_MODULE macro/call. This is because
-// `nb_internals *internals` is declared within the non-unique
-// `nanobind::detail` namespace (i.e., the same namespace for all bindings
-// packages).
-//===----------------------------------------------------------------------===//
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-/// Adds the RegionList-specific `__iter__` method to the class binding `c`.
-void PyRegionList::bindDerived(ClassTy &c) {
- c.def("__iter__", &PyRegionList::dunderIter,
- "Returns an iterator over the regions in the sequence.");
-}
-
-/// Adds the OpResult-specific read-only properties to `c`. `owner` returns the
-/// producing operation as an OpView. `result_number` returns the result's
-/// index.
-void PyOpResult::bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "owner",
- [](PyOpResult &self) -> nanobind::typed<nanobind::object, PyOpView> {
- // Debug-only consistency check: the cached parent operation must be
- // the result's owner in the IR.
- assert(mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match that in "
- "the IR");
- return self.getParentOperation()->createOpView();
- },
- "Returns the operation that produces this result.");
- c.def_prop_ro(
- "result_number",
- [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); },
- "Returns the position of this result in the operation's result list.");
-}
-
-/// Adds the OpResultList-specific read-only properties to `c`. `types` returns
-/// the result types. `owner` returns the owning operation as an OpView.
-void PyOpResultList::bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "types",
- [](PyOpResultList &self) {
- return getValueTypes(self, self.operation->getContext());
- },
- "Returns a list of types for all results in this result list.");
- c.def_prop_ro(
- "owner",
- [](PyOpResultList &self) -> nanobind::typed<nanobind::object, PyOpView> {
- return self.operation->createOpView();
- },
- "Returns the operation that owns this result list.");
-}
-
-/// Makes operand lists assignable by index from Python (`operands[i] = v`).
-void PyOpOperandList::bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpOperandList::dunderSetItem, nanobind::arg("index"),
- nanobind::arg("value"),
- "Sets the operand at the specified index to a new value.");
-}
-
-void PyOpSuccessors::bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nanobind::arg("index"),
- nanobind::arg("block"),
- "Sets the successor block at the specified index.");
-}
-
-void PyBlockArgumentList::bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "types",
- [](PyBlockArgumentList &self) {
- return getValueTypes(self, self.operation->getContext());
- },
- "Returns a list of types for all arguments in this argument list.");
-}
-
-void PyAttrBuilderMap::bind(nanobind::module_ &m) {
- nanobind::class_<PyAttrBuilderMap>(m, "AttrBuilder")
- .def_static("contains", &PyAttrBuilderMap::dunderContains,
- nanobind::arg("attribute_kind"),
- "Checks whether an attribute builder is registered for the "
- "given attribute kind.")
- .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
- nanobind::arg("attribute_kind"),
- "Gets the registered attribute builder for the given "
- "attribute kind.")
- .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
- nanobind::arg("attribute_kind"),
- nanobind::arg("attr_builder"),
- nanobind::arg("replace") = false,
- "Register an attribute builder for building MLIR "
- "attributes from Python values.");
-}
-
-void PyRegionIterator::bind(nanobind::module_ &m) {
- nanobind::class_<PyRegionIterator>(m, "RegionIterator")
- .def("__iter__", &PyRegionIterator::dunderIter,
- "Returns an iterator over the regions in the operation.")
- .def("__next__", &PyRegionIterator::dunderNext,
- "Returns the next region in the iteration.");
-}
-
-void PyBlockIterator::bind(nanobind::module_ &m) {
- nanobind::class_<PyBlockIterator>(m, "BlockIterator")
- .def("__iter__", &PyBlockIterator::dunderIter,
- "Returns an iterator over the blocks in the operation's region.")
- .def("__next__", &PyBlockIterator::dunderNext,
- "Returns the next block in the iteration.");
-}
-
-void PyBlockList::bind(nanobind::module_ &m) {
- nanobind::class_<PyBlockList>(m, "BlockList")
- .def("__getitem__", &PyBlockList::dunderGetItem,
- "Returns the block at the specified index.")
- .def("__iter__", &PyBlockList::dunderIter,
- "Returns an iterator over blocks in the operation's region.")
- .def("__len__", &PyBlockList::dunderLen,
- "Returns the number of blocks in the operation's region.")
- .def("append", &PyBlockList::appendBlock,
- R"(
- Appends a new block, with argument types as positional args.
-
- Returns:
- The created block.
- )",
- nanobind::arg("args"), nanobind::kw_only(),
- nanobind::arg("arg_locs") = std::nullopt);
-}
-
-void PyOperationIterator::bind(nanobind::module_ &m) {
- nanobind::class_<PyOperationIterator>(m, "OperationIterator")
- .def("__iter__", &PyOperationIterator::dunderIter,
- "Returns an iterator over the operations in an operation's block.")
- .def("__next__", &PyOperationIterator::dunderNext,
- "Returns the next operation in the iteration.");
-}
-
-void PyOperationList::bind(nanobind::module_ &m) {
- nanobind::class_<PyOperationList>(m, "OperationList")
- .def("__getitem__", &PyOperationList::dunderGetItem,
- "Returns the operation at the specified index.")
- .def("__iter__", &PyOperationList::dunderIter,
- "Returns an iterator over operations in the list.")
- .def("__len__", &PyOperationList::dunderLen,
- "Returns the number of operations in the list.");
-}
-
-void PyOpOperand::bind(nanobind::module_ &m) {
- nanobind::class_<PyOpOperand>(m, "OpOperand")
- .def_prop_ro("owner", &PyOpOperand::getOwner,
- "Returns the operation that owns this operand.")
- .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
- "Returns the operand number in the owning operation.");
-}
-
-void PyOpOperandIterator::bind(nanobind::module_ &m) {
- nanobind::class_<PyOpOperandIterator>(m, "OpOperandIterator")
- .def("__iter__", &PyOpOperandIterator::dunderIter,
- "Returns an iterator over operands.")
- .def("__next__", &PyOpOperandIterator::dunderNext,
- "Returns the next operand in the iteration.");
-}
-
-void PyOpAttributeMap::bind(nanobind::module_ &m) {
- nanobind::class_<PyOpAttributeMap>(m, "OpAttributeMap")
- .def("__contains__", &PyOpAttributeMap::dunderContains,
- nanobind::arg("name"),
- "Checks if an attribute with the given name exists in the map.")
- .def("__len__", &PyOpAttributeMap::dunderLen,
- "Returns the number of attributes in the map.")
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
- nanobind::arg("name"), "Gets an attribute by name.")
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
- nanobind::arg("index"), "Gets a named attribute by index.")
- .def("__setitem__", &PyOpAttributeMap::dunderSetItem,
- nanobind::arg("name"), nanobind::arg("attr"),
- "Sets an attribute with the given name.")
- .def("__delitem__", &PyOpAttributeMap::dunderDelItem,
- nanobind::arg("name"), "Deletes an attribute with the given name.")
- .def(
- "__iter__",
- [](PyOpAttributeMap &self) {
- nanobind::list keys;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
- keys.append(nanobind::str(name.data, name.length));
- });
- return nanobind::iter(keys);
- },
- "Iterates over attribute names.")
- .def(
- "keys",
- [](PyOpAttributeMap &self) {
- nanobind::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
- out.append(nanobind::str(name.data, name.length));
- });
- return out;
- },
- "Returns a list of attribute names.")
- .def(
- "values",
- [](PyOpAttributeMap &self) {
- nanobind::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
- out.append(PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast());
- });
- return out;
- },
- "Returns a list of attribute values.")
- .def(
- "items",
- [](PyOpAttributeMap &self) {
- nanobind::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute attr) {
- out.append(nanobind::make_tuple(
- nanobind::str(name.data, name.length),
- PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast()));
- });
- return out;
- },
- "Returns a list of `(name, attribute)` tuples.");
-}
-
void populateIRAffine(nb::module_ &m);
void populateIRAttributes(nb::module_ &m);
void populateIRInterfaces(nb::module_ &m);
void populateIRTypes(nb::module_ &m);
+void populateIRCore(nb::module_ &m);
+void populateRoot(nb::module_ &m);
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
-//------------------------------------------------------------------------------
-// Populates the core exports of the 'ir' submodule.
-//------------------------------------------------------------------------------
-
-static void populateIRCore(nb::module_ &m) {
- // disable leak warnings which tend to be false positives.
- nb::set_leak_warnings(false);
- //----------------------------------------------------------------------------
- // Enums.
- //----------------------------------------------------------------------------
- nb::enum_<PyMlirDiagnosticSeverity>(m, "DiagnosticSeverity")
- .value("ERROR", PyMlirDiagnosticSeverity::MlirDiagnosticError)
- .value("WARNING", PyMlirDiagnosticSeverity::MlirDiagnosticWarning)
- .value("NOTE", PyMlirDiagnosticSeverity::MlirDiagnosticNote)
- .value("REMARK", PyMlirDiagnosticSeverity::MlirDiagnosticRemark);
-
- nb::enum_<PyMlirWalkOrder>(m, "WalkOrder")
- .value("PRE_ORDER", PyMlirWalkOrder::MlirWalkPreOrder)
- .value("POST_ORDER", PyMlirWalkOrder::MlirWalkPostOrder);
-
- nb::enum_<PyMlirWalkResult>(m, "WalkResult")
- .value("ADVANCE", PyMlirWalkResult::MlirWalkResultAdvance)
- .value("INTERRUPT", PyMlirWalkResult::MlirWalkResultInterrupt)
- .value("SKIP", PyMlirWalkResult::MlirWalkResultSkip);
-
- //----------------------------------------------------------------------------
- // Mapping of Diagnostics.
- //----------------------------------------------------------------------------
- nb::class_<PyDiagnostic>(m, "Diagnostic")
- .def_prop_ro("severity", &PyDiagnostic::getSeverity,
- "Returns the severity of the diagnostic.")
- .def_prop_ro("location", &PyDiagnostic::getLocation,
- "Returns the location associated with the diagnostic.")
- .def_prop_ro("message", &PyDiagnostic::getMessage,
- "Returns the message text of the diagnostic.")
- .def_prop_ro("notes", &PyDiagnostic::getNotes,
- "Returns a tuple of attached note diagnostics.")
- .def(
- "__str__",
- [](PyDiagnostic &self) -> nb::str {
- if (!self.isValid())
- return nb::str("<Invalid Diagnostic>");
- return self.getMessage();
- },
- "Returns the diagnostic message as a string.");
-
- nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
- .def(
- "__init__",
- [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
- new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
- },
- "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
- .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
- "The severity level of the diagnostic.")
- .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
- "The location associated with the diagnostic.")
- .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
- "The message text of the diagnostic.")
- .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
- "List of attached note diagnostics.")
- .def(
- "__str__",
- [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
- "Returns the diagnostic message as a string.");
-
- nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
- .def("detach", &PyDiagnosticHandler::detach,
- "Detaches the diagnostic handler from the context.")
- .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
- "Returns True if the handler is attached to a context.")
- .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
- "Returns True if an error was encountered during diagnostic "
- "handling.")
- .def("__enter__", &PyDiagnosticHandler::contextEnter,
- "Enters the diagnostic handler as a context manager.")
- .def("__exit__", &PyDiagnosticHandler::contextExit,
- nb::arg("exc_type").none(), nb::arg("exc_value").none(),
- nb::arg("traceback").none(),
- "Exits the diagnostic handler context manager.");
-
- // Expose DefaultThreadPool to python
- nb::class_<PyThreadPool>(m, "ThreadPool")
- .def(
- "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
- "Creates a new thread pool with default concurrency.")
- .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
- "Returns the maximum number of threads in the pool.")
- .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
- "Returns the raw pointer to the LLVM thread pool as a string.");
-
- nb::class_<PyMlirContext>(m, "Context")
- .def(
- "__init__",
- [](PyMlirContext &self) {
- MlirContext context = mlirContextCreateWithThreading(false);
- new (&self) PyMlirContext(context);
- },
- R"(
- Creates a new MLIR context.
-
- The context is the top-level container for all MLIR objects. It owns the storage
- for types, attributes, locations, and other core IR objects. A context can be
- configured to allow or disallow unregistered dialects and can have dialects
- loaded on-demand.)")
- .def_static("_get_live_count", &PyMlirContext::getLiveCount,
- "Gets the number of live Context objects.")
- .def(
- "_get_context_again",
- [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
- PyMlirContextRef ref = PyMlirContext::forContext(self.get());
- return ref.releaseObject();
- },
- "Gets another reference to the same context.")
- .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
- "Gets the number of live modules owned by this context.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule,
- "Gets a capsule wrapping the MlirContext.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyMlirContext::createFromCapsule,
- "Creates a Context from a capsule wrapping MlirContext.")
- .def("__enter__", &PyMlirContext::contextEnter,
- "Enters the context as a context manager.")
- .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
- nb::arg("exc_value").none(), nb::arg("traceback").none(),
- "Exits the context manager.")
- .def_prop_ro_static(
- "current",
- [](nb::object & /*class*/)
- -> std::optional<nb::typed<nb::object, PyMlirContext>> {
- auto *context = PyThreadContextEntry::getDefaultContext();
- if (!context)
- return {};
- return nb::cast(context);
- },
- nb::sig("def current(/) -> Context | None"),
- "Gets the Context bound to the current thread or returns None if no "
- "context is set.")
- .def_prop_ro(
- "dialects",
- [](PyMlirContext &self) { return PyDialects(self.getRef()); },
- "Gets a container for accessing dialects by name.")
- .def_prop_ro(
- "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
- "Alias for `dialects`.")
- .def(
- "get_dialect_descriptor",
- [=](PyMlirContext &self, std::string &name) {
- MlirDialect dialect = mlirContextGetOrLoadDialect(
- self.get(), {name.data(), name.size()});
- if (mlirDialectIsNull(dialect)) {
- throw nb::value_error(
- (Twine("Dialect '") + name + "' not found").str().c_str());
- }
- return PyDialectDescriptor(self.getRef(), dialect);
- },
- nb::arg("dialect_name"),
- "Gets or loads a dialect by name, returning its descriptor object.")
- .def_prop_rw(
- "allow_unregistered_dialects",
- [](PyMlirContext &self) -> bool {
- return mlirContextGetAllowUnregisteredDialects(self.get());
- },
- [](PyMlirContext &self, bool value) {
- mlirContextSetAllowUnregisteredDialects(self.get(), value);
- },
- "Controls whether unregistered dialects are allowed in this context.")
- .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
- nb::arg("callback"),
- "Attaches a diagnostic handler that will receive callbacks.")
- .def(
- "enable_multithreading",
- [](PyMlirContext &self, bool enable) {
- mlirContextEnableMultithreading(self.get(), enable);
- },
- nb::arg("enable"),
- R"(
- Enables or disables multi-threading support in the context.
-
- Args:
- enable: Whether to enable (True) or disable (False) multi-threading.
- )")
- .def(
- "set_thread_pool",
- [](PyMlirContext &self, PyThreadPool &pool) {
- // we should disable multi-threading first before setting
- // new thread pool otherwise the assert in
- // MLIRContext::setThreadPool will be raised.
- mlirContextEnableMultithreading(self.get(), false);
- mlirContextSetThreadPool(self.get(), pool.get());
- },
- R"(
- Sets a custom thread pool for the context to use.
-
- Args:
- pool: A ThreadPool object to use for parallel operations.
-
- Note:
- Multi-threading is automatically disabled before setting the thread pool.)")
- .def(
- "get_num_threads",
- [](PyMlirContext &self) {
- return mlirContextGetNumThreads(self.get());
- },
- "Gets the number of threads in the context's thread pool.")
- .def(
- "_mlir_thread_pool_ptr",
- [](PyMlirContext &self) {
- MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
- std::stringstream ss;
- ss << pool.ptr;
- return ss.str();
- },
- "Gets the raw pointer to the LLVM thread pool as a string.")
- .def(
- "is_registered_operation",
- [](PyMlirContext &self, std::string &name) {
- return mlirContextIsRegisteredOperation(
- self.get(), MlirStringRef{name.data(), name.size()});
- },
- nb::arg("operation_name"),
- R"(
- Checks whether an operation with the given name is registered.
-
- Args:
- operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
-
- Returns:
- True if the operation is registered, False otherwise.)")
- .def(
- "append_dialect_registry",
- [](PyMlirContext &self, PyDialectRegistry ®istry) {
- mlirContextAppendDialectRegistry(self.get(), registry);
- },
- nb::arg("registry"),
- R"(
- Appends the contents of a dialect registry to the context.
-
- Args:
- registry: A DialectRegistry containing dialects to append.)")
- .def_prop_rw("emit_error_diagnostics",
- &PyMlirContext::getEmitErrorDiagnostics,
- &PyMlirContext::setEmitErrorDiagnostics,
- R"(
- Controls whether error diagnostics are emitted to diagnostic handlers.
-
- By default, error diagnostics are captured and reported through MLIRError exceptions.)")
- .def(
- "load_all_available_dialects",
- [](PyMlirContext &self) {
- mlirContextLoadAllAvailableDialects(self.get());
- },
- R"(
- Loads all dialects available in the registry into the context.
-
- This eagerly loads all dialects that have been registered, making them
- immediately available for use.)");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialectDescriptor
- //----------------------------------------------------------------------------
- nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
- .def_prop_ro(
- "namespace",
- [](PyDialectDescriptor &self) {
- MlirStringRef ns = mlirDialectGetNamespace(self.get());
- return nb::str(ns.data, ns.length);
- },
- "Returns the namespace of the dialect.")
- .def(
- "__repr__",
- [](PyDialectDescriptor &self) {
- MlirStringRef ns = mlirDialectGetNamespace(self.get());
- std::string repr("<DialectDescriptor ");
- repr.append(ns.data, ns.length);
- repr.append(">");
- return repr;
- },
- nb::sig("def __repr__(self) -> str"),
- "Returns a string representation of the dialect descriptor.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialects
- //----------------------------------------------------------------------------
- nb::class_<PyDialects>(m, "Dialects")
- .def(
- "__getitem__",
- [=](PyDialects &self, std::string keyName) {
- MlirDialect dialect =
- self.getDialectForKey(keyName, /*attrError=*/false);
- nb::object descriptor =
- nb::cast(PyDialectDescriptor{self.getContext(), dialect});
- return createCustomDialectWrapper(keyName, std::move(descriptor));
- },
- "Gets a dialect by name using subscript notation.")
- .def(
- "__getattr__",
- [=](PyDialects &self, std::string attrName) {
- MlirDialect dialect =
- self.getDialectForKey(attrName, /*attrError=*/true);
- nb::object descriptor =
- nb::cast(PyDialectDescriptor{self.getContext(), dialect});
- return createCustomDialectWrapper(attrName, std::move(descriptor));
- },
- "Gets a dialect by name using attribute notation.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialect
- //----------------------------------------------------------------------------
- nb::class_<PyDialect>(m, "Dialect")
- .def(nb::init<nb::object>(), nb::arg("descriptor"),
- "Creates a Dialect from a DialectDescriptor.")
- .def_prop_ro(
- "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
- "Returns the DialectDescriptor for this dialect.")
- .def(
- "__repr__",
- [](const nb::object &self) {
- auto clazz = self.attr("__class__");
- return nb::str("<Dialect ") +
- self.attr("descriptor").attr("namespace") +
- nb::str(" (class ") + clazz.attr("__module__") +
- nb::str(".") + clazz.attr("__name__") + nb::str(")>");
- },
- nb::sig("def __repr__(self) -> str"),
- "Returns a string representation of the dialect.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialectRegistry
- //----------------------------------------------------------------------------
- nb::class_<PyDialectRegistry>(m, "DialectRegistry")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule,
- "Gets a capsule wrapping the MlirDialectRegistry.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyDialectRegistry::createFromCapsule,
- "Creates a DialectRegistry from a capsule wrapping "
- "`MlirDialectRegistry`.")
- .def(nb::init<>(), "Creates a new empty dialect registry.");
-
- //----------------------------------------------------------------------------
- // Mapping of Location
- //----------------------------------------------------------------------------
- nb::class_<PyLocation>(m, "Location")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule,
- "Gets a capsule wrapping the MlirLocation.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule,
- "Creates a Location from a capsule wrapping MlirLocation.")
- .def("__enter__", &PyLocation::contextEnter,
- "Enters the location as a context manager.")
- .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
- nb::arg("exc_value").none(), nb::arg("traceback").none(),
- "Exits the location context manager.")
- .def(
- "__eq__",
- [](PyLocation &self, PyLocation &other) -> bool {
- return mlirLocationEqual(self, other);
- },
- "Compares two locations for equality.")
- .def(
- "__eq__", [](PyLocation &self, nb::object other) { return false; },
- "Compares location with non-location object (always returns False).")
- .def_prop_ro_static(
- "current",
- [](nb::object & /*class*/) -> std::optional<PyLocation *> {
- auto *loc = PyThreadContextEntry::getDefaultLocation();
- if (!loc)
- return std::nullopt;
- return loc;
- },
- // clang-format off
- nb::sig("def current(/) -> Location | None"),
- // clang-format on
- "Gets the Location bound to the current thread or raises ValueError.")
- .def_static(
- "unknown",
- [](DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationUnknownGet(context->get()));
- },
- nb::arg("context") = nb::none(),
- "Gets a Location representing an unknown location.")
- .def_static(
- "callsite",
- [](PyLocation callee, const std::vector<PyLocation> &frames,
- DefaultingPyMlirContext context) {
- if (frames.empty())
- throw nb::value_error("No caller frames provided.");
- MlirLocation caller = frames.back().get();
- for (const PyLocation &frame :
- llvm::reverse(llvm::ArrayRef(frames).drop_back()))
- caller = mlirLocationCallSiteGet(frame.get(), caller);
- return PyLocation(context->getRef(),
- mlirLocationCallSiteGet(callee.get(), caller));
- },
- nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(),
- "Gets a Location representing a caller and callsite.")
- .def("is_a_callsite", mlirLocationIsACallSite,
- "Returns True if this location is a CallSiteLoc.")
- .def_prop_ro(
- "callee",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationCallSiteGetCallee(self));
- },
- "Gets the callee location from a CallSiteLoc.")
- .def_prop_ro(
- "caller",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationCallSiteGetCaller(self));
- },
- "Gets the caller location from a CallSiteLoc.")
- .def_static(
- "file",
- [](std::string filename, int line, int col,
- DefaultingPyMlirContext context) {
- return PyLocation(
- context->getRef(),
- mlirLocationFileLineColGet(
- context->get(), toMlirStringRef(filename), line, col));
- },
- nb::arg("filename"), nb::arg("line"), nb::arg("col"),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a file, line and column.")
- .def_static(
- "file",
- [](std::string filename, int startLine, int startCol, int endLine,
- int endCol, DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationFileLineColRangeGet(
- context->get(), toMlirStringRef(filename),
- startLine, startCol, endLine, endCol));
- },
- nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
- nb::arg("end_line"), nb::arg("end_col"),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a file, line and column range.")
- .def("is_a_file", mlirLocationIsAFileLineColRange,
- "Returns True if this location is a FileLineColLoc.")
- .def_prop_ro(
- "filename",
- [](MlirLocation loc) {
- return mlirIdentifierStr(
- mlirLocationFileLineColRangeGetFilename(loc));
- },
- "Gets the filename from a FileLineColLoc.")
- .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
- "Gets the start line number from a `FileLineColLoc`.")
- .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
- "Gets the start column number from a `FileLineColLoc`.")
- .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
- "Gets the end line number from a `FileLineColLoc`.")
- .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
- "Gets the end column number from a `FileLineColLoc`.")
- .def_static(
- "fused",
- [](const std::vector<PyLocation> &pyLocations,
- std::optional<PyAttribute> metadata,
- DefaultingPyMlirContext context) {
- llvm::SmallVector<MlirLocation, 4> locations;
- locations.reserve(pyLocations.size());
- for (auto &pyLocation : pyLocations)
- locations.push_back(pyLocation.get());
- MlirLocation location = mlirLocationFusedGet(
- context->get(), locations.size(), locations.data(),
- metadata ? metadata->get() : MlirAttribute{0});
- return PyLocation(context->getRef(), location);
- },
- nb::arg("locations"), nb::arg("metadata") = nb::none(),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a fused location with optional "
- "metadata.")
- .def("is_a_fused", mlirLocationIsAFused,
- "Returns True if this location is a `FusedLoc`.")
- .def_prop_ro(
- "locations",
- [](PyLocation &self) {
- unsigned numLocations = mlirLocationFusedGetNumLocations(self);
- std::vector<MlirLocation> locations(numLocations);
- if (numLocations)
- mlirLocationFusedGetLocations(self, locations.data());
- std::vector<PyLocation> pyLocations{};
- pyLocations.reserve(numLocations);
- for (unsigned i = 0; i < numLocations; ++i)
- pyLocations.emplace_back(self.getContext(), locations[i]);
- return pyLocations;
- },
- "Gets the list of locations from a `FusedLoc`.")
- .def_static(
- "name",
- [](std::string name, std::optional<PyLocation> childLoc,
- DefaultingPyMlirContext context) {
- return PyLocation(
- context->getRef(),
- mlirLocationNameGet(
- context->get(), toMlirStringRef(name),
- childLoc ? childLoc->get()
- : mlirLocationUnknownGet(context->get())));
- },
- nb::arg("name"), nb::arg("childLoc") = nb::none(),
- nb::arg("context") = nb::none(),
- "Gets a Location representing a named location with optional child "
- "location.")
- .def("is_a_name", mlirLocationIsAName,
- "Returns True if this location is a `NameLoc`.")
- .def_prop_ro(
- "name_str",
- [](MlirLocation loc) {
- return mlirIdentifierStr(mlirLocationNameGetName(loc));
- },
- "Gets the name string from a `NameLoc`.")
- .def_prop_ro(
- "child_loc",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationNameGetChildLoc(self));
- },
- "Gets the child location from a `NameLoc`.")
- .def_static(
- "from_attr",
- [](PyAttribute &attribute, DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationFromAttribute(attribute));
- },
- nb::arg("attribute"), nb::arg("context") = nb::none(),
- "Gets a Location from a `LocationAttr`.")
- .def_prop_ro(
- "context",
- [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that owns the `Location`.")
- .def_prop_ro(
- "attr",
- [](PyLocation &self) {
- return PyAttribute(self.getContext(),
- mlirLocationGetAttribute(self));
- },
- "Get the underlying `LocationAttr`.")
- .def(
- "emit_error",
- [](PyLocation &self, std::string message) {
- mlirEmitError(self, message.c_str());
- },
- nb::arg("message"),
- R"(
- Emits an error diagnostic at this location.
-
- Args:
- message: The error message to emit.)")
- .def(
- "__repr__",
- [](PyLocation &self) {
- PyPrintAccumulator printAccum;
- mlirLocationPrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly representation of the location.");
-
- //----------------------------------------------------------------------------
- // Mapping of Module
- //----------------------------------------------------------------------------
- nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule,
- "Gets a capsule wrapping the MlirModule.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
- R"(
- Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
-
- This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
- prevent double-frees (of the underlying `mlir::Module`).)")
- .def("_clear_mlir_module", &PyModule::clearMlirModule,
- R"(
- Clears the internal MLIR module reference.
-
- This is used internally to prevent double-free when ownership is transferred
- via the C API capsule mechanism. Not intended for normal use.)")
- .def_static(
- "parse",
- [](const std::string &moduleAsm, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyModule> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirModule module = mlirModuleCreateParse(
- context->get(), toMlirStringRef(moduleAsm));
- if (mlirModuleIsNull(module))
- throw MLIRError("Unable to parse module assembly", errors.take());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
- .def_static(
- "parse",
- [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyModule> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirModule module = mlirModuleCreateParse(
- context->get(), toMlirStringRef(moduleAsm));
- if (mlirModuleIsNull(module))
- throw MLIRError("Unable to parse module assembly", errors.take());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
- .def_static(
- "parseFile",
- [](const std::string &path, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyModule> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirModule module = mlirModuleCreateParseFromFile(
- context->get(), toMlirStringRef(path));
- if (mlirModuleIsNull(module))
- throw MLIRError("Unable to parse module assembly", errors.take());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("path"), nb::arg("context") = nb::none(),
- kModuleParseDocstring)
- .def_static(
- "create",
- [](const std::optional<PyLocation> &loc)
- -> nb::typed<nb::object, PyModule> {
- PyLocation pyLoc = maybeGetTracebackLocation(loc);
- MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
- return PyModule::forModule(module).releaseObject();
- },
- nb::arg("loc") = nb::none(), "Creates an empty module.")
- .def_prop_ro(
- "context",
- [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that created the `Module`.")
- .def_prop_ro(
- "operation",
- [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
- return PyOperation::forOperation(self.getContext(),
- mlirModuleGetOperation(self.get()),
- self.getRef().releaseObject())
- .releaseObject();
- },
- "Accesses the module as an operation.")
- .def_prop_ro(
- "body",
- [](PyModule &self) {
- PyOperationRef moduleOp = PyOperation::forOperation(
- self.getContext(), mlirModuleGetOperation(self.get()),
- self.getRef().releaseObject());
- PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
- return returnBlock;
- },
- "Return the block for this module.")
- .def(
- "dump",
- [](PyModule &self) {
- mlirOperationDump(mlirModuleGetOperation(self.get()));
- },
- kDumpDocstring)
- .def(
- "__str__",
- [](const nb::object &self) {
- // Defer to the operation's __str__.
- return self.attr("operation").attr("__str__")();
- },
- nb::sig("def __str__(self) -> str"),
- R"(
- Gets the assembly form of the operation with default options.
-
- If more advanced control over the assembly formatting or I/O options is needed,
- use the dedicated print or get_asm method, which supports keyword arguments to
- customize behavior.
- )")
- .def(
- "__eq__",
- [](PyModule &self, PyModule &other) {
- return mlirModuleEqual(self.get(), other.get());
- },
- "other"_a, "Compares two modules for equality.")
- .def(
- "__hash__",
- [](PyModule &self) { return mlirModuleHashValue(self.get()); },
- "Returns the hash value of the module.");
-
- //----------------------------------------------------------------------------
- // Mapping of Operation.
- //----------------------------------------------------------------------------
- nb::class_<PyOperationBase>(m, "_OperationBase")
- .def_prop_ro(
- MLIR_PYTHON_CAPI_PTR_ATTR,
- [](PyOperationBase &self) {
- return self.getOperation().getCapsule();
- },
- "Gets a capsule wrapping the `MlirOperation`.")
- .def(
- "__eq__",
- [](PyOperationBase &self, PyOperationBase &other) {
- return mlirOperationEqual(self.getOperation().get(),
- other.getOperation().get());
- },
- "Compares two operations for equality.")
- .def(
- "__eq__",
- [](PyOperationBase &self, nb::object other) { return false; },
- "Compares operation with non-operation object (always returns "
- "False).")
- .def(
- "__hash__",
- [](PyOperationBase &self) {
- return mlirOperationHashValue(self.getOperation().get());
- },
- "Returns the hash value of the operation.")
- .def_prop_ro(
- "attributes",
- [](PyOperationBase &self) {
- return PyOpAttributeMap(self.getOperation().getRef());
- },
- "Returns a dictionary-like map of operation attributes.")
- .def_prop_ro(
- "context",
- [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
- PyOperation &concreteOperation = self.getOperation();
- concreteOperation.checkValid();
- return concreteOperation.getContext().getObject();
- },
- "Context that owns the operation.")
- .def_prop_ro(
- "name",
- [](PyOperationBase &self) {
- auto &concreteOperation = self.getOperation();
- concreteOperation.checkValid();
- MlirOperation operation = concreteOperation.get();
- return mlirIdentifierStr(mlirOperationGetName(operation));
- },
- "Returns the fully qualified name of the operation.")
- .def_prop_ro(
- "operands",
- [](PyOperationBase &self) {
- return PyOpOperandList(self.getOperation().getRef());
- },
- "Returns the list of operation operands.")
- .def_prop_ro(
- "regions",
- [](PyOperationBase &self) {
- return PyRegionList(self.getOperation().getRef());
- },
- "Returns the list of operation regions.")
- .def_prop_ro(
- "results",
- [](PyOperationBase &self) {
- return PyOpResultList(self.getOperation().getRef());
- },
- "Returns the list of Operation results.")
- .def_prop_ro(
- "result",
- [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
- auto &operation = self.getOperation();
- return PyOpResult(operation.getRef(), getUniqueResult(operation))
- .maybeDownCast();
- },
- "Shortcut to get an op result if it has only one (throws an error "
- "otherwise).")
- .def_prop_rw(
- "location",
- [](PyOperationBase &self) {
- PyOperation &operation = self.getOperation();
- return PyLocation(operation.getContext(),
- mlirOperationGetLocation(operation.get()));
- },
- [](PyOperationBase &self, const PyLocation &location) {
- PyOperation &operation = self.getOperation();
- mlirOperationSetLocation(operation.get(), location.get());
- },
- nb::for_getter("Returns the source location the operation was "
- "defined or derived from."),
- nb::for_setter("Sets the source location the operation was defined "
- "or derived from."))
- .def_prop_ro(
- "parent",
- [](PyOperationBase &self)
- -> std::optional<nb::typed<nb::object, PyOperation>> {
- auto parent = self.getOperation().getParentOperation();
- if (parent)
- return parent->getObject();
- return {};
- },
- "Returns the parent operation, or `None` if at top level.")
- .def(
- "__str__",
- [](PyOperationBase &self) {
- return self.getAsm(/*binary=*/false,
- /*largeElementsLimit=*/std::nullopt,
- /*largeResourceLimit=*/std::nullopt,
- /*enableDebugInfo=*/false,
- /*prettyDebugInfo=*/false,
- /*printGenericOpForm=*/false,
- /*useLocalScope=*/false,
- /*useNameLocAsPrefix=*/false,
- /*assumeVerified=*/false,
- /*skipRegions=*/false);
- },
- nb::sig("def __str__(self) -> str"),
- "Returns the assembly form of the operation.")
- .def("print",
- nb::overload_cast<PyAsmState &, nb::object, bool>(
- &PyOperationBase::print),
- nb::arg("state"), nb::arg("file") = nb::none(),
- nb::arg("binary") = false,
- R"(
- Prints the assembly form of the operation to a file like object.
-
- Args:
- state: `AsmState` capturing the operation numbering and flags.
- file: Optional file like object to write to. Defaults to sys.stdout.
- binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
- .def("print",
- nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
- bool, bool, bool, bool, bool, bool, nb::object,
- bool, bool>(&PyOperationBase::print),
- // Careful: Lots of arguments must match up with print method.
- nb::arg("large_elements_limit") = nb::none(),
- nb::arg("large_resource_limit") = nb::none(),
- nb::arg("enable_debug_info") = false,
- nb::arg("pretty_debug_info") = false,
- nb::arg("print_generic_op_form") = false,
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- nb::arg("assume_verified") = false, nb::arg("file") = nb::none(),
- nb::arg("binary") = false, nb::arg("skip_regions") = false,
- R"(
- Prints the assembly form of the operation to a file like object.
-
- Args:
- large_elements_limit: Whether to elide elements attributes above this
- number of elements. Defaults to None (no limit).
- large_resource_limit: Whether to elide resource attributes above this
- number of characters. Defaults to None (no limit). If large_elements_limit
- is set and this is None, the behavior will be to use large_elements_limit
- as large_resource_limit.
- enable_debug_info: Whether to print debug/location information. Defaults
- to False.
- pretty_debug_info: Whether to format debug information for easier reading
- by a human (warning: the result is unparseable). Defaults to False.
- print_generic_op_form: Whether to print the generic assembly forms of all
- ops. Defaults to False.
- use_local_scope: Whether to print in a way that is more optimized for
- multi-threaded access but may not be consistent with how the overall
- module prints.
- use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
- prefixes for the SSA identifiers. Defaults to False.
- assume_verified: By default, if not printing generic form, the verifier
- will be run and if it fails, generic form will be printed with a comment
- about failed verification. While a reasonable default for interactive use,
- for systematic use, it is often better for the caller to verify explicitly
- and report failures in a more robust fashion. Set this to True if doing this
- in order to avoid running a redundant verification. If the IR is actually
- invalid, behavior is undefined.
- file: The file like object to write to. Defaults to sys.stdout.
- binary: Whether to write bytes (True) or str (False). Defaults to False.
- skip_regions: Whether to skip printing regions. Defaults to False.)")
- .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
- nb::arg("desired_version") = nb::none(),
- R"(
- Write the bytecode form of the operation to a file like object.
-
- Args:
- file: The file like object to write to.
- desired_version: Optional version of bytecode to emit.
- Returns:
- The bytecode writer status.)")
- .def("get_asm", &PyOperationBase::getAsm,
- // Careful: Lots of arguments must match up with get_asm method.
- nb::arg("binary") = false,
- nb::arg("large_elements_limit") = nb::none(),
- nb::arg("large_resource_limit") = nb::none(),
- nb::arg("enable_debug_info") = false,
- nb::arg("pretty_debug_info") = false,
- nb::arg("print_generic_op_form") = false,
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
- R"(
- Gets the assembly form of the operation with all options available.
-
- Args:
- binary: Whether to return a bytes (True) or str (False) object. Defaults to
- False.
- ... others ...: See the print() method for common keyword arguments for
- configuring the printout.
- Returns:
- Either a bytes or str object, depending on the setting of the `binary`
- argument.)")
- .def("verify", &PyOperationBase::verify,
- "Verify the operation. Raises MLIRError if verification fails, and "
- "returns true otherwise.")
- .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
- "Puts self immediately after the other operation in its parent "
- "block.")
- .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
- "Puts self immediately before the other operation in its parent "
- "block.")
- .def("is_before_in_block", &PyOperationBase::isBeforeInBlock,
- nb::arg("other"),
- R"(
- Checks if this operation is before another in the same block.
-
- Args:
- other: Another operation in the same parent block.
-
- Returns:
- True if this operation is before `other` in the operation list of the parent block.)")
- .def(
- "clone",
- [](PyOperationBase &self,
- const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
- return self.getOperation().clone(ip);
- },
- nb::arg("ip") = nb::none(),
- R"(
- Creates a deep copy of the operation.
-
- Args:
- ip: Optional insertion point where the cloned operation should be inserted.
- If None, the current insertion point is used. If False, the operation
- remains detached.
-
- Returns:
- A new Operation that is a clone of this operation.)")
- .def(
- "detach_from_parent",
- [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
- PyOperation &operation = self.getOperation();
- operation.checkValid();
- if (!operation.isAttached())
- throw nb::value_error("Detached operation has no parent.");
-
- operation.detachFromParent();
- return operation.createOpView();
- },
- "Detaches the operation from its parent block.")
- .def_prop_ro(
- "attached",
- [](PyOperationBase &self) {
- PyOperation &operation = self.getOperation();
- operation.checkValid();
- return operation.isAttached();
- },
- "Reports if the operation is attached to its parent block.")
- .def(
- "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
- R"(
- Erases the operation and frees its memory.
-
- Note:
- After erasing, any Python references to the operation become invalid.)")
- .def("walk", &PyOperationBase::walk, nb::arg("callback"),
- nb::arg("walk_order") = PyMlirWalkOrder::MlirWalkPostOrder,
- // clang-format off
- nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
- // clang-format on
- R"(
- Walks the operation tree with a callback function.
-
- Args:
- callback: A callable that takes an Operation and returns a WalkResult.
- walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
-
- nb::class_<PyOperation, PyOperationBase>(m, "Operation")
- .def_static(
- "create",
- [](std::string_view name,
- std::optional<std::vector<PyType *>> results,
- std::optional<std::vector<PyValue *>> operands,
- std::optional<nb::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors, int regions,
- const std::optional<PyLocation> &location,
- const nb::object &maybeIp,
- bool inferType) -> nb::typed<nb::object, PyOperation> {
- // Unpack/validate operands.
- llvm::SmallVector<MlirValue, 4> mlirOperands;
- if (operands) {
- mlirOperands.reserve(operands->size());
- for (PyValue *operand : *operands) {
- if (!operand)
- throw nb::value_error("operand value cannot be None");
- mlirOperands.push_back(operand->get());
- }
- }
-
- PyLocation pyLoc = maybeGetTracebackLocation(location);
- return PyOperation::create(name, results, mlirOperands, attributes,
- successors, regions, pyLoc, maybeIp,
- inferType);
- },
- nb::arg("name"), nb::arg("results") = nb::none(),
- nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(), nb::arg("regions") = 0,
- nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
- nb::arg("infer_type") = false,
- R"(
- Creates a new operation.
-
- Args:
- name: Operation name (e.g. `dialect.operation`).
- results: Optional sequence of Type representing op result types.
- operands: Optional operands of the operation.
- attributes: Optional Dict of {str: Attribute}.
- successors: Optional List of Block for the operation's successors.
- regions: Number of regions to create (default = 0).
- location: Optional Location object (defaults to resolve from context manager).
- ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
- infer_type: Whether to infer result types (default = False).
- Returns:
- A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
- .def_static(
- "parse",
- [](const std::string &sourceStr, const std::string &sourceName,
- DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyOpView> {
- return PyOperation::parse(context->getRef(), sourceStr, sourceName)
- ->createOpView();
- },
- nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
- nb::arg("context") = nb::none(),
- "Parses an operation. Supports both text assembly format and binary "
- "bytecode format.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule,
- "Gets a capsule wrapping the MlirOperation.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyOperation::createFromCapsule,
- "Creates an Operation from a capsule wrapping MlirOperation.")
- .def_prop_ro(
- "operation",
- [](nb::object self) -> nb::typed<nb::object, PyOperation> {
- return self;
- },
- "Returns self (the operation).")
- .def_prop_ro(
- "opview",
- [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
- return self.createOpView();
- },
- R"(
- Returns an OpView of this operation.
-
- Note:
- If the operation has a registered and loaded dialect then this OpView will
- be concrete wrapper class.)")
- .def_prop_ro("block", &PyOperation::getBlock,
- "Returns the block containing this operation.")
- .def_prop_ro(
- "successors",
- [](PyOperationBase &self) {
- return PyOpSuccessors(self.getOperation().getRef());
- },
- "Returns the list of Operation successors.")
- .def(
- "replace_uses_of_with",
- [](PyOperation &self, PyValue &of, PyValue &with) {
- mlirOperationReplaceUsesOfWith(self.get(), of.get(), with.get());
- },
- "of"_a, "with_"_a,
- "Replaces uses of the 'of' value with the 'with' value inside the "
- "operation.")
- .def("_set_invalid", &PyOperation::setInvalid,
- "Invalidate the operation.");
-
- auto opViewClass =
- nb::class_<PyOpView, PyOperationBase>(m, "OpView")
- .def(nb::init<nb::typed<nb::object, PyOperation>>(),
- nb::arg("operation"))
- .def(
- "__init__",
- [](PyOpView *self, std::string_view name,
- std::tuple<int, bool> opRegionSpec,
- nb::object operandSegmentSpecObj,
- nb::object resultSegmentSpecObj,
- std::optional<nb::list> resultTypeList, nb::list operandList,
- std::optional<nb::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions,
- const std::optional<PyLocation> &location,
- const nb::object &maybeIp) {
- PyLocation pyLoc = maybeGetTracebackLocation(location);
- new (self) PyOpView(PyOpView::buildGeneric(
- name, opRegionSpec, operandSegmentSpecObj,
- resultSegmentSpecObj, resultTypeList, operandList,
- attributes, successors, regions, pyLoc, maybeIp));
- },
- nb::arg("name"), nb::arg("opRegionSpec"),
- nb::arg("operandSegmentSpecObj") = nb::none(),
- nb::arg("resultSegmentSpecObj") = nb::none(),
- nb::arg("results") = nb::none(), nb::arg("operands") = nb::none(),
- nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(),
- nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(),
- nb::arg("ip") = nb::none())
- .def_prop_ro(
- "operation",
- [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
- return self.getOperationObject();
- })
- .def_prop_ro("opview",
- [](nb::object self) -> nb::typed<nb::object, PyOpView> {
- return self;
- })
- .def(
- "__str__",
- [](PyOpView &self) { return nb::str(self.getOperationObject()); })
- .def_prop_ro(
- "successors",
- [](PyOperationBase &self) {
- return PyOpSuccessors(self.getOperation().getRef());
- },
- "Returns the list of Operation successors.")
- .def(
- "_set_invalid",
- [](PyOpView &self) { self.getOperation().setInvalid(); },
- "Invalidate the operation.");
- opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
- opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
- opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
- // It is faster to pass the operation_name, ods_regions, and
- // ods_operand_segments/ods_result_segments as arguments to the constructor,
- // rather than to access them as attributes.
- opViewClass.attr("build_generic") = classmethod(
- [](nb::handle cls, std::optional<nb::list> resultTypeList,
- nb::list operandList, std::optional<nb::dict> attributes,
- std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, std::optional<PyLocation> location,
- const nb::object &maybeIp) {
- std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
- std::tuple<int, bool> opRegionSpec =
- nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
- nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
- nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
- PyLocation pyLoc = maybeGetTracebackLocation(location);
- return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
- resultSegmentSpec, resultTypeList,
- operandList, attributes, successors,
- regions, pyLoc, maybeIp);
- },
- nb::arg("cls"), nb::arg("results") = nb::none(),
- nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
- nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(),
- nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
- "Builds a specific, generated OpView based on class level attributes.");
- opViewClass.attr("parse") = classmethod(
- [](const nb::object &cls, const std::string &sourceStr,
- const std::string &sourceName,
- DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
- PyOperationRef parsed =
- PyOperation::parse(context->getRef(), sourceStr, sourceName);
-
- // Check if the expected operation was parsed, and cast to to the
- // appropriate `OpView` subclass if successful.
- // NOTE: This accesses attributes that have been automatically added to
- // `OpView` subclasses, and is not intended to be used on `OpView`
- // directly.
- std::string clsOpName =
- nb::cast<std::string>(cls.attr("OPERATION_NAME"));
- MlirStringRef identifier =
- mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
- std::string_view parsedOpName(identifier.data, identifier.length);
- if (clsOpName != parsedOpName)
- throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
- parsedOpName + "'");
- return PyOpView::constructDerived(cls, parsed.getObject());
- },
- nb::arg("cls"), nb::arg("source"), nb::kw_only(),
- nb::arg("source_name") = "", nb::arg("context") = nb::none(),
- "Parses a specific, generated OpView based on class level attributes.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyRegion.
- //----------------------------------------------------------------------------
- nb::class_<PyRegion>(m, "Region")
- .def_prop_ro(
- "blocks",
- [](PyRegion &self) {
- return PyBlockList(self.getParentOperation(), self.get());
- },
- "Returns a forward-optimized sequence of blocks.")
- .def_prop_ro(
- "owner",
- [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
- return self.getParentOperation()->createOpView();
- },
- "Returns the operation owning this region.")
- .def(
- "__iter__",
- [](PyRegion &self) {
- self.checkValid();
- MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
- return PyBlockIterator(self.getParentOperation(), firstBlock);
- },
- "Iterates over blocks in the region.")
- .def(
- "__eq__",
- [](PyRegion &self, PyRegion &other) {
- return self.get().ptr == other.get().ptr;
- },
- "Compares two regions for pointer equality.")
- .def(
- "__eq__", [](PyRegion &self, nb::object &other) { return false; },
- "Compares region with non-region object (always returns False).");
-
- //----------------------------------------------------------------------------
- // Mapping of PyBlock.
- //----------------------------------------------------------------------------
- nb::class_<PyBlock>(m, "Block")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule,
- "Gets a capsule wrapping the MlirBlock.")
- .def_prop_ro(
- "owner",
- [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
- return self.getParentOperation()->createOpView();
- },
- "Returns the owning operation of this block.")
- .def_prop_ro(
- "region",
- [](PyBlock &self) {
- MlirRegion region = mlirBlockGetParentRegion(self.get());
- return PyRegion(self.getParentOperation(), region);
- },
- "Returns the owning region of this block.")
- .def_prop_ro(
- "arguments",
- [](PyBlock &self) {
- return PyBlockArgumentList(self.getParentOperation(), self.get());
- },
- "Returns a list of block arguments.")
- .def(
- "add_argument",
- [](PyBlock &self, const PyType &type, const PyLocation &loc) {
- return PyBlockArgument(self.getParentOperation(),
- mlirBlockAddArgument(self.get(), type, loc));
- },
- "type"_a, "loc"_a,
- R"(
- Appends an argument of the specified type to the block.
-
- Args:
- type: The type of the argument to add.
- loc: The source location for the argument.
-
- Returns:
- The newly added block argument.)")
- .def(
- "erase_argument",
- [](PyBlock &self, unsigned index) {
- return mlirBlockEraseArgument(self.get(), index);
- },
- nb::arg("index"),
- R"(
- Erases the argument at the specified index.
-
- Args:
- index: The index of the argument to erase.)")
- .def_prop_ro(
- "operations",
- [](PyBlock &self) {
- return PyOperationList(self.getParentOperation(), self.get());
- },
- "Returns a forward-optimized sequence of operations.")
- .def_static(
- "create_at_start",
- [](PyRegion &parent, const nb::sequence &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- parent.checkValid();
- MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
- mlirRegionInsertOwnedBlock(parent, 0, block);
- return PyBlock(parent.getParentOperation(), block);
- },
- nb::arg("parent"), nb::arg("arg_types") = nb::list(),
- nb::arg("arg_locs") = std::nullopt,
- "Creates and returns a new Block at the beginning of the given "
- "region (with given argument types and locations).")
- .def(
- "append_to",
- [](PyBlock &self, PyRegion ®ion) {
- MlirBlock b = self.get();
- if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
- mlirBlockDetach(b);
- mlirRegionAppendOwnedBlock(region.get(), b);
- },
- nb::arg("region"),
- R"(
- Appends this block to a region.
-
- Transfers ownership if the block is currently owned by another region.
-
- Args:
- region: The region to append the block to.)")
- .def(
- "create_before",
- [](PyBlock &self, const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- self.checkValid();
- MlirBlock block =
- createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
- MlirRegion region = mlirBlockGetParentRegion(self.get());
- mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
- return PyBlock(self.getParentOperation(), block);
- },
- nb::arg("arg_types"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt,
- "Creates and returns a new Block before this block "
- "(with given argument types and locations).")
- .def(
- "create_after",
- [](PyBlock &self, const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
- self.checkValid();
- MlirBlock block =
- createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
- MlirRegion region = mlirBlockGetParentRegion(self.get());
- mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
- return PyBlock(self.getParentOperation(), block);
- },
- nb::arg("arg_types"), nb::kw_only(),
- nb::arg("arg_locs") = std::nullopt,
- "Creates and returns a new Block after this block "
- "(with given argument types and locations).")
- .def(
- "__iter__",
- [](PyBlock &self) {
- self.checkValid();
- MlirOperation firstOperation =
- mlirBlockGetFirstOperation(self.get());
- return PyOperationIterator(self.getParentOperation(),
- firstOperation);
- },
- "Iterates over operations in the block.")
- .def(
- "__eq__",
- [](PyBlock &self, PyBlock &other) {
- return self.get().ptr == other.get().ptr;
- },
- "Compares two blocks for pointer equality.")
- .def(
- "__eq__", [](PyBlock &self, nb::object &other) { return false; },
- "Compares block with non-block object (always returns False).")
- .def(
- "__hash__",
- [](PyBlock &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the block.")
- .def(
- "__str__",
- [](PyBlock &self) {
- self.checkValid();
- PyPrintAccumulator printAccum;
- mlirBlockPrint(self.get(), printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly form of the block.")
- .def(
- "append",
- [](PyBlock &self, PyOperationBase &operation) {
- if (operation.getOperation().isAttached())
- operation.getOperation().detachFromParent();
-
- MlirOperation mlirOperation = operation.getOperation().get();
- mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
- operation.getOperation().setAttached(
- self.getParentOperation().getObject());
- },
- nb::arg("operation"),
- R"(
- Appends an operation to this block.
-
- If the operation is currently in another block, it will be moved.
-
- Args:
- operation: The operation to append to the block.)")
- .def_prop_ro(
- "successors",
- [](PyBlock &self) {
- return PyBlockSuccessors(self, self.getParentOperation());
- },
- "Returns the list of Block successors.")
- .def_prop_ro(
- "predecessors",
- [](PyBlock &self) {
- return PyBlockPredecessors(self, self.getParentOperation());
- },
- "Returns the list of Block predecessors.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyInsertionPoint.
- //----------------------------------------------------------------------------
-
- nb::class_<PyInsertionPoint>(m, "InsertionPoint")
- .def(nb::init<PyBlock &>(), nb::arg("block"),
- "Inserts after the last operation but still inside the block.")
- .def("__enter__", &PyInsertionPoint::contextEnter,
- "Enters the insertion point as a context manager.")
- .def("__exit__", &PyInsertionPoint::contextExit,
- nb::arg("exc_type").none(), nb::arg("exc_value").none(),
- nb::arg("traceback").none(),
- "Exits the insertion point context manager.")
- .def_prop_ro_static(
- "current",
- [](nb::object & /*class*/) {
- auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
- if (!ip)
- throw nb::value_error("No current InsertionPoint");
- return ip;
- },
- nb::sig("def current(/) -> InsertionPoint"),
- "Gets the InsertionPoint bound to the current thread or raises "
- "ValueError if none has been set.")
- .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
- "Inserts before a referenced operation.")
- .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
- nb::arg("block"),
- R"(
- Creates an insertion point at the beginning of a block.
-
- Args:
- block: The block at whose beginning operations should be inserted.
-
- Returns:
- An InsertionPoint at the block's beginning.)")
- .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
- nb::arg("block"),
- R"(
- Creates an insertion point before a block's terminator.
-
- Args:
- block: The block whose terminator to insert before.
-
- Returns:
- An InsertionPoint before the terminator.
-
- Raises:
- ValueError: If the block has no terminator.)")
- .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
- R"(
- Creates an insertion point immediately after an operation.
-
- Args:
- operation: The operation after which to insert.
-
- Returns:
- An InsertionPoint after the operation.)")
- .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
- R"(
- Inserts an operation at this insertion point.
-
- Args:
- operation: The operation to insert.)")
- .def_prop_ro(
- "block", [](PyInsertionPoint &self) { return self.getBlock(); },
- "Returns the block that this `InsertionPoint` points to.")
- .def_prop_ro(
- "ref_operation",
- [](PyInsertionPoint &self)
- -> std::optional<nb::typed<nb::object, PyOperation>> {
- auto refOperation = self.getRefOperation();
- if (refOperation)
- return refOperation->getObject();
- return {};
- },
- "The reference operation before which new operations are "
- "inserted, or None if the insertion point is at the end of "
- "the block.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyAttribute.
- //----------------------------------------------------------------------------
- nb::class_<PyAttribute>(m, "Attribute")
- // Delegate to the PyAttribute copy constructor, which will also lifetime
- // extend the backing context which owns the MlirAttribute.
- .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
- "Casts the passed attribute to the generic `Attribute`.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule,
- "Gets a capsule wrapping the MlirAttribute.")
- .def_static(
- MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule,
- "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
- .def_static(
- "parse",
- [](const std::string &attrSpec, DefaultingPyMlirContext context)
- -> nb::typed<nb::object, PyAttribute> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute attr = mlirAttributeParseGet(
- context->get(), toMlirStringRef(attrSpec));
- if (mlirAttributeIsNull(attr))
- throw MLIRError("Unable to parse attribute", errors.take());
- return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- "Parses an attribute from an assembly form. Raises an `MLIRError` on "
- "failure.")
- .def_prop_ro(
- "context",
- [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that owns the `Attribute`.")
- .def_prop_ro(
- "type",
- [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirAttributeGetType(self))
- .maybeDownCast();
- },
- "Returns the type of the `Attribute`.")
- .def(
- "get_named",
- [](PyAttribute &self, std::string name) {
- return PyNamedAttribute(self, std::move(name));
- },
- nb::keep_alive<0, 1>(),
- R"(
- Binds a name to the attribute, creating a `NamedAttribute`.
-
- Args:
- name: The name to bind to the `Attribute`.
-
- Returns:
- A `NamedAttribute` with the given name and this attribute.)")
- .def(
- "__eq__",
- [](PyAttribute &self, PyAttribute &other) { return self == other; },
- "Compares two attributes for equality.")
- .def(
- "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
- "Compares attribute with non-attribute object (always returns "
- "False).")
- .def(
- "__hash__",
- [](PyAttribute &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the attribute.")
- .def(
- "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
- kDumpDocstring)
- .def(
- "__str__",
- [](PyAttribute &self) {
- PyPrintAccumulator printAccum;
- mlirAttributePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly form of the Attribute.")
- .def(
- "__repr__",
- [](PyAttribute &self) {
- // Generally, assembly formats are not printed for __repr__ because
- // this can cause exceptionally long debug output and exceptions.
- // However, attribute values are generally considered useful and
- // are printed. This may need to be re-evaluated if debug dumps end
- // up being excessive.
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Attribute(");
- mlirAttributePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- "Returns a string representation of the attribute.")
- .def_prop_ro(
- "typeid",
- [](PyAttribute &self) {
- MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
- assert(!mlirTypeIDIsNull(mlirTypeID) &&
- "mlirTypeID was expected to be non-null.");
- return PyTypeID(mlirTypeID);
- },
- "Returns the `TypeID` of the attribute.")
- .def(
- MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
- return self.maybeDownCast();
- },
- "Downcasts the attribute to a more specific attribute if possible.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyNamedAttribute
- //----------------------------------------------------------------------------
- nb::class_<PyNamedAttribute>(m, "NamedAttribute")
- .def(
- "__repr__",
- [](PyNamedAttribute &self) {
- PyPrintAccumulator printAccum;
- printAccum.parts.append("NamedAttribute(");
- printAccum.parts.append(
- nb::str(mlirIdentifierStr(self.namedAttr.name).data,
- mlirIdentifierStr(self.namedAttr.name).length));
- printAccum.parts.append("=");
- mlirAttributePrint(self.namedAttr.attribute,
- printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- "Returns a string representation of the named attribute.")
- .def_prop_ro(
- "name",
- [](PyNamedAttribute &self) {
- return mlirIdentifierStr(self.namedAttr.name);
- },
- "The name of the `NamedAttribute` binding.")
- .def_prop_ro(
- "attr",
- [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
- nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
- "The underlying generic attribute of the `NamedAttribute` binding.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyType.
- //----------------------------------------------------------------------------
- nb::class_<PyType>(m, "Type")
- // Delegate to the PyType copy constructor, which will also lifetime
- // extend the backing context which owns the MlirType.
- .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
- "Casts the passed type to the generic `Type`.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule,
- "Gets a capsule wrapping the `MlirType`.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule,
- "Creates a Type from a capsule wrapping `MlirType`.")
- .def_static(
- "parse",
- [](std::string typeSpec,
- DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType type =
- mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
- if (mlirTypeIsNull(type))
- throw MLIRError("Unable to parse type", errors.take());
- return PyType(context.get()->getRef(), type).maybeDownCast();
- },
- nb::arg("asm"), nb::arg("context") = nb::none(),
- R"(
- Parses the assembly form of a type.
-
- Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
-
- See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
- .def_prop_ro(
- "context",
- [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getContext().getObject();
- },
- "Context that owns the `Type`.")
- .def(
- "__eq__", [](PyType &self, PyType &other) { return self == other; },
- "Compares two types for equality.")
- .def(
- "__eq__", [](PyType &self, nb::object &other) { return false; },
- nb::arg("other").none(),
- "Compares type with non-type object (always returns False).")
- .def(
- "__hash__",
- [](PyType &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the `Type`.")
- .def(
- "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
- .def(
- "__str__",
- [](PyType &self) {
- PyPrintAccumulator printAccum;
- mlirTypePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- "Returns the assembly form of the `Type`.")
- .def(
- "__repr__",
- [](PyType &self) {
- // Generally, assembly formats are not printed for __repr__ because
- // this can cause exceptionally long debug output and exceptions.
- // However, types are an exception as they typically have compact
- // assembly forms and printing them is useful.
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Type(");
- mlirTypePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- "Returns a string representation of the `Type`.")
- .def(
- MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyType &self) -> nb::typed<nb::object, PyType> {
- return self.maybeDownCast();
- },
- "Downcasts the Type to a more specific `Type` if possible.")
- .def_prop_ro(
- "typeid",
- [](PyType &self) {
- MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
- if (!mlirTypeIDIsNull(mlirTypeID))
- return PyTypeID(mlirTypeID);
- auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
- throw nb::value_error(
- (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
- },
- "Returns the `TypeID` of the `Type`, or raises `ValueError` if "
- "`Type` has no "
- "`TypeID`.");
-
- //----------------------------------------------------------------------------
- // Mapping of PyTypeID.
- //----------------------------------------------------------------------------
- nb::class_<PyTypeID>(m, "TypeID")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule,
- "Gets a capsule wrapping the `MlirTypeID`.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule,
- "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
- // Note, this tests whether the underlying TypeIDs are the same,
- // not whether the wrapper MlirTypeIDs are the same, nor whether
- // the Python objects are the same (i.e., PyTypeID is a value type).
- .def(
- "__eq__",
- [](PyTypeID &self, PyTypeID &other) { return self == other; },
- "Compares two `TypeID`s for equality.")
- .def(
- "__eq__",
- [](PyTypeID &self, const nb::object &other) { return false; },
- "Compares TypeID with non-TypeID object (always returns False).")
- // Note, this gives the hash value of the underlying TypeID, not the
- // hash value of the Python object, nor the hash value of the
- // MlirTypeID wrapper.
- .def(
- "__hash__",
- [](PyTypeID &self) {
- return static_cast<size_t>(mlirTypeIDHashValue(self));
- },
- "Returns the hash value of the `TypeID`.");
-
- //----------------------------------------------------------------------------
- // Mapping of Value.
- //----------------------------------------------------------------------------
- m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
-
- nb::class_<PyValue>(m, "Value", nb::is_generic(),
- nb::sig("class Value(Generic[_T])"))
- .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"),
- "Creates a Value reference from another `Value`.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule,
- "Gets a capsule wrapping the `MlirValue`.")
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule,
- "Creates a `Value` from a capsule wrapping `MlirValue`.")
- .def_prop_ro(
- "context",
- [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
- return self.getParentOperation()->getContext().getObject();
- },
- "Context in which the value lives.")
- .def(
- "dump", [](PyValue &self) { mlirValueDump(self.get()); },
- kDumpDocstring)
- .def_prop_ro(
- "owner",
- [](PyValue &self) -> nb::typed<nb::object, PyOpView> {
- MlirValue v = self.get();
- if (mlirValueIsAOpResult(v)) {
- assert(mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match "
- "that in "
- "the IR");
- return self.getParentOperation()->createOpView();
- }
-
- if (mlirValueIsABlockArgument(v)) {
- MlirBlock block = mlirBlockArgumentGetOwner(self.get());
- return nb::cast(PyBlock(self.getParentOperation(), block));
- }
-
- assert(false && "Value must be a block argument or an op result");
- return nb::none();
- },
- "Returns the owner of the value (`Operation` for results, `Block` "
- "for "
- "arguments).")
- .def_prop_ro(
- "uses",
- [](PyValue &self) {
- return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
- },
- "Returns an iterator over uses of this value.")
- .def(
- "__eq__",
- [](PyValue &self, PyValue &other) {
- return self.get().ptr == other.get().ptr;
- },
- "Compares two values for pointer equality.")
- .def(
- "__eq__", [](PyValue &self, nb::object other) { return false; },
- "Compares value with non-value object (always returns False).")
- .def(
- "__hash__",
- [](PyValue &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- },
- "Returns the hash value of the value.")
- .def(
- "__str__",
- [](PyValue &self) {
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Value(");
- mlirValuePrint(self.get(), printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- },
- R"(
- Returns the string form of the value.
-
- If the value is a block argument, this is the assembly form of its type and the
- position in the argument list. If the value is an operation result, this is
- equivalent to printing the operation that produced it.
- )")
- .def(
- "get_name",
- [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
- PyPrintAccumulator printAccum;
- MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
- if (useLocalScope)
- mlirOpPrintingFlagsUseLocalScope(flags);
- if (useNameLocAsPrefix)
- mlirOpPrintingFlagsPrintNameLocAsPrefix(flags);
- MlirAsmState valueState =
- mlirAsmStateCreateForValue(self.get(), flags);
- mlirValuePrintAsOperand(self.get(), valueState,
- printAccum.getCallback(),
- printAccum.getUserData());
- mlirOpPrintingFlagsDestroy(flags);
- mlirAsmStateDestroy(valueState);
- return printAccum.join();
- },
- nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false,
- R"(
- Returns the string form of value as an operand.
-
- Args:
- use_local_scope: Whether to use local scope for naming.
- use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
-
- Returns:
- The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
- .def(
- "get_name",
- [](PyValue &self, PyAsmState &state) {
- PyPrintAccumulator printAccum;
- MlirAsmState valueState = state.get();
- mlirValuePrintAsOperand(self.get(), valueState,
- printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- },
- nb::arg("state"),
- "Returns the string form of value as an operand (i.e., the ValueID).")
- .def_prop_ro(
- "type",
- [](PyValue &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getParentOperation()->getContext(),
- mlirValueGetType(self.get()))
- .maybeDownCast();
- },
- "Returns the type of the value.")
- .def(
- "set_type",
- [](PyValue &self, const PyType &type) {
- mlirValueSetType(self.get(), type);
- },
- nb::arg("type"), "Sets the type of the value.",
- nb::sig("def set_type(self, type: _T)"))
- .def(
- "replace_all_uses_with",
- [](PyValue &self, PyValue &with) {
- mlirValueReplaceAllUsesOfWith(self.get(), with.get());
- },
- "Replace all uses of value with the new value, updating anything in "
- "the IR that uses `self` to use the other value instead.")
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with, PyOperation &exception) {
- MlirOperation exceptedUser = exception.get();
- mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with, const nb::list &exceptions) {
- // Convert Python list to a SmallVector of MlirOperations
- llvm::SmallVector<MlirOperation> exceptionOps;
- for (nb::handle exception : exceptions) {
- exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
- }
-
- mlirValueReplaceAllUsesExcept(
- self, with, static_cast<intptr_t>(exceptionOps.size()),
- exceptionOps.data());
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with, PyOperation &exception) {
- MlirOperation exceptedUser = exception.get();
- mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- "replace_all_uses_except",
- [](PyValue &self, PyValue &with,
- std::vector<PyOperation> &exceptions) {
- // Convert Python list to a SmallVector of MlirOperations
- llvm::SmallVector<MlirOperation> exceptionOps;
- for (PyOperation &exception : exceptions)
- exceptionOps.push_back(exception);
- mlirValueReplaceAllUsesExcept(
- self, with, static_cast<intptr_t>(exceptionOps.size()),
- exceptionOps.data());
- },
- nb::arg("with_"), nb::arg("exceptions"),
- kValueReplaceAllUsesExceptDocstring)
- .def(
- MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyValue &self) -> nb::typed<nb::object, PyValue> {
- return self.maybeDownCast();
- },
- "Downcasts the `Value` to a more specific kind if possible.")
- .def_prop_ro(
- "location",
- [](MlirValue self) {
- return PyLocation(
- PyMlirContext::forContext(mlirValueGetContext(self)),
- mlirValueGetLocation(self));
- },
- "Returns the source location of the value.");
-
- PyBlockArgument::bind(m);
- PyOpResult::bind(m);
- PyOpOperand::bind(m);
-
- nb::class_<PyAsmState>(m, "AsmState")
- .def(nb::init<PyValue &, bool>(), nb::arg("value"),
- nb::arg("use_local_scope") = false,
- R"(
- Creates an `AsmState` for consistent SSA value naming.
-
- Args:
- value: The value to create state for.
- use_local_scope: Whether to use local scope for naming.)")
- .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
- nb::arg("use_local_scope") = false,
- R"(
- Creates an AsmState for consistent SSA value naming.
-
- Args:
- op: The operation to create state for.
- use_local_scope: Whether to use local scope for naming.)");
-
- //----------------------------------------------------------------------------
- // Mapping of SymbolTable.
- //----------------------------------------------------------------------------
- nb::class_<PySymbolTable>(m, "SymbolTable")
- .def(nb::init<PyOperationBase &>(),
- R"(
- Creates a symbol table for an operation.
-
- Args:
- operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
-
- Raises:
- TypeError: If the operation is not a symbol table.)")
- .def(
- "__getitem__",
- [](PySymbolTable &self,
- const std::string &name) -> nb::typed<nb::object, PyOpView> {
- return self.dunderGetItem(name);
- },
- R"(
- Looks up a symbol by name in the symbol table.
-
- Args:
- name: The name of the symbol to look up.
-
- Returns:
- The operation defining the symbol.
-
- Raises:
- KeyError: If the symbol is not found.)")
- .def("insert", &PySymbolTable::insert, nb::arg("operation"),
- R"(
- Inserts a symbol operation into the symbol table.
-
- Args:
- operation: An operation with a symbol name to insert.
-
- Returns:
- The symbol name attribute of the inserted operation.
-
- Raises:
- ValueError: If the operation does not have a symbol name.)")
- .def("erase", &PySymbolTable::erase, nb::arg("operation"),
- R"(
- Erases a symbol operation from the symbol table.
-
- Args:
- operation: The symbol operation to erase.
-
- Note:
- The operation is also erased from the IR and invalidated.)")
- .def("__delitem__", &PySymbolTable::dunderDel,
- "Deletes a symbol by name from the symbol table.")
- .def(
- "__contains__",
- [](PySymbolTable &table, const std::string &name) {
- return !mlirOperationIsNull(mlirSymbolTableLookup(
- table, mlirStringRefCreate(name.data(), name.length())));
- },
- "Checks if a symbol with the given name exists in the table.")
- // Static helpers.
- .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
- nb::arg("symbol"), nb::arg("name"),
- "Sets the symbol name for a symbol operation.")
- .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
- nb::arg("symbol"),
- "Gets the symbol name from a symbol operation.")
- .def_static("get_visibility", &PySymbolTable::getVisibility,
- nb::arg("symbol"),
- "Gets the visibility attribute of a symbol operation.")
- .def_static("set_visibility", &PySymbolTable::setVisibility,
- nb::arg("symbol"), nb::arg("visibility"),
- "Sets the visibility attribute of a symbol operation.")
- .def_static("replace_all_symbol_uses",
- &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
- nb::arg("new_symbol"), nb::arg("from_op"),
- "Replaces all uses of a symbol with a new symbol name within "
- "the given operation.")
- .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
- nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
- nb::arg("callback"),
- "Walks symbol tables starting from an operation with a "
- "callback function.");
-
- // Container bindings.
- PyBlockArgumentList::bind(m);
- PyBlockIterator::bind(m);
- PyBlockList::bind(m);
- PyBlockSuccessors::bind(m);
- PyBlockPredecessors::bind(m);
- PyOperationIterator::bind(m);
- PyOperationList::bind(m);
- PyOpAttributeMap::bind(m);
- PyOpOperandIterator::bind(m);
- PyOpOperandList::bind(m);
- PyOpResultList::bind(m);
- PyOpSuccessors::bind(m);
- PyRegionIterator::bind(m);
- PyRegionList::bind(m);
-
- // Debug bindings.
- PyGlobalDebugFlag::bind(m);
-
- // Attribute builder getter.
- PyAttrBuilderMap::bind(m);
-}
-
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
-
NB_MODULE(_mlir, m) {
- m.doc() = "MLIR Python Native Extension";
- m.attr("T") = nb::type_var("T");
- m.attr("U") = nb::type_var("U");
-
- nb::class_<PyGlobals>(m, "_Globals")
- .def_prop_rw("dialect_search_modules",
- &PyGlobals::getDialectSearchPrefixes,
- &PyGlobals::setDialectSearchPrefixes)
- .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
- "module_name"_a)
- .def(
- "_check_dialect_module_loaded",
- [](PyGlobals &self, const std::string &dialectNamespace) {
- return self.loadDialectModule(dialectNamespace);
- },
- "dialect_namespace"_a)
- .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
- "dialect_namespace"_a, "dialect_class"_a,
- "Testing hook for directly registering a dialect")
- .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
- "operation_name"_a, "operation_class"_a, nb::kw_only(),
- "replace"_a = false,
- "Testing hook for directly registering an operation")
- .def("loc_tracebacks_enabled",
- [](PyGlobals &self) {
- return self.getTracebackLoc().locTracebacksEnabled();
- })
- .def("set_loc_tracebacks_enabled",
- [](PyGlobals &self, bool enabled) {
- self.getTracebackLoc().setLocTracebacksEnabled(enabled);
- })
- .def("loc_tracebacks_frame_limit",
- [](PyGlobals &self) {
- return self.getTracebackLoc().locTracebackFramesLimit();
- })
- .def("set_loc_tracebacks_frame_limit",
- [](PyGlobals &self, std::optional<int> n) {
- self.getTracebackLoc().setLocTracebackFramesLimit(
- n.value_or(PyGlobals::TracebackLoc::kMaxFrames));
- })
- .def("register_traceback_file_inclusion",
- [](PyGlobals &self, const std::string &filename) {
- self.getTracebackLoc().registerTracebackFileInclusion(filename);
- })
- .def("register_traceback_file_exclusion",
- [](PyGlobals &self, const std::string &filename) {
- self.getTracebackLoc().registerTracebackFileExclusion(filename);
- });
-
- // Aside from making the globals accessible to python, having python manage
- // it is necessary to make sure it is destroyed (and releases its python
- // resources) properly.
- m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
-
- // Registration decorators.
- m.def(
- "register_dialect",
- [](nb::type_object pyClass) {
- std::string dialectNamespace =
- nanobind::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
- PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
- return pyClass;
- },
- "dialect_class"_a,
- "Class decorator for registering a custom Dialect wrapper");
- m.def(
- "register_operation",
- [](const nb::type_object &dialectClass, bool replace) -> nb::object {
- return nb::cpp_function(
- [dialectClass,
- replace](nb::type_object opClass) -> nb::type_object {
- std::string operationName =
- nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
- PyGlobals::get().registerOperationImpl(operationName, opClass,
- replace);
- // Dict-stuff the new opClass by name onto the dialect class.
- nb::object opClassName = opClass.attr("__name__");
- dialectClass.attr(opClassName) = opClass;
- return opClass;
- });
- },
- // clang-format off
- nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) "
- "-> typing.Callable[[type[T]], type[T]]"),
- // clang-format on
- "dialect_class"_a, nb::kw_only(), "replace"_a = false,
- "Produce a class decorator for registering an Operation class as part of "
- "a dialect");
- m.def(
- MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
- [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
- return nb::cpp_function([mlirTypeID, replace](
- nb::callable typeCaster) -> nb::object {
- PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
- return typeCaster;
- });
- },
- // clang-format off
- nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
- "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
- // clang-format on
- "typeid"_a, nb::kw_only(), "replace"_a = false,
- "Register a type caster for casting MLIR types to custom user types.");
- m.def(
- MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
- [](MlirTypeID mlirTypeID, bool replace) -> nb::object {
- return nb::cpp_function(
- [mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
- PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
- replace);
- return valueCaster;
- });
- },
- // clang-format off
- nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
- "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
- // clang-format on
- "typeid"_a, nb::kw_only(), "replace"_a = false,
- "Register a value caster for casting MLIR values to custom user values.");
+ // disable leak warnings which tend to be false positives.
+ nb::set_leak_warnings(false);
+ m.doc() = "MLIR Python Native Extension";
+ populateRoot(m);
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
populateIRCore(irModule);
>From 46252a91679eaf40f842ea97f2f8e27953d8ecd3 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 30 Dec 2025 20:48:33 -0800
Subject: [PATCH 36/38] gate standalone
---
mlir/test/Examples/standalone/test.wheel.toy | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Examples/standalone/test.wheel.toy b/mlir/test/Examples/standalone/test.wheel.toy
index 8dedaa07c84f7..46f170579a977 100644
--- a/mlir/test/Examples/standalone/test.wheel.toy
+++ b/mlir/test/Examples/standalone/test.wheel.toy
@@ -1,6 +1,10 @@
# There's no real issue with windows here, it's just that some CMake generated paths for targets end up being longer
# than 255 chars when combined with the fact that pip wants to install into a tmp directory buried under
# C/Users/ContainerAdministrator/AppData/Local/Temp.
+# UNSUPPORTED: target={{.*(windows).*}}
+# REQUIRES: expensive_checks
+# REQUIRES: non-shared-libs-build
+# REQUIRES: bindings-python
# RUN: export CMAKE_BUILD_TYPE=%cmake_build_type
# RUN: export CMAKE_CXX_COMPILER=%host_cxx
@@ -14,8 +18,7 @@
# RUN: export MLIR_PYTHON_PACKAGE_PREFIX=mlir_standalone
# RUN: export MLIR_BINDINGS_PYTHON_NB_DOMAIN=mlir_standalone
-# RUN: %python -m pip install scikit-build-core
-# RUN: %python -m pip wheel "%mlir_src_root/examples/standalone" -w "%mlir_obj_root/wheelhouse" -v --no-build-isolation | tee %t
+# RUN: %python -m pip wheel "%mlir_src_root/examples/standalone" -w "%mlir_obj_root/wheelhouse" -v | tee %t
# RUN: rm -rf "%mlir_obj_root/standalone-python-bindings-install"
# RUN: %python -m pip install standalone_python_bindings -f "%mlir_obj_root/wheelhouse" --target "%mlir_obj_root/standalone-python-bindings-install" -v | tee -a %t
@@ -34,6 +37,8 @@
# CHECK: %[[V0:.*]] = standalone.foo %[[C2]] : i32
# CHECK: }
+# CHECK: !standalone.custom<"foo">
+
# CHECK: Testing mlir package
# CHECK-NOT: RuntimeWarning: nanobind: type '{{.*}}' was already registered!
>From 92a8875fa284db9eb0e7932c3577b214ca59715f Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 31 Dec 2025 11:05:21 -0800
Subject: [PATCH 37/38] fix after rebase
---
mlir/examples/standalone/test/python/smoketest.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index dbb664d9190b2..fe4e40e6e8a99 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -21,7 +21,7 @@
custom_type = standalone_d.CustomType.get("foo")
# CHECK: !standalone.custom<"foo">
- print(custom_type)
+ print(custom_type, file=sys.stderr)
# CHECK: Testing mlir package
>From 8de0bcbf76ecd1db066f3f3fc34d0238a0285cb5 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 31 Dec 2025 14:20:39 -0800
Subject: [PATCH 38/38] [mlir][Python] move IRTypes and IRAttributes to public
headers
---
.../mlir/Bindings/Python/IRAttributes.h | 593 +++++
mlir/include/mlir/Bindings/Python/IRCore.h | 17 +-
mlir/include/mlir/Bindings/Python/IRTypes.h | 393 ++-
mlir/lib/Bindings/Python/IRAttributes.cpp | 2297 +++++++----------
mlir/lib/Bindings/Python/IRTypes.cpp | 1609 +++++-------
mlir/lib/Bindings/Python/MainModule.cpp | 2 +
mlir/python/CMakeLists.txt | 6 +-
.../python/lib/PythonTestModuleNanobind.cpp | 131 +-
8 files changed, 2632 insertions(+), 2416 deletions(-)
create mode 100644 mlir/include/mlir/Bindings/Python/IRAttributes.h
diff --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
new file mode 100644
index 0000000000000..d64e32037664c
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -0,0 +1,593 @@
+//===- IRAttributes.h - Exports builtin and standard attributes -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H
+#define MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H
+
+#include <optional>
+#include <string>
+#include <string_view>
+#include <utility>
+
+#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir/Bindings/Python/NanobindUtils.h"
+
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+/// Description of a Python buffer (pointer, item size, format, shape and
+/// strides) handed between the buffer protocol and C++. When built from a
+/// `Py_buffer *`, that view is handed to the out-of-line constructor together
+/// with `PyBuffer_Release` as its deleter.
+struct nb_buffer_info {
+  void *ptr = nullptr;
+  ssize_t itemsize = 0;
+  ssize_t size = 0;
+  const char *format = nullptr;
+  ssize_t ndim = 0;
+  SmallVector<ssize_t, 4> shape;
+  SmallVector<ssize_t, 4> strides;
+  bool readonly = false;
+
+  /// Defined out of line. `owned_view_in` presumably ends up in `owned_view`
+  /// (released by its deleter on destruction) — TODO confirm in the .cpp.
+  nb_buffer_info(
+      void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
+      SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
+      bool readonly = false,
+      std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
+          std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr));
+
+  /// Wraps `view`, taking over its release via `PyBuffer_Release`.
+  explicit nb_buffer_info(Py_buffer *view)
+      : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
+                       {view->shape, view->shape + view->ndim},
+                       stridesOf(view), view->readonly != 0,
+                       std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
+                           view, PyBuffer_Release)) {}
+
+  nb_buffer_info(const nb_buffer_info &) = delete;
+  nb_buffer_info(nb_buffer_info &&) = default;
+  nb_buffer_info &operator=(const nb_buffer_info &) = delete;
+  nb_buffer_info &operator=(nb_buffer_info &&) = default;
+
+private:
+  /// Returns the strides of `view`. An exporter leaves `strides` NULL when
+  /// the consumer did not request PyBUF_STRIDES, and the data is then
+  /// C-contiguous. In that case the strides are derived from the shape and
+  /// item size instead of reading through a null pointer.
+  /// NOTE(review): assumes `shape` is populated (PyBUF_ND or stronger), as
+  /// the shape construction in the constructor above already does.
+  static SmallVector<ssize_t, 4> stridesOf(const Py_buffer *view) {
+    if (view->strides)
+      return SmallVector<ssize_t, 4>(view->strides,
+                                     view->strides + view->ndim);
+    SmallVector<ssize_t, 4> result(view->ndim);
+    ssize_t stride = view->itemsize;
+    for (ssize_t dim = view->ndim - 1; dim >= 0; --dim) {
+      result[dim] = stride;
+      stride *= view->shape[dim];
+    }
+    return result;
+  }
+
+  std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
+};
+
+/// nanobind object wrapper for Python objects exposing the buffer protocol;
+/// `PyObject_CheckBuffer` is the check function passed to NB_OBJECT_DEFAULT.
+class MLIR_PYTHON_API_EXPORTED nb_buffer : public nanobind::object {
+ NB_OBJECT_DEFAULT(nb_buffer, object, "Buffer", PyObject_CheckBuffer);
+
+ /// Requests a buffer view of the wrapped object; defined out of line.
+ nb_buffer_info request() const;
+};
+
+/// Maps a C++ element type to its buffer-protocol format string through a
+/// static `format()` (used by `PyDenseElementsAttribute::bufferInfo`). The
+/// primary template is empty; presumably specializations are provided in the
+/// implementation file — TODO confirm.
+template <typename T>
+struct nb_format_descriptor {};
+
+/// C++ side of the Python `AffineMapAttr` class; matches attributes for which
+/// `mlirAttributeIsAAffineMap` holds.
+class MLIR_PYTHON_API_EXPORTED PyAffineMapAttribute
+    : public PyConcreteAttribute<PyAffineMapAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "AffineMapAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirAffineMapAttrGetTypeID;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `IntegerSetAttr` class; matches attributes for which
+/// `mlirAttributeIsAIntegerSet` holds.
+class MLIR_PYTHON_API_EXPORTED PyIntegerSetAttribute
+    : public PyConcreteAttribute<PyIntegerSetAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "IntegerSetAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIntegerSetAttrGetTypeID;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Casts `object` to `T` with nanobind. A failed cast (nanobind::cast_error)
+/// or another std::runtime_error is rethrown as std::runtime_error with
+/// context prepended to the original message.
+/// NOTE(review): the messages mention ArrayAttribute even when the caller is
+/// building something else (e.g. dense array `__add__`). They are kept
+/// verbatim in case callers or tests match on them.
+///
+/// Declared `inline` rather than `static`: a `static` template in a public
+/// header gives every including translation unit its own internal copy.
+template <typename T>
+inline T pyTryCast(nanobind::handle object) {
+  try {
+    return nanobind::cast<T>(object);
+  } catch (const nanobind::cast_error &err) {
+    // Must precede the runtime_error handler: nanobind's cast_error is
+    // presumably derived from std::runtime_error.
+    throw std::runtime_error(std::string("Invalid attribute when attempting to "
+                                         "create an ArrayAttribute (") +
+                             err.what() + ")");
+  } catch (const std::runtime_error &err) {
+    throw std::runtime_error(
+        std::string("Invalid attribute (None?) when attempting "
+                    "to create an ArrayAttribute (") +
+        err.what() + ")");
+  }
+}
+
+/// A python-wrapped dense array attribute with an element type and a derived
+/// implementation class.
+///
+/// `DerivedT` supplies static `getAttribute` (C API constructor),
+/// `getElement` (C API element accessor) and `pyIteratorName`, plus the
+/// `isaFunction`/`pyClassName` used by PyConcreteAttribute. The
+/// PyDense*ArrayAttribute structs in this header are examples.
+template <typename EltTy, typename DerivedT>
+class MLIR_PYTHON_API_EXPORTED PyDenseArrayAttribute
+    : public PyConcreteAttribute<DerivedT> {
+public:
+  using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
+
+  /// Iterator over the elements of a dense array.
+  class PyDenseArrayIterator {
+  public:
+    PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
+
+    /// Return a copy of the iterator.
+    PyDenseArrayIterator dunderIter() { return *this; }
+
+    /// Return the next element, raising StopIteration past the end.
+    EltTy dunderNext() {
+      if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
+        throw nanobind::stop_iteration();
+      return DerivedT::getElement(attr.get(), nextIndex++);
+    }
+
+    /// Bind the iterator class.
+    static void bind(nanobind::module_ &m) {
+      nanobind::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
+          .def("__iter__", &PyDenseArrayIterator::dunderIter)
+          .def("__next__", &PyDenseArrayIterator::dunderNext);
+    }
+
+  private:
+    /// The referenced dense array attribute.
+    PyAttribute attr;
+    /// The next index to read. intptr_t matches the element count returned
+    /// by mlirDenseArrayGetNumElements, so large arrays cannot overflow it.
+    intptr_t nextIndex = 0;
+  };
+
+  /// Get the element at the given index (no bounds checking).
+  EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
+
+  /// Bind the attribute class.
+  static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
+    // Bind the constructor.
+    if constexpr (std::is_same_v<EltTy, bool>) {
+      c.def_static(
+          "get",
+          [](const nanobind::sequence &py_values, DefaultingPyMlirContext ctx) {
+            std::vector<bool> values;
+            for (nanobind::handle py_value : py_values) {
+              // Accept any Python object using its truthiness.
+              int is_true = PyObject_IsTrue(py_value.ptr());
+              if (is_true < 0) {
+                throw nanobind::python_error();
+              }
+              values.push_back(is_true);
+            }
+            return getAttribute(values, ctx->getRef());
+          },
+          nanobind::arg("values"), nanobind::arg("context") = nanobind::none(),
+          "Gets a uniqued dense array attribute");
+    } else {
+      c.def_static(
+          "get",
+          [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
+            return getAttribute(values, ctx->getRef());
+          },
+          nanobind::arg("values"), nanobind::arg("context") = nanobind::none(),
+          "Gets a uniqued dense array attribute");
+    }
+    // Bind the array methods.
+    c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
+      intptr_t numElements = mlirDenseArrayGetNumElements(arr);
+      // Python sequence semantics: negative indices count from the end.
+      // Anything still outside [0, numElements) would read out of bounds in
+      // the C API, so reject it.
+      if (i < 0)
+        i += numElements;
+      if (i < 0 || i >= numElements)
+        throw nanobind::index_error("DenseArray index out of range");
+      return arr.getItem(i);
+    });
+    c.def("__len__", [](const DerivedT &arr) {
+      return mlirDenseArrayGetNumElements(arr);
+    });
+    c.def("__iter__",
+          [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
+    c.def("__add__", [](DerivedT &arr, const nanobind::list &extras) {
+      std::vector<EltTy> values;
+      intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
+      values.reserve(numOldElements + nanobind::len(extras));
+      for (intptr_t i = 0; i < numOldElements; ++i)
+        values.push_back(arr.getItem(i));
+      for (nanobind::handle attr : extras)
+        values.push_back(pyTryCast<EltTy>(attr));
+      return getAttribute(values, arr.getContext());
+    });
+  }
+
+private:
+  /// Builds a DerivedT attribute holding `values` in context `ctx`.
+  static DerivedT getAttribute(const std::vector<EltTy> &values,
+                               PyMlirContextRef ctx) {
+    if constexpr (std::is_same_v<EltTy, bool>) {
+      // std::vector<bool> is bit-packed and has no data(); widen to int for
+      // the C API.
+      std::vector<int> intValues(values.begin(), values.end());
+      MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
+                                                  intValues.data());
+      return DerivedT(ctx, attr);
+    } else {
+      MlirAttribute attr =
+          DerivedT::getAttribute(ctx->get(), values.size(), values.data());
+      return DerivedT(ctx, attr);
+    }
+  }
+};
+
+/// Instantiate the python dense array classes. Each is exported like every
+/// other public attribute class in this header.
+///
+/// Python `DenseBoolArrayAttr`: dense array of `bool` elements.
+struct MLIR_PYTHON_API_EXPORTED PyDenseBoolArrayAttribute
+    : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
+  static constexpr auto getAttribute = mlirDenseBoolArrayGet;
+  static constexpr auto getElement = mlirDenseBoolArrayGetElement;
+  static constexpr const char *pyClassName = "DenseBoolArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+/// Python `DenseI8ArrayAttr`: dense array of `int8_t` elements.
+struct MLIR_PYTHON_API_EXPORTED PyDenseI8ArrayAttribute
+    : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
+  static constexpr auto getAttribute = mlirDenseI8ArrayGet;
+  static constexpr auto getElement = mlirDenseI8ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI8ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+/// Python `DenseI16ArrayAttr`: dense array of `int16_t` elements.
+struct MLIR_PYTHON_API_EXPORTED PyDenseI16ArrayAttribute
+    : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
+  static constexpr auto getAttribute = mlirDenseI16ArrayGet;
+  static constexpr auto getElement = mlirDenseI16ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI16ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+/// Python `DenseI32ArrayAttr`: dense array of `int32_t` elements.
+struct MLIR_PYTHON_API_EXPORTED PyDenseI32ArrayAttribute
+    : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
+  static constexpr auto getAttribute = mlirDenseI32ArrayGet;
+  static constexpr auto getElement = mlirDenseI32ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI32ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+/// Python `DenseI64ArrayAttr`: dense array of `int64_t` elements.
+struct MLIR_PYTHON_API_EXPORTED PyDenseI64ArrayAttribute
+    : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
+  static constexpr auto getAttribute = mlirDenseI64ArrayGet;
+  static constexpr auto getElement = mlirDenseI64ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseI64ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+/// Python `DenseF32ArrayAttr`: dense array of `float` elements.
+struct MLIR_PYTHON_API_EXPORTED PyDenseF32ArrayAttribute
+    : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
+  static constexpr auto getAttribute = mlirDenseF32ArrayGet;
+  static constexpr auto getElement = mlirDenseF32ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseF32ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+/// Python `DenseF64ArrayAttr`: dense array of `double` elements.
+struct MLIR_PYTHON_API_EXPORTED PyDenseF64ArrayAttribute
+    : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
+  static constexpr auto getAttribute = mlirDenseF64ArrayGet;
+  static constexpr auto getElement = mlirDenseF64ArrayGetElement;
+  static constexpr const char *pyClassName = "DenseF64ArrayAttr";
+  static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
+  using PyDenseArrayAttribute::PyDenseArrayAttribute;
+};
+
+/// C++ side of the Python `ArrayAttr` class, with indexed access and a
+/// nested Python iterator type.
+class MLIR_PYTHON_API_EXPORTED PyArrayAttribute
+    : public PyConcreteAttribute<PyArrayAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "ArrayAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirArrayAttrGetTypeID;
+
+  /// Returns element `i`; defined out of line.
+  MlirAttribute getItem(intptr_t i) const;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+
+  /// Python iterator walking the elements of an array attribute.
+  class PyArrayAttributeIterator {
+  public:
+    PyArrayAttributeIterator(PyAttribute array) : attr(std::move(array)) {}
+
+    /// `__iter__`: hands back this same iterator.
+    PyArrayAttributeIterator &dunderIter() { return *this; }
+
+    /// `__next__`; defined out of line.
+    nanobind::typed<nanobind::object, PyAttribute> dunderNext();
+
+    /// Registers the iterator class on `m`; defined out of line.
+    static void bind(nanobind::module_ &m);
+
+  private:
+    // The array attribute being walked.
+    PyAttribute attr;
+    // Position of the element the next `__next__` call returns.
+    int nextIndex = 0;
+  };
+};
+
+/// C++ side of the Python `FloatAttr` class (floating-point attribute).
+class MLIR_PYTHON_API_EXPORTED PyFloatAttribute
+    : public PyConcreteAttribute<PyFloatAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "FloatAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloatAttrGetTypeID;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `IntegerAttr` class. Note that no TypeID caster is
+/// registered for it here (no `getTypeIdFunction`).
+class MLIR_PYTHON_API_EXPORTED PyIntegerAttribute
+    : public PyConcreteAttribute<PyIntegerAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "IntegerAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+
+private:
+  /// Converts the attribute's value to int64_t; defined out of line.
+  static int64_t toPyInt(PyIntegerAttribute &self);
+};
+
+/// C++ side of the Python `BoolAttr` class.
+class MLIR_PYTHON_API_EXPORTED PyBoolAttribute
+    : public PyConcreteAttribute<PyBoolAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "BoolAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `SymbolRefAttr` class.
+class MLIR_PYTHON_API_EXPORTED PySymbolRefAttribute
+    : public PyConcreteAttribute<PySymbolRefAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "SymbolRefAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
+
+  /// Builds a symbol reference from the list of symbol names in `context`;
+  /// defined out of line.
+  static PySymbolRefAttribute fromList(const std::vector<std::string> &symbols,
+                                       PyMlirContext &context);
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `FlatSymbolRefAttr` class.
+class MLIR_PYTHON_API_EXPORTED PyFlatSymbolRefAttribute
+    : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "FlatSymbolRefAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `OpaqueAttr` class.
+class MLIR_PYTHON_API_EXPORTED PyOpaqueAttribute
+    : public PyConcreteAttribute<PyOpaqueAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "OpaqueAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirOpaqueAttrGetTypeID;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+// TODO: Support construction of string elements.
+/// C++ side of the Python `DenseElementsAttr` class. Instances are built from
+/// a list of attributes, from a buffer-protocol object, or as a splat. They
+/// expose their own storage as a read-only buffer (see `bufferInfo`).
+class MLIR_PYTHON_API_EXPORTED PyDenseElementsAttribute
+    : public PyConcreteAttribute<PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
+  static constexpr const char *pyClassName = "DenseElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static PyDenseElementsAttribute
+  getFromList(const nanobind::list &attributes,
+              std::optional<PyType> explicitType,
+              DefaultingPyMlirContext contextWrapper);
+
+  static PyDenseElementsAttribute
+  getFromBuffer(const nb_buffer &array, bool signless,
+                const std::optional<PyType> &explicitType,
+                std::optional<std::vector<int64_t>> explicitShape,
+                DefaultingPyMlirContext contextWrapper);
+
+  static PyDenseElementsAttribute getSplat(const PyType &shapedType,
+                                           PyAttribute &elementAttr);
+
+  intptr_t dunderLen() const;
+
+  std::unique_ptr<nb_buffer_info> accessBuffer();
+
+  static void bindDerived(ClassTy &c);
+
+  /// Python type slots for this class (presumably wiring up the
+  /// `bf_getbuffer`/`bf_releasebuffer` hooks below — TODO confirm).
+  static PyType_Slot slots[];
+
+private:
+  static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
+  static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
+
+  static bool isUnsignedIntegerFormat(std::string_view format);
+
+  static bool isSignedIntegerFormat(std::string_view format);
+
+  static MlirType
+  getShapedType(std::optional<MlirType> bulkLoadElementType,
+                std::optional<std::vector<int64_t>> explicitShape,
+                Py_buffer &view);
+
+  static MlirAttribute getAttributeFromBuffer(
+      Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+      const std::optional<std::vector<int64_t>> &explicitShape,
+      MlirContext &context);
+
+  // There is a complication for boolean numpy arrays, as numpy represents
+  // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
+  // booleans per byte.
+  static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
+      Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+      MlirContext &context);
+
+  // This does the opposite transformation of
+  // `getBitpackedAttributeFromBooleanBuffer`
+  std::unique_ptr<nb_buffer_info>
+  getBooleanBufferFromBitpackedAttribute() const;
+
+  /// Describes this attribute's raw storage as a read-only buffer whose
+  /// elements are read as `Type`. `explicitFormat`, when non-null, overrides
+  /// the format string from `nb_format_descriptor<Type>`.
+  template <typename Type>
+  std::unique_ptr<nb_buffer_info>
+  bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
+    intptr_t rank = mlirShapedTypeGetRank(shapedType);
+    // The buffer is marked read-only below, so casting away const is safe.
+    Type *data = static_cast<Type *>(
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+    // Use ssize_t so the vectors match nb_buffer_info's constructor exactly,
+    // instead of relying on intptr_t and ssize_t being the same type.
+    SmallVector<ssize_t, 4> shape;
+    shape.reserve(rank);
+    for (intptr_t i = 0; i < rank; ++i)
+      shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
+    // Splats are special: only the single value is stored, so every stride
+    // stays zero. Otherwise compute row-major strides in one backwards pass.
+    SmallVector<ssize_t, 4> strides(rank, 0);
+    if (!mlirDenseElementsAttrIsSplat(*this)) {
+      ssize_t stride = sizeof(Type);
+      for (intptr_t i = rank - 1; i >= 0; --i) {
+        strides[i] = stride;
+        stride *= shape[i];
+      }
+    }
+    const char *format =
+        explicitFormat ? explicitFormat : nb_format_descriptor<Type>::format();
+    return std::make_unique<nb_buffer_info>(
+        data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
+        /*readonly=*/true);
+  }
+};
+
+/// C++ side of the Python `DenseIntElementsAttr` class. It specializes
+/// PyDenseElementsAttribute for integer (and boolean) elements and adds
+/// per-element access.
+class MLIR_PYTHON_API_EXPORTED PyDenseIntElementsAttribute
+    : public PyConcreteAttribute<PyDenseIntElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "DenseIntElementsAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
+
+  /// Element at linear position `pos`, as a Python int. Documented as
+  /// asserting when `pos` is out of range; defined out of line.
+  nanobind::int_ dunderGetItem(intptr_t pos) const;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `DenseResourceElementsAttr` class.
+class MLIR_PYTHON_API_EXPORTED PyDenseResourceElementsAttribute
+    : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "DenseResourceElementsAttr";
+  static constexpr IsAFunctionTy isaFunction =
+      mlirAttributeIsADenseResourceElements;
+
+  /// Creates the attribute from the contents of a buffer-protocol object;
+  /// defined out of line.
+  static PyDenseResourceElementsAttribute
+  getFromBuffer(const nb_buffer &buffer, const std::string &name,
+                const PyType &type, std::optional<size_t> alignment,
+                bool isMutable, DefaultingPyMlirContext contextWrapper);
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `DictAttr` class, with `len()` and `in` support.
+class MLIR_PYTHON_API_EXPORTED PyDictAttribute
+    : public PyConcreteAttribute<PyDictAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "DictAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirDictionaryAttrGetTypeID;
+
+  /// Number of entries (`__len__`); defined out of line.
+  intptr_t dunderLen() const;
+
+  /// Whether an entry named `name` exists (`__contains__`); defined out of
+  /// line.
+  bool dunderContains(const std::string &name) const;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `DenseFPElementsAttr` class. It specializes
+/// PyDenseElementsAttribute for floating-point elements and adds per-element
+/// access.
+class MLIR_PYTHON_API_EXPORTED PyDenseFPElementsAttribute
+    : public PyConcreteAttribute<PyDenseFPElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "DenseFPElementsAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
+
+  /// Element at linear position `pos`, as a Python float; defined out of
+  /// line.
+  nanobind::float_ dunderGetItem(intptr_t pos) const;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `TypeAttr` class.
+class MLIR_PYTHON_API_EXPORTED PyTypeAttribute
+    : public PyConcreteAttribute<PyTypeAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "TypeAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirTypeAttrGetTypeID;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `UnitAttr` class. A unit attribute carries no
+/// value.
+class MLIR_PYTHON_API_EXPORTED PyUnitAttribute
+    : public PyConcreteAttribute<PyUnitAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "UnitAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUnitAttrGetTypeID;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `StridedLayoutAttr` class.
+class MLIR_PYTHON_API_EXPORTED PyStridedLayoutAttribute
+    : public PyConcreteAttribute<PyStridedLayoutAttribute> {
+public:
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static constexpr const char *pyClassName = "StridedLayoutAttr";
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirStridedLayoutAttrGetTypeID;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// Module-population entry point for the attribute bindings. It presumably
+/// binds the attribute classes declared in this header into `m`; the
+/// definition is not visible here.
+MLIR_PYTHON_API_EXPORTED void populateIRAttributes(nanobind::module_ &m);
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index af6c8dbbb7fa8..1e435a1d442d4 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -989,7 +989,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
PyGlobals::get().registerTypeCaster(
DerivedTy::getTypeIdFunction(),
nanobind::cast<nanobind::callable>(nanobind::cpp_function(
- [](PyType pyType) -> DerivedTy { return pyType; })));
+ [](PyType pyType) -> DerivedTy { return pyType; })),
+ /*replace*/ true);
}
DerivedTy::bindDerived(cls);
@@ -1133,7 +1134,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
nanobind::cast<nanobind::callable>(
nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
return pyAttribute;
- })));
+ })),
+ /*replace*/ true);
}
DerivedTy::bindDerived(cls);
@@ -1517,6 +1519,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
// and redefine bindDerived.
using ClassTy = nanobind::class_<DerivedTy, PyValue>;
using IsAFunctionTy = bool (*)(MlirValue);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
PyConcreteValue() = default;
PyConcreteValue(PyOperationRef operationRef, MlirValue value)
@@ -1559,6 +1563,15 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
[](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {
return self.maybeDownCast();
});
+
+ if (DerivedTy::getTypeIdFunction) {
+ PyGlobals::get().registerValueCaster(
+ DerivedTy::getTypeIdFunction(),
+ nanobind::cast<nanobind::callable>(nanobind::cpp_function(
+ [](PyValue pyValue) -> DerivedTy { return pyValue; })),
+ /*replace*/ true);
+ }
+
DerivedTy::bindDerived(cls);
}
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index 87e0e10764bd8..a0901fefec5ce 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -9,13 +9,284 @@
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "mlir-c/BuiltinTypes.h"
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+
+/// Presumably returns nonzero when `type` is an integer or floating-point type
+/// (definition not visible here — TODO confirm).
+/// NOTE(review): despite the `mlir` prefix this is not a C API function; it is
+/// declared inside the Python bindings namespace.
+MLIR_PYTHON_API_EXPORTED int mlirTypeIsAIntegerOrFloat(MlirType type);
+
+/// C++ side of the Python `IntegerType` class.
+class MLIR_PYTHON_API_EXPORTED PyIntegerType
+    : public PyConcreteType<PyIntegerType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "IntegerType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIntegerTypeGetTypeID;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `IndexType` class.
+class MLIR_PYTHON_API_EXPORTED PyIndexType
+    : public PyConcreteType<PyIndexType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "IndexType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIndexTypeGetTypeID;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `FloatType` class, the base of the concrete float
+/// type classes below. It has no TypeID function, so no type caster is
+/// registered for it.
+class MLIR_PYTHON_API_EXPORTED PyFloatType
+    : public PyConcreteType<PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "FloatType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
+
+  /// Adds the class-specific Python members; defined out of line.
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `Float4E2M1FNType` class (derives from PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyFloat4E2M1FNType
+    : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float4E2M1FNType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat4E2M1FNTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `Float6E2M3FNType` class (derives from PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyFloat6E2M3FNType
+    : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float6E2M3FNType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat6E2M3FNTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `Float6E3M2FNType` class (derives from PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyFloat6E3M2FNType
+    : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float6E3M2FNType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat6E3M2FNTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `Float8E4M3FNType` class (derives from PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNType
+    : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E4M3FNType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3FNTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `Float8E5M2Type` class (derives from PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2Type
+    : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E5M2Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E5M2TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `Float8E4M3Type` class (derives from PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3Type
+    : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E4M3Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `Float8E4M3FNUZType` class (derives from
+/// PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNUZType
+    : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E4M3FNUZType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3FNUZTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `Float8E4M3B11FNUZType` class (derives from
+/// PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3B11FNUZType
+    : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3B11FNUZTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `Float8E5M2FNUZType` class (derives from
+/// PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2FNUZType
+    : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E5M2FNUZType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E5M2FNUZTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `Float8E3M4Type` class (derives from PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyFloat8E3M4Type
+    : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E3M4Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E3M4TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `Float8E8M0FNUType` class (derives from
+/// PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyFloat8E8M0FNUType
+    : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "Float8E8M0FNUType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E8M0FNUTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `BF16Type` class (derives from PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyBF16Type
+    : public PyConcreteType<PyBF16Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "BF16Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirBFloat16TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `F16Type` class (derives from PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyF16Type
+    : public PyConcreteType<PyF16Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "F16Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat16TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `FloatTF32Type` class (derives from PyFloatType).
+/// Note that the Python-visible name differs from the C++ class name.
+class MLIR_PYTHON_API_EXPORTED PyTF32Type
+    : public PyConcreteType<PyTF32Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "FloatTF32Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloatTF32TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `F32Type` class (derives from PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyF32Type
+    : public PyConcreteType<PyF32Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "F32Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat32TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `F64Type` class (derives from PyFloatType).
+class MLIR_PYTHON_API_EXPORTED PyF64Type
+    : public PyConcreteType<PyF64Type, PyFloatType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "F64Type";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat64TypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `NoneType` class (the MLIR `none` type).
+class MLIR_PYTHON_API_EXPORTED PyNoneType
+    : public PyConcreteType<PyNoneType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "NoneType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirNoneTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// C++ side of the Python `ComplexType` class.
+class MLIR_PYTHON_API_EXPORTED PyComplexType
+    : public PyConcreteType<PyComplexType> {
+public:
+  using PyConcreteType::PyConcreteType;
+
+  static constexpr const char *pyClassName = "ComplexType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirComplexTypeGetTypeID;
+
+  static void bindDerived(ClassTy &c);
+};
+
/// Shaped Type Interface - ShapedType
class MLIR_PYTHON_API_EXPORTED PyShapedType
: public PyConcreteType<PyShapedType> {
public:
static const IsAFunctionTy isaFunction;
@@ -27,6 +298,124 @@ class MLIR_PYTHON_API_EXPORTED PyShapedType
private:
void requireHasRank();
};
+
+/// Vector Type subclass - VectorType. Base: PyShapedType.
+class MLIR_PYTHON_API_EXPORTED PyVectorType
+ : public PyConcreteType<PyVectorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirVectorTypeGetTypeID;
+ static constexpr const char *pyClassName = "VectorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+
+private:
+ static PyVectorType
+ getChecked(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nanobind::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyLocation loc);
+
+ static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nanobind::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyMlirContext context);
+};
+
+/// Ranked Tensor Type subclass - RankedTensorType. Base: PyShapedType.
+class MLIR_PYTHON_API_EXPORTED PyRankedTensorType
+ : public PyConcreteType<PyRankedTensorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirRankedTensorTypeGetTypeID;
+ static constexpr const char *pyClassName = "RankedTensorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Unranked Tensor Type subclass - UnrankedTensorType. Base: PyShapedType.
+class MLIR_PYTHON_API_EXPORTED PyUnrankedTensorType
+ : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirUnrankedTensorTypeGetTypeID;
+ static constexpr const char *pyClassName = "UnrankedTensorType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Ranked MemRef Type subclass - MemRefType. Base: PyShapedType.
+class MLIR_PYTHON_API_EXPORTED PyMemRefType
+ : public PyConcreteType<PyMemRefType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirMemRefTypeGetTypeID;
+ static constexpr const char *pyClassName = "MemRefType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Unranked MemRef Type subclass - UnrankedMemRefType. Base: PyShapedType.
+class MLIR_PYTHON_API_EXPORTED PyUnrankedMemRefType
+ : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirUnrankedMemRefTypeGetTypeID;
+ static constexpr const char *pyClassName = "UnrankedMemRefType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Tuple Type subclass - TupleType. Matches mlirTypeIsATuple.
+class MLIR_PYTHON_API_EXPORTED PyTupleType
+ : public PyConcreteType<PyTupleType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirTupleTypeGetTypeID;
+ static constexpr const char *pyClassName = "TupleType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Function Type subclass - FunctionType. Matches mlirTypeIsAFunction.
+class MLIR_PYTHON_API_EXPORTED PyFunctionType
+ : public PyConcreteType<PyFunctionType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFunctionTypeGetTypeID;
+ static constexpr const char *pyClassName = "FunctionType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+/// Opaque Type subclass - OpaqueType. Matches mlirTypeIsAOpaque.
+class MLIR_PYTHON_API_EXPORTED PyOpaqueType
+ : public PyConcreteType<PyOpaqueType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirOpaqueTypeGetTypeID;
+ static constexpr const char *pyClassName = "OpaqueType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c);
+};
+
+MLIR_PYTHON_API_EXPORTED void populateIRTypes(nanobind::module_ &m);
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index f0f0ae9ba741e..3cd3ce5c4c0ee 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -14,6 +14,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
@@ -125,65 +126,29 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-struct nb_buffer_info {
- void *ptr = nullptr;
- ssize_t itemsize = 0;
- ssize_t size = 0;
- const char *format = nullptr;
- ssize_t ndim = 0;
- SmallVector<ssize_t, 4> shape;
- SmallVector<ssize_t, 4> strides;
- bool readonly = false;
-
- nb_buffer_info(
- void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
- SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
- bool readonly = false,
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr))
- : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
- shape(std::move(shape_in)), strides(std::move(strides_in)),
- readonly(readonly), owned_view(std::move(owned_view_in)) {
- size = 1;
- for (ssize_t i = 0; i < ndim; ++i) {
- size *= shape[i];
- }
+nb_buffer_info::nb_buffer_info(
+ void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
+ SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
+ bool readonly,
+ std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in)
+ : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
+ shape(std::move(shape_in)), strides(std::move(strides_in)),
+ readonly(readonly), owned_view(std::move(owned_view_in)) {
+ size = 1; // Element count (not bytes): product of the ndim extents.
+ for (ssize_t i = 0; i < ndim; ++i) {
+ size *= shape[i];
   }
+}
- explicit nb_buffer_info(Py_buffer *view)
- : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
- {view->shape, view->shape + view->ndim},
- // TODO(phawkins): check for null strides
- {view->strides, view->strides + view->ndim},
- view->readonly != 0,
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
- view, PyBuffer_Release)) {}
-
- nb_buffer_info(const nb_buffer_info &) = delete;
- nb_buffer_info(nb_buffer_info &&) = default;
- nb_buffer_info &operator=(const nb_buffer_info &) = delete;
- nb_buffer_info &operator=(nb_buffer_info &&) = default;
-
-private:
- std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
-};
-
-class nb_buffer : public nb::object {
- NB_OBJECT_DEFAULT(nb_buffer, object, "Buffer", PyObject_CheckBuffer);
-
- nb_buffer_info request() const {
- int flags = PyBUF_STRIDES | PyBUF_FORMAT;
- auto *view = new Py_buffer();
- if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
- delete view;
- throw nb::python_error();
- }
- return nb_buffer_info(view);
+nb_buffer_info nb_buffer::request() const {
+ int flags = PyBUF_STRIDES | PyBUF_FORMAT; // Contiguity is not required.
+ auto *view = new Py_buffer(); // NOTE(review): who deletes this? See return.
+ if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
+ delete view;
+ throw nb::python_error();
   }
-};
-
-template <typename T>
-struct nb_format_descriptor {};
+ return nb_buffer_info(view); // Removed ctor's deleter: PyBuffer_Release only.
+}
template <>
struct nb_format_descriptor<bool> {
@@ -230,1052 +195,719 @@ struct nb_format_descriptor<double> {
static const char *format() { return "d"; }
};
-class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
- static constexpr const char *pyClassName = "AffineMapAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirAffineMapAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyAffineMap &affineMap) {
- MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
- return PyAffineMapAttribute(affineMap.getContext(), attr);
- },
- nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
- c.def_prop_ro(
- "value",
- [](PyAffineMapAttribute &self) {
- return PyAffineMap(self.getContext(),
- mlirAffineMapAttrGetValue(self));
- },
- "Returns the value of the AffineMap attribute");
- }
-};
+void PyAffineMapAttribute::bindDerived(ClassTy &c) {
+ c.def_static( // AffineMapAttr.get(affine_map); uses the map's context.
+ "get",
+ [](PyAffineMap &affineMap) {
+ MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
+ return PyAffineMapAttribute(affineMap.getContext(), attr);
+ },
+ nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
+ c.def_prop_ro( // Read-only .value: a fresh PyAffineMap wrapper.
+ "value",
+ [](PyAffineMapAttribute &self) {
+ return PyAffineMap(self.getContext(), mlirAffineMapAttrGetValue(self));
+ },
+ "Returns the value of the AffineMap attribute");
+}
-class PyIntegerSetAttribute
- : public PyConcreteAttribute<PyIntegerSetAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
- static constexpr const char *pyClassName = "IntegerSetAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIntegerSetAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyIntegerSet &integerSet) {
- MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
- return PyIntegerSetAttribute(integerSet.getContext(), attr);
- },
- nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
- }
-};
+void PyIntegerSetAttribute::bindDerived(ClassTy &c) {
+ c.def_static( // IntegerSetAttr.get(integer_set); uses the set's context.
+ "get",
+ [](PyIntegerSet &integerSet) {
+ MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
+ return PyIntegerSetAttribute(integerSet.getContext(), attr);
+ },
+ nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
+}
-template <typename T>
-static T pyTryCast(nb::handle object) {
- try {
- return nb::cast<T>(object);
- } catch (nb::cast_error &err) {
- std::string msg = std::string("Invalid attribute when attempting to "
- "create an ArrayAttribute (") +
- err.what() + ")";
- throw std::runtime_error(msg.c_str());
- } catch (std::runtime_error &err) {
- std::string msg = std::string("Invalid attribute (None?) when attempting "
- "to create an ArrayAttribute (") +
- err.what() + ")";
- throw std::runtime_error(msg.c_str());
- }
+nb::typed<nb::object, PyAttribute>
+PyArrayAttribute::PyArrayAttributeIterator::dunderNext() {
+ // Yields the elements of the wrapped ArrayAttr in order, each passed
+ // through maybeDownCast(), and raises StopIteration once the cursor
+ // reaches the element count.
+ // TODO: Throw is an inefficient way to stop iteration.
+ intptr_t numElements = mlirArrayAttrGetNumElements(attr.get());
+ if (nextIndex >= numElements)
+ throw nb::stop_iteration();
+ // Read the element under the cursor, then advance the cursor so the next
+ // call resumes after it.
+ MlirAttribute element = mlirArrayAttrGetElement(attr.get(), nextIndex);
+ ++nextIndex;
+ return PyAttribute(attr.getContext(), element).maybeDownCast();
}
-/// A python-wrapped dense array attribute with an element type and a derived
-/// implementation class.
-template <typename EltTy, typename DerivedT>
-class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
-public:
- using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
-
- /// Iterator over the integer elements of a dense array.
- class PyDenseArrayIterator {
- public:
- PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
-
- /// Return a copy of the iterator.
- PyDenseArrayIterator dunderIter() { return *this; }
-
- /// Return the next element.
- EltTy dunderNext() {
- // Throw if the index has reached the end.
- if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
- throw nb::stop_iteration();
- return DerivedT::getElement(attr.get(), nextIndex++);
- }
+void PyArrayAttribute::PyArrayAttributeIterator::bind(nb::module_ &m) {
+ nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator") // from __iter__
+ .def("__iter__", &PyArrayAttributeIterator::dunderIter)
+ .def("__next__", &PyArrayAttributeIterator::dunderNext);
+}
- /// Bind the iterator class.
- static void bind(nb::module_ &m) {
- nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
- .def("__iter__", &PyDenseArrayIterator::dunderIter)
- .def("__next__", &PyDenseArrayIterator::dunderNext);
- }
+MlirAttribute PyArrayAttribute::getItem(intptr_t i) const {
+ return mlirArrayAttrGetElement(*this, i); // Unchecked; callers bound i.
+}
- private:
- /// The referenced dense array attribute.
- PyAttribute attr;
- /// The next index to read.
- int nextIndex = 0;
- };
+void PyArrayAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const nb::list &attributes, DefaultingPyMlirContext context) {
+ SmallVector<MlirAttribute> mlirAttributes;
+ mlirAttributes.reserve(nb::len(attributes));
+ for (auto attribute : attributes) {
+ mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
+ }
+ MlirAttribute attr = mlirArrayAttrGet(
+ context->get(), mlirAttributes.size(), mlirAttributes.data());
+ return PyArrayAttribute(context->getRef(), attr);
+ },
+ nb::arg("attributes"), nb::arg("context") = nb::none(),
+ "Gets a uniqued Array attribute");
+ c.def("__getitem__",
+ [](PyArrayAttribute &arr,
+ intptr_t i) -> nb::typed<nb::object, PyAttribute> {
+ if (i < 0 || i >= mlirArrayAttrGetNumElements(arr)) // getItem is unchecked
+ throw nb::index_error("ArrayAttribute index out of range");
+ return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
+ })
+ .def("__len__",
+ [](const PyArrayAttribute &arr) {
+ return mlirArrayAttrGetNumElements(arr);
+ })
+ .def("__iter__", [](const PyArrayAttribute &arr) {
+ return PyArrayAttributeIterator(arr);
+ });
+ c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
+ std::vector<MlirAttribute> attributes;
+ intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
+ attributes.reserve(numOldElements + nb::len(extras));
+ for (intptr_t i = 0; i < numOldElements; ++i)
+ attributes.push_back(arr.getItem(i));
+ for (nb::handle attr : extras)
+ attributes.push_back(pyTryCast<PyAttribute>(attr));
+ MlirAttribute arrayAttr = mlirArrayAttrGet(
+ arr.getContext()->get(), attributes.size(), attributes.data());
+ return PyArrayAttribute(arr.getContext(), arrayAttr);
+ });
+}
+void PyFloatAttribute::bindDerived(ClassTy &c) {
+ c.def_static( // Verified: diagnostics captured, raised as MLIRError.
+ "get",
+ [](PyType &type, double value, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
+ MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Invalid attribute", errors.take());
+ return PyFloatAttribute(type.getContext(), attr);
+ },
+ nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
+ "Gets an uniqued float point attribute associated to a type");
+ c.def_static(
+ "get_unchecked", // NOTE(review): wraps with type's context, not `context`.
+ [](PyType &type, double value, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirAttribute attr =
+ mlirFloatAttrDoubleGet(context.get()->get(), type, value);
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Invalid attribute", errors.take());
+ return PyFloatAttribute(type.getContext(), attr);
+ },
+ nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets an uniqued float point attribute associated to a type");
+ c.def_static(
+ "get_f32",
+ [](double value, DefaultingPyMlirContext context) {
+ MlirAttribute attr = mlirFloatAttrDoubleGet(
+ context->get(), mlirF32TypeGet(context->get()), value);
+ return PyFloatAttribute(context->getRef(), attr);
+ },
+ nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets an uniqued float point attribute associated to a f32 type");
+ c.def_static(
+ "get_f64",
+ [](double value, DefaultingPyMlirContext context) {
+ MlirAttribute attr = mlirFloatAttrDoubleGet(
+ context->get(), mlirF64TypeGet(context->get()), value);
+ return PyFloatAttribute(context->getRef(), attr);
+ },
+ nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets an uniqued float point attribute associated to a f64 type");
+ c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
+ "Returns the value of the float attribute");
+ c.def("__float__", mlirFloatAttrGetValueDouble,
+ "Converts the value of the float attribute to a Python float");
+}
- /// Get the element at the given index.
- EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
-
- /// Bind the attribute class.
- static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
- // Bind the constructor.
- if constexpr (std::is_same_v<EltTy, bool>) {
- c.def_static(
- "get",
- [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
- std::vector<bool> values;
- for (nb::handle py_value : py_values) {
- int is_true = PyObject_IsTrue(py_value.ptr());
- if (is_true < 0) {
- throw nb::python_error();
- }
- values.push_back(is_true);
- }
- return getAttribute(values, ctx->getRef());
- },
- nb::arg("values"), nb::arg("context") = nb::none(),
- "Gets a uniqued dense array attribute");
- } else {
- c.def_static(
- "get",
- [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
- return getAttribute(values, ctx->getRef());
- },
- nb::arg("values"), nb::arg("context") = nb::none(),
- "Gets a uniqued dense array attribute");
- }
- // Bind the array methods.
- c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
- if (i >= mlirDenseArrayGetNumElements(arr))
- throw nb::index_error("DenseArray index out of range");
- return arr.getItem(i);
- });
- c.def("__len__", [](const DerivedT &arr) {
- return mlirDenseArrayGetNumElements(arr);
- });
- c.def("__iter__",
- [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
- c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
- std::vector<EltTy> values;
- intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
- values.reserve(numOldElements + nb::len(extras));
- for (intptr_t i = 0; i < numOldElements; ++i)
- values.push_back(arr.getItem(i));
- for (nb::handle attr : extras)
- values.push_back(pyTryCast<EltTy>(attr));
- return getAttribute(values, arr.getContext());
- });
- }
+void PyIntegerAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](PyType &type, int64_t value) {
+ MlirAttribute attr = mlirIntegerAttrGet(type, value); // NOTE(review): type kind unchecked
+ return PyIntegerAttribute(type.getContext(), attr);
+ },
+ nb::arg("type"), nb::arg("value"),
+ "Gets an uniqued integer attribute associated to a type");
+ c.def_prop_ro("value", toPyInt, "Returns the value of the integer attribute");
+ c.def("__int__", toPyInt,
+ "Converts the value of the integer attribute to a Python int");
+ c.def_prop_ro_static(
+ "static_typeid",
+ [](nb::object & /*class*/) {
+ return PyTypeID(mlirIntegerAttrGetTypeID());
+ },
+ nb::sig("def static_typeid(/) -> TypeID"));
+}
-private:
- static DerivedT getAttribute(const std::vector<EltTy> &values,
- PyMlirContextRef ctx) {
- if constexpr (std::is_same_v<EltTy, bool>) {
- std::vector<int> intValues(values.begin(), values.end());
- MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
- intValues.data());
- return DerivedT(ctx, attr);
- } else {
- MlirAttribute attr =
- DerivedT::getAttribute(ctx->get(), values.size(), values.data());
- return DerivedT(ctx, attr);
- }
- }
-};
+int64_t PyIntegerAttribute::toPyInt(PyIntegerAttribute &self) {
+ MlirType type = mlirAttributeGetType(self); // Accessor chosen by type signedness.
+ if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
+ return mlirIntegerAttrGetValueInt(self);
+ if (mlirIntegerTypeIsSigned(type))
+ return mlirIntegerAttrGetValueSInt(self);
+ return mlirIntegerAttrGetValueUInt(self); // NOTE(review): > INT64_MAX wraps negative.
+}
-/// Instantiate the python dense array classes.
-struct PyDenseBoolArrayAttribute
- : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
- static constexpr auto getAttribute = mlirDenseBoolArrayGet;
- static constexpr auto getElement = mlirDenseBoolArrayGetElement;
- static constexpr const char *pyClassName = "DenseBoolArrayAttr";
- static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI8ArrayAttribute
- : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
- static constexpr auto getAttribute = mlirDenseI8ArrayGet;
- static constexpr auto getElement = mlirDenseI8ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI8ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI16ArrayAttribute
- : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
- static constexpr auto getAttribute = mlirDenseI16ArrayGet;
- static constexpr auto getElement = mlirDenseI16ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI16ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI32ArrayAttribute
- : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
- static constexpr auto getAttribute = mlirDenseI32ArrayGet;
- static constexpr auto getElement = mlirDenseI32ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI32ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseI64ArrayAttribute
- : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
- static constexpr auto getAttribute = mlirDenseI64ArrayGet;
- static constexpr auto getElement = mlirDenseI64ArrayGetElement;
- static constexpr const char *pyClassName = "DenseI64ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseF32ArrayAttribute
- : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
- static constexpr auto getAttribute = mlirDenseF32ArrayGet;
- static constexpr auto getElement = mlirDenseF32ArrayGetElement;
- static constexpr const char *pyClassName = "DenseF32ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
-struct PyDenseF64ArrayAttribute
- : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
- static constexpr auto getAttribute = mlirDenseF64ArrayGet;
- static constexpr auto getElement = mlirDenseF64ArrayGetElement;
- static constexpr const char *pyClassName = "DenseF64ArrayAttr";
- static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
- using PyDenseArrayAttribute::PyDenseArrayAttribute;
-};
+void PyBoolAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](bool value, DefaultingPyMlirContext context) {
+ MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
+ return PyBoolAttribute(context->getRef(), attr);
+ },
+ nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets an uniqued bool attribute");
+ c.def_prop_ro("value", mlirBoolAttrGetValue, // shared with __bool__
+ "Returns the value of the bool attribute");
+ c.def("__bool__", mlirBoolAttrGetValue,
+ "Converts the value of the bool attribute to a Python bool");
+}
-class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
- static constexpr const char *pyClassName = "ArrayAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirArrayAttrGetTypeID;
-
- class PyArrayAttributeIterator {
- public:
- PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
-
- PyArrayAttributeIterator &dunderIter() { return *this; }
-
- nb::typed<nb::object, PyAttribute> dunderNext() {
- // TODO: Throw is an inefficient way to stop iteration.
- if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
- throw nb::stop_iteration();
- return PyAttribute(this->attr.getContext(),
- mlirArrayAttrGetElement(attr.get(), nextIndex++))
- .maybeDownCast();
- }
+PySymbolRefAttribute
+PySymbolRefAttribute::fromList(const std::vector<std::string> &symbols,
+ PyMlirContext &context) {
+ if (symbols.empty())
+ throw std::runtime_error("SymbolRefAttr must be composed of at least "
+ "one symbol.");
+ MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); // First name: root.
+ SmallVector<MlirAttribute, 3> referenceAttrs;
+ for (size_t i = 1; i < symbols.size(); ++i) { // Rest: flat nested refs.
+ referenceAttrs.push_back(
+ mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
+ }
+ return PySymbolRefAttribute(context.getRef(),
+ mlirSymbolRefAttrGet(context.get(), rootSymbol,
+ referenceAttrs.size(),
+ referenceAttrs.data()));
+}
- static void bind(nb::module_ &m) {
- nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
- .def("__iter__", &PyArrayAttributeIterator::dunderIter)
- .def("__next__", &PyArrayAttributeIterator::dunderNext);
- }
+void PySymbolRefAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const std::vector<std::string> &symbols,
+ DefaultingPyMlirContext context) {
+ return PySymbolRefAttribute::fromList(symbols, context.resolve());
+ },
+ nb::arg("symbols"), nb::arg("context") = nb::none(),
+ "Gets a uniqued SymbolRef attribute from a list of symbol names");
+ c.def_prop_ro(
+ "value",
+ [](PySymbolRefAttribute &self) {
+ std::vector<std::string> symbols = { // Root first, then each nested ref.
+ unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
+ for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); ++i)
+ symbols.push_back(
+ unwrap(mlirSymbolRefAttrGetRootReference(
+ mlirSymbolRefAttrGetNestedReference(self, i)))
+ .str());
+ return symbols;
+ },
+ "Returns the value of the SymbolRef attribute as a list[str]");
+}
- private:
- PyAttribute attr;
- int nextIndex = 0;
- };
+void PyFlatSymbolRefAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const std::string &value, DefaultingPyMlirContext context) {
+ MlirAttribute attr =
+ mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
+ return PyFlatSymbolRefAttribute(context->getRef(), attr);
+ },
+ nb::arg("value"), nb::arg("context") = nb::none(),
+ "Gets a uniqued FlatSymbolRef attribute");
+ c.def_prop_ro(
+ "value",
+ [](PyFlatSymbolRefAttribute &self) {
+ MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
+ return nb::str(stringRef.data, stringRef.length); // explicit length
+ },
+ "Returns the value of the FlatSymbolRef attribute as a string");
+}
- MlirAttribute getItem(intptr_t i) {
- return mlirArrayAttrGetElement(*this, i);
- }
+void PyOpaqueAttribute::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](const std::string &dialectNamespace, const nb_buffer &buffer,
+ PyType &type, DefaultingPyMlirContext context) {
+ const nb_buffer_info bufferInfo = buffer.request();
+ intptr_t bufferSize = bufferInfo.size * bufferInfo.itemsize; // bytes, not elements
+ MlirAttribute attr = mlirOpaqueAttrGet(
+ context->get(), toMlirStringRef(dialectNamespace), bufferSize,
+ static_cast<char *>(bufferInfo.ptr), type);
+ return PyOpaqueAttribute(context->getRef(), attr);
+ },
+ nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
+ nb::arg("context") = nb::none(),
+ // clang-format off
+ nb::sig("def get(dialect_namespace: str, buffer: typing_extensions.Buffer, type: Type, context: Context | None = None) -> OpaqueAttr"),
+ // clang-format on
+ "Gets an Opaque attribute.");
+ c.def_prop_ro(
+ "dialect_namespace",
+ [](PyOpaqueAttribute &self) {
+ MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
+ return nb::str(stringRef.data, stringRef.length);
+ },
+ "Returns the dialect namespace for the Opaque attribute as a string");
+ c.def_prop_ro(
+ "data",
+ [](PyOpaqueAttribute &self) {
+ MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
+ return nb::bytes(stringRef.data, stringRef.length);
+ },
+ "Returns the data for the Opaqued attributes as `bytes`");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const nb::list &attributes, DefaultingPyMlirContext context) {
- SmallVector<MlirAttribute> mlirAttributes;
- mlirAttributes.reserve(nb::len(attributes));
- for (auto attribute : attributes) {
- mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
- }
- MlirAttribute attr = mlirArrayAttrGet(
- context->get(), mlirAttributes.size(), mlirAttributes.data());
- return PyArrayAttribute(context->getRef(), attr);
- },
- nb::arg("attributes"), nb::arg("context") = nb::none(),
- "Gets a uniqued Array attribute");
- c.def(
- "__getitem__",
- [](PyArrayAttribute &arr,
- intptr_t i) -> nb::typed<nb::object, PyAttribute> {
- if (i >= mlirArrayAttrGetNumElements(arr))
- throw nb::index_error("ArrayAttribute index out of range");
- return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
- })
- .def("__len__",
- [](const PyArrayAttribute &arr) {
- return mlirArrayAttrGetNumElements(arr);
- })
- .def("__iter__", [](const PyArrayAttribute &arr) {
- return PyArrayAttributeIterator(arr);
- });
- c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
- std::vector<MlirAttribute> attributes;
- intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
- attributes.reserve(numOldElements + nb::len(extras));
- for (intptr_t i = 0; i < numOldElements; ++i)
- attributes.push_back(arr.getItem(i));
- for (nb::handle attr : extras)
- attributes.push_back(pyTryCast<PyAttribute>(attr));
- MlirAttribute arrayAttr = mlirArrayAttrGet(
- arr.getContext()->get(), attributes.size(), attributes.data());
- return PyArrayAttribute(arr.getContext(), arrayAttr);
- });
+/// Builds a DenseElementsAttr from a Python list of attributes.
+///
+/// When `explicitType` is given it must be a ShapedType with a static shape;
+/// otherwise a 1-D RankedTensorType of length `len(attributes)` is created
+/// whose element type is the type of the first attribute. Every attribute's
+/// type must equal the shaped type's element type.
+///
+/// Throws nb::value_error when the list is empty, when `explicitType` is not
+/// a statically shaped ShapedType, or when an attribute's type mismatches.
+PyDenseElementsAttribute
+PyDenseElementsAttribute::getFromList(const nb::list &attributes,
+                                      std::optional<PyType> explicitType,
+                                      DefaultingPyMlirContext contextWrapper) {
+  const size_t numAttributes = nb::len(attributes);
+  if (numAttributes == 0)
+    throw nb::value_error("Attributes list must be non-empty.");
+
+  MlirType shapedType;
+  if (explicitType) {
+    if ((!mlirTypeIsAShaped(*explicitType) ||
+         !mlirShapedTypeHasStaticShape(*explicitType))) {
+
+      std::string message;
+      llvm::raw_string_ostream os(message);
+      os << "Expected a static ShapedType for the shaped_type parameter: "
+         << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
+      throw nb::value_error(message.c_str());
+    }
+    shapedType = *explicitType;
+  } else {
+    // No explicit type: infer a 1-D tensor of the list's length, taking the
+    // element type from the first attribute.
+    SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
+    shapedType = mlirRankedTensorTypeGet(
+        shape.size(), shape.data(),
+        mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
+        mlirAttributeGetNull());
   }
-};
-/// Float Point Attribute subclass - FloatAttr.
-class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
-  static constexpr const char *pyClassName = "FloatAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloatAttrGetTypeID;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &type, double value, DefaultingPyLocation loc) {
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
-          if (mlirAttributeIsNull(attr))
-            throw MLIRError("Invalid attribute", errors.take());
-          return PyFloatAttribute(type.getContext(), attr);
-        },
-        nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
-        "Gets an uniqued float point attribute associated to a type");
-    c.def_static(
-        "get_unchecked",
-        [](PyType &type, double value, DefaultingPyMlirContext context) {
-          PyMlirContext::ErrorCapture errors(context->getRef());
-          MlirAttribute attr =
-              mlirFloatAttrDoubleGet(context.get()->get(), type, value);
-          if (mlirAttributeIsNull(attr))
-            throw MLIRError("Invalid attribute", errors.take());
-          return PyFloatAttribute(type.getContext(), attr);
-        },
-        nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
-        "Gets an uniqued float point attribute associated to a type");
-    c.def_static(
-        "get_f32",
-        [](double value, DefaultingPyMlirContext context) {
-          MlirAttribute attr = mlirFloatAttrDoubleGet(
-              context->get(), mlirF32TypeGet(context->get()), value);
-          return PyFloatAttribute(context->getRef(), attr);
-        },
-        nb::arg("value"), nb::arg("context") = nb::none(),
-        "Gets an uniqued float point attribute associated to a f32 type");
-    c.def_static(
-        "get_f64",
-        [](double value, DefaultingPyMlirContext context) {
-          MlirAttribute attr = mlirFloatAttrDoubleGet(
-              context->get(), mlirF64TypeGet(context->get()), value);
-          return PyFloatAttribute(context->getRef(), attr);
-        },
-        nb::arg("value"), nb::arg("context") = nb::none(),
-        "Gets an uniqued float point attribute associated to a f64 type");
-    c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
-                  "Returns the value of the float attribute");
-    c.def("__float__", mlirFloatAttrGetValueDouble,
-          "Converts the value of the float attribute to a Python float");
+  SmallVector<MlirAttribute> mlirAttributes;
+  mlirAttributes.reserve(numAttributes);
+  for (const nb::handle &attribute : attributes) {
+    MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
+    MlirType attrType = mlirAttributeGetType(mlirAttribute);
+    mlirAttributes.push_back(mlirAttribute);
+
+    if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
+      std::string message;
+      llvm::raw_string_ostream os(message);
+      os << "All attributes must be of the same type and match "
+         << "the type parameter: expected="
+         << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
+         << ", but got=" << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
+      throw nb::value_error(message.c_str());
+    }
   }
-};
-/// Integer Attribute subclass - IntegerAttr.
-class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
-  static constexpr const char *pyClassName = "IntegerAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &type, int64_t value) {
-          MlirAttribute attr = mlirIntegerAttrGet(type, value);
-          return PyIntegerAttribute(type.getContext(), attr);
-        },
-        nb::arg("type"), nb::arg("value"),
-        "Gets an uniqued integer attribute associated to a type");
-    c.def_prop_ro("value", toPyInt,
-                  "Returns the value of the integer attribute");
-    c.def("__int__", toPyInt,
-          "Converts the value of the integer attribute to a Python int");
-    c.def_prop_ro_static(
-        "static_typeid",
-        [](nb::object & /*class*/) {
-          return PyTypeID(mlirIntegerAttrGetTypeID());
-        },
-        nanobind::sig("def static_typeid(/) -> TypeID"));
-  }
+  MlirAttribute elements = mlirDenseElementsAttrGet(
+      shapedType, mlirAttributes.size(), mlirAttributes.data());
-private:
-  static int64_t toPyInt(PyIntegerAttribute &self) {
-    MlirType type = mlirAttributeGetType(self);
-    if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
-      return mlirIntegerAttrGetValueInt(self);
-    if (mlirIntegerTypeIsSigned(type))
-      return mlirIntegerAttrGetValueSInt(self);
-    return mlirIntegerAttrGetValueUInt(self);
-  }
-};
+  return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+}
-/// Bool Attribute subclass - BoolAttr.
-class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
- static constexpr const char *pyClassName = "BoolAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](bool value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
- return PyBoolAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets an uniqued bool attribute");
- c.def_prop_ro("value", mlirBoolAttrGetValue,
- "Returns the value of the bool attribute");
- c.def("__bool__", mlirBoolAttrGetValue,
- "Converts the value of the bool attribute to a Python bool");
+/// Builds a DenseElementsAttr from an object implementing the Python buffer
+/// protocol.
+///
+/// The buffer is requested with PyBUF_ND; its format string is additionally
+/// requested (PyBUF_FORMAT) only when no `explicitType` is given, because the
+/// element type must then be inferred from the format. The view is released
+/// on every exit path. Throws nb::python_error if the buffer cannot be
+/// obtained and std::invalid_argument if no attribute could be constructed.
+PyDenseElementsAttribute PyDenseElementsAttribute::getFromBuffer(
+    const nb_buffer &array, bool signless,
+    const std::optional<PyType> &explicitType,
+    std::optional<std::vector<int64_t>> explicitShape,
+    DefaultingPyMlirContext contextWrapper) {
+  // Request a contiguous view. In exotic cases, this will cause a copy.
+  int flags = PyBUF_ND;
+  if (!explicitType) {
+    flags |= PyBUF_FORMAT;
   }
-};
-
-class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
-  static constexpr const char *pyClassName = "SymbolRefAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static PySymbolRefAttribute fromList(const std::vector<std::string> &symbols,
-                                       PyMlirContext &context) {
-    if (symbols.empty())
-      throw std::runtime_error("SymbolRefAttr must be composed of at least "
-                               "one symbol.");
-    MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
-    SmallVector<MlirAttribute, 3> referenceAttrs;
-    for (size_t i = 1; i < symbols.size(); ++i) {
-      referenceAttrs.push_back(
-          mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
-    }
-    return PySymbolRefAttribute(context.getRef(),
-                                mlirSymbolRefAttrGet(context.get(), rootSymbol,
-                                                     referenceAttrs.size(),
-                                                     referenceAttrs.data()));
+  Py_buffer view;
+  if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
+    throw nb::python_error();
   }
+  // Release the view on every exit path, including thrown exceptions.
+  auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](const std::vector<std::string> &symbols,
-           DefaultingPyMlirContext context) {
-          return PySymbolRefAttribute::fromList(symbols, context.resolve());
-        },
-        nb::arg("symbols"), nb::arg("context") = nb::none(),
-        "Gets a uniqued SymbolRef attribute from a list of symbol names");
-    c.def_prop_ro(
-        "value",
-        [](PySymbolRefAttribute &self) {
-          std::vector<std::string> symbols = {
-              unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
-          for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
-               ++i)
-            symbols.push_back(
-                unwrap(mlirSymbolRefAttrGetRootReference(
-                           mlirSymbolRefAttrGetNestedReference(self, i)))
-                    .str());
-          return symbols;
-        },
-        "Returns the value of the SymbolRef attribute as a list[str]");
+  MlirContext context = contextWrapper->get();
+  MlirAttribute attr = getAttributeFromBuffer(
+      view, signless, explicitType, std::move(explicitShape), context);
+  if (mlirAttributeIsNull(attr)) {
+    throw std::invalid_argument(
+        "DenseElementsAttr could not be constructed from the given buffer. "
+        "This may mean that the Python buffer layout does not match that "
+        "MLIR expected layout and is a bug.");
   }
-};
+  return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
+}
-class PyFlatSymbolRefAttribute
- : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
- static constexpr const char *pyClassName = "FlatSymbolRefAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &value, DefaultingPyMlirContext context) {
- MlirAttribute attr =
- mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
- return PyFlatSymbolRefAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets a uniqued FlatSymbolRef attribute");
- c.def_prop_ro(
- "value",
- [](PyFlatSymbolRefAttribute &self) {
- MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the value of the FlatSymbolRef attribute as a string");
+/// Builds a DenseElementsAttr of type `shapedType` in which every element is
+/// `elementAttr`.
+///
+/// `elementAttr` must be an IntegerAttr or FloatAttr whose type equals the
+/// element type of `shapedType`, and `shapedType` must be a ShapedType with a
+/// static shape; otherwise nb::value_error is thrown. The result is tied to
+/// the context that `shapedType` belongs to.
+PyDenseElementsAttribute
+PyDenseElementsAttribute::getSplat(const PyType &shapedType,
+                                   PyAttribute &elementAttr) {
+  auto contextWrapper =
+      PyMlirContext::forContext(mlirTypeGetContext(shapedType));
+  if (!mlirAttributeIsAInteger(elementAttr) &&
+      !mlirAttributeIsAFloat(elementAttr)) {
+    std::string message = "Illegal element type for DenseElementsAttr: ";
+    message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
+    throw nb::value_error(message.c_str());
   }
-};
-
-class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
-  static constexpr const char *pyClassName = "OpaqueAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirOpaqueAttrGetTypeID;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](const std::string &dialectNamespace, const nb_buffer &buffer,
-           PyType &type, DefaultingPyMlirContext context) {
-          const nb_buffer_info bufferInfo = buffer.request();
-          intptr_t bufferSize = bufferInfo.size;
-          MlirAttribute attr = mlirOpaqueAttrGet(
-              context->get(), toMlirStringRef(dialectNamespace), bufferSize,
-              static_cast<char *>(bufferInfo.ptr), type);
-          return PyOpaqueAttribute(context->getRef(), attr);
-        },
-        nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
-        nb::arg("context") = nb::none(),
-        // clang-format off
-        nb::sig("def get(dialect_namespace: str, buffer: typing_extensions.Buffer, type: Type, context: Context | None = None) -> OpaqueAttr"),
-        // clang-format on
-        "Gets an Opaque attribute.");
-    c.def_prop_ro(
-        "dialect_namespace",
-        [](PyOpaqueAttribute &self) {
-          MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
-          return nb::str(stringRef.data, stringRef.length);
-        },
-        "Returns the dialect namespace for the Opaque attribute as a string");
-    c.def_prop_ro(
-        "data",
-        [](PyOpaqueAttribute &self) {
-          MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
-          return nb::bytes(stringRef.data, stringRef.length);
-        },
-        "Returns the data for the Opaqued attributes as `bytes`");
+  if (!mlirTypeIsAShaped(shapedType) ||
+      !mlirShapedTypeHasStaticShape(shapedType)) {
+    std::string message =
+        "Expected a static ShapedType for the shaped_type parameter: ";
+    message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
+    throw nb::value_error(message.c_str());
+  }
+  MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
+  MlirType attrType = mlirAttributeGetType(elementAttr);
+  if (!mlirTypeEqual(shapedElementType, attrType)) {
+    std::string message =
+        "Shaped element type and attribute type must be equal: shaped=";
+    message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
+    message.append(", element=");
+    message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
+    throw nb::value_error(message.c_str());
   }
-};
-// TODO: Support construction of string elements.
-class PyDenseElementsAttribute
-    : public PyConcreteAttribute<PyDenseElementsAttribute> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
-  static constexpr const char *pyClassName = "DenseElementsAttr";
-  using PyConcreteAttribute::PyConcreteAttribute;
-
-  static PyDenseElementsAttribute
-  getFromList(const nb::list &attributes, std::optional<PyType> explicitType,
-              DefaultingPyMlirContext contextWrapper) {
-    const size_t numAttributes = nb::len(attributes);
-    if (numAttributes == 0)
-      throw nb::value_error("Attributes list must be non-empty.");
-
-    MlirType shapedType;
-    if (explicitType) {
-      if ((!mlirTypeIsAShaped(*explicitType) ||
-           !mlirShapedTypeHasStaticShape(*explicitType))) {
-
-        std::string message;
-        llvm::raw_string_ostream os(message);
-        os << "Expected a static ShapedType for the shaped_type parameter: "
-           << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
-        throw nb::value_error(message.c_str());
-      }
-      shapedType = *explicitType;
-    } else {
-      SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
-      shapedType = mlirRankedTensorTypeGet(
-          shape.size(), shape.data(),
-          mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
-          mlirAttributeGetNull());
-    }
+  MlirAttribute elements =
+      mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
+  return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+}
- SmallVector<MlirAttribute> mlirAttributes;
- mlirAttributes.reserve(numAttributes);
- for (const nb::handle &attribute : attributes) {
- MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
- MlirType attrType = mlirAttributeGetType(mlirAttribute);
- mlirAttributes.push_back(mlirAttribute);
-
- if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
- std::string message;
- llvm::raw_string_ostream os(message);
- os << "All attributes must be of the same type and match "
- << "the type parameter: expected="
- << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
- << ", but got="
- << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
- throw nb::value_error(message.c_str());
- }
- }
+/// Returns the total number of elements in the attribute; bound as the
+/// Python `__len__` of DenseElementsAttr.
+intptr_t PyDenseElementsAttribute::dunderLen() const {
+  return mlirElementsAttrGetNumElements(*this);
+}
- MlirAttribute elements = mlirDenseElementsAttrGet(
- shapedType, mlirAttributes.size(), mlirAttributes.data());
+/// Describes the attribute's element storage as a Python buffer.
+///
+/// Supported element types are f16 (exposed as 16-bit storage with format
+/// "e"), f32, f64, index (64-bit), and 8/16/32/64-bit integers: signless and
+/// signed integers map to signed C types, unsigned integers to unsigned ones.
+/// i1 data is bit-packed in MLIR, so it is converted through
+/// getBooleanBufferFromBitpackedAttribute(). Any other element type throws
+/// std::invalid_argument.
+std::unique_ptr<nb_buffer_info> PyDenseElementsAttribute::accessBuffer() {
+  MlirType shapedType = mlirAttributeGetType(*this);
+  MlirType elementType = mlirShapedTypeGetElementType(shapedType);
-  return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+  if (mlirTypeIsAF32(elementType)) {
+    // f32
+    return bufferInfo<float>(shapedType);
   }
-
-  static PyDenseElementsAttribute
-  getFromBuffer(const nb_buffer &array, bool signless,
-                const std::optional<PyType> &explicitType,
-                std::optional<std::vector<int64_t>> explicitShape,
-                DefaultingPyMlirContext contextWrapper) {
-    // Request a contiguous view. In exotic cases, this will cause a copy.
-    int flags = PyBUF_ND;
-    if (!explicitType) {
-      flags |= PyBUF_FORMAT;
-    }
-    Py_buffer view;
-    if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
-      throw nb::python_error();
-    }
-    auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
-
-    MlirContext context = contextWrapper->get();
-    MlirAttribute attr = getAttributeFromBuffer(
-        view, signless, explicitType, std::move(explicitShape), context);
-    if (mlirAttributeIsNull(attr)) {
-      throw std::invalid_argument(
-          "DenseElementsAttr could not be constructed from the given buffer. "
-          "This may mean that the Python buffer layout does not match that "
-          "MLIR expected layout and is a bug.");
-    }
-    return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
+  if (mlirTypeIsAF64(elementType)) {
+    // f64
+    return bufferInfo<double>(shapedType);
   }
-
-  static PyDenseElementsAttribute getSplat(const PyType &shapedType,
-                                           PyAttribute &elementAttr) {
-    auto contextWrapper =
-        PyMlirContext::forContext(mlirTypeGetContext(shapedType));
-    if (!mlirAttributeIsAInteger(elementAttr) &&
-        !mlirAttributeIsAFloat(elementAttr)) {
-      std::string message = "Illegal element type for DenseElementsAttr: ";
-      message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
-      throw nb::value_error(message.c_str());
+  if (mlirTypeIsAF16(elementType)) {
+    // f16
+    return bufferInfo<uint16_t>(shapedType, "e");
+  }
+  if (mlirTypeIsAIndex(elementType)) {
+    // Same as IndexType::kInternalStorageBitWidth
+    return bufferInfo<int64_t>(shapedType);
+  }
+  if (mlirTypeIsAInteger(elementType) &&
+      mlirIntegerTypeGetWidth(elementType) == 32) {
+    if (mlirIntegerTypeIsSignless(elementType) ||
+        mlirIntegerTypeIsSigned(elementType)) {
+      // i32
+      return bufferInfo<int32_t>(shapedType);
     }
-    if (!mlirTypeIsAShaped(shapedType) ||
-        !mlirShapedTypeHasStaticShape(shapedType)) {
-      std::string message =
-          "Expected a static ShapedType for the shaped_type parameter: ";
-      message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
-      throw nb::value_error(message.c_str());
+    if (mlirIntegerTypeIsUnsigned(elementType)) {
+      // unsigned i32
+      return bufferInfo<uint32_t>(shapedType);
     }
-    MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
-    MlirType attrType = mlirAttributeGetType(elementAttr);
-    if (!mlirTypeEqual(shapedElementType, attrType)) {
-      std::string message =
-          "Shaped element type and attribute type must be equal: shaped=";
-      message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
-      message.append(", element=");
-      message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
-      throw nb::value_error(message.c_str());
+  } else if (mlirTypeIsAInteger(elementType) &&
+             mlirIntegerTypeGetWidth(elementType) == 64) {
+    if (mlirIntegerTypeIsSignless(elementType) ||
+        mlirIntegerTypeIsSigned(elementType)) {
+      // i64
+      return bufferInfo<int64_t>(shapedType);
     }
-
-    MlirAttribute elements =
-        mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
-    return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
-  }
-
-  intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
-
-  std::unique_ptr<nb_buffer_info> accessBuffer() {
-    MlirType shapedType = mlirAttributeGetType(*this);
-    MlirType elementType = mlirShapedTypeGetElementType(shapedType);
-    std::string format;
-
-    if (mlirTypeIsAF32(elementType)) {
-      // f32
-      return bufferInfo<float>(shapedType);
+    if (mlirIntegerTypeIsUnsigned(elementType)) {
+      // unsigned i64
+      return bufferInfo<uint64_t>(shapedType);
     }
-    if (mlirTypeIsAF64(elementType)) {
-      // f64
-      return bufferInfo<double>(shapedType);
+  } else if (mlirTypeIsAInteger(elementType) &&
+             mlirIntegerTypeGetWidth(elementType) == 8) {
+    if (mlirIntegerTypeIsSignless(elementType) ||
+        mlirIntegerTypeIsSigned(elementType)) {
+      // i8
+      return bufferInfo<int8_t>(shapedType);
     }
-    if (mlirTypeIsAF16(elementType)) {
-      // f16
-      return bufferInfo<uint16_t>(shapedType, "e");
+    if (mlirIntegerTypeIsUnsigned(elementType)) {
+      // unsigned i8
+      return bufferInfo<uint8_t>(shapedType);
     }
-    if (mlirTypeIsAIndex(elementType)) {
-      // Same as IndexType::kInternalStorageBitWidth
-      return bufferInfo<int64_t>(shapedType);
+  } else if (mlirTypeIsAInteger(elementType) &&
+             mlirIntegerTypeGetWidth(elementType) == 16) {
+    if (mlirIntegerTypeIsSignless(elementType) ||
+        mlirIntegerTypeIsSigned(elementType)) {
+      // i16
+      return bufferInfo<int16_t>(shapedType);
     }
-    if (mlirTypeIsAInteger(elementType) &&
-        mlirIntegerTypeGetWidth(elementType) == 32) {
-      if (mlirIntegerTypeIsSignless(elementType) ||
-          mlirIntegerTypeIsSigned(elementType)) {
-        // i32
-        return bufferInfo<int32_t>(shapedType);
-      }
-      if (mlirIntegerTypeIsUnsigned(elementType)) {
-        // unsigned i32
-        return bufferInfo<uint32_t>(shapedType);
-      }
-    } else if (mlirTypeIsAInteger(elementType) &&
-               mlirIntegerTypeGetWidth(elementType) == 64) {
-      if (mlirIntegerTypeIsSignless(elementType) ||
-          mlirIntegerTypeIsSigned(elementType)) {
-        // i64
-        return bufferInfo<int64_t>(shapedType);
-      }
-      if (mlirIntegerTypeIsUnsigned(elementType)) {
-        // unsigned i64
-        return bufferInfo<uint64_t>(shapedType);
-      }
-    } else if (mlirTypeIsAInteger(elementType) &&
-               mlirIntegerTypeGetWidth(elementType) == 8) {
-      if (mlirIntegerTypeIsSignless(elementType) ||
-          mlirIntegerTypeIsSigned(elementType)) {
-        // i8
-        return bufferInfo<int8_t>(shapedType);
-      }
-      if (mlirIntegerTypeIsUnsigned(elementType)) {
-        // unsigned i8
-        return bufferInfo<uint8_t>(shapedType);
-      }
-    } else if (mlirTypeIsAInteger(elementType) &&
-               mlirIntegerTypeGetWidth(elementType) == 16) {
-      if (mlirIntegerTypeIsSignless(elementType) ||
-          mlirIntegerTypeIsSigned(elementType)) {
-        // i16
-        return bufferInfo<int16_t>(shapedType);
-      }
-      if (mlirIntegerTypeIsUnsigned(elementType)) {
-        // unsigned i16
-        return bufferInfo<uint16_t>(shapedType);
-      }
-    } else if (mlirTypeIsAInteger(elementType) &&
-               mlirIntegerTypeGetWidth(elementType) == 1) {
-      // i1 / bool
-      // We can not send the buffer directly back to Python, because the i1
-      // values are bitpacked within MLIR. We call numpy's unpackbits function
-      // to convert the bytes.
-      return getBooleanBufferFromBitpackedAttribute();
+    if (mlirIntegerTypeIsUnsigned(elementType)) {
+      // unsigned i16
+      return bufferInfo<uint16_t>(shapedType);
     }
-
-    // TODO: Currently crashes the program.
-    // Reported as https://github.com/pybind/pybind11/issues/3336
-    throw std::invalid_argument(
-        "unsupported data type for conversion to Python buffer");
+  } else if (mlirTypeIsAInteger(elementType) &&
+             mlirIntegerTypeGetWidth(elementType) == 1) {
+    // i1 / bool
+    // We can not send the buffer directly back to Python, because the i1
+    // values are bitpacked within MLIR. We call numpy's unpackbits function
+    // to convert the bytes.
+    return getBooleanBufferFromBitpackedAttribute();
   }
-  static void bindDerived(ClassTy &c) {
+  // TODO: Currently crashes the program.
+  // Reported as https://github.com/pybind/pybind11/issues/3336
+  throw std::invalid_argument(
+      "unsupported data type for conversion to Python buffer");
+}
+
+/// Registers the Python-visible members of DenseElementsAttr: `__len__`, the
+/// two `get` overloads (from a buffer and from a list of attributes),
+/// `get_splat`, the `is_splat` property, and `get_splat_value`.
+void PyDenseElementsAttribute::bindDerived(ClassTy &c) {
 #if PY_VERSION_HEX < 0x03090000
-    PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
-    tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
-    tp->tp_as_buffer->bf_releasebuffer =
-        PyDenseElementsAttribute::bf_releasebuffer;
+  // On Python < 3.9, install the buffer-protocol callbacks into the type
+  // object's tp_as_buffer slots.
+  PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
+  tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
+  tp->tp_as_buffer->bf_releasebuffer =
+      PyDenseElementsAttribute::bf_releasebuffer;
 #endif
-    c.def("__len__", &PyDenseElementsAttribute::dunderLen)
-        .def_static(
-            "get", PyDenseElementsAttribute::getFromBuffer, nb::arg("array"),
-            nb::arg("signless") = true, nb::arg("type") = nb::none(),
-            nb::arg("shape") = nb::none(), nb::arg("context") = nb::none(),
-            // clang-format off
+  c.def("__len__", &PyDenseElementsAttribute::dunderLen)
+      .def_static(
+          "get", PyDenseElementsAttribute::getFromBuffer, nb::arg("array"),
+          nb::arg("signless") = true, nb::arg("type") = nb::none(),
+          nb::arg("shape") = nb::none(), nb::arg("context") = nb::none(),
+          // clang-format off
 nb::sig("def get(array: typing_extensions.Buffer, signless: bool = True, type: Type | None = None, shape: Sequence[int] | None = None, context: Context | None = None) -> DenseElementsAttr"),
-            // clang-format on
-            kDenseElementsAttrGetDocstring)
-        .def_static("get", PyDenseElementsAttribute::getFromList,
-                    nb::arg("attrs"), nb::arg("type") = nb::none(),
-                    nb::arg("context") = nb::none(),
-                    kDenseElementsAttrGetFromListDocstring)
-        .def_static("get_splat", PyDenseElementsAttribute::getSplat,
-                    nb::arg("shaped_type"), nb::arg("element_attr"),
-                    "Gets a DenseElementsAttr where all values are the same")
-        .def_prop_ro("is_splat",
-                     [](PyDenseElementsAttribute &self) -> bool {
-                       return mlirDenseElementsAttrIsSplat(self);
-                     })
-        .def("get_splat_value",
-             [](PyDenseElementsAttribute &self)
-                 -> nb::typed<nb::object, PyAttribute> {
-               if (!mlirDenseElementsAttrIsSplat(self))
-                 throw nb::value_error(
-                     "get_splat_value called on a non-splat attribute");
-               return PyAttribute(self.getContext(),
-                                  mlirDenseElementsAttrGetSplatValue(self))
-                   .maybeDownCast();
-             });
-  }
-
-  static PyType_Slot slots[];
+          // clang-format on
+          kDenseElementsAttrGetDocstring)
+      .def_static("get", PyDenseElementsAttribute::getFromList,
+                  nb::arg("attrs"), nb::arg("type") = nb::none(),
+                  nb::arg("context") = nb::none(),
+                  kDenseElementsAttrGetFromListDocstring)
+      .def_static("get_splat", PyDenseElementsAttribute::getSplat,
+                  nb::arg("shaped_type"), nb::arg("element_attr"),
+                  "Gets a DenseElementsAttr where all values are the same")
+      .def_prop_ro("is_splat",
+                   [](PyDenseElementsAttribute &self) -> bool {
+                     return mlirDenseElementsAttrIsSplat(self);
+                   })
+      .def("get_splat_value",
+           [](PyDenseElementsAttribute &self)
+               -> nb::typed<nb::object, PyAttribute> {
+             // Only splat attributes have a single, well-defined value.
+             if (!mlirDenseElementsAttrIsSplat(self))
+               throw nb::value_error(
+                   "get_splat_value called on a non-splat attribute");
+             return PyAttribute(self.getContext(),
+                                mlirDenseElementsAttrGetSplatValue(self))
+                 .maybeDownCast();
+           });
+}
-private:
- static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
- static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
+/// Returns true if the buffer format string `format` begins with one of the
+/// unsigned integer codes 'B', 'H', 'I', 'L' or 'Q'; false for an empty
+/// string. Only the first character is examined.
+/// NOTE(review): a leading byte-order/size prefix ('@', '=', '<', '>', '!')
+/// is not skipped, so formats carrying one are rejected.
+bool PyDenseElementsAttribute::isUnsignedIntegerFormat(
+    std::string_view format) {
+  if (format.empty())
+    return false;
+  char code = format[0];
+  return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
+         code == 'Q';
+}
- static bool isUnsignedIntegerFormat(std::string_view format) {
- if (format.empty())
- return false;
- char code = format[0];
- return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
- code == 'Q';
- }
+/// Returns true if the buffer format string `format` begins with one of the
+/// signed integer codes 'b', 'h', 'i', 'l' or 'q'; false for an empty
+/// string. Only the first character is examined.
+/// NOTE(review): a leading byte-order/size prefix ('@', '=', '<', '>', '!')
+/// is not skipped, so formats carrying one are rejected.
+bool PyDenseElementsAttribute::isSignedIntegerFormat(std::string_view format) {
+  if (format.empty())
+    return false;
+  char code = format[0];
+  return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
+         code == 'q';
+}
- static bool isSignedIntegerFormat(std::string_view format) {
- if (format.empty())
- return false;
- char code = format[0];
- return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
- code == 'q';
+/// Computes the shaped type used when bulk-loading a buffer into a
+/// DenseElementsAttr.
+///
+/// If `*bulkLoadElementType` is already a ShapedType, it is returned unchanged.
+/// In that case, passing an `explicitShape` throws std::invalid_argument.
+/// Otherwise, a RankedTensorType with no encoding is built. Its shape comes
+/// from `explicitShape`, or from the buffer view's shape when none is given.
+///
+/// NOTE(review): `bulkLoadElementType` is dereferenced unconditionally, so
+/// callers must pass an engaged optional.
+MlirType PyDenseElementsAttribute::getShapedType(
+    std::optional<MlirType> bulkLoadElementType,
+    std::optional<std::vector<int64_t>> explicitShape, Py_buffer &view) {
+  SmallVector<int64_t> shape;
+  if (explicitShape) {
+    shape.append(explicitShape->begin(), explicitShape->end());
+  } else {
+    shape.append(view.shape, view.shape + view.ndim);
   }
-  static MlirType
-  getShapedType(std::optional<MlirType> bulkLoadElementType,
-                std::optional<std::vector<int64_t>> explicitShape,
-                Py_buffer &view) {
-    SmallVector<int64_t> shape;
+  if (mlirTypeIsAShaped(*bulkLoadElementType)) {
     if (explicitShape) {
-      shape.append(explicitShape->begin(), explicitShape->end());
-    } else {
-      shape.append(view.shape, view.shape + view.ndim);
-    }
-
-    if (mlirTypeIsAShaped(*bulkLoadElementType)) {
-      if (explicitShape) {
-        throw std::invalid_argument("Shape can only be specified explicitly "
-                                    "when the type is not a shaped type.");
-      }
-      return *bulkLoadElementType;
+      throw std::invalid_argument("Shape can only be specified explicitly "
+                                  "when the type is not a shaped type.");
     }
-    MlirAttribute encodingAttr = mlirAttributeGetNull();
-    return mlirRankedTensorTypeGet(shape.size(), shape.data(),
-                                   *bulkLoadElementType, encodingAttr);
+    return *bulkLoadElementType;
   }
+  MlirAttribute encodingAttr = mlirAttributeGetNull();
+  return mlirRankedTensorTypeGet(shape.size(), shape.data(),
+                                 *bulkLoadElementType, encodingAttr);
+}
- static MlirAttribute getAttributeFromBuffer(
- Py_buffer &view, bool signless, std::optional<PyType> explicitType,
- const std::optional<std::vector<int64_t>> &explicitShape,
- MlirContext &context) {
- // Detect format codes that are suitable for bulk loading. This includes
- // all byte aligned integer and floating point types up to 8 bytes.
- // Notably, this excludes exotics types which do not have a direct
- // representation in the buffer protocol (i.e. complex, etc).
- std::optional<MlirType> bulkLoadElementType;
- if (explicitType) {
- bulkLoadElementType = *explicitType;
- } else {
- std::string_view format(view.format);
- if (format == "f") {
- // f32
- assert(view.itemsize == 4 && "mismatched array itemsize");
- bulkLoadElementType = mlirF32TypeGet(context);
- } else if (format == "d") {
- // f64
- assert(view.itemsize == 8 && "mismatched array itemsize");
- bulkLoadElementType = mlirF64TypeGet(context);
- } else if (format == "e") {
- // f16
- assert(view.itemsize == 2 && "mismatched array itemsize");
- bulkLoadElementType = mlirF16TypeGet(context);
- } else if (format == "?") {
- // i1
- // The i1 type needs to be bit-packed, so we will handle it separately
- return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
- context);
- } else if (isSignedIntegerFormat(format)) {
- if (view.itemsize == 4) {
- // i32
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeSignedGet(context, 32);
- } else if (view.itemsize == 8) {
- // i64
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeSignedGet(context, 64);
- } else if (view.itemsize == 1) {
- // i8
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeSignedGet(context, 8);
- } else if (view.itemsize == 2) {
- // i16
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeSignedGet(context, 16);
- }
- } else if (isUnsignedIntegerFormat(format)) {
- if (view.itemsize == 4) {
- // unsigned i32
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeUnsignedGet(context, 32);
- } else if (view.itemsize == 8) {
- // unsigned i64
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeUnsignedGet(context, 64);
- } else if (view.itemsize == 1) {
- // i8
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeUnsignedGet(context, 8);
- } else if (view.itemsize == 2) {
- // i16
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeUnsignedGet(context, 16);
- }
+MlirAttribute PyDenseElementsAttribute::getAttributeFromBuffer(
+ Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+ const std::optional<std::vector<int64_t>> &explicitShape,
+ MlirContext &context) {
+ // Detect format codes that are suitable for bulk loading. This includes
+ // all byte aligned integer and floating point types up to 8 bytes.
+ // Notably, this excludes exotics types which do not have a direct
+ // representation in the buffer protocol (i.e. complex, etc).
+ std::optional<MlirType> bulkLoadElementType;
+ if (explicitType) {
+ bulkLoadElementType = *explicitType;
+ } else {
+ std::string_view format(view.format);
+ if (format == "f") {
+ // f32
+ assert(view.itemsize == 4 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF32TypeGet(context);
+ } else if (format == "d") {
+ // f64
+ assert(view.itemsize == 8 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF64TypeGet(context);
+ } else if (format == "e") {
+ // f16
+ assert(view.itemsize == 2 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF16TypeGet(context);
+ } else if (format == "?") {
+ // i1
+ // The i1 type needs to be bit-packed, so we will handle it separately
+ return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
+ context);
+ } else if (isSignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
+ // i32
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeSignedGet(context, 32);
+ } else if (view.itemsize == 8) {
+ // i64
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeSignedGet(context, 64);
+ } else if (view.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeSignedGet(context, 8);
+ } else if (view.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeSignedGet(context, 16);
}
- if (!bulkLoadElementType) {
- throw std::invalid_argument(
- std::string("unimplemented array format conversion from format: ") +
- std::string(format));
+ } else if (isUnsignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
+ // unsigned i32
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeUnsignedGet(context, 32);
+ } else if (view.itemsize == 8) {
+ // unsigned i64
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeUnsignedGet(context, 64);
+ } else if (view.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeUnsignedGet(context, 8);
+ } else if (view.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeUnsignedGet(context, 16);
}
}
-
- MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
- return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
- }
-
- // There is a complication for boolean numpy arrays, as numpy represents
- // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
- // booleans per byte.
- static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
- Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
- MlirContext &context) {
- if (llvm::endianness::native != llvm::endianness::little) {
- // Given we have no good way of testing the behavior on big-endian
- // systems we will throw
- throw nb::type_error("Constructing a bit-packed MLIR attribute is "
- "unsupported on big-endian systems");
+ if (!bulkLoadElementType) {
+ throw std::invalid_argument(
+ std::string("unimplemented array format conversion from format: ") +
+ std::string(format));
}
- nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
- /*data=*/static_cast<uint8_t *>(view.buf),
- /*shape=*/{static_cast<size_t>(view.len)});
-
- nb::module_ numpy = nb::module_::import_("numpy");
- nb::object packbitsFunc = numpy.attr("packbits");
- nb::object packedBooleans =
- packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
- nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
-
- MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1),
- std::move(explicitShape), view);
- assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
- // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
- // packedBooleans, hence the MlirAttribute will remain valid even when
- // packedBooleans get reclaimed by the end of the function.
- return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
- pythonBuffer.ptr);
}
- // This does the opposite transformation of
- // `getBitpackedAttributeFromBooleanBuffer`
- std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() {
- if (llvm::endianness::native != llvm::endianness::little) {
- // Given we have no good way of testing the behavior on big-endian
- // systems we will throw
- throw nb::type_error("Constructing a numpy array from a MLIR attribute "
- "is unsupported on big-endian systems");
- }
+ MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
+ return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
+}
- int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
- int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
- uint8_t *bitpackedData = static_cast<uint8_t *>(
- const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
- nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
- /*data=*/bitpackedData,
- /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
-
- nb::module_ numpy = nb::module_::import_("numpy");
- nb::object unpackbitsFunc = numpy.attr("unpackbits");
- nb::object equalFunc = numpy.attr("equal");
- nb::object reshapeFunc = numpy.attr("reshape");
- nb::object unpackedBooleans =
- unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
-
- // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
- // We need to:
- // 1. Slice away the padded bits
- // 2. Make the boolean array have the correct shape
- // 3. Convert the array to a boolean array
- unpackedBooleans = unpackedBooleans[nb::slice(
- nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
- unpackedBooleans = equalFunc(unpackedBooleans, 1);
-
- MlirType shapedType = mlirAttributeGetType(*this);
- intptr_t rank = mlirShapedTypeGetRank(shapedType);
- std::vector<intptr_t> shape(rank);
- for (intptr_t i = 0; i < rank; ++i) {
- shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
- }
- unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
+MlirAttribute PyDenseElementsAttribute::getBitpackedAttributeFromBooleanBuffer(
+    Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+    MlirContext &context) {
+  if (llvm::endianness::native != llvm::endianness::little) {
+    // Bit-packing is untested on big-endian hosts (there is no good way to
+    // test it), so refuse rather than risk building a mis-packed attribute.
+    throw nb::type_error("Constructing a bit-packed MLIR attribute is "
+                         "unsupported on big-endian systems");
+  }
+  nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
+      /*data=*/static_cast<uint8_t *>(view.buf),
+      /*shape=*/{static_cast<size_t>(view.len)});
+  // Pack 8 bools per byte with numpy.packbits (first element -> lowest bit).
+  nb::module_ numpy = nb::module_::import_("numpy");
+  nb::object packbitsFunc = numpy.attr("packbits");
+  nb::object packedBooleans =
+      packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
+  nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
+  // Result type: an i1 shaped type built from explicitShape and view.
+  MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1),
+                                         std::move(explicitShape), view);
+  assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
+  // `mlirDenseElementsAttrRawBufferGet` copies the packed bytes, so the
+  // returned MlirAttribute stays valid after packedBooleans is released at
+  // the end of this function.
+  return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
+                                           pythonBuffer.ptr);
+}
- // Make sure the returned nb::buffer_view claims ownership of the data in
- // `pythonBuffer` so it remains valid when Python reads it
- nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
- return std::make_unique<nb_buffer_info>(pythonBuffer.request());
+std::unique_ptr<nb_buffer_info> // Unpacks i1 data into a numpy bool array.
+PyDenseElementsAttribute::getBooleanBufferFromBitpackedAttribute() const {
+  if (llvm::endianness::native != llvm::endianness::little) {
+    // Unpacking is untested on big-endian hosts (there is no good way to
+    // test it), so refuse rather than risk returning mis-ordered booleans.
+    throw nb::type_error("Constructing a numpy array from a MLIR attribute "
+                         "is unsupported on big-endian systems");
+  }
- template <typename Type>
- std::unique_ptr<nb_buffer_info>
- bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
- intptr_t rank = mlirShapedTypeGetRank(shapedType);
- // Prepare the data for the buffer_info.
- // Buffer is configured for read-only access below.
- Type *data = static_cast<Type *>(
- const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
- // Prepare the shape for the buffer_info.
- SmallVector<intptr_t, 4> shape;
- for (intptr_t i = 0; i < rank; ++i)
- shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
- // Prepare the strides for the buffer_info.
- SmallVector<intptr_t, 4> strides;
- if (mlirDenseElementsAttrIsSplat(*this)) {
- // Splats are special, only the single value is stored.
- strides.assign(rank, 0);
- } else {
- for (intptr_t i = 1; i < rank; ++i) {
- intptr_t strideFactor = 1;
- for (intptr_t j = i; j < rank; ++j)
- strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
- strides.push_back(sizeof(Type) * strideFactor);
- }
- strides.push_back(sizeof(Type));
- }
- const char *format;
- if (explicitFormat) {
- format = explicitFormat;
- } else {
- format = nb_format_descriptor<Type>::format();
- }
- return std::make_unique<nb_buffer_info>(
- data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
- /*readonly=*/true);
+  int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
+  // A splat attribute stores a single byte (all bits equal to the splat
+  // value) rather than ceil(n / 8) bytes, so only one byte may be read.
+  int64_t numBitpackedBytes = mlirDenseElementsAttrIsSplat(*this)
+                                  ? 1
+                                  : llvm::divideCeil(numBooleans, 8);
+  uint8_t *bitpackedData = static_cast<uint8_t *>(
+      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+  nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
+      /*data=*/bitpackedData,
+      /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
+  nb::module_ numpy = nb::module_::import_("numpy");
+  nb::object unpackbitsFunc = numpy.attr("unpackbits");
+  nb::object resizeFunc = numpy.attr("resize");
+  nb::object equalFunc = numpy.attr("equal");
+  nb::object reshapeFunc = numpy.attr("reshape");
+  nb::object unpackedBooleans =
+      unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
+  // unpackbits yields a flat 0/1 array of 8 * numBitpackedBytes entries.
+  // numpy.resize cuts it to numBooleans (dropping padding bits, or repeating
+  // the 8 identical bits of a splat); equal(.., 1) then yields bools.
+  unpackedBooleans = resizeFunc(unpackedBooleans, numBooleans);
+  unpackedBooleans = equalFunc(unpackedBooleans, 1);
+
+  MlirType shapedType = mlirAttributeGetType(*this);
+  intptr_t rank = mlirShapedTypeGetRank(shapedType);
+  std::vector<intptr_t> shape(rank);
+  for (intptr_t i = 0; i < rank; ++i) {
+    shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
}
-}; // namespace
+  unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
+  // Hand the result to Python through the array's buffer protocol.
+  // request() acquires a Py_buffer on `unpackedBooleans`; presumably that
+  // keeps the numpy array (and its data) alive while Python reads it.
+  nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
+  return std::make_unique<nb_buffer_info>(pythonBuffer.request());
+}
PyType_Slot PyDenseElementsAttribute::slots[] = {
// Python 3.8 doesn't allow setting the buffer protocol slots from a type spec.
@@ -1333,364 +965,294 @@ PyType_Slot PyDenseElementsAttribute::slots[] = {
delete reinterpret_cast<nb_buffer_info *>(view->internal);
}
-/// Refinement of the PyDenseElementsAttribute for attributes containing
-/// integer (and boolean) values. Supports element access.
-class PyDenseIntElementsAttribute
- : public PyConcreteAttribute<PyDenseIntElementsAttribute,
- PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
- static constexpr const char *pyClassName = "DenseIntElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- /// Returns the element at the given linear position. Asserts if the index
- /// is out of range.
- nb::int_ dunderGetItem(intptr_t pos) {
- if (pos < 0 || pos >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds element");
- }
+nb::int_ PyDenseIntElementsAttribute::dunderGetItem(intptr_t pos) const {
+  if (pos < 0 || pos >= dunderLen()) {
+    throw nb::index_error("attempt to access out of bounds element");
+  }
-    MlirType type = mlirAttributeGetType(*this);
-    type = mlirShapedTypeGetElementType(type);
-    // Index type can also appear as a DenseIntElementsAttr and therefore can be
-    // casted to integer.
-    assert(mlirTypeIsAInteger(type) ||
-           mlirTypeIsAIndex(type) && "expected integer/index element type in "
-                                     "dense int elements attribute");
-    // Dispatch element extraction to an appropriate C function based on the
-    // elemental type of the attribute. nb::int_ is implicitly constructible
-    // from any C++ integral type and handles bitwidth correctly.
-    // TODO: consider caching the type properties in the constructor to avoid
-    // querying them on each element access.
-    if (mlirTypeIsAIndex(type)) {
-      return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
+  MlirType type = mlirAttributeGetType(*this);
+  type = mlirShapedTypeGetElementType(type);
+  // Index-typed elements are also stored in a DenseIntElementsAttr and are
+  // read back as integers below.
+  assert((mlirTypeIsAInteger(type) || mlirTypeIsAIndex(type)) &&
+         "expected integer/index element type in dense int elements "
+         "attribute");
+  // Dispatch element extraction to an appropriate C function based on the
+  // elemental type of the attribute. nb::int_ is implicitly
+  // constructible from any C++ integral type and handles bitwidth correctly.
+  // TODO: consider caching the type properties in the constructor to avoid
+  // querying them on each element access.
+  if (mlirTypeIsAIndex(type)) {
+    return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
+  }
+  unsigned width = mlirIntegerTypeGetWidth(type);
+  bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
+  if (isUnsigned) {
+    if (width == 1) {
+      return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
}
- unsigned width = mlirIntegerTypeGetWidth(type);
- bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
- if (isUnsigned) {
- if (width == 1) {
- return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
- }
- if (width == 8) {
- return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
- }
- if (width == 16) {
- return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
- }
- if (width == 32) {
- return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
- }
- if (width == 64) {
- return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
- }
- } else {
- if (width == 1) {
- return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
- }
- if (width == 8) {
- return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
- }
- if (width == 16) {
- return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
- }
- if (width == 32) {
- return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
- }
- if (width == 64) {
- return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
- }
+    if (width == 8) {
+      return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
+    }
+    if (width == 16) {
+      return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
+    }
+    if (width == 32) {
+      return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
+    }
+    if (width == 64) {
+      return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
+    }
+  } else { // Signless and signed integer types.
+    if (width == 1) {
+      return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
+    }
+    if (width == 8) {
+      return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
+    }
+    if (width == 16) {
+      return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
+    }
+    if (width == 32) {
+      return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
+    }
+    if (width == 64) {
+      return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
     }
-    throw nb::type_error("Unsupported integer type");
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
   }
-};
+  throw nb::type_error("Unsupported integer type"); // width not 1/8/16/32/64
+}
+void PyDenseIntElementsAttribute::bindDerived(ClassTy &c) {
+  c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
+}
// Check if the python version is less than 3.13. Py_IsFinalizing is a part
// of stable ABI since 3.13 and before it was available as _Py_IsFinalizing.
#if PY_VERSION_HEX < 0x030d0000
#define Py_IsFinalizing _Py_IsFinalizing
#endif
-class PyDenseResourceElementsAttribute
- : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction =
- mlirAttributeIsADenseResourceElements;
- static constexpr const char *pyClassName = "DenseResourceElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static PyDenseResourceElementsAttribute
- getFromBuffer(const nb_buffer &buffer, const std::string &name,
- const PyType &type, std::optional<size_t> alignment,
- bool isMutable, DefaultingPyMlirContext contextWrapper) {
- if (!mlirTypeIsAShaped(type)) {
- throw std::invalid_argument(
- "Constructing a DenseResourceElementsAttr requires a ShapedType.");
- }
-
- // Do not request any conversions as we must ensure to use caller
- // managed memory.
- int flags = PyBUF_STRIDES;
- std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
- if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
- throw nb::python_error();
- }
+PyDenseResourceElementsAttribute
+PyDenseResourceElementsAttribute::getFromBuffer(
+    const nb_buffer &buffer, const std::string &name, const PyType &type,
+    std::optional<size_t> alignment, bool isMutable,
+    DefaultingPyMlirContext contextWrapper) {
+  if (!mlirTypeIsAShaped(type)) {
+    throw std::invalid_argument(
+        "Constructing a DenseResourceElementsAttr requires a ShapedType.");
+  }
-  // This scope releaser will only release if we haven't yet transferred
-  // ownership.
-  auto freeBuffer = llvm::make_scope_exit([&]() {
-    if (view)
-      PyBuffer_Release(view.get());
-  });
+  // Request strides only (no conversions) so the attribute aliases the
+  // caller's memory rather than a converted copy of it.
+  int flags = PyBUF_STRIDES;
+  std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
+  if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
+    throw nb::python_error();
+  }
-  if (!PyBuffer_IsContiguous(view.get(), 'A')) {
-    throw std::invalid_argument("Contiguous buffer is required.");
-  }
+  // Release the buffer on early exit; after ownership moves to the attribute
+  // (view.release()), `view` is null and this becomes a no-op.
+  auto freeBuffer = llvm::make_scope_exit([&]() {
+    if (view)
+      PyBuffer_Release(view.get());
+  });
- // Infer alignment to be the stride of one element if not explicit.
- size_t inferredAlignment;
- if (alignment)
- inferredAlignment = *alignment;
- else
- inferredAlignment = view->strides[view->ndim - 1];
-
- // The userData is a Py_buffer* that the deleter owns.
- auto deleter = [](void *userData, const void *data, size_t size,
- size_t align) {
- if (Py_IsFinalizing())
- return;
- assert(Py_IsInitialized() && "expected interpreter to be initialized");
- Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
- nb::gil_scoped_acquire gil;
- PyBuffer_Release(ownedView);
- delete ownedView;
- };
-
- size_t rawBufferSize = view->len;
- MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
- type, toMlirStringRef(name), view->buf, rawBufferSize,
- inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
- if (mlirAttributeIsNull(attr)) {
- throw std::invalid_argument(
- "DenseResourceElementsAttr could not be constructed from the given "
- "buffer. "
- "This may mean that the Python buffer layout does not match that "
- "MLIR expected layout and is a bug.");
- }
- view.release();
- return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
+  if (!PyBuffer_IsContiguous(view.get(), 'A')) {
+    throw std::invalid_argument("Contiguous buffer is required.");
   }
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get_from_buffer",
-                 PyDenseResourceElementsAttribute::getFromBuffer,
-                 nb::arg("array"), nb::arg("name"), nb::arg("type"),
-                 nb::arg("alignment") = nb::none(),
-                 nb::arg("is_mutable") = false, nb::arg("context") = nb::none(),
-                 // clang-format off
-                 nb::sig("def get_from_buffer(array: typing_extensions.Buffer, name: str, type: Type, alignment: int | None = None, is_mutable: bool = False, context: Context | None = None) -> DenseResourceElementsAttr"),
-                 // clang-format on
-                 kDenseResourceElementsAttrGetFromBufferDocstring);
+  // Infer alignment as the stride of one element if not explicit. A 0-d
+  // buffer has no strides (NULL per PEP 3118), so use its itemsize instead.
+  size_t inferredAlignment = view->itemsize;
+  if (alignment)
+    inferredAlignment = *alignment;
+  else if (view->ndim > 0)
+    inferredAlignment = view->strides[view->ndim - 1];
+  // The userData is a Py_buffer* that the deleter owns.
+  auto deleter = [](void *userData, const void *data, size_t size,
+                    size_t align) {
+    if (Py_IsFinalizing()) // Leak the view: interpreter is shutting down.
+      return;
+    assert(Py_IsInitialized() && "expected interpreter to be initialized");
+    Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
+    nb::gil_scoped_acquire gil;
+    PyBuffer_Release(ownedView);
+    delete ownedView;
+  };
+
+  size_t rawBufferSize = view->len;
+  MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
+      type, toMlirStringRef(name), view->buf, rawBufferSize, inferredAlignment,
+      isMutable, deleter, static_cast<void *>(view.get()));
+  if (mlirAttributeIsNull(attr)) {
+    throw std::invalid_argument(
+        "DenseResourceElementsAttr could not be constructed from the given "
+        "buffer. "
+        "This may mean that the Python buffer layout does not match that "
+        "MLIR expected layout and is a bug.");
   }
-};
+  view.release();
+  return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
+}
-class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
- static constexpr const char *pyClassName = "DictAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirDictionaryAttrGetTypeID;
+void PyDenseResourceElementsAttribute::bindDerived(ClassTy &c) {
+  c.def_static( // Static factory; nb::sig pins the Python stub signature.
+      "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
+      nb::arg("array"), nb::arg("name"), nb::arg("type"),
+      nb::arg("alignment") = nb::none(), nb::arg("is_mutable") = false,
+      nb::arg("context") = nb::none(),
+      // clang-format off
+      nb::sig("def get_from_buffer(array: typing_extensions.Buffer, name: str, type: Type, alignment: int | None = None, is_mutable: bool = False, context: Context | None = None) -> DenseResourceElementsAttr"),
+      // clang-format on
+      kDenseResourceElementsAttrGetFromBufferDocstring);
+}
- intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
+intptr_t PyDictAttribute::dunderLen() const {
+  return mlirDictionaryAttrGetNumElements(*this); // Number of entries.
+}
-  bool dunderContains(const std::string &name) {
-    return !mlirAttributeIsNull(
-        mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
-  }
+bool PyDictAttribute::dunderContains(const std::string &name) const {
+  return !mlirAttributeIsNull( // Null lookup result means `name` is absent.
+      mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
+}
- static void bindDerived(ClassTy &c) {
- c.def("__contains__", &PyDictAttribute::dunderContains);
- c.def("__len__", &PyDictAttribute::dunderLen);
- c.def_static(
- "get",
- [](const nb::dict &attributes, DefaultingPyMlirContext context) {
- SmallVector<MlirNamedAttribute> mlirNamedAttributes;
- mlirNamedAttributes.reserve(attributes.size());
- for (std::pair<nb::handle, nb::handle> it : attributes) {
- auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
- auto name = nb::cast<std::string>(it.first);
- mlirNamedAttributes.push_back(mlirNamedAttributeGet(
- mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
- toMlirStringRef(name)),
- mlirAttr));
- }
+void PyDictAttribute::bindDerived(ClassTy &c) {
+  c.def("__contains__", &PyDictAttribute::dunderContains);
+  c.def("__len__", &PyDictAttribute::dunderLen);
+  c.def_static(
+      "get",
+      [](const nb::dict &attributes, DefaultingPyMlirContext context) {
+        SmallVector<MlirNamedAttribute> mlirNamedAttributes;
+        mlirNamedAttributes.reserve(attributes.size());
+        for (std::pair<nb::handle, nb::handle> it : attributes) {
+          auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
+          auto name = nb::cast<std::string>(it.first);
+          mlirNamedAttributes.push_back(mlirNamedAttributeGet(
+              mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
+                                toMlirStringRef(name)),
+              mlirAttr)); // NOTE(review): key uses mlirAttr's context
+        }
+        MlirAttribute attr =
+            mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
+                                  mlirNamedAttributes.data());
+        return PyDictAttribute(context->getRef(), attr);
+      },
+      nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
+      "Gets an uniqued dict attribute");
+  c.def("__getitem__", // By key; raises KeyError if absent.
+        [](PyDictAttribute &self,
+           const std::string &name) -> nb::typed<nb::object, PyAttribute> {
MlirAttribute attr =
- mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
- mlirNamedAttributes.data());
- return PyDictAttribute(context->getRef(), attr);
- },
- nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
- "Gets an uniqued dict attribute");
- c.def("__getitem__",
- [](PyDictAttribute &self,
- const std::string &name) -> nb::typed<nb::object, PyAttribute> {
- MlirAttribute attr =
- mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
- if (mlirAttributeIsNull(attr))
- throw nb::key_error("attempt to access a non-existent attribute");
- return PyAttribute(self.getContext(), attr).maybeDownCast();
- });
- c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
- if (index < 0 || index >= self.dunderLen()) {
- throw nb::index_error("attempt to access out of bounds attribute");
- }
- MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
- return PyNamedAttribute(
- namedAttr.attribute,
- std::string(mlirIdentifierStr(namedAttr.name).data));
- });
- }
-};
-
-/// Refinement of PyDenseElementsAttribute for attributes containing
-/// floating-point values. Supports element access.
-class PyDenseFPElementsAttribute
- : public PyConcreteAttribute<PyDenseFPElementsAttribute,
- PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
- static constexpr const char *pyClassName = "DenseFPElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- nb::float_ dunderGetItem(intptr_t pos) {
- if (pos < 0 || pos >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds element");
+              mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+          if (mlirAttributeIsNull(attr))
+            throw nb::key_error("attempt to access a non-existent attribute");
+          return PyAttribute(self.getContext(), attr).maybeDownCast();
+        });
+  c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
+    if (index < 0 || index >= self.dunderLen()) {
+      throw nb::index_error("attempt to access out of bounds attribute");
     }
+    MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
+    MlirStringRef name = mlirIdentifierStr(namedAttr.name); // no NUL guarantee
+    return PyNamedAttribute(namedAttr.attribute,
+                            std::string(name.data, name.length));
+  });
+}
- MlirType type = mlirAttributeGetType(*this);
- type = mlirShapedTypeGetElementType(type);
- // Dispatch element extraction to an appropriate C function based on the
- // elemental type of the attribute. nb::float_ is implicitly constructible
- // from float and double.
- // TODO: consider caching the type properties in the constructor to avoid
- // querying them on each element access.
- if (mlirTypeIsAF32(type)) {
- return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
- }
- if (mlirTypeIsAF64(type)) {
- return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
- }
- throw nb::type_error("Unsupported floating-point type");
+nb::float_ PyDenseFPElementsAttribute::dunderGetItem(intptr_t pos) const {
+  if (pos < 0 || pos >= dunderLen()) {
+    throw nb::index_error("attempt to access out of bounds element");
   }
-  static void bindDerived(ClassTy &c) {
-    c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
+  MlirType type = mlirAttributeGetType(*this);
+  type = mlirShapedTypeGetElementType(type);
+  // Only f32 and f64 elements are supported: each has its own C accessor,
+  // and nb::float_ can be constructed from both float and double. Other
+  // floating-point element types fall through to the TypeError below.
+  // TODO: consider caching the type properties in the constructor to avoid
+  // querying them on each element access.
+  if (mlirTypeIsAF32(type)) {
+    return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
}
-};
-
-class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
- static constexpr const char *pyClassName = "TypeAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirTypeAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const PyType &value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirTypeAttrGet(value.get());
- return PyTypeAttribute(context->getRef(), attr);
- },
- nb::arg("value"), nb::arg("context") = nb::none(),
- "Gets a uniqued Type attribute");
- c.def_prop_ro(
- "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
- .maybeDownCast();
- });
+  if (mlirTypeIsAF64(type)) {
+    return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
   }
-};
+  throw nb::type_error("Unsupported floating-point type"); // not f32/f64
+}
-/// Unit Attribute subclass. Unit attributes don't have values.
-class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
- static constexpr const char *pyClassName = "UnitAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnitAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- return PyUnitAttribute(context->getRef(),
- mlirUnitAttrGet(context->get()));
- },
- nb::arg("context") = nb::none(), "Create a Unit attribute.");
- }
-};
+void PyDenseFPElementsAttribute::bindDerived(ClassTy &c) {
+  c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); // attr[i]
+}
-/// Strided layout attribute subclass.
-class PyStridedLayoutAttribute
- : public PyConcreteAttribute<PyStridedLayoutAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
- static constexpr const char *pyClassName = "StridedLayoutAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirStridedLayoutAttrGetTypeID;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](int64_t offset, const std::vector<int64_t> &strides,
- DefaultingPyMlirContext ctx) {
- MlirAttribute attr = mlirStridedLayoutAttrGet(
- ctx->get(), offset, strides.size(), strides.data());
- return PyStridedLayoutAttribute(ctx->getRef(), attr);
- },
- nb::arg("offset"), nb::arg("strides"), nb::arg("context") = nb::none(),
- "Gets a strided layout attribute.");
- c.def_static(
- "get_fully_dynamic",
- [](int64_t rank, DefaultingPyMlirContext ctx) {
- auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
- std::vector<int64_t> strides(rank);
- llvm::fill(strides, dynamic);
- MlirAttribute attr = mlirStridedLayoutAttrGet(
- ctx->get(), dynamic, strides.size(), strides.data());
- return PyStridedLayoutAttribute(ctx->getRef(), attr);
- },
- nb::arg("rank"), nb::arg("context") = nb::none(),
- "Gets a strided layout attribute with dynamic offset and strides of "
- "a "
- "given rank.");
- c.def_prop_ro(
- "offset",
- [](PyStridedLayoutAttribute &self) {
- return mlirStridedLayoutAttrGetOffset(self);
- },
- "Returns the value of the float point attribute");
- c.def_prop_ro(
- "strides",
- [](PyStridedLayoutAttribute &self) {
- intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
- std::vector<int64_t> strides(size);
- for (intptr_t i = 0; i < size; i++) {
- strides[i] = mlirStridedLayoutAttrGetStride(self, i);
- }
- return strides;
- },
- "Returns the value of the float point attribute");
- }
-};
+void PyTypeAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const PyType &value, DefaultingPyMlirContext context) {
+        MlirAttribute attr = mlirTypeAttrGet(value.get()); // no context arg
+        return PyTypeAttribute(context->getRef(), attr); // assumes same ctx
+      },
+      nb::arg("value"), nb::arg("context") = nb::none(),
+      "Gets a uniqued Type attribute");
+  c.def_prop_ro(
+      "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
+        return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
+            .maybeDownCast();
+      });
+}
+
+void PyUnitAttribute::bindDerived(ClassTy &c) {
+  c.def_static( // Unit attributes carry no value; only a context is needed.
+      "get",
+      [](DefaultingPyMlirContext context) {
+        return PyUnitAttribute(context->getRef(),
+                               mlirUnitAttrGet(context->get()));
+      },
+      nb::arg("context") = nb::none(), "Create a Unit attribute.");
+}
+
+void PyStridedLayoutAttribute::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](int64_t offset, const std::vector<int64_t> &strides,
+         DefaultingPyMlirContext ctx) {
+        MlirAttribute attr = mlirStridedLayoutAttrGet(
+            ctx->get(), offset, strides.size(), strides.data());
+        return PyStridedLayoutAttribute(ctx->getRef(), attr);
+      },
+      nb::arg("offset"), nb::arg("strides"), nb::arg("context") = nb::none(),
+      "Gets a strided layout attribute.");
+  c.def_static(
+      "get_fully_dynamic",
+      [](int64_t rank, DefaultingPyMlirContext ctx) {
+        auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
+        if (rank < 0) // A negative size would wrap to a huge size_t.
+          throw std::invalid_argument("rank must be non-negative");
+        std::vector<int64_t> strides(rank, dynamic);
+        MlirAttribute attr = mlirStridedLayoutAttrGet(
+            ctx->get(), dynamic, strides.size(), strides.data());
+        return PyStridedLayoutAttribute(ctx->getRef(), attr);
+      },
+      nb::arg("rank"), nb::arg("context") = nb::none(),
+      "Gets a strided layout attribute with dynamic offset and strides of a "
+      "given rank.");
+  c.def_prop_ro(
+      "offset",
+      [](PyStridedLayoutAttribute &self) {
+        return mlirStridedLayoutAttrGetOffset(self);
+      },
+      "Returns the offset of the strided layout attribute");
+  c.def_prop_ro(
+      "strides",
+      [](PyStridedLayoutAttribute &self) {
+        intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
+        std::vector<int64_t> strides(size);
+        for (intptr_t i = 0; i < size; i++) {
+          strides[i] = mlirStridedLayoutAttrGetStride(self, i);
+        }
+        return strides;
+      },
+      "Returns the strides of the strided layout attribute");
+}
nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
@@ -1747,10 +1309,6 @@ nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
throw nb::type_error(msg.c_str());
}
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
-
void PyStringAttribute::bindDerived(ClassTy &c) {
c.def_static(
"get",
@@ -1795,9 +1353,6 @@ void PyStringAttribute::bindDerived(ClassTy &c) {
"Returns the value of the string attribute as `bytes`");
}
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateIRAttributes(nb::module_ &m) {
PyAffineMapAttribute::bind(m);
PyDenseBoolArrayAttribute::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 7350046f428c7..ca56fc3248ed8 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -29,490 +29,269 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-/// Checks whether the given type is an integer or float type.
-static int mlirTypeIsAIntegerOrFloat(MlirType type) {
+int mlirTypeIsAIntegerOrFloat(MlirType type) {
return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
}
-class PyIntegerType : public PyConcreteType<PyIntegerType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIntegerTypeGetTypeID;
- static constexpr const char *pyClassName = "IntegerType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_signless",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create a signless integer type");
- c.def_static(
- "get_signed",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create a signed integer type");
- c.def_static(
- "get_unsigned",
- [](unsigned width, DefaultingPyMlirContext context) {
- MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
- return PyIntegerType(context->getRef(), t);
- },
- nb::arg("width"), nb::arg("context") = nb::none(),
- "Create an unsigned integer type");
- c.def_prop_ro(
- "width",
- [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
- "Returns the width of the integer type");
- c.def_prop_ro(
- "is_signless",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSignless(self);
- },
- "Returns whether this is a signless integer");
- c.def_prop_ro(
- "is_signed",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsSigned(self);
- },
- "Returns whether this is a signed integer");
- c.def_prop_ro(
- "is_unsigned",
- [](PyIntegerType &self) -> bool {
- return mlirIntegerTypeIsUnsigned(self);
- },
- "Returns whether this is an unsigned integer");
- }
-};
-
-/// Index Type subclass - IndexType.
-class PyIndexType : public PyConcreteType<PyIndexType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirIndexTypeGetTypeID;
- static constexpr const char *pyClassName = "IndexType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirIndexTypeGet(context->get());
- return PyIndexType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a index type.");
- }
-};
-
-class PyFloatType : public PyConcreteType<PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
- static constexpr const char *pyClassName = "FloatType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro(
- "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
- "Returns the width of the floating-point type");
- }
-};
-
-/// Floating Point Type subclass - Float4E2M1FNType.
-class PyFloat4E2M1FNType
- : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat4E2M1FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float4E2M1FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
- return PyFloat4E2M1FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float4_e2m1fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E2M3FNType.
-class PyFloat6E2M3FNType
- : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E2M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E2M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
- return PyFloat6E2M3FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float6_e2m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float6E3M2FNType.
-class PyFloat6E3M2FNType
- : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat6E3M2FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float6E3M2FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
- return PyFloat6E3M2FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float6_e3m2fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType
- : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
- return PyFloat8E4M3FNType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3fn type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2TypeGet(context->get());
- return PyFloat8E5M2Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e5m2 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3Type.
-class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3Type";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3TypeGet(context->get());
- return PyFloat8E4M3Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3 type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType
- : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
- return PyFloat8E4M3FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType
- : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E4M3B11FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
- return PyFloat8E4M3B11FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type.");
- }
-};
-
-/// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType
- : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E5M2FNUZTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E5M2FNUZType";
- using PyConcreteType::PyConcreteType;
+void PyIntegerType::bindDerived(ClassTy &c) {
+  // Static factories, one per signedness; each takes (width, context=None).
+  c.def_static(
+      "get_signless",
+      [](unsigned width, DefaultingPyMlirContext ctx) {
+        MlirType intTy = mlirIntegerTypeGet(ctx->get(), width);
+        return PyIntegerType(ctx->getRef(), intTy);
+      },
+      nb::arg("width"), nb::arg("context") = nb::none(),
+      "Create a signless integer type");
+  c.def_static(
+      "get_signed",
+      [](unsigned width, DefaultingPyMlirContext ctx) {
+        MlirType intTy = mlirIntegerTypeSignedGet(ctx->get(), width);
+        return PyIntegerType(ctx->getRef(), intTy);
+      },
+      nb::arg("width"), nb::arg("context") = nb::none(),
+      "Create a signed integer type");
+  c.def_static(
+      "get_unsigned",
+      [](unsigned width, DefaultingPyMlirContext ctx) {
+        MlirType intTy = mlirIntegerTypeUnsignedGet(ctx->get(), width);
+        return PyIntegerType(ctx->getRef(), intTy);
+      },
+      nb::arg("width"), nb::arg("context") = nb::none(),
+      "Create an unsigned integer type");
+
+  // Read-only queries. The signedness predicates are declared `-> bool` so
+  // they surface as Python bools.
+  c.def_prop_ro(
+      "width",
+      [](PyIntegerType &t) { return mlirIntegerTypeGetWidth(t); },
+      "Returns the width of the integer type");
+  c.def_prop_ro(
+      "is_signless",
+      [](PyIntegerType &t) -> bool { return mlirIntegerTypeIsSignless(t); },
+      "Returns whether this is a signless integer");
+  c.def_prop_ro(
+      "is_signed",
+      [](PyIntegerType &t) -> bool { return mlirIntegerTypeIsSigned(t); },
+      "Returns whether this is a signed integer");
+  c.def_prop_ro(
+      "is_unsigned",
+      [](PyIntegerType &t) -> bool { return mlirIntegerTypeIsUnsigned(t); },
+      "Returns whether this is an unsigned integer");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
- return PyFloat8E5M2FNUZType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type.");
- }
-};
+void PyIndexType::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType indexTy = mlirIndexTypeGet(ctx->get());
+    return PyIndexType(ctx->getRef(), indexTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create an index type.");
+}
-/// Floating Point Type subclass - Float8E3M4Type.
-class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E3M4TypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E3M4Type";
- using PyConcreteType::PyConcreteType;
+void PyFloatType::bindDerived(ClassTy &c) {
+  // Read-only `width` property, queried through the FloatType C API.
+  auto width = [](PyFloatType &t) { return mlirFloatTypeGetWidth(t); };
+  c.def_prop_ro("width", width, "Returns the width of the floating-point type");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E3M4TypeGet(context->get());
- return PyFloat8E3M4Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e3m4 type.");
- }
-};
+void PyFloat4E2M1FNType::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirFloat4E2M1FNTypeGet(ctx->get());
+    return PyFloat4E2M1FNType(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a float4_e2m1fn type.");
+}
-/// Floating Point Type subclass - Float8E8M0FNUType.
-class PyFloat8E8M0FNUType
- : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat8E8M0FNUTypeGetTypeID;
- static constexpr const char *pyClassName = "Float8E8M0FNUType";
- using PyConcreteType::PyConcreteType;
+void PyFloat6E2M3FNType::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirFloat6E2M3FNTypeGet(ctx->get());
+    return PyFloat6E2M3FNType(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a float6_e2m3fn type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
- return PyFloat8E8M0FNUType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type.");
- }
-};
+void PyFloat6E3M2FNType::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirFloat6E3M2FNTypeGet(ctx->get());
+    return PyFloat6E3M2FNType(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a float6_e3m2fn type.");
+}
-/// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirBFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "BF16Type";
- using PyConcreteType::PyConcreteType;
+void PyFloat8E4M3FNType::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirFloat8E4M3FNTypeGet(ctx->get());
+    return PyFloat8E4M3FNType(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a float8_e4m3fn type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirBF16TypeGet(context->get());
- return PyBF16Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a bf16 type.");
- }
-};
+void PyFloat8E5M2Type::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirFloat8E5M2TypeGet(ctx->get());
+    return PyFloat8E5M2Type(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a float8_e5m2 type.");
+}
-/// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat16TypeGetTypeID;
- static constexpr const char *pyClassName = "F16Type";
- using PyConcreteType::PyConcreteType;
+void PyFloat8E4M3Type::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirFloat8E4M3TypeGet(ctx->get());
+    return PyFloat8E4M3Type(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a float8_e4m3 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF16TypeGet(context->get());
- return PyF16Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f16 type.");
- }
-};
+void PyFloat8E4M3FNUZType::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirFloat8E4M3FNUZTypeGet(ctx->get());
+    return PyFloat8E4M3FNUZType(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a float8_e4m3fnuz type.");
+}
-/// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloatTF32TypeGetTypeID;
- static constexpr const char *pyClassName = "FloatTF32Type";
- using PyConcreteType::PyConcreteType;
+void PyFloat8E4M3B11FNUZType::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirFloat8E4M3B11FNUZTypeGet(ctx->get());
+    return PyFloat8E4M3B11FNUZType(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a float8_e4m3b11fnuz type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirTF32TypeGet(context->get());
- return PyTF32Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a tf32 type.");
- }
-};
+void PyFloat8E5M2FNUZType::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirFloat8E5M2FNUZTypeGet(ctx->get());
+    return PyFloat8E5M2FNUZType(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a float8_e5m2fnuz type.");
+}
-/// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat32TypeGetTypeID;
- static constexpr const char *pyClassName = "F32Type";
- using PyConcreteType::PyConcreteType;
+void PyFloat8E3M4Type::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirFloat8E3M4TypeGet(ctx->get());
+    return PyFloat8E3M4Type(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a float8_e3m4 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF32TypeGet(context->get());
- return PyF32Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f32 type.");
- }
-};
+void PyFloat8E8M0FNUType::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirFloat8E8M0FNUTypeGet(ctx->get());
+    return PyFloat8E8M0FNUType(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a float8_e8m0fnu type.");
+}
-/// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFloat64TypeGetTypeID;
- static constexpr const char *pyClassName = "F64Type";
- using PyConcreteType::PyConcreteType;
+void PyBF16Type::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirBF16TypeGet(ctx->get());
+    return PyBF16Type(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a bf16 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirF64TypeGet(context->get());
- return PyF64Type(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a f64 type.");
- }
-};
+void PyF16Type::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirF16TypeGet(ctx->get());
+    return PyF16Type(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a f16 type.");
+}
-/// None Type subclass - NoneType.
-class PyNoneType : public PyConcreteType<PyNoneType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirNoneTypeGetTypeID;
- static constexpr const char *pyClassName = "NoneType";
- using PyConcreteType::PyConcreteType;
+void PyTF32Type::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirTF32TypeGet(ctx->get());
+    return PyTF32Type(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a tf32 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- MlirType t = mlirNoneTypeGet(context->get());
- return PyNoneType(context->getRef(), t);
- },
- nb::arg("context") = nb::none(), "Create a none type.");
- }
-};
+void PyF32Type::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirF32TypeGet(ctx->get());
+    return PyF32Type(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a f32 type.");
+}
-/// Complex Type subclass - ComplexType.
-class PyComplexType : public PyConcreteType<PyComplexType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirComplexTypeGetTypeID;
- static constexpr const char *pyClassName = "ComplexType";
- using PyConcreteType::PyConcreteType;
+void PyF64Type::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType floatTy = mlirF64TypeGet(ctx->get());
+    return PyF64Type(ctx->getRef(), floatTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a f64 type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType) {
- // The element must be a floating point or integer scalar type.
- if (mlirTypeIsAIntegerOrFloat(elementType)) {
- MlirType t = mlirComplexTypeGet(elementType);
- return PyComplexType(elementType.getContext(), t);
- }
- throw nb::value_error(
- (Twine("invalid '") +
- nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
- "' and expected floating point or integer type.")
- .str()
- .c_str());
- },
- "Create a complex type");
- c.def_prop_ro(
- "element_type",
- [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
- .maybeDownCast();
- },
- "Returns element type.");
- }
-};
+void PyNoneType::bindDerived(ClassTy &c) {
+  // Static factory "get"; `context` defaults to None.
+  auto factory = [](DefaultingPyMlirContext ctx) {
+    MlirType noneTy = mlirNoneTypeGet(ctx->get());
+    return PyNoneType(ctx->getRef(), noneTy);
+  };
+  c.def_static("get", factory, nb::arg("context") = nb::none(),
+               "Create a none type.");
+}
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
+void PyComplexType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &elementType) {
+        // Only integer or BF16/F16/F32/F64 element types are accepted here.
+        if (!mlirTypeIsAIntegerOrFloat(elementType)) {
+          auto repr = nb::cast<std::string>(nb::repr(nb::cast(elementType)));
+          throw nb::value_error(
+              (Twine("invalid '") + repr +
+               "' and expected floating point or integer type.")
+                  .str()
+                  .c_str());
+        }
+        MlirType complexTy = mlirComplexTypeGet(elementType);
+        return PyComplexType(elementType.getContext(), complexTy);
+      },
+      "Create a complex type");
+  c.def_prop_ro(
+      "element_type",
+      [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
+        MlirType eltTy = mlirComplexTypeGetElementType(self);
+        return PyType(self.getContext(), eltTy).maybeDownCast();
+      },
+      "Returns element type.");
+}
// Shaped Type Interface - ShapedType
void PyShapedType::bindDerived(ClassTy &c) {
@@ -629,526 +408,424 @@ void PyShapedType::requireHasRank() {
const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped;
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-
-/// Vector Type subclass - VectorType.
-class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirVectorTypeGetTypeID;
- static constexpr const char *pyClassName = "VectorType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
- nb::arg("element_type"), nb::kw_only(),
- nb::arg("scalable") = nb::none(),
- nb::arg("scalable_dims") = nb::none(),
- nb::arg("loc") = nb::none(), "Create a vector type")
- .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
- nb::arg("element_type"), nb::kw_only(),
- nb::arg("scalable") = nb::none(),
- nb::arg("scalable_dims") = nb::none(),
- nb::arg("context") = nb::none(), "Create a vector type")
- .def_prop_ro(
- "scalable",
- [](MlirType self) { return mlirVectorTypeIsScalable(self); })
- .def_prop_ro("scalable_dims", [](MlirType self) {
- std::vector<bool> scalableDims;
- size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
- scalableDims.reserve(rank);
- for (size_t i = 0; i < rank; ++i)
- scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
- return scalableDims;
- });
- }
-
-private:
- static PyVectorType
- getChecked(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
- std::optional<std::vector<int64_t>> scalableDims,
- DefaultingPyLocation loc) {
- if (scalable && scalableDims) {
- throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
- "are mutually exclusive.");
- }
-
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType type;
- if (scalable) {
- if (scalable->size() != shape.size())
- throw nb::value_error("Expected len(scalable) == len(shape).");
-
- SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
- *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
- type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
- scalableDimFlags.data(),
- elementType);
- } else if (scalableDims) {
- SmallVector<bool> scalableDimFlags(shape.size(), false);
- for (int64_t dim : *scalableDims) {
- if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
- throw nb::value_error("Scalable dimension index out of bounds.");
- scalableDimFlags[dim] = true;
- }
- type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
- scalableDimFlags.data(),
- elementType);
- } else {
- type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
- elementType);
- }
- if (mlirTypeIsNull(type))
- throw MLIRError("Invalid type", errors.take());
- return PyVectorType(elementType.getContext(), type);
- }
+void PyVectorType::bindDerived(ClassTy &c) {
+  c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
+               nb::arg("element_type"), nb::kw_only(),
+               nb::arg("scalable") = nb::none(),
+               nb::arg("scalable_dims") = nb::none(),
+               nb::arg("loc") = nb::none(), "Create a vector type");
+  c.def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
+               nb::arg("element_type"), nb::kw_only(),
+               nb::arg("scalable") = nb::none(),
+               nb::arg("scalable_dims") = nb::none(),
+               nb::arg("context") = nb::none(), "Create a vector type");
+  c.def_prop_ro("scalable",
+                [](MlirType self) { return mlirVectorTypeIsScalable(self); });
+  // List of booleans, one per dimension, telling whether it is scalable.
+  c.def_prop_ro("scalable_dims", [](MlirType self) {
+    size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+    std::vector<bool> flags(rank);
+    for (size_t dim = 0; dim < rank; ++dim)
+      flags[dim] = mlirVectorTypeIsDimScalable(self, dim);
+    return flags;
+  });
+}
- static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
- std::optional<nb::list> scalable,
- std::optional<std::vector<int64_t>> scalableDims,
- DefaultingPyMlirContext context) {
- if (scalable && scalableDims) {
- throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
- "are mutually exclusive.");
+PyVectorType
+PyVectorType::getChecked(std::vector<int64_t> shape, PyType &elementType,
+                         std::optional<nb::list> scalable,
+                         std::optional<std::vector<int64_t>> scalableDims,
+                         DefaultingPyLocation loc) {
+  if (scalable && scalableDims)
+    throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
+                          "are mutually exclusive.");
+  // Diagnostics collected here are attached to the MLIRError thrown below.
+  PyMlirContext::ErrorCapture errors(loc->getContext());
+  MlirType vecTy;
+  if (scalable) {
+    if (shape.size() != scalable->size())
+      throw nb::value_error("Expected len(scalable) == len(shape).");
+
+    SmallVector<bool> flags = llvm::to_vector(llvm::map_range(
+        *scalable, [](const nb::handle &obj) { return nb::cast<bool>(obj); }));
+    vecTy = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+                                             flags.data(), elementType);
+  } else if (scalableDims) {
+    SmallVector<bool> flags(shape.size(), false);
+    for (int64_t dim : *scalableDims) {
+      bool inBounds = dim >= 0 && static_cast<size_t>(dim) < flags.size();
+      if (!inBounds)
+        throw nb::value_error("Scalable dimension index out of bounds.");
+      flags[dim] = true;
     }
+    vecTy = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+                                             flags.data(), elementType);
+  } else {
+    vecTy = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+                                     elementType);
+  }
+  if (mlirTypeIsNull(vecTy))
+    throw MLIRError("Invalid type", errors.take());
+  return PyVectorType(elementType.getContext(), vecTy);
+}
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType type;
- if (scalable) {
- if (scalable->size() != shape.size())
- throw nb::value_error("Expected len(scalable) == len(shape).");
-
- SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
- *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
- type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
- scalableDimFlags.data(), elementType);
- } else if (scalableDims) {
- SmallVector<bool> scalableDimFlags(shape.size(), false);
- for (int64_t dim : *scalableDims) {
- if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
- throw nb::value_error("Scalable dimension index out of bounds.");
- scalableDimFlags[dim] = true;
- }
- type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
- scalableDimFlags.data(), elementType);
- } else {
- type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+/// Builds a VectorType in `context` from `shape` and `elementType` using the
+/// non-`Checked` C API constructors (no location is taken).
+///
+/// Scalable dimensions may be given in one of two mutually exclusive forms:
+///  - `scalable`: one bool per dimension; its length must equal len(shape);
+///  - `scalableDims`: indices of scalable dimensions, each in [0, len(shape)).
+/// Throws nb::value_error for malformed scalability arguments and MLIRError,
+/// carrying the captured diagnostics, when the C API returns a null type.
+PyVectorType PyVectorType::get(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<nb::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyMlirContext context) {
+ if (scalable && scalableDims) {
+ throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
+ "are mutually exclusive.");
+ }
+
+ // Collect diagnostics emitted during construction for the MLIRError below.
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirType type;
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw nb::value_error("Expected len(scalable) == len(shape).");
+
+ // Convert each Python element of `scalable` into a bool flag.
+ SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+ *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else if (scalableDims) {
+ // Start with every dimension fixed, then mark the listed indices scalable.
+ SmallVector<bool> scalableDimFlags(shape.size(), false);
+ for (int64_t dim : *scalableDims) {
+ if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+ throw nb::value_error("Scalable dimension index out of bounds.");
+ scalableDimFlags[dim] = true;
 }
- if (mlirTypeIsNull(type))
- throw MLIRError("Invalid type", errors.take());
- return PyVectorType(elementType.getContext(), type);
- }
-};
-
-/// Ranked Tensor Type subclass - RankedTensorType.
-class PyRankedTensorType
- : public PyConcreteType<PyRankedTensorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirRankedTensorTypeGetTypeID;
- static constexpr const char *pyClassName = "RankedTensorType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirRankedTensorTypeGetChecked(
- loc, shape.size(), shape.data(), elementType,
- encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyRankedTensorType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
- "Create a ranked tensor type");
- c.def_static(
- "get_unchecked",
- [](std::vector<int64_t> shape, PyType &elementType,
- std::optional<PyAttribute> &encodingAttr,
- DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType t = mlirRankedTensorTypeGet(
- shape.size(), shape.data(), elementType,
- encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyRankedTensorType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
- "Create a ranked tensor type");
- c.def_prop_ro(
- "encoding",
- [](PyRankedTensorType &self)
- -> std::optional<nb::typed<nb::object, PyAttribute>> {
- MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
- if (mlirAttributeIsNull(encoding))
- return std::nullopt;
- return PyAttribute(self.getContext(), encoding).maybeDownCast();
- });
- }
-};
-
-/// Unranked Tensor Type subclass - UnrankedTensorType.
-class PyUnrankedTensorType
- : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnrankedTensorTypeGetTypeID;
- static constexpr const char *pyClassName = "UnrankedTensorType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType, DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedTensorType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("loc") = nb::none(),
- "Create a unranked tensor type");
- c.def_static(
- "get_unchecked",
- [](PyType &elementType, DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirType t = mlirUnrankedTensorTypeGet(elementType);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedTensorType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("context") = nb::none(),
- "Create a unranked tensor type");
- }
-};
-
-/// Ranked MemRef Type subclass - MemRefType.
-class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirMemRefTypeGetTypeID;
- static constexpr const char *pyClassName = "MemRefType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- PyAttribute *layout, PyAttribute *memorySpace,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
- MlirAttribute memSpaceAttr =
- memorySpace ? *memorySpace : mlirAttributeGetNull();
- MlirType t =
- mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
- shape.data(), layoutAttr, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyMemRefType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
- nb::arg("loc") = nb::none(), "Create a memref type")
- .def_static(
- "get_unchecked",
- [](std::vector<int64_t> shape, PyType &elementType,
- PyAttribute *layout, PyAttribute *memorySpace,
- DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute layoutAttr =
- layout ? *layout : mlirAttributeGetNull();
- MlirAttribute memSpaceAttr =
- memorySpace ? *memorySpace : mlirAttributeGetNull();
- MlirType t =
- mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
- layoutAttr, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyMemRefType(elementType.getContext(), t);
- },
- nb::arg("shape"), nb::arg("element_type"),
- nb::arg("layout") = nb::none(),
- nb::arg("memory_space") = nb::none(),
- nb::arg("context") = nb::none(), "Create a memref type")
- .def_prop_ro(
- "layout",
- [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
- return PyAttribute(self.getContext(),
- mlirMemRefTypeGetLayout(self))
- .maybeDownCast();
- },
- "The layout of the MemRef type.")
- .def(
- "get_strides_and_offset",
- [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
- std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
- int64_t offset;
- if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
- self, strides.data(), &offset)))
- throw std::runtime_error(
- "Failed to extract strides and offset from memref.");
- return {strides, offset};
- },
- "The strides and offset of the MemRef type.")
- .def_prop_ro(
- "affine_map",
- [](PyMemRefType &self) -> PyAffineMap {
- MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
- return PyAffineMap(self.getContext(), map);
- },
- "The layout of the MemRef type as an affine map.")
- .def_prop_ro(
- "memory_space",
- [](PyMemRefType &self)
- -> std::optional<nb::typed<nb::object, PyAttribute>> {
- MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
- if (mlirAttributeIsNull(a))
- return std::nullopt;
- return PyAttribute(self.getContext(), a).maybeDownCast();
- },
- "Returns the memory space of the given MemRef type.");
- }
-};
-
-/// Unranked MemRef Type subclass - UnrankedMemRefType.
-class PyUnrankedMemRefType
- : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirUnrankedMemRefTypeGetTypeID;
- static constexpr const char *pyClassName = "UnrankedMemRefType";
- using PyConcreteType::PyConcreteType;
+ type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+ scalableDimFlags.data(), elementType);
+ } else {
+ type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+ }
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), type);
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &elementType, PyAttribute *memorySpace,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirAttribute memSpaceAttr = {};
- if (memorySpace)
- memSpaceAttr = *memorySpace;
+/// Exposes RankedTensorType to Python:
+///  - `get` builds the type through the location-checked C API;
+///  - `get_unchecked` builds it through the context-only C API;
+///  - `encoding` returns the encoding attribute, or None when it is absent.
+/// Both factories raise MLIRError with the captured diagnostics when the C
+/// API yields a null type.
+void PyRankedTensorType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](std::vector<int64_t> shape, PyType &elementType,
+         std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
+        PyMlirContext::ErrorCapture capture(loc->getContext());
+        // A missing encoding is forwarded to the C API as a null attribute.
+        MlirAttribute encoding = mlirAttributeGetNull();
+        if (encodingAttr)
+          encoding = encodingAttr->get();
+        MlirType tensorTy = mlirRankedTensorTypeGetChecked(
+            loc, shape.size(), shape.data(), elementType, encoding);
+        if (mlirTypeIsNull(tensorTy))
+          throw MLIRError("Invalid type", capture.take());
+        return PyRankedTensorType(elementType.getContext(), tensorTy);
+      },
+      nb::arg("shape"), nb::arg("element_type"),
+      nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
+      "Create a ranked tensor type");
+  c.def_static(
+      "get_unchecked",
+      [](std::vector<int64_t> shape, PyType &elementType,
+         std::optional<PyAttribute> &encodingAttr,
+         DefaultingPyMlirContext context) {
+        PyMlirContext::ErrorCapture capture(context->getRef());
+        MlirAttribute encoding = mlirAttributeGetNull();
+        if (encodingAttr)
+          encoding = encodingAttr->get();
+        MlirType tensorTy = mlirRankedTensorTypeGet(shape.size(), shape.data(),
+                                                    elementType, encoding);
+        if (mlirTypeIsNull(tensorTy))
+          throw MLIRError("Invalid type", capture.take());
+        return PyRankedTensorType(elementType.getContext(), tensorTy);
+      },
+      nb::arg("shape"), nb::arg("element_type"),
+      nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
+      "Create a ranked tensor type");
+  c.def_prop_ro(
+      "encoding",
+      [](PyRankedTensorType &self)
+          -> std::optional<nb::typed<nb::object, PyAttribute>> {
+        MlirAttribute attr = mlirRankedTensorTypeGetEncoding(self.get());
+        if (!mlirAttributeIsNull(attr))
+          return PyAttribute(self.getContext(), attr).maybeDownCast();
+        return std::nullopt;
+      });
+}
- MlirType t =
- mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedMemRefType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("memory_space").none(),
- nb::arg("loc") = nb::none(), "Create a unranked memref type")
- .def_static(
- "get_unchecked",
- [](PyType &elementType, PyAttribute *memorySpace,
- DefaultingPyMlirContext context) {
- PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute memSpaceAttr = {};
- if (memorySpace)
- memSpaceAttr = *memorySpace;
+/// Exposes UnrankedTensorType to Python through two factories:
+///  - `get` uses the location-checked C API;
+///  - `get_unchecked` uses the context-only C API.
+/// Each raises MLIRError with the captured diagnostics when the C API yields
+/// a null type.
+void PyUnrankedTensorType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &elementType, DefaultingPyLocation loc) {
+        PyMlirContext::ErrorCapture capture(loc->getContext());
+        MlirType tensorTy = mlirUnrankedTensorTypeGetChecked(loc, elementType);
+        if (!mlirTypeIsNull(tensorTy))
+          return PyUnrankedTensorType(elementType.getContext(), tensorTy);
+        throw MLIRError("Invalid type", capture.take());
+      },
+      nb::arg("element_type"), nb::arg("loc") = nb::none(),
+      "Create a unranked tensor type");
+  c.def_static(
+      "get_unchecked",
+      [](PyType &elementType, DefaultingPyMlirContext context) {
+        PyMlirContext::ErrorCapture capture(context->getRef());
+        MlirType tensorTy = mlirUnrankedTensorTypeGet(elementType);
+        if (!mlirTypeIsNull(tensorTy))
+          return PyUnrankedTensorType(elementType.getContext(), tensorTy);
+        throw MLIRError("Invalid type", capture.take());
+      },
+      nb::arg("element_type"), nb::arg("context") = nb::none(),
+      "Create a unranked tensor type");
+}
- MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyUnrankedMemRefType(elementType.getContext(), t);
- },
- nb::arg("element_type"), nb::arg("memory_space").none(),
- nb::arg("context") = nb::none(), "Create a unranked memref type")
- .def_prop_ro(
- "memory_space",
- [](PyUnrankedMemRefType &self)
- -> std::optional<nb::typed<nb::object, PyAttribute>> {
- MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
- if (mlirAttributeIsNull(a))
- return std::nullopt;
- return PyAttribute(self.getContext(), a).maybeDownCast();
- },
- "Returns the memory space of the given Unranked MemRef type.");
- }
-};
+/// Exposes the ranked MemRefType to Python.
+///
+/// Factories:
+///  - `get` builds the type through the location-checked C API;
+///  - `get_unchecked` builds it through the context-only C API.
+/// In both, a None layout or memory space is forwarded as a null attribute.
+/// A null result raises MLIRError carrying the captured diagnostics.
+///
+/// Accessors: the layout attribute, strides and offset, the layout as an
+/// affine map, and the memory space (None when unset).
+void PyMemRefType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](std::vector<int64_t> shape, PyType &elementType, PyAttribute *layout,
+         PyAttribute *memorySpace, DefaultingPyLocation loc) {
+        PyMlirContext::ErrorCapture errors(loc->getContext());
+        MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
+        MlirAttribute memSpaceAttr =
+            memorySpace ? *memorySpace : mlirAttributeGetNull();
+        MlirType t =
+            mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
+                                     shape.data(), layoutAttr, memSpaceAttr);
+        if (mlirTypeIsNull(t))
+          throw MLIRError("Invalid type", errors.take());
+        return PyMemRefType(elementType.getContext(), t);
+      },
+      nb::arg("shape"), nb::arg("element_type"),
+      nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
+      nb::arg("loc") = nb::none(), "Create a memref type")
+      .def_static(
+          "get_unchecked",
+          [](std::vector<int64_t> shape, PyType &elementType,
+             PyAttribute *layout, PyAttribute *memorySpace,
+             DefaultingPyMlirContext context) {
+            PyMlirContext::ErrorCapture errors(context->getRef());
+            MlirAttribute layoutAttr =
+                layout ? *layout : mlirAttributeGetNull();
+            MlirAttribute memSpaceAttr =
+                memorySpace ? *memorySpace : mlirAttributeGetNull();
+            MlirType t =
+                mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
+                                  layoutAttr, memSpaceAttr);
+            if (mlirTypeIsNull(t))
+              throw MLIRError("Invalid type", errors.take());
+            return PyMemRefType(elementType.getContext(), t);
+          },
+          nb::arg("shape"), nb::arg("element_type"),
+          nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
+          nb::arg("context") = nb::none(), "Create a memref type")
+      .def_prop_ro(
+          "layout",
+          [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
+            return PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self))
+                .maybeDownCast();
+          },
+          "The layout of the MemRef type.")
+      .def(
+          "get_strides_and_offset",
+          [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+            // One stride slot per dimension of the memref.
+            std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+            // Initialized so `offset` never holds an indeterminate value, even
+            // if the C API does not write it.
+            int64_t offset = 0;
+            if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
+                    self, strides.data(), &offset)))
+              throw std::runtime_error(
+                  "Failed to extract strides and offset from memref.");
+            // Move the strides into the result instead of copying them.
+            return {std::move(strides), offset};
+          },
+          "The strides and offset of the MemRef type.")
+      .def_prop_ro(
+          "affine_map",
+          [](PyMemRefType &self) -> PyAffineMap {
+            MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
+            return PyAffineMap(self.getContext(), map);
+          },
+          "The layout of the MemRef type as an affine map.")
+      .def_prop_ro(
+          "memory_space",
+          [](PyMemRefType &self)
+              -> std::optional<nb::typed<nb::object, PyAttribute>> {
+            MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
+            if (mlirAttributeIsNull(a))
+              return std::nullopt;
+            return PyAttribute(self.getContext(), a).maybeDownCast();
+          },
+          "Returns the memory space of the given MemRef type.");
+}
-/// Tuple Type subclass - TupleType.
-class PyTupleType : public PyConcreteType<PyTupleType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirTupleTypeGetTypeID;
- static constexpr const char *pyClassName = "TupleType";
- using PyConcreteType::PyConcreteType;
+/// Exposes UnrankedMemRefType to Python.
+///
+/// `get` (location-checked) and `get_unchecked` (context-only) build the type
+/// from an element type and an optional memory space. `memory_space` defaults
+/// to None, matching MemRefType.get/get_unchecked, and a None memory space is
+/// forwarded as a null attribute. A null result raises MLIRError carrying the
+/// captured diagnostics. The `memory_space` property returns None when unset.
+void PyUnrankedMemRefType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &elementType, PyAttribute *memorySpace,
+         DefaultingPyLocation loc) {
+        PyMlirContext::ErrorCapture errors(loc->getContext());
+        MlirAttribute memSpaceAttr =
+            memorySpace ? *memorySpace : mlirAttributeGetNull();
+        MlirType t =
+            mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
+        if (mlirTypeIsNull(t))
+          throw MLIRError("Invalid type", errors.take());
+        return PyUnrankedMemRefType(elementType.getContext(), t);
+      },
+      // The None default keeps existing callers working (an explicit None or
+      // attribute is still accepted) while letting the argument be omitted.
+      nb::arg("element_type"), nb::arg("memory_space").none() = nb::none(),
+      nb::arg("loc") = nb::none(), "Create a unranked memref type")
+      .def_static(
+          "get_unchecked",
+          [](PyType &elementType, PyAttribute *memorySpace,
+             DefaultingPyMlirContext context) {
+            PyMlirContext::ErrorCapture errors(context->getRef());
+            MlirAttribute memSpaceAttr =
+                memorySpace ? *memorySpace : mlirAttributeGetNull();
+            MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
+            if (mlirTypeIsNull(t))
+              throw MLIRError("Invalid type", errors.take());
+            return PyUnrankedMemRefType(elementType.getContext(), t);
+          },
+          nb::arg("element_type"), nb::arg("memory_space").none() = nb::none(),
+          nb::arg("context") = nb::none(), "Create a unranked memref type")
+      .def_prop_ro(
+          "memory_space",
+          [](PyUnrankedMemRefType &self)
+              -> std::optional<nb::typed<nb::object, PyAttribute>> {
+            MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
+            if (mlirAttributeIsNull(a))
+              return std::nullopt;
+            return PyAttribute(self.getContext(), a).maybeDownCast();
+          },
+          "Returns the memory space of the given Unranked MemRef type.");
+}
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get_tuple",
- [](const std::vector<PyType> &elements,
- DefaultingPyMlirContext context) {
- std::vector<MlirType> mlirElements;
- mlirElements.reserve(elements.size());
- for (const auto &element : elements)
- mlirElements.push_back(element.get());
- MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
- mlirElements.data());
- return PyTupleType(context->getRef(), t);
- },
- nb::arg("elements"), nb::arg("context") = nb::none(),
- "Create a tuple type");
- c.def_static(
- "get_tuple",
- [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
- MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
- elements.data());
- return PyTupleType(context->getRef(), t);
- },
- nb::arg("elements"), nb::arg("context") = nb::none(),
- // clang-format off
+/// Binds TupleType:
+///  - two `get_tuple` overloads, one taking `PyType` objects and one taking
+///    raw `MlirType` values (the latter annotated with an explicit nb::sig);
+///  - `get_type(pos)`;
+///  - the `num_types` property.
+void PyTupleType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get_tuple",
+ [](const std::vector<PyType> &elements, DefaultingPyMlirContext context) {
+ // Unwrap each PyType into the MlirType array the C API expects.
+ std::vector<MlirType> mlirElements;
+ mlirElements.reserve(elements.size());
+ for (const auto &element : elements)
+ mlirElements.push_back(element.get());
+ MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
+ mlirElements.data());
+ return PyTupleType(context->getRef(), t);
+ },
+ nb::arg("elements"), nb::arg("context") = nb::none(),
+ "Create a tuple type");
+ c.def_static(
+ "get_tuple",
+ [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+ MlirType t =
+ mlirTupleTypeGet(context->get(), elements.size(), elements.data());
+ return PyTupleType(context->getRef(), t);
+ },
+ nb::arg("elements"), nb::arg("context") = nb::none(),
+ // clang-format off
nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
- // clang-format on
- "Create a tuple type");
- c.def(
- "get_type",
- [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
- .maybeDownCast();
- },
- nb::arg("pos"), "Returns the pos-th type in the tuple type.");
- c.def_prop_ro(
- "num_types",
- [](PyTupleType &self) -> intptr_t {
- return mlirTupleTypeGetNumTypes(self);
- },
- "Returns the number of types contained in a tuple.");
- }
-};
-
-/// Function type.
-class PyFunctionType : public PyConcreteType<PyFunctionType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFunctionTypeGetTypeID;
- static constexpr const char *pyClassName = "FunctionType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<PyType> inputs, std::vector<PyType> results,
- DefaultingPyMlirContext context) {
- std::vector<MlirType> mlirInputs;
- mlirInputs.reserve(inputs.size());
- for (const auto &input : inputs)
- mlirInputs.push_back(input.get());
- std::vector<MlirType> mlirResults;
- mlirResults.reserve(results.size());
- for (const auto &result : results)
- mlirResults.push_back(result.get());
+ // clang-format on
+ "Create a tuple type");
+ c.def(
+ "get_type",
+ [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
+ // NOTE(review): `pos` is passed to the C API without a bounds check here.
+ return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
+ .maybeDownCast();
+ },
+ nb::arg("pos"), "Returns the pos-th type in the tuple type.");
+ c.def_prop_ro(
+ "num_types",
+ [](PyTupleType &self) -> intptr_t {
+ return mlirTupleTypeGetNumTypes(self);
+ },
+ "Returns the number of types contained in a tuple.");
+}
- MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
- mlirInputs.data(), results.size(),
- mlirResults.data());
- return PyFunctionType(context->getRef(), t);
- },
- nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
- "Gets a FunctionType from a list of input and result types");
- c.def_static(
- "get",
- [](std::vector<MlirType> inputs, std::vector<MlirType> results,
- DefaultingPyMlirContext context) {
- MlirType t =
- mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
- results.size(), results.data());
- return PyFunctionType(context->getRef(), t);
- },
- nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
- // clang-format off
+/// Binds FunctionType:
+///  - two `get` overloads, built from lists of `PyType` or of raw `MlirType`
+///    (the latter annotated with an explicit nb::sig);
+///  - the `inputs` and `results` properties, which return Python lists built
+///    by appending the `MlirType` of each input/result.
+void PyFunctionType::bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](std::vector<PyType> inputs, std::vector<PyType> results,
+ DefaultingPyMlirContext context) {
+ // Unwrap the PyType lists into the MlirType arrays the C API expects.
+ std::vector<MlirType> mlirInputs;
+ mlirInputs.reserve(inputs.size());
+ for (const auto &input : inputs)
+ mlirInputs.push_back(input.get());
+ std::vector<MlirType> mlirResults;
+ mlirResults.reserve(results.size());
+ for (const auto &result : results)
+ mlirResults.push_back(result.get());
+
+ MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
+ mlirInputs.data(), results.size(),
+ mlirResults.data());
+ return PyFunctionType(context->getRef(), t);
+ },
+ nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
+ "Gets a FunctionType from a list of input and result types");
+ c.def_static(
+ "get",
+ [](std::vector<MlirType> inputs, std::vector<MlirType> results,
+ DefaultingPyMlirContext context) {
+ MlirType t =
+ mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
+ results.size(), results.data());
+ return PyFunctionType(context->getRef(), t);
+ },
+ nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
+ // clang-format off
nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
- // clang-format on
- "Gets a FunctionType from a list of input and result types");
- c.def_prop_ro(
- "inputs",
- [](PyFunctionType &self) {
- MlirType t = self;
- nb::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
- ++i) {
- types.append(mlirFunctionTypeGetInput(t, i));
- }
- return types;
- },
- "Returns the list of input types in the FunctionType.");
- c.def_prop_ro(
- "results",
- [](PyFunctionType &self) {
- nb::list types;
- for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
- ++i) {
- types.append(mlirFunctionTypeGetResult(self, i));
- }
- return types;
- },
- "Returns the list of result types in the FunctionType.");
- }
-};
-
-/// Opaque Type subclass - OpaqueType.
-class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirOpaqueTypeGetTypeID;
- static constexpr const char *pyClassName = "OpaqueType";
- using PyConcreteType::PyConcreteType;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](const std::string &dialectNamespace, const std::string &typeData,
- DefaultingPyMlirContext context) {
- MlirType type = mlirOpaqueTypeGet(context->get(),
- toMlirStringRef(dialectNamespace),
- toMlirStringRef(typeData));
- return PyOpaqueType(context->getRef(), type);
- },
- nb::arg("dialect_namespace"), nb::arg("buffer"),
- nb::arg("context") = nb::none(),
- "Create an unregistered (opaque) dialect type.");
- c.def_prop_ro(
- "dialect_namespace",
- [](PyOpaqueType &self) {
- MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the dialect namespace for the Opaque type as a string.");
- c.def_prop_ro(
- "data",
- [](PyOpaqueType &self) {
- MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
- return nb::str(stringRef.data, stringRef.length);
- },
- "Returns the data for the Opaque type as a string.");
- }
-};
+ // clang-format on
+ "Gets a FunctionType from a list of input and result types");
+ c.def_prop_ro(
+ "inputs",
+ [](PyFunctionType &self) {
+ MlirType t = self;
+ nb::list types;
+ for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
+ ++i) {
+ types.append(mlirFunctionTypeGetInput(t, i));
+ }
+ return types;
+ },
+ "Returns the list of input types in the FunctionType.");
+ c.def_prop_ro(
+ "results",
+ [](PyFunctionType &self) {
+ nb::list types;
+ for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
+ ++i) {
+ types.append(mlirFunctionTypeGetResult(self, i));
+ }
+ return types;
+ },
+ "Returns the list of result types in the FunctionType.");
+}
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
+/// Exposes OpaqueType (a type from an unregistered dialect) to Python:
+///  - `get` builds it from a dialect namespace and a raw type-data string;
+///  - `dialect_namespace` and `data` return those two strings back.
+void PyOpaqueType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const std::string &dialectNamespace, const std::string &typeData,
+         DefaultingPyMlirContext context) {
+        MlirStringRef nsRef = toMlirStringRef(dialectNamespace);
+        MlirStringRef dataRef = toMlirStringRef(typeData);
+        MlirType opaqueTy = mlirOpaqueTypeGet(context->get(), nsRef, dataRef);
+        return PyOpaqueType(context->getRef(), opaqueTy);
+      },
+      nb::arg("dialect_namespace"), nb::arg("buffer"),
+      nb::arg("context") = nb::none(),
+      "Create an unregistered (opaque) dialect type.");
+  c.def_prop_ro(
+      "dialect_namespace",
+      [](PyOpaqueType &self) {
+        MlirStringRef ref = mlirOpaqueTypeGetDialectNamespace(self);
+        return nb::str(ref.data, ref.length);
+      },
+      "Returns the dialect namespace for the Opaque type as a string.");
+  c.def_prop_ro(
+      "data",
+      [](PyOpaqueType &self) {
+        MlirStringRef ref = mlirOpaqueTypeGetData(self);
+        return nb::str(ref.data, ref.length);
+      },
+      "Returns the data for the Opaque type as a string.");
+}
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateIRTypes(nb::module_ &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index b2c9380bc1d73..88f58d45cdd75 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -9,7 +9,9 @@
#include "Pass.h"
#include "Rewrite.h"
#include "mlir/Bindings/Python/Globals.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/IRTypes.h"
#include "mlir/Bindings/Python/Nanobind.h"
namespace nb = nanobind;
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 4a9fb127ee08c..003a06b16daac 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -533,9 +533,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
SOURCES
MainModule.cpp
IRAffine.cpp
- IRAttributes.cpp
IRInterfaces.cpp
- IRTypes.cpp
Pass.cpp
Rewrite.cpp
@@ -846,8 +844,10 @@ declare_mlir_python_extension(MLIRPythonExtension.MLIRPythonSupport
ADD_TO_PARENT MLIRPythonSources.Core
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
- IRCore.cpp
Globals.cpp
+ IRAttributes.cpp
+ IRCore.cpp
+ IRTypes.cpp
)
################################################################################
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index 43573cbc305fa..b229c02ccf5e6 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -15,6 +15,7 @@
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/IRTypes.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
@@ -47,6 +48,49 @@ struct PyTestType
}
};
+/// Test-only Python subclass of RankedTensorType.
+///
+/// It is registered under the RankedTensorType TypeID. Its isa check is
+/// `mlirTypeIsARankedIntegerTensor`.
+struct PyTestIntegerRankedTensorType
+    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
+          PyTestIntegerRankedTensorType,
+          mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedIntegerTensor;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TestIntegerRankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Adds `get(shape, width, context=None)`. It builds a ranked tensor whose
+  /// element type comes from `mlirIntegerTypeGet(context, width)`, with no
+  /// encoding attribute.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](std::vector<int64_t> shape, unsigned width,
+           mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+               ctx) {
+          MlirContext rawCtx = ctx.get()->get();
+          MlirType elementTy = mlirIntegerTypeGet(rawCtx, width);
+          MlirType tensorTy = mlirRankedTensorTypeGet(
+              shape.size(), shape.data(), elementTy, mlirAttributeGetNull());
+          return PyTestIntegerRankedTensorType(ctx->getRef(), tensorTy);
+        },
+        nb::arg("shape"), nb::arg("width"),
+        nb::arg("context").none() = nb::none());
+  }
+};
+
+/// Test-only concrete Value subclass.
+///
+/// Its isa check is `mlirTypeIsAPythonTestTestTensorValue`, and it is
+/// registered under the RankedTensorType TypeID.
+struct PyTestTensorValue
+    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
+          PyTestTensorValue> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAPythonTestTestTensorValue;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TestTensorValue";
+  using PyConcreteValue::PyConcreteValue;
+
+  /// Adds `is_null()`, which reports whether the wrapped MlirValue is null.
+  static void bindDerived(ClassTy &c) {
+    c.def("is_null",
+          [](MlirValue &self) -> bool { return mlirValueIsNull(self); });
+  }
+};
+
class PyTestAttr
: public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
PyTestAttr> {
@@ -73,18 +117,18 @@ class PyTestAttr
NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"register_python_test_dialect",
- [](MlirContext context, bool load) {
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ context,
+ bool load) {
MlirDialectHandle pythonTestDialect =
mlirGetDialectHandle__python_test__();
- mlirDialectHandleRegisterDialect(pythonTestDialect, context);
+ mlirDialectHandleRegisterDialect(pythonTestDialect,
+ context.get()->get());
if (load) {
- mlirDialectHandleLoadDialect(pythonTestDialect, context);
+ mlirDialectHandleLoadDialect(pythonTestDialect, context.get()->get());
}
},
- nb::arg("context"), nb::arg("load") = true,
- // clang-format off
- nb::sig("def register_python_test_dialect(context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", load: bool = True) -> None"));
- // clang-format on
+ nb::arg("context").none() = nb::none(), nb::arg("load") = true);
m.def(
"register_dialect",
@@ -100,73 +144,16 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"test_diagnostics_with_errors_and_notes",
- [](MlirContext ctx) {
- mlir::python::CollectDiagnosticsToStringScope handler(ctx);
- mlirPythonTestEmitDiagnosticWithNote(ctx);
+ [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+ ctx) {
+ mlir::python::CollectDiagnosticsToStringScope handler(ctx.get()->get());
+ mlirPythonTestEmitDiagnosticWithNote(ctx.get()->get());
throw nb::value_error(handler.takeMessage().c_str());
},
- // clang-format off
- nb::sig("def test_diagnostics_with_errors_and_notes(arg: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", /) -> None"));
- // clang-format on
+ nb::arg("context").none() = nb::none());
PyTestAttr::bind(m);
PyTestType::bind(m);
-
- auto typeCls =
- mlir_type_subclass(m, "TestIntegerRankedTensorType",
- mlirTypeIsARankedIntegerTensor,
- nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("RankedTensorType"))
- .def_classmethod(
- "get",
- [](const nb::object &cls, std::vector<int64_t> shape,
- unsigned width, MlirContext ctx) {
- MlirAttribute encoding = mlirAttributeGetNull();
- return cls(mlirRankedTensorTypeGet(
- shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
- encoding));
- },
- // clang-format off
- nb::sig("def get(cls: object, shape: collections.abc.Sequence[int], width: int, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
- // clang-format on
- nb::arg("cls"), nb::arg("shape"), nb::arg("width"),
- nb::arg("context").none() = nb::none());
-
- assert(nb::hasattr(typeCls.get_class(), "static_typeid") &&
- "TestIntegerRankedTensorType has no static_typeid");
-
- MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
-
- nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
- mlirRankedTensorTypeID, nb::arg("replace") = true)(
- nanobind::cpp_function([typeCls](const nb::object &mlirType) {
- return typeCls.get_class()(mlirType);
- }));
-
- auto valueCls = mlir_value_subclass(m, "TestTensorValue",
- mlirTypeIsAPythonTestTestTensorValue)
- .def("is_null", [](MlirValue &self) {
- return mlirValueIsNull(self);
- });
-
- nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
- mlirRankedTensorTypeID)(
- nanobind::cpp_function([valueCls](const nb::object &valueObj) {
- std::optional<nb::object> capsule =
- mlirApiObjectToCapsule(valueObj);
- assert(capsule.has_value() && "capsule is not null");
- MlirValue v = mlirPythonCapsuleToValue(capsule.value().ptr());
- MlirType t = mlirValueGetType(v);
- // This is hyper-specific in order to exercise/test registering a
- // value caster from cpp (but only for a single test case; see
- // testTensorValue python_test.py).
- if (mlirShapedTypeHasStaticShape(t) &&
- mlirShapedTypeGetDimSize(t, 0) == 1 &&
- mlirShapedTypeGetDimSize(t, 1) == 2 &&
- mlirShapedTypeGetDimSize(t, 2) == 3)
- return valueCls.get_class()(valueObj);
- return valueObj;
- }));
-}
+ PyTestIntegerRankedTensorType::bind(m);
+ PyTestTensorValue::bind(m);
+}
\ No newline at end of file
More information about the llvm-commits
mailing list